diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..0519ecba6ea913e21689ec692e81e9e4973fbf73 --- /dev/null +++ b/Dockerfile @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5c7bf410c5b374b13b57d1e891159905db9b8ad8 --- /dev/null +++ b/README.md @@ -0,0 +1,496 @@ +--- +title: MuseV +emoji: 🎨 +colorFrom: blue +colorTo: purple +sdk: docker +app_port: 7860 +pinned: false +--- + +# MuseV [English](README.md) [中文](README-zh.md) + +MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising +
+Zhiqiang Xia \*, +Zhaokang Chen\*, +Bin Wu, +Chao Li, +Kwok-Wai Hung, +Chao Zhan, +Yingjie He, +Wenjiang Zhou +(*co-first author, Corresponding Author, benbinwu@tencent.com) +
+ +**[github](https://github.com/TMElyralab/MuseV)** **[huggingface](https://huggingface.co/TMElyralab/MuseV)** **[HuggingfaceSpace](https://huggingface.co/spaces/AnchorFake/MuseVDemo)** **[project](https://tmelyralab.github.io/MuseV_Page/)** **Technical report (comming soon)** + + +We have setup **the world simulator vision since March 2023, believing diffusion models can simulate the world**. `MuseV` was a milestone achieved around **July 2023**. Amazed by the progress of Sora, we decided to opensource `MuseV`, hopefully it will benefit the community. Next we will move on to the promising diffusion+transformer scheme. + + +Update: We have released MuseTalk, a real-time high quality lip sync model, which can be applied with MuseV as a complete virtual human generation solution. + +# Overview +`MuseV` is a diffusion-based virtual human video generation framework, which +1. supports **infinite length** generation using a novel **Visual Conditioned Parallel Denoising scheme**. +2. checkpoint available for virtual human video generation trained on human dataset. +3. supports Image2Video, Text2Image2Video, Video2Video. +4. compatible with the **Stable Diffusion ecosystem**, including `base_model`, `lora`, `controlnet`, etc. +5. supports multi reference image technology, including `IPAdapter`, `ReferenceOnly`, `ReferenceNet`, `IPAdapterFaceID`. +6. training codes (comming very soon). + +# Important bug fixes +1. `musev_referencenet_pose`: model_name of `unet`, `ip_adapter` of Command is not correct, please use `musev_referencenet_pose` instead of `musev_referencenet`. + +# News +- [03/27/2024] release `MuseV` project and trained model `musev`, `muse_referencenet`. +- [03/30/2024] add huggingface space gradio to generate video in gui + +## Model +### Overview of model structure +![model_structure](./data/models/musev_structure.png) +### Parallel denoising +![parallel_denoise](./data//models/parallel_denoise.png) + +## Cases +All frames were generated directly from text2video model, without any post process. +MoreCase is in **[project](https://tmelyralab.github.io/MuseV_Page/)**, including **1-2 minute video**. + + +Examples bellow can be accessed at `configs/tasks/example.yaml` + + +### Text/Image2Video + +#### Human + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
imagevideo prompt
+ + + + (masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3) +
+ + + + + (masterpiece, best quality, highres:1), peaceful beautiful sea scene +
+ + + + + (masterpiece, best quality, highres:1), peaceful beautiful sea scene +
+ + + + + (masterpiece, best quality, highres:1), playing guitar +
+ + + + + (masterpiece, best quality, highres:1), playing guitar +
+ + + + + (masterpiece, best quality, highres:1),(1man, solo:1),(eye blinks:1.8),(head wave:1.3),Chinese ink painting style +
+ + + + + (masterpiece, best quality, highres:1),(1girl, solo:1),(beautiful face, + soft skin, costume:1),(eye blinks:{eye_blinks_factor}),(head wave:1.3) +
+ +#### Scene + + + + + + + + + + + + + + + + + + + +
imagevideoprompt
+ + + + + (masterpiece, best quality, highres:1), peaceful beautiful waterfall, an + endless waterfall +
+ + + + (masterpiece, best quality, highres:1), peaceful beautiful sea scene +
+ +### VideoMiddle2Video + +**pose2video** +In `duffy` mode, pose of the vision condition frame is not aligned with the first frame of control video. `posealign` will solve the problem. + + + + + + + + + + + + + + + + + +
imagevideoprompt
+ + + + + + (masterpiece, best quality, highres:1) , a girl is dancing, animation +
+ + + + + (masterpiece, best quality, highres:1), is dancing, animation +
+ +### MuseTalk +The character of talk, `Sun Xinying` is a supermodel KOL. You can follow her on [douyin](https://www.douyin.com/user/MS4wLjABAAAAWDThbMPN_6Xmm_JgXexbOii1K-httbu2APdG8DvDyM8). + + + + + + + + + + + + + + +
namevideo
+ talk + + +
+ sing + + +
+ + +# TODO: +- [ ] technical report (comming soon). +- [ ] training codes. +- [ ] release pretrained unet model, which is trained with controlnet、referencenet、IPAdapter, which is better on pose2video. +- [ ] support diffusion transformer generation framework. +- [ ] release `posealign` module + +# Quickstart +Prepare python environment and install extra package like `diffusers`, `controlnet_aux`, `mmcm`. + +## Third party integration +Thanks for the third-party integration, which makes installation and use more convenient for everyone. +We also hope you note that we have not verified, maintained, or updated third-party. Please refer to this project for specific results. + +### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseV) +### [One click integration package in windows](https://www.bilibili.com/video/BV1ux4y1v7pF/?vd_source=fe03b064abab17b79e22a692551405c3) +netdisk:https://www.123pan.com/s/Pf5Yjv-Bb9W3.html + +code: glut + +## Prepare environment +You are recommended to use `docker` primarily to prepare python environment. +### prepare python env +**Attention**: we only test with docker, there are maybe trouble with conda, or requirement. We will try to fix it. Use `docker` Please. + +#### Method 1: docker +1. pull docker image +```bash +docker pull anchorxia/musev:latest +``` +2. run docker +```bash +docker run --gpus all -it --entrypoint /bin/bash anchorxia/musev:latest +``` +The default conda env is `musev`. + +#### Method 2: conda +create conda environment from environment.yaml +``` +conda env create --name musev --file ./environment.yml +``` +#### Method 3: pip requirements +``` +pip install -r requirements.txt +``` +#### Prepare mmlab package +if not use docker, should install mmlab package additionally. +```bash +pip install --no-cache-dir -U openmim +mim install mmengine +mim install "mmcv>=2.0.1" +mim install "mmdet>=3.1.0" +mim install "mmpose>=1.1.0" +``` + +### Prepare custom package / modified package +#### clone +```bash +git clone --recursive https://github.com/TMElyralab/MuseV.git +``` +#### prepare PYTHONPATH +```bash +current_dir=$(pwd) +export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV +export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/MMCM +export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/diffusers/src +export PYTHONPATH=${PYTHONPATH}:${current_dir}/MuseV/controlnet_aux/src +cd MuseV +``` + +1. `MMCM`: multi media, cross modal process package。 +1. `diffusers`: modified diffusers package based on [diffusers](https://github.com/huggingface/diffusers) +1. `controlnet_aux`: modified based on [controlnet_aux](https://github.com/TMElyralab/controlnet_aux) + + +## Download models +```bash +git clone https://huggingface.co/TMElyralab/MuseV ./checkpoints +``` +- `motion`: text2video model, trained on tiny `ucf101` and tiny `webvid` dataset, approximately 60K videos text pairs. GPU memory consumption testing on `resolution`$=512*512$, `time_size=12`. + - `musev/unet`: only has and train `unet` motion module. `GPU memory consumption` $\approx 8G$. + - `musev_referencenet`: train `unet` module, `referencenet`, `IPAdapter`. `GPU memory consumption` $\approx 12G$. + - `unet`: `motion` module, which has `to_k`, `to_v` in `Attention` layer refer to `IPAdapter` + - `referencenet`: similar to `AnimateAnyone` + - `ip_adapter_image_proj.bin`: images clip emb project layer, refer to `IPAdapter` + - `musev_referencenet_pose`: based on `musev_referencenet`, fix `referencenet`and `controlnet_pose`, train `unet motion` and `IPAdapter`. `GPU memory consumption` $\approx 12G$ +- `t2i/sd1.5`: text2image model, parameter are frozen when training motion module. Different `t2i` base_model has a significant impact.could be replaced with other t2i base. + - `majicmixRealv6Fp16`: example, download from [majicmixRealv6Fp16](https://civitai.com/models/43331?modelVersionId=94640) + - `fantasticmix_v10`: example, download from [fantasticmix_v10](https://civitai.com/models/22402?modelVersionId=26744) +- `IP-Adapter/models`: download from [IPAdapter](https://huggingface.co/h94/IP-Adapter/tree/main) + - `image_encoder`: vision clip model. + - `ip-adapter_sd15.bin`: original IPAdapter model checkpoint. + - `ip-adapter-faceid_sd15.bin`: original IPAdapter model checkpoint. + +## Inference + +### Prepare model_path +Skip this step when run example task with example inference command. +Set model path and abbreviation in config, to use abbreviation in inference script. +- T2I SD:ref to `musev/configs/model/T2I_all_model.py` +- Motion Unet: refer to `musev/configs/model/motion_model.py` +- Task: refer to `musev/configs/tasks/example.yaml` + +### musev_referencenet +#### text2video +```bash +python scripts/inference/text2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev_referencenet --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --target_datas yongen --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --time_size 12 --fps 12 +``` +**common parameters**: +- `test_data_path`: task_path in yaml extention +- `target_datas`: sep is `,`, sample subtasks if `name` in `test_data_path` is in `target_datas`. +- `sd_model_cfg_path`: T2I sd models path, model config path or model path. +- `sd_model_name`: sd model name, which use to choose full model path in sd_model_cfg_path. multi model names with sep =`,`, or `all` +- `unet_model_cfg_path`: motion unet model config path or model path。 +- `unet_model_name`: unet model name, use to get model path in `unet_model_cfg_path`, and init unet class instance in `musev/models/unet_loader.py`. multi model names with sep=`,`, or `all`. If `unet_model_cfg_path` is model path, `unet_name` must be supported in `musev/models/unet_loader.py` +- `time_size`: num_frames per diffusion denoise generation。default=`12`. +- `n_batch`: generation numbers of shot, $total\_frames=n\_batch * time\_size + n\_viscond$, default=`1`。 +- `context_frames`: context_frames num. If `time_size` > `context_frame`,`time_size` window is split into many sub-windows for parallel denoising"。 default=`12`。 + +**To generate long videos**, there two ways: +1. `visual conditioned parallel denoise`: set `n_batch=1`, `time_size` = all frames you want. +1. `traditional end-to-end`: set `time_size` = `context_frames` = frames of a shot (`12`), `context_overlap` = 0; + + +**model parameters**: +supports `referencenet`, `IPAdapter`, `IPAdapterFaceID`, `Facein`. +- referencenet_model_name: `referencenet` model name. +- ImageClipVisionFeatureExtractor: `ImageEmbExtractor` name, extractor vision clip emb used in `IPAdapter`. +- vision_clip_model_path: `ImageClipVisionFeatureExtractor` model path. +- ip_adapter_model_name: from `IPAdapter`, it's `ImagePromptEmbProj`, used with `ImageEmbExtractor`。 +- ip_adapter_face_model_name: `IPAdapterFaceID`, from `IPAdapter` to keep faceid,should set `face_image_path`。 + +**Some parameters that affect the motion range and generation results**: +- `video_guidance_scale`: Similar to text2image, control influence between cond and uncond,default=`3.5` +- `use_condition_image`: Whether to use the given first frame for video generation, if not generate vision condition frames first. Default=`True`. +- `redraw_condition_image`: Whether to redraw the given first frame image. +- `video_negative_prompt`: Abbreviation of full `negative_prompt` in config path. default=`V2`. + + +#### video2video +`t2i` base_model has a significant impact. In this case, `fantasticmix_v10` performs better than `majicmixRealv6Fp16`. +```bash +python scripts/inference/video2video.py --sd_model_name fantasticmix_v10 --unet_model_name musev_referencenet --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet -test_data_path ./configs/tasks/example.yaml --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12 +``` +**import parameters** + +Most of the parameters are same as `musev_text2video`. Special parameters of `video2video` are: +1. need to set `video_path` as reference video in `test_data`. Now reference video supports `rgb video` and `controlnet_middle_video`。 +- `which2video`: whether `rgb` video influences initial noise, influence of `rgb` is stronger than of controlnet condition. +- `controlnet_name`:whether to use `controlnet condition`, such as `dwpose,depth`. +- `video_is_middle`: `video_path` is `rgb video` or `controlnet_middle_video`. Can be set for every `test_data` in test_data_path. +- `video_has_condition`: whether condtion_images is aligned with the first frame of video_path. If Not, exrtact condition of `condition_images` firstly generate, and then align with concatation. set in `test_data`。 + +all controlnet_names refer to [mmcm](https://github.com/TMElyralab/MMCM/blob/main/mmcm/vision/feature_extractor/controlnet.py#L513) +```python +['pose', 'pose_body', 'pose_hand', 'pose_face', 'pose_hand_body', 'pose_hand_face', 'dwpose', 'dwpose_face', 'dwpose_hand', 'dwpose_body', 'dwpose_body_hand', 'canny', 'tile', 'hed', 'hed_scribble', 'depth', 'pidi', 'normal_bae', 'lineart', 'lineart_anime', 'zoe', 'sam', 'mobile_sam', 'leres', 'content', 'face_detector'] +``` + +### musev_referencenet_pose +Only used for `pose2video` +train based on `musev_referencenet`, fix `referencenet`, `pose-controlnet`, and `T2I`, train `motion` module and `IPAdapter`. + +`t2i` base_model has a significant impact. In this case, `fantasticmix_v10` performs better than `majicmixRealv6Fp16`. + +```bash +python scripts/inference/video2video.py --sd_model_name fantasticmix_v10 --unet_model_name musev_referencenet_pose --referencenet_model_name musev_referencenet --ip_adapter_model_name musev_referencenet_pose -test_data_path ./configs/tasks/example.yaml --vision_clip_extractor_class_name ImageClipVisionFeatureExtractor --vision_clip_model_path ./checkpoints/IP-Adapter/models/image_encoder --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12 +``` + +### musev +Only has motion module, no referencenet, requiring less gpu memory. +#### text2video +```bash +python scripts/inference/text2video.py --sd_model_name majicmixRealv6Fp16 --unet_model_name musev -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --target_datas yongen --time_size 12 --fps 12 +``` +#### video2video +##### pose align +```bash +python ./pose_align/pose_align.py --max_frame 200 --vidfn ./data/source_video/dance.mp4 --imgfn_refer ./data/images/man.jpg --outfn_ref_img_pose ./data/pose_align_results/ref_img_pose.jpg --outfn_align_pose_video ./data/pose_align_results/align_pose_video.mp4 --outfn ./data/pose_align_results/align_demo.mp4 +``` +- `max_frame`: how many frames to align (count from the first frame) +- `vidfn`:real dance video in rgb +- `imgfn_refer`: refer image path +- `outfn_ref_img_pose`: output path of the pose of the refer img +- `outfn_align_pose_video`: output path of the aligned video of the refer img +- `outfn`: output path of the alignment visualization + + + +https://github.com/TMElyralab/MuseV/assets/47803475/787d7193-ec69-43f4-a0e5-73986a808f51 + + + + +then you can use the aligned pose `outfn_align_pose_video` for pose guided generation. You may need to modify the example in the config file `./configs/tasks/example.yaml` +##### generation +```bash +python scripts/inference/video2video.py --sd_model_name fantasticmix_v10 --unet_model_name musev -test_data_path ./configs/tasks/example.yaml --output_dir ./output --n_batch 1 --controlnet_name dwpose_body_hand --which2video "video_middle" --target_datas dance1 --fps 12 --time_size 12 +``` + +### Gradio demo +MuseV provides gradio script to generate a GUI in a local machine to generate video conveniently. + +```bash +cd scripts/gradio +python app.py +``` + + +# Acknowledgements + +1. MuseV has referred much to [TuneAVideo](https://github.com/showlab/Tune-A-Video), [diffusers](https://github.com/huggingface/diffusers), [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone/tree/master/src/pipelines), [animatediff](https://github.com/guoyww/AnimateDiff), [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter), [AnimateAnyone](https://arxiv.org/abs/2311.17117), [VideoFusion](https://arxiv.org/abs/2303.08320), [insightface](https://github.com/deepinsight/insightface). +2. MuseV has been built on `ucf101` and `webvid` datasets. + +Thanks for open-sourcing! + +# Limitation +There are still many limitations, including + +1. Lack of generalization ability. Some visual condition image perform well, some perform bad. Some t2i pretraied model perform well, some perform bad. +1. Limited types of video generation and limited motion range, partly because of limited types of training data. The released `MuseV` has been trained on approximately 60K human text-video pairs with resolution `512*320`. `MuseV` has greater motion range while lower video quality at lower resolution. `MuseV` tends to generate less motion range with high video quality. Trained on larger, higher resolution, higher quality text-video dataset may make `MuseV` better. +1. Watermarks may appear because of `webvid`. A cleaner dataset without watermarks may solve this issue. +1. Limited types of long video generation. Visual Conditioned Parallel Denoise can solve accumulated error of video generation, but the current method is only suitable for relatively fixed camera scenes. +1. Undertrained referencenet and IP-Adapter, beacause of limited time and limited resources. +1. Understructured code. `MuseV` supports rich and dynamic features, but with complex and unrefacted codes. It takes time to familiarize. + + + +# Citation +```bib +@article{musev, + title={MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising}, + author={Xia, Zhiqiang and Chen, Zhaokang and Wu, Bin and Li, Chao and Hung, Kwok-Wai and Zhan, Chao and He, Yingjie and Zhou, Wenjiang}, + journal={arxiv}, + year={2024} +} +``` +# Disclaimer/License +1. `code`: The code of MuseV is released under the MIT License. There is no limitation for both academic and commercial usage. +1. `model`: The trained model are available for non-commercial research purposes only. +1. `other opensource model`: Other open-source models used must comply with their license, such as `insightface`, `IP-Adapter`, `ft-mse-vae`, etc. +1. The testdata are collected from internet, which are available for non-commercial research purposes only. +1. `AIGC`: This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users. diff --git a/musev/__init__.py b/musev/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..1f718b3053cc235d303f1cbeb9170aa35688f436 --- /dev/null +++ b/musev/__init__.py @@ -0,0 +1,9 @@ +import os +import logging +import logging.config + +# 读取日志配置文件内容 +logging.config.fileConfig(os.path.join(os.path.dirname(__file__), "logging.conf")) + +# 创建一个日志器logger +logger = logging.getLogger("musev") diff --git a/musev/auto_prompt/__init__.py b/musev/auto_prompt/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/musev/auto_prompt/attributes/__init__.py b/musev/auto_prompt/attributes/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..effa177152548c15756a6bc8a67d273a1b709438 --- /dev/null +++ b/musev/auto_prompt/attributes/__init__.py @@ -0,0 +1,8 @@ +from ...utils.register import Register + +AttrRegister = Register(registry_name="attributes") + +# must import like bellow to ensure that each class is registered with AttrRegister: +from .human import * +from .render import * +from .style import * diff --git a/musev/auto_prompt/attributes/attr2template.py b/musev/auto_prompt/attributes/attr2template.py new file mode 100755 index 0000000000000000000000000000000000000000..707ecd23d38b63ca33a1ab19505e6467623e8978 --- /dev/null +++ b/musev/auto_prompt/attributes/attr2template.py @@ -0,0 +1,127 @@ +r""" +中文 +该模块将关键词字典转化为描述文本,生成完整的提词,从而降低对比实验成本、提升控制能力和效率。 +提词(prompy)对比实验会需要控制关键属性发生变化、其他属性不变的文本对。当需要控制的属性变量发生较大变化时,靠人为复制粘贴进行完成文本撰写工作量会非常大。 +该模块主要有三种类,分别是: +1. `BaseAttribute2Text`: 单属性文本转换类 +2. `MultiAttr2Text` 多属性文本转化类,输出`List[Tuple[str, str]`。具体如何转换为文本在 `MultiAttr2PromptTemplate`中实现。 +3. `MultiAttr2PromptTemplate`:先将2生成的多属性文本字典列表转化为完整的文本,然后再使用内置的模板`template`拼接。拼接后的文本作为实际模型输入的提词。 + 1. `template`字段若没有{},且有字符,则认为输入就是完整输入网络的`prompt`; + 2. `template`字段若含有{key},则认为是带关键词的字符串目标,多个属性由`template`字符串中顺序完全决定。关键词内容由表格中相关列通过`attr2text`转化而来; + 3. `template`字段有且只含有一个{},如`a portrait of {}`,则相关内容由 `PresetMultiAttr2PromptTemplate`中预定义好的`attrs`列表指定先后顺序; + +English +This module converts a keyword dictionary into descriptive text, generating complete prompts to reduce the cost of comparison experiments, and improve control and efficiency. + +Prompt-based comparison experiments require text pairs where the key attributes are controlled while other attributes remain constant. When the variable attributes to be controlled undergo significant changes, manually copying and pasting to write text can be very time-consuming. + +This module mainly consists of three classes: + +BaseAttribute2Text: A class for converting single attribute text. +MultiAttr2Text: A class for converting multi-attribute text, outputting List[Tuple[str, str]]. The specific implementation of how to convert to text is implemented in MultiAttr2PromptTemplate. +MultiAttr2PromptTemplate: First, the list of multi-attribute text dictionaries generated by 2 is converted into complete text, and then the built-in template template is used for concatenation. The concatenated text serves as the prompt for the actual model input. +If the template field does not contain {}, and there are characters, the input is considered the complete prompt for the network. +If the template field contains {key}, it is considered a string target with keywords, and the order of multiple attributes is completely determined by the template string. The keyword content is generated by attr2text from the relevant columns in the table. +If the template field contains only one {}, such as a portrait of {}, the relevant content is specified in the order defined by the attrs list predefined in PresetMultiAttr2PromptTemplate. +""" + +from typing import List, Tuple, Union + +from mmcm.utils.str_util import ( + has_key_brace, + merge_near_same_char, + get_word_from_key_brace_string, +) + +from .attributes import MultiAttr2Text, merge_multi_attrtext, AttriributeIsText +from . import AttrRegister + + +class MultiAttr2PromptTemplate(object): + """ + 将多属性转化为模型输入文本的实际类 + The actual class that converts multiple attributes into model input text is + """ + + def __init__( + self, + template: str, + attr2text: MultiAttr2Text, + name: str, + ) -> None: + """ + Args: + template (str): 提词模板, prompt template. + 如果`template`含有{key},则根据key来取值。 if the template field contains {key}, it means that the actual value for that part of the prompt will be determined by the corresponding key + 如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。if the template field in MultiAttr2PromptTemplate contains only one {} placeholder, such as "a portrait of {}", the order of the attributes is determined by the attrs list predefined in PresetMultiAttr2PromptTemplate. The values of the attributes in the texts list are concatenated in the order specified by the attrs list. + attr2text (MultiAttr2Text): 多属性转换类。Class for converting multiple attributes into text prompt. + name (str): 该多属性文本模板类的名字,便于记忆. Class Instance name + """ + self.attr2text = attr2text + self.name = name + if template == "": + template = "{}" + self.template = template + self.template_has_key_brace = has_key_brace(template) + + def __call__(self, attributes: dict) -> Union[str, List[str]]: + texts = self.attr2text(attributes) + if not isinstance(texts, list): + texts = [texts] + prompts = [merge_multi_attrtext(text, self.template) for text in texts] + prompts = [merge_near_same_char(prompt) for prompt in prompts] + if len(prompts) == 1: + prompts = prompts[0] + return prompts + + +class KeywordMultiAttr2PromptTemplate(MultiAttr2PromptTemplate): + def __init__(self, template: str, name: str = "keywords") -> None: + """关键词模板属性2文本转化类 + 1. 获取关键词模板字符串中的关键词属性; + 2. 从import * 存储在locals()中变量中获取对应的类; + 3. 将集成了多属性转换类的`MultiAttr2Text` + Args: + template (str): 含有{key}的模板字符串 + name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "keywords". + + class for converting keyword template attributes to text + 1. Get the keyword attributes in the keyword template string; + 2. Get the corresponding class from the variables stored in locals() by import *; + 3. The `MultiAttr2Text` integrated with multiple attribute conversion classes + Args: + template (str): template string containing {key} + name (str, optional): the name of the template string, no actual use. Defaults to "keywords". + """ + assert has_key_brace( + template + ), "template should have key brace, but given {}".format(template) + keywords = get_word_from_key_brace_string(template) + funcs = [] + for word in keywords: + if word in AttrRegister: + func = AttrRegister[word](name=word) + else: + func = AttriributeIsText(name=word) + funcs.append(func) + attr2text = MultiAttr2Text(funcs, name=name) + super().__init__(template, attr2text, name) + + +class OnlySpacePromptTemplate(MultiAttr2PromptTemplate): + def __init__(self, template: str, name: str = "space_prompt") -> None: + """纯空模板,无论输入啥,都只返回空格字符串作为prompt。 + Args: + template (str): 符合只输出空格字符串的模板, + name (str, optional): 该模板字符串名字,暂无实际用处. Defaults to "space_prompt". + + Pure empty template, no matter what the input is, it will only return a space string as the prompt. + Args: + template (str): template that only outputs a space string, + name (str, optional): the name of the template string, no actual use. Defaults to "space_prompt". + """ + attr2text = None + super().__init__(template, attr2text, name) + + def __call__(self, attributes: dict) -> Union[str, List[str]]: + return "" diff --git a/musev/auto_prompt/attributes/attributes.py b/musev/auto_prompt/attributes/attributes.py new file mode 100755 index 0000000000000000000000000000000000000000..1df78a2fa7be1929c871fd752ea68377d65bddd3 --- /dev/null +++ b/musev/auto_prompt/attributes/attributes.py @@ -0,0 +1,227 @@ +from copy import deepcopy +from typing import List, Tuple, Dict + +from mmcm.utils.str_util import has_key_brace + + +class BaseAttribute2Text(object): + """ + 属性转化为文本的基类,该类作用就是输入属性,转化为描述文本。 + Base class for converting attributes to text which converts attributes to prompt text. + """ + + name = "base_attribute" + + def __init__(self, name: str = None) -> None: + """这里类实例初始化设置`name`参数,主要是为了便于一些没有提前实现、通过字符串参数实现的新属性。 + Theses class instances are initialized with the `name` parameter to facilitate the implementation of new attributes that are not implemented in advance and are implemented through string parameters. + + Args: + name (str, optional): _description_. Defaults to None. + """ + if name is not None: + self.name = name + + def __call__(self, attributes) -> str: + raise NotImplementedError + + +class AttributeIsTextAndName(BaseAttribute2Text): + """ + 属性文本转换功能类,将key和value拼接在一起作为文本. + class for converting attributes to text which concatenates the key and value together as text. + """ + + name = "attribute_is_text_name" + + def __call__(self, attributes) -> str: + if attributes == "" or attributes is None: + return "" + attributes = attributes.split(",") + text = ", ".join( + [ + "{} {}".format(attr, self.name) if attr != "" else "" + for attr in attributes + ] + ) + return text + + +class AttriributeIsText(BaseAttribute2Text): + """ + 属性文本转换功能类,将value作为文本. + class for converting attributes to text which only uses the value as text. + """ + + name = "attribute_is_text" + + def __call__(self, attributes: str) -> str: + if attributes == "" or attributes is None: + return "" + attributes = str(attributes) + attributes = attributes.split(",") + text = ", ".join(["{}".format(attr) for attr in attributes]) + return text + + +class MultiAttr2Text(object): + """将多属性组成的字典转换成完整的文本描述,目前采用简单的前后拼接方式,以`, `作为拼接符号 + class for converting a dictionary of multiple attributes into a complete text description. Currently, a simple front and back splicing method is used, with `, ` as the splicing symbol. + + Args: + object (_type_): _description_ + """ + + def __init__(self, funcs: list, name) -> None: + """ + Args: + funcs (list): 继承`BaseAttribute2Text`并实现了`__call__`函数的类. Inherited `BaseAttribute2Text` and implemented the `__call__` function of the class. + name (_type_): 该多属性的一个名字,可通过该类方便了解对应相关属性都是关于啥的。 name of the multi-attribute, which can be used to easily understand what the corresponding related attributes are about. + """ + if not isinstance(funcs, list): + funcs = [funcs] + self.funcs = funcs + self.name = name + + def __call__( + self, dct: dict, ignored_blank_str: bool = False + ) -> List[Tuple[str, str]]: + """ + 有时候一个属性可能会返回多个文本,如 style cartoon会返回宫崎骏和皮克斯两种风格,采用外积增殖成多个字典。 + sometimes an attribute may return multiple texts, such as style cartoon will return two styles, Miyazaki and Pixar, which are multiplied into multiple dictionaries by the outer product. + Args: + dct (dict): 多属性组成的字典,可能有self.funcs关注的属性也可能没有,self.funcs按照各自的名字按需提取关注的属性和值,并转化成文本. + Dict of multiple attributes, may or may not have the attributes that self.funcs is concerned with. self.funcs extracts the attributes and values of interest according to their respective names and converts them into text. + ignored_blank_str (bool): 如果某个attr2text返回的是空字符串,是否要过滤掉该属性。默认`False`. + If the text returned by an attr2text is an empty string, whether to filter out the attribute. Defaults to `False`. + Returns: + Union[List[List[Tuple[str, str]]], List[Tuple[str, str]]: 多组多属性文本字典列表. Multiple sets of multi-attribute text dictionaries. + """ + attrs_lst = [[]] + for func in self.funcs: + if func.name in dct: + attrs = func(dct[func.name]) + if isinstance(attrs, str): + for i in range(len(attrs_lst)): + attrs_lst[i].append((func.name, attrs)) + else: + # 一个属性可能会返回多个文本 + n_attrs = len(attrs) + new_attrs_lst = [] + for n in range(n_attrs): + attrs_lst_cp = deepcopy(attrs_lst) + for i in range(len(attrs_lst_cp)): + attrs_lst_cp[i].append((func.name, attrs[n])) + new_attrs_lst.extend(attrs_lst_cp) + attrs_lst = new_attrs_lst + + texts = [ + [ + (attr, text) + for (attr, text) in attrs + if not (text == "" and ignored_blank_str) + ] + for attrs in attrs_lst + ] + return texts + + +def format_tuple_texts(template: str, texts: Tuple[str, str]) -> str: + """使用含有"{}" 的模板对多属性文本元组进行拼接,形成新文本 + concatenate multiple attribute text tuples using a template containing "{}" to form a new text + Args: + template (str): + texts (Tuple[str, str]): 多属性文本元组. multiple attribute text tuples + + Returns: + str: 拼接后的新文本, merged new text + """ + merged_text = ", ".join([text[1] for text in texts if text[1] != ""]) + merged_text = template.format(merged_text) + return merged_text + + +def format_dct_texts(template: str, texts: Dict[str, str]) -> str: + """使用含有"{key}" 的模板对多属性文本字典进行拼接,形成新文本 + concatenate multiple attribute text dictionaries using a template containing "{key}" to form a new text + Args: + template (str): + texts (Tuple[str, str]): 多属性文本字典. multiple attribute text dictionaries + + Returns: + str: 拼接后的新文本, merged new text + """ + merged_text = template.format(**texts) + return merged_text + + +def merge_multi_attrtext(texts: List[Tuple[str, str]], template: str = None) -> str: + """对多属性文本元组进行拼接,形成新文本。 + 如果`template`含有{key},则根据key来取值; + 如果`template`有且只有1个{},则根据先后顺序对texts中的值进行拼接。 + + concatenate multiple attribute text tuples to form a new text. + if `template` contains {key}, the value is taken according to the key; + if `template` contains only one {}, the values in texts are concatenated in order. + Args: + texts (List[Tuple[str, str]]): Tuple[str, str]第一个str是属性名,第二个str是属性转化的文本. + Tuple[str, str] The first str is the attribute name, and the second str is the text of the attribute conversion. + template (str, optional): template . Defaults to None. + + Returns: + str: 拼接后的新文本, merged new text + """ + if not isinstance(texts, List): + texts = [texts] + if template is None or template == "": + template = "{}" + if has_key_brace(template): + texts = {k: v for k, v in texts} + merged_text = format_dct_texts(template, texts) + else: + merged_text = format_tuple_texts(template, texts) + return merged_text + + +class PresetMultiAttr2Text(MultiAttr2Text): + """预置了多种关注属性转换的类,方便维护 + class for multiple attribute conversion with multiple attention attributes preset for easy maintenance + + """ + + preset_attributes = [] + + def __init__( + self, funcs: List = None, use_preset: bool = True, name: str = "preset" + ) -> None: + """虽然预置了关注的属性列表和转换类,但也允许定义示例时,进行更新。 + 注意`self.preset_attributes`的元素只是类名字,以便减少实例化的资源消耗。而funcs是实例化后的属性转换列表。 + + Although the list of attention attributes and conversion classes is preset, it is also allowed to be updated when defining an instance. + Note that the elements of `self.preset_attributes` are only class names, in order to reduce the resource consumption of instantiation. And funcs is a list of instantiated attribute conversions. + + Args: + funcs (List, optional): list of funcs . Defaults to None. + use_preset (bool, optional): _description_. Defaults to True. + name (str, optional): _description_. Defaults to "preset". + """ + if use_preset: + preset_funcs = self.preset() + else: + preset_funcs = [] + if funcs is None: + funcs = [] + if not isinstance(funcs, list): + funcs = [funcs] + funcs_names = [func.name for func in funcs] + preset_funcs = [ + preset_func + for preset_func in preset_funcs + if preset_func.name not in funcs_names + ] + funcs = funcs + preset_funcs + super().__init__(funcs, name) + + def preset(self): + funcs = [cls() for cls in self.preset_attributes] + return funcs diff --git a/musev/auto_prompt/attributes/human.py b/musev/auto_prompt/attributes/human.py new file mode 100755 index 0000000000000000000000000000000000000000..974ac421b5e20418fa2fe7dc28125373b3ed28ff --- /dev/null +++ b/musev/auto_prompt/attributes/human.py @@ -0,0 +1,424 @@ +from copy import deepcopy +import numpy as np +import random +import json + +from .attributes import ( + MultiAttr2Text, + AttriributeIsText, + AttributeIsTextAndName, + PresetMultiAttr2Text, +) +from .style import Style +from .render import Render +from . import AttrRegister + + +__all__ = [ + "Age", + "Sex", + "Singing", + "Country", + "Lighting", + "Headwear", + "Eyes", + "Irises", + "Hair", + "Skin", + "Face", + "Smile", + "Expression", + "Clothes", + "Nose", + "Mouth", + "Beard", + "Necklace", + "KeyWords", + "InsightFace", + "Caption", + "Env", + "Decoration", + "Festival", + "SpringHeadwear", + "SpringClothes", + "Animal", +] + + +@AttrRegister.register +class Sex(AttriributeIsText): + name = "sex" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Headwear(AttriributeIsText): + name = "headwear" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Expression(AttriributeIsText): + name = "expression" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class KeyWords(AttriributeIsText): + name = "keywords" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Singing(AttriributeIsText): + def __init__(self, name: str = "singing") -> None: + super().__init__(name) + + +@AttrRegister.register +class Country(AttriributeIsText): + name = "country" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Clothes(AttriributeIsText): + name = "clothes" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Age(AttributeIsTextAndName): + name = "age" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + def __call__(self, attributes: str) -> str: + if not isinstance(attributes, str): + attributes = str(attributes) + attributes = attributes.split(",") + text = ", ".join( + ["{}-year-old".format(attr) if attr != "" else "" for attr in attributes] + ) + return text + + +@AttrRegister.register +class Eyes(AttributeIsTextAndName): + name = "eyes" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Hair(AttributeIsTextAndName): + name = "hair" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Background(AttributeIsTextAndName): + name = "background" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Skin(AttributeIsTextAndName): + name = "skin" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Face(AttributeIsTextAndName): + name = "face" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Smile(AttributeIsTextAndName): + name = "smile" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Nose(AttributeIsTextAndName): + name = "nose" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Mouth(AttributeIsTextAndName): + name = "mouth" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Beard(AttriributeIsText): + name = "beard" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Necklace(AttributeIsTextAndName): + name = "necklace" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Irises(AttributeIsTextAndName): + name = "irises" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +@AttrRegister.register +class Lighting(AttributeIsTextAndName): + name = "lighting" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + +PresetPortraitAttributes = [ + Age, + Sex, + Singing, + Country, + Lighting, + Headwear, + Eyes, + Irises, + Hair, + Skin, + Face, + Smile, + Expression, + Clothes, + Nose, + Mouth, + Beard, + Necklace, + Style, + KeyWords, + Render, +] + + +class PortraitMultiAttr2Text(PresetMultiAttr2Text): + preset_attributes = PresetPortraitAttributes + + def __init__(self, funcs: list = None, use_preset=True, name="portrait") -> None: + super().__init__(funcs, use_preset, name) + + +@AttrRegister.register +class InsightFace(AttriributeIsText): + name = "insight_face" + face_render_dict = { + "boy": "handsome,elegant", + "girl": "gorgeous,kawaii,colorful", + } + key_words = "delicate face,beautiful eyes" + + def __call__(self, attributes: str) -> str: + """将insight faces 检测的结果转化成prompt + convert the results of insight faces detection to prompt + Args: + face_list (_type_): _description_ + + Returns: + _type_: _description_ + """ + attributes = json.loads(attributes) + face_list = attributes["info"] + if len(face_list) == 0: + return "" + + if attributes["image_type"] == "body": + for face in face_list: + if "black" in face and face["black"]: + return "african,dark skin" + return "" + + gender_dict = {"girl": 0, "boy": 0} + face_render_list = [] + black = False + + for face in face_list: + if face["ratio"] < 0.02: + continue + + if face["gender"] == 0: + gender_dict["girl"] += 1 + face_render_list.append(self.face_render_dict["girl"]) + else: + gender_dict["boy"] += 1 + face_render_list.append(self.face_render_dict["boy"]) + + if "black" in face and face["black"]: + black = True + + if len(face_render_list) == 0: + return "" + elif len(face_render_list) == 1: + solo = True + else: + solo = False + + gender = "" + for g, num in gender_dict.items(): + if num > 0: + if gender: + gender += ", " + gender += "{}{}".format(num, g) + if num > 1: + gender += "s" + + face_render_list = ",".join(face_render_list) + face_render_list = face_render_list.split(",") + face_render = list(set(face_render_list)) + face_render.sort(key=face_render_list.index) + face_render = ",".join(face_render) + if gender_dict["girl"] == 0: + face_render = "male focus," + face_render + + insightface_prompt = "{},{},{}".format(gender, face_render, self.key_words) + + if solo: + insightface_prompt += ",solo" + if black: + insightface_prompt = "african,dark skin," + insightface_prompt + + return insightface_prompt + + +@AttrRegister.register +class Caption(AttriributeIsText): + name = "caption" + + +@AttrRegister.register +class Env(AttriributeIsText): + name = "env" + envs_list = [ + "east asian architecture", + "fireworks", + "snow, snowflakes", + "snowing, snowflakes", + ] + + def __call__(self, attributes: str = None) -> str: + if attributes != "" and attributes != " " and attributes is not None: + return attributes + else: + return random.choice(self.envs_list) + + +@AttrRegister.register +class Decoration(AttriributeIsText): + name = "decoration" + + def __init__(self, name: str = None) -> None: + self.decoration_list = [ + "chinese knot", + "flowers", + "food", + "lanterns", + "red envelop", + ] + super().__init__(name) + + def __call__(self, attributes: str = None) -> str: + if attributes != "" and attributes != " " and attributes is not None: + return attributes + else: + return random.choice(self.decoration_list) + + +@AttrRegister.register +class Festival(AttriributeIsText): + name = "festival" + festival_list = ["new year"] + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + def __call__(self, attributes: str = None) -> str: + if attributes != "" and attributes != " " and attributes is not None: + return attributes + else: + return random.choice(self.festival_list) + + +@AttrRegister.register +class SpringHeadwear(AttriributeIsText): + name = "spring_headwear" + headwear_list = ["rabbit ears", "rabbit ears, fur hat"] + + def __call__(self, attributes: str = None) -> str: + if attributes != "" and attributes != " " and attributes is not None: + return attributes + else: + return random.choice(self.headwear_list) + + +@AttrRegister.register +class SpringClothes(AttriributeIsText): + name = "spring_clothes" + clothes_list = [ + "mittens,chinese clothes", + "mittens,fur trim", + "mittens,red scarf", + "mittens,winter clothes", + ] + + def __call__(self, attributes: str = None) -> str: + if attributes != "" and attributes != " " and attributes is not None: + return attributes + else: + return random.choice(self.clothes_list) + + +@AttrRegister.register +class Animal(AttriributeIsText): + name = "animal" + animal_list = ["rabbit", "holding rabbits"] + + def __call__(self, attributes: str = None) -> str: + if attributes != "" and attributes != " " and attributes is not None: + return attributes + else: + return random.choice(self.animal_list) diff --git a/musev/auto_prompt/attributes/render.py b/musev/auto_prompt/attributes/render.py new file mode 100755 index 0000000000000000000000000000000000000000..8dda519f985595e1b23f390121d10e7a11652ee4 --- /dev/null +++ b/musev/auto_prompt/attributes/render.py @@ -0,0 +1,33 @@ +from mmcm.utils.util import flatten + +from .attributes import BaseAttribute2Text +from . import AttrRegister + +__all__ = ["Render"] + +RenderMap = { + "Epic": "artstation, epic environment, highly detailed, 8k, HD", + "HD": "8k, highly detailed", + "EpicHD": "hyper detailed, beautiful lighting, epic environment, octane render, cinematic, 8k", + "Digital": "detailed illustration, crisp lines, digital art, 8k, trending on artstation", + "Unreal1": "artstation, concept art, smooth, sharp focus, illustration, unreal engine 5, 8k", + "Unreal2": "concept art, octane render, artstation, epic environment, highly detailed, 8k", +} + + +@AttrRegister.register +class Render(BaseAttribute2Text): + name = "render" + + def __init__(self, name: str = None) -> None: + super().__init__(name) + + def __call__(self, attributes: str) -> str: + if attributes == "" or attributes is None: + return "" + attributes = attributes.split(",") + render = [RenderMap[attr] for attr in attributes if attr in RenderMap] + render = flatten(render, ignored_iterable_types=[str]) + if len(render) == 1: + render = render[0] + return render diff --git a/musev/auto_prompt/attributes/style.py b/musev/auto_prompt/attributes/style.py new file mode 100755 index 0000000000000000000000000000000000000000..da81b6cab0d5213882cc87d855e36c3475ceb7ec --- /dev/null +++ b/musev/auto_prompt/attributes/style.py @@ -0,0 +1,12 @@ +from .attributes import AttriributeIsText +from . import AttrRegister + +__all__ = ["Style"] + + +@AttrRegister.register +class Style(AttriributeIsText): + name = "style" + + def __init__(self, name: str = None) -> None: + super().__init__(name) diff --git a/musev/auto_prompt/human.py b/musev/auto_prompt/human.py new file mode 100755 index 0000000000000000000000000000000000000000..77a38b65f04e8d7fd34da1ae2efec6a8947f666e --- /dev/null +++ b/musev/auto_prompt/human.py @@ -0,0 +1,40 @@ +"""负责按照人相关的属性转化成提词 +""" +from typing import List + +from .attributes.human import PortraitMultiAttr2Text +from .attributes.attributes import BaseAttribute2Text +from .attributes.attr2template import MultiAttr2PromptTemplate + + +class PortraitAttr2PromptTemplate(MultiAttr2PromptTemplate): + """可以将任务字典转化为形象提词模板类 + template class for converting task dictionaries into image prompt templates + Args: + MultiAttr2PromptTemplate (_type_): _description_ + """ + + templates = "a portrait of {}" + + def __init__( + self, templates: str = None, attr2text: List = None, name: str = "portrait" + ) -> None: + """ + + Args: + templates (str, optional): 形象提词模板,若为None,则使用默认的类属性. Defaults to None. + portrait prompt template, if None, the default class attribute is used. + attr2text (List, optional): 形象类需要新增、更新的属性列表,默认使用PortraitMultiAttr2Text中定义的形象属性. Defaults to None. + the list of attributes that need to be added or updated in the image class, by default, the image attributes defined in PortraitMultiAttr2Text are used. + name (str, optional): 该形象类的名字. Defaults to "portrait". + class name of this class instance + """ + if ( + attr2text is None + or isinstance(attr2text, list) + or isinstance(attr2text, BaseAttribute2Text) + ): + attr2text = PortraitMultiAttr2Text(funcs=attr2text) + if templates is None: + templates = self.templates + super().__init__(templates, attr2text, name=name) diff --git a/musev/auto_prompt/load_template.py b/musev/auto_prompt/load_template.py new file mode 100755 index 0000000000000000000000000000000000000000..3676b5c629e4fc19dddfd762d63137cc1cf1e23b --- /dev/null +++ b/musev/auto_prompt/load_template.py @@ -0,0 +1,37 @@ +from mmcm.utils.str_util import has_key_brace + +from .human import PortraitAttr2PromptTemplate +from .attributes.attr2template import ( + KeywordMultiAttr2PromptTemplate, + OnlySpacePromptTemplate, +) + + +def get_template_by_name(template: str, name: str = None): + """根据 template_name 确定 prompt 生成器类 + choose prompt generator class according to template_name + Args: + name (str): template 的名字简称,便于指定. template name abbreviation, for easy reference + + Raises: + ValueError: ValueError: 如果name不在支持的列表中,则报错. if name is not in the supported list, an error is reported. + + Returns: + MultiAttr2PromptTemplate: 能够将任务字典转化为提词的 实现了__call__功能的类. class that can convert task dictionaries into prompts and implements the __call__ function + + """ + if template == "" or template is None: + template = OnlySpacePromptTemplate(template=template) + elif has_key_brace(template): + # if has_key_brace(template): + template = KeywordMultiAttr2PromptTemplate(template=template) + else: + if name == "portrait": + template = PortraitAttr2PromptTemplate(templates=template) + else: + raise ValueError( + "PresetAttr2PromptTemplate only support one of [portrait], but given {}".format( + name + ) + ) + return template diff --git a/musev/auto_prompt/util.py b/musev/auto_prompt/util.py new file mode 100755 index 0000000000000000000000000000000000000000..ec7c67e483e2815cd59947c886b2157f10c7c3df --- /dev/null +++ b/musev/auto_prompt/util.py @@ -0,0 +1,25 @@ +from copy import deepcopy +from typing import Dict, List + +from .load_template import get_template_by_name + + +def generate_prompts(tasks: List[Dict]) -> List[Dict]: + new_tasks = [] + for task in tasks: + task["origin_prompt"] = deepcopy(task["prompt"]) + # 如果prompt单元值含有模板 {},或者 没有填写任何值(默认为空模板),则使用原prompt值 + if "{" not in task["prompt"] and len(task["prompt"]) != 0: + new_tasks.append(task) + else: + template = get_template_by_name( + template=task["prompt"], name=task.get("template_name", None) + ) + prompts = template(task) + if not isinstance(prompts, list) and isinstance(prompts, str): + prompts = [prompts] + for prompt in prompts: + task_cp = deepcopy(task) + task_cp["prompt"] = prompt + new_tasks.append(task_cp) + return new_tasks diff --git a/musev/data/__init__.py b/musev/data/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/musev/data/data_util.py b/musev/data/data_util.py new file mode 100755 index 0000000000000000000000000000000000000000..d413175d4234a355998641e9cd64c7076b154897 --- /dev/null +++ b/musev/data/data_util.py @@ -0,0 +1,681 @@ +from typing import List, Dict, Literal, Union, Tuple +import os +import string +import logging + +import torch +import numpy as np +from einops import rearrange, repeat + +logger = logging.getLogger(__name__) + + +def generate_tasks_of_dir( + path: str, + output_dir: str, + exts: Tuple[str], + same_dir_name: bool = False, + **kwargs, +) -> List[Dict]: + """covert video directory into tasks + + Args: + path (str): _description_ + output_dir (str): _description_ + exts (Tuple[str]): _description_ + same_dir_name (bool, optional): 存储路径是否保留和源视频相同的父文件名. Defaults to False. + whether keep the same parent dir name as the source video + Returns: + List[Dict]: _description_ + """ + tasks = [] + for rootdir, dirs, files in os.walk(path): + for basename in files: + if basename.lower().endswith(exts): + video_path = os.path.join(rootdir, basename) + filename, ext = basename.split(".") + rootdir_name = os.path.basename(rootdir) + if same_dir_name: + save_path = os.path.join( + output_dir, rootdir_name, f"{filename}.h5py" + ) + save_dir = os.path.join(output_dir, rootdir_name) + else: + save_path = os.path.join(output_dir, f"{filename}.h5py") + save_dir = output_dir + task = { + "video_path": video_path, + "output_path": save_path, + "output_dir": save_dir, + "filename": filename, + "ext": ext, + } + task.update(kwargs) + tasks.append(task) + return tasks + + +def sample_by_idx( + T: int, + n_sample: int, + sample_rate: int, + sample_start_idx: int = None, + change_sample_rate: bool = False, + seed: int = None, + whether_random: bool = True, + n_independent: int = 0, +) -> List[int]: + """given a int to represent candidate list, sample n_sample with sample_rate from the candidate list + + Args: + T (int): _description_ + n_sample (int): 目标采样数目. sample number + sample_rate (int): 采样率, 每隔sample_rate个采样一个. sample interval, pick one per sample_rate number + sample_start_idx (int, optional): 采样开始位置的选择. start position to sample . Defaults to 0. + change_sample_rate (bool, optional): 是否可以通过降低sample_rate的方式来完成采样. whether allow changing sample_rate to finish sample process. Defaults to False. + whether_random (bool, optional): 是否最后随机选择开始点. whether randomly choose sample start position. Defaults to False. + + Raises: + ValueError: T / sample_rate should be larger than n_sample + Returns: + List[int]: 采样的索引位置. sampled index position + """ + if T < n_sample: + raise ValueError(f"T({T}) < n_sample({n_sample})") + else: + if T / sample_rate < n_sample: + if not change_sample_rate: + raise ValueError( + f"T({T}) / sample_rate({sample_rate}) < n_sample({n_sample})" + ) + else: + while T / sample_rate < n_sample: + sample_rate -= 1 + logger.error( + f"sample_rate{sample_rate+1} is too large, decrease to {sample_rate}" + ) + if sample_rate == 0: + raise ValueError("T / sample_rate < n_sample") + + if sample_start_idx is None: + if whether_random: + sample_start_idx_candidates = np.arange(T - n_sample * sample_rate) + if seed is not None: + np.random.seed(seed) + sample_start_idx = np.random.choice(sample_start_idx_candidates, 1)[0] + + else: + sample_start_idx = 0 + sample_end_idx = sample_start_idx + sample_rate * n_sample + sample = list(range(sample_start_idx, sample_end_idx, sample_rate)) + if n_independent == 0: + n_independent_sample = None + else: + left_candidate = np.array( + list(range(0, sample_start_idx)) + list(range(sample_end_idx, T)) + ) + if len(left_candidate) >= n_independent: + # 使用两端的剩余空间采样, use the left space to sample + n_independent_sample = np.random.choice(left_candidate, n_independent) + else: + # 当两端没有剩余采样空间时,使用任意不是sample中的帧 + # if no enough space to sample, use any frame not in sample + left_candidate = np.array(list(set(range(T) - set(sample)))) + n_independent_sample = np.random.choice(left_candidate, n_independent) + + return sample, sample_rate, n_independent_sample + + +def sample_tensor_by_idx( + tensor: Union[torch.Tensor, np.ndarray], + n_sample: int, + sample_rate: int, + sample_start_idx: int = 0, + change_sample_rate: bool = False, + seed: int = None, + dim: int = 0, + return_type: Literal["numpy", "torch"] = "torch", + whether_random: bool = True, + n_independent: int = 0, +) -> Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]: + """sample sub_tensor + + Args: + tensor (Union[torch.Tensor, np.ndarray]): _description_ + n_sample (int): _description_ + sample_rate (int): _description_ + sample_start_idx (int, optional): _description_. Defaults to 0. + change_sample_rate (bool, optional): _description_. Defaults to False. + seed (int, optional): _description_. Defaults to None. + dim (int, optional): _description_. Defaults to 0. + return_type (Literal["numpy", "torch"], optional): _description_. Defaults to "torch". + whether_random (bool, optional): _description_. Defaults to True. + n_independent (int, optional): 独立于n_sample的采样数量. Defaults to 0. + n_independent sample number that is independent of n_sample + + Returns: + Tuple[torch.Tensor, torch.Tensor, int, torch.Tensor, torch.Tensor]: sampled tensor + """ + if isinstance(tensor, np.ndarray): + tensor = torch.from_numpy(tensor) + T = tensor.shape[dim] + sample_idx, sample_rate, independent_sample_idx = sample_by_idx( + T, + n_sample, + sample_rate, + sample_start_idx, + change_sample_rate, + seed, + whether_random=whether_random, + n_independent=n_independent, + ) + sample_idx = torch.LongTensor(sample_idx) + sample = torch.index_select(tensor, dim, sample_idx) + if independent_sample_idx is not None: + independent_sample_idx = torch.LongTensor(independent_sample_idx) + independent_sample = torch.index_select(tensor, dim, independent_sample_idx) + else: + independent_sample = None + independent_sample_idx = None + if return_type == "numpy": + sample = sample.cpu().numpy() + return sample, sample_idx, sample_rate, independent_sample, independent_sample_idx + + +def concat_two_tensor( + data1: torch.Tensor, + data2: torch.Tensor, + dim: int, + method: Literal[ + "first_in_first_out", "first_in_last_out", "intertwine", "index" + ] = "first_in_first_out", + data1_index: torch.long = None, + data2_index: torch.long = None, + return_index: bool = False, +): + """concat two tensor along dim with given method + + Args: + data1 (torch.Tensor): first in data + data2 (torch.Tensor): last in data + dim (int): _description_ + method (Literal[ "first_in_first_out", "first_in_last_out", "intertwine" ], optional): _description_. Defaults to "first_in_first_out". + + Raises: + NotImplementedError: unsupported method + ValueError: unsupported method + + Returns: + _type_: _description_ + """ + len_data1 = data1.shape[dim] + len_data2 = data2.shape[dim] + + if method == "first_in_first_out": + res = torch.concat([data1, data2], dim=dim) + data1_index = range(len_data1) + data2_index = [len_data1 + x for x in range(len_data2)] + elif method == "first_in_last_out": + res = torch.concat([data2, data1], dim=dim) + data2_index = range(len_data2) + data1_index = [len_data2 + x for x in range(len_data1)] + elif method == "intertwine": + raise NotImplementedError("intertwine") + elif method == "index": + res = concat_two_tensor_with_index( + data1=data1, + data1_index=data1_index, + data2=data2, + data2_index=data2_index, + dim=dim, + ) + else: + raise ValueError( + "only support first_in_first_out, first_in_last_out, intertwine, index" + ) + if return_index: + return res, data1_index, data2_index + else: + return res + + +def concat_two_tensor_with_index( + data1: torch.Tensor, + data1_index: torch.LongTensor, + data2: torch.Tensor, + data2_index: torch.LongTensor, + dim: int, +) -> torch.Tensor: + """_summary_ + + Args: + data1 (torch.Tensor): b1*c1*h1*w1*... + data1_index (torch.LongTensor): N, if dim=1, N=c1 + data2 (torch.Tensor): b2*c2*h2*w2*... + data2_index (torch.LongTensor): M, if dim=1, M=c2 + dim (int): int + + Returns: + torch.Tensor: b*c*h*w*..., if dim=1, b=b1=b2, c=c1+c2, h=h1=h2, w=w1=w2,... + """ + shape1 = list(data1.shape) + shape2 = list(data2.shape) + target_shape = list(shape1) + target_shape[dim] = shape1[dim] + shape2[dim] + target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype) + target = batch_index_copy(target, dim=dim, index=data1_index, source=data1) + target = batch_index_copy(target, dim=dim, index=data2_index, source=data2) + return target + + +def repeat_index_to_target_size( + index: torch.LongTensor, target_size: int +) -> torch.LongTensor: + if len(index.shape) == 1: + index = repeat(index, "n -> b n", b=target_size) + if len(index.shape) == 2: + remainder = target_size % index.shape[0] + assert ( + remainder == 0 + ), f"target_size % index.shape[0] must be zero, but give {target_size % index.shape[0]}" + index = repeat(index, "b n -> (b c) n", c=int(target_size / index.shape[0])) + return index + + +def batch_concat_two_tensor_with_index( + data1: torch.Tensor, + data1_index: torch.LongTensor, + data2: torch.Tensor, + data2_index: torch.LongTensor, + dim: int, +) -> torch.Tensor: + return concat_two_tensor_with_index(data1, data1_index, data2, data2_index, dim) + + +def interwine_two_tensor( + data1: torch.Tensor, + data2: torch.Tensor, + dim: int, + return_index: bool = False, +) -> torch.Tensor: + shape1 = list(data1.shape) + shape2 = list(data2.shape) + target_shape = list(shape1) + target_shape[dim] = shape1[dim] + shape2[dim] + target = torch.zeros(target_shape, device=data1.device, dtype=data1.dtype) + data1_reshape = torch.swapaxes(data1, 0, dim) + data2_reshape = torch.swapaxes(data2, 0, dim) + target = torch.swapaxes(target, 0, dim) + total_index = set(range(target_shape[dim])) + data1_index = range(0, 2 * shape1[dim], 2) + data2_index = sorted(list(set(total_index) - set(data1_index))) + data1_index = torch.LongTensor(data1_index) + data2_index = torch.LongTensor(data2_index) + target[data1_index, ...] = data1_reshape + target[data2_index, ...] = data2_reshape + target = torch.swapaxes(target, 0, dim) + if return_index: + return target, data1_index, data2_index + else: + return target + + +def split_index( + indexs: torch.Tensor, + n_first: int = None, + n_last: int = None, + method: Literal[ + "first_in_first_out", "first_in_last_out", "intertwine", "index", "random" + ] = "first_in_first_out", +): + """_summary_ + + Args: + indexs (List): _description_ + n_first (int): _description_ + n_last (int): _description_ + method (Literal[ "first_in_first_out", "first_in_last_out", "intertwine", "index" ], optional): _description_. Defaults to "first_in_first_out". + + Raises: + NotImplementedError: _description_ + + Returns: + first_index: _description_ + last_index: + """ + # assert ( + # n_first is None and n_last is None + # ), "must assign one value for n_first or n_last" + n_total = len(indexs) + if n_first is None: + n_first = n_total - n_last + if n_last is None: + n_last = n_total - n_first + assert len(indexs) == n_first + n_last + if method == "first_in_first_out": + first_index = indexs[:n_first] + last_index = indexs[n_first:] + elif method == "first_in_last_out": + first_index = indexs[n_last:] + last_index = indexs[:n_last] + elif method == "intertwine": + raise NotImplementedError + elif method == "random": + idx_ = torch.randperm(len(indexs)) + first_index = indexs[idx_[:n_first]] + last_index = indexs[idx_[n_first:]] + return first_index, last_index + + +def split_tensor( + tensor: torch.Tensor, + dim: int, + n_first=None, + n_last=None, + method: Literal[ + "first_in_first_out", "first_in_last_out", "intertwine", "index", "random" + ] = "first_in_first_out", + need_return_index: bool = False, +): + device = tensor.device + total = tensor.shape[dim] + if n_first is None: + n_first = total - n_last + if n_last is None: + n_last = total - n_first + indexs = torch.arange( + total, + dtype=torch.long, + device=device, + ) + ( + first_index, + last_index, + ) = split_index( + indexs=indexs, + n_first=n_first, + method=method, + ) + first_tensor = torch.index_select(tensor, dim=dim, index=first_index) + last_tensor = torch.index_select(tensor, dim=dim, index=last_index) + if need_return_index: + return ( + first_tensor, + last_tensor, + first_index, + last_index, + ) + else: + return (first_tensor, last_tensor) + + +# TODO: 待确定batch_index_select的优化 +def batch_index_select( + tensor: torch.Tensor, index: torch.LongTensor, dim: int +) -> torch.Tensor: + """_summary_ + + Args: + tensor (torch.Tensor): D1*D2*D3*D4... + index (torch.LongTensor): D1*N or N, N<= tensor.shape[dim] + dim (int): dim to select + + Returns: + torch.Tensor: D1*...*N*... + """ + # TODO: now only support N same for every d1 + if len(index.shape) == 1: + return torch.index_select(tensor, dim=dim, index=index) + else: + index = repeat_index_to_target_size(index, tensor.shape[0]) + out = [] + for i in torch.arange(tensor.shape[0]): + sub_tensor = tensor[i] + sub_index = index[i] + d = torch.index_select(sub_tensor, dim=dim - 1, index=sub_index) + out.append(d) + return torch.stack(out).to(dtype=tensor.dtype) + + +def batch_index_copy( + tensor: torch.Tensor, dim: int, index: torch.LongTensor, source: torch.Tensor +) -> torch.Tensor: + """_summary_ + + Args: + tensor (torch.Tensor): b*c*h + dim (int): + index (torch.LongTensor): b*d, + source (torch.Tensor): + b*d*h*..., if dim=1 + b*c*d*..., if dim=2 + + Returns: + torch.Tensor: b*c*d*... + """ + if len(index.shape) == 1: + tensor.index_copy_(dim=dim, index=index, source=source) + else: + index = repeat_index_to_target_size(index, tensor.shape[0]) + + batch_size = tensor.shape[0] + for b in torch.arange(batch_size): + sub_index = index[b] + sub_source = source[b] + sub_tensor = tensor[b] + sub_tensor.index_copy_(dim=dim - 1, index=sub_index, source=sub_source) + tensor[b] = sub_tensor + return tensor + + +def batch_index_fill( + tensor: torch.Tensor, + dim: int, + index: torch.LongTensor, + value: Literal[torch.Tensor, torch.float], +) -> torch.Tensor: + """_summary_ + + Args: + tensor (torch.Tensor): b*c*h + dim (int): + index (torch.LongTensor): b*d, + value (torch.Tensor): b + + Returns: + torch.Tensor: b*c*d*... + """ + index = repeat_index_to_target_size(index, tensor.shape[0]) + batch_size = tensor.shape[0] + for b in torch.arange(batch_size): + sub_index = index[b] + sub_value = value[b] if isinstance(value, torch.Tensor) else value + sub_tensor = tensor[b] + sub_tensor.index_fill_(dim - 1, sub_index, sub_value) + tensor[b] = sub_tensor + return tensor + + +def adaptive_instance_normalization( + src: torch.Tensor, + dst: torch.Tensor, + eps: float = 1e-6, +): + """ + Args: + src (torch.Tensor): b c t h w + dst (torch.Tensor): b c t h w + """ + ndim = src.ndim + if ndim == 5: + dim = (2, 3, 4) + elif ndim == 4: + dim = (2, 3) + elif ndim == 3: + dim = 2 + else: + raise ValueError("only support ndim in [3,4,5], but given {ndim}") + var, mean = torch.var_mean(src, dim=dim, keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + dst = align_repeat_tensor_single_dim(dst, src.shape[0], dim=0) + mean_acc, var_acc = torch.var_mean(dst, dim=dim, keepdim=True, correction=0) + # mean_acc = sum(mean_acc) / float(len(mean_acc)) + # var_acc = sum(var_acc) / float(len(var_acc)) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + src = (((src - mean) / std) * std_acc) + mean_acc + return src + + +def adaptive_instance_normalization_with_ref( + src: torch.LongTensor, + dst: torch.LongTensor, + style_fidelity: float = 0.5, + do_classifier_free_guidance: bool = True, +): + # logger.debug( + # f"src={src.shape}, min={src.min()}, max={src.max()}, mean={src.mean()}, \n" + # f"dst={src.shape}, min={dst.min()}, max={dst.max()}, mean={dst.mean()}" + # ) + batch_size = src.shape[0] // 2 + uc_mask = torch.Tensor([1] * batch_size + [0] * batch_size).type_as(src).bool() + src_uc = adaptive_instance_normalization(src, dst) + src_c = src_uc.clone() + # TODO: 该部分默认 do_classifier_free_guidance and style_fidelity > 0 = True + if do_classifier_free_guidance and style_fidelity > 0: + src_c[uc_mask] = src[uc_mask] + src = style_fidelity * src_c + (1.0 - style_fidelity) * src_uc + return src + + +def batch_adain_conditioned_tensor( + tensor: torch.Tensor, + src_index: torch.LongTensor, + dst_index: torch.LongTensor, + keep_dim: bool = True, + num_frames: int = None, + dim: int = 2, + style_fidelity: float = 0.5, + do_classifier_free_guidance: bool = True, + need_style_fidelity: bool = False, +): + """_summary_ + + Args: + tensor (torch.Tensor): b c t h w + src_index (torch.LongTensor): _description_ + dst_index (torch.LongTensor): _description_ + keep_dim (bool, optional): _description_. Defaults to True. + + Returns: + _type_: _description_ + """ + ndim = tensor.ndim + dtype = tensor.dtype + if ndim == 4 and num_frames is not None: + tensor = rearrange(tensor, "(b t) c h w-> b c t h w ", t=num_frames) + src = batch_index_select(tensor, dim=dim, index=src_index).contiguous() + dst = batch_index_select(tensor, dim=dim, index=dst_index).contiguous() + if need_style_fidelity: + src = adaptive_instance_normalization_with_ref( + src=src, + dst=dst, + style_fidelity=style_fidelity, + do_classifier_free_guidance=do_classifier_free_guidance, + need_style_fidelity=need_style_fidelity, + ) + else: + src = adaptive_instance_normalization( + src=src, + dst=dst, + ) + if keep_dim: + src = batch_concat_two_tensor_with_index( + src.to(dtype=dtype), + src_index, + dst.to(dtype=dtype), + dst_index, + dim=dim, + ) + + if ndim == 4 and num_frames is not None: + src = rearrange(tensor, "b c t h w ->(b t) c h w") + return src + + +def align_repeat_tensor_single_dim( + src: torch.Tensor, + target_length: int, + dim: int = 0, + n_src_base_length: int = 1, + src_base_index: List[int] = None, +) -> torch.Tensor: + """沿着 dim 纬度, 补齐 src 的长度到目标 target_length。 + 当 src 长度不如 target_length 时, 取其中 前 n_src_base_length 然后 repeat 到 target_length + + align length of src to target_length along dim + when src length is less than target_length, take the first n_src_base_length and repeat to target_length + + Args: + src (torch.Tensor): 输入 tensor, input tensor + target_length (int): 目标长度, target_length + dim (int, optional): 处理纬度, target dim . Defaults to 0. + n_src_base_length (int, optional): src 的基本单元长度, basic length of src. Defaults to 1. + + Returns: + torch.Tensor: _description_ + """ + src_dim_length = src.shape[dim] + if target_length > src_dim_length: + if target_length % src_dim_length == 0: + new = src.repeat_interleave( + repeats=target_length // src_dim_length, dim=dim + ) + else: + if src_base_index is None and n_src_base_length is not None: + src_base_index = torch.arange(n_src_base_length) + + new = src.index_select( + dim=dim, + index=torch.LongTensor(src_base_index).to(device=src.device), + ) + new = new.repeat_interleave( + repeats=target_length // len(src_base_index), + dim=dim, + ) + elif target_length < src_dim_length: + new = src.index_select( + dim=dim, + index=torch.LongTensor(torch.arange(target_length)).to(device=src.device), + ) + else: + new = src + return new + + +def fuse_part_tensor( + src: torch.Tensor, + dst: torch.Tensor, + overlap: int, + weight: float = 0.5, + skip_step: int = 0, +) -> torch.Tensor: + """fuse overstep tensor with weight of src into dst + out = src_fused_part * weight + dst * (1-weight) for overlap + + Args: + src (torch.Tensor): b c t h w + dst (torch.Tensor): b c t h w + overlap (int): 1 + weight (float, optional): weight of src tensor part. Defaults to 0.5. + + Returns: + torch.Tensor: fused tensor + """ + if overlap == 0: + return dst + else: + dst[:, :, skip_step : skip_step + overlap] = ( + weight * src[:, :, -overlap:] + + (1 - weight) * dst[:, :, skip_step : skip_step + overlap] + ) + return dst diff --git a/musev/logging.conf b/musev/logging.conf new file mode 100755 index 0000000000000000000000000000000000000000..409adb4f6af24c4db11b762e221cbb56682307d7 --- /dev/null +++ b/musev/logging.conf @@ -0,0 +1,32 @@ +[loggers] +keys=root,musev + +[handlers] +keys=consoleHandler + +[formatters] +keys=musevFormatter + +[logger_root] +level=INFO +handlers=consoleHandler + +# logger level 尽量设置低一点 +[logger_musev] +level=DEBUG +handlers=consoleHandler +qualname=musev +propagate=0 + +# handler level 设置比 logger level高 +[handler_consoleHandler] +class=StreamHandler +level=DEBUG +# level=INFO + +formatter=musevFormatter +args=(sys.stdout,) + +[formatter_musevFormatter] +format=%(asctime)s- %(name)s:%(lineno)d- %(levelname)s- %(message)s +datefmt= \ No newline at end of file diff --git a/musev/models/__init__.py b/musev/models/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..3b7d572a82332b6d3bf4bb39ff600a13165d919e --- /dev/null +++ b/musev/models/__init__.py @@ -0,0 +1,3 @@ +from ..utils.register import Register + +Model_Register = Register(registry_name="torch_model") diff --git a/musev/models/attention.py b/musev/models/attention.py new file mode 100755 index 0000000000000000000000000000000000000000..ce981b61b26379b1e445534843986016aa540631 --- /dev/null +++ b/musev/models/attention.py @@ -0,0 +1,431 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/huggingface/diffusers/blob/64bf5d33b7ef1b1deac256bed7bd99b55020c4e0/src/diffusers/models/attention.py +from __future__ import annotations +from copy import deepcopy + +from typing import Any, Dict, List, Literal, Optional, Callable, Tuple +import logging +from einops import rearrange + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention_processor import Attention as DiffusersAttention +from diffusers.models.attention import ( + BasicTransformerBlock as DiffusersBasicTransformerBlock, + AdaLayerNormZero, + AdaLayerNorm, + FeedForward, +) +from diffusers.models.attention_processor import AttnProcessor + +from .attention_processor import IPAttention, BaseIPAttnProcessor + + +logger = logging.getLogger(__name__) + + +def not_use_xformers_anyway( + use_memory_efficient_attention_xformers: bool, + attention_op: Optional[Callable] = None, +): + return None + + +@maybe_allow_in_graph +class BasicTransformerBlock(DiffusersBasicTransformerBlock): + print_idx = 0 + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0, + cross_attention_dim: int | None = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: int | None = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, + attention_type: str = "default", + allow_xformers: bool = True, + cross_attn_temporal_cond: bool = False, + image_scale: float = 1.0, + processor: AttnProcessor | None = None, + ip_adapter_cross_attn: bool = False, + need_t2i_facein: bool = False, + need_t2i_ip_adapter_face: bool = False, + ): + if not only_cross_attention and double_self_attention: + cross_attention_dim = None + super().__init__( + dim, + num_attention_heads, + attention_head_dim, + dropout, + cross_attention_dim, + activation_fn, + num_embeds_ada_norm, + attention_bias, + only_cross_attention, + double_self_attention, + upcast_attention, + norm_elementwise_affine, + norm_type, + final_dropout, + attention_type, + ) + + self.attn1 = IPAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + cross_attn_temporal_cond=cross_attn_temporal_cond, + image_scale=image_scale, + ip_adapter_dim=cross_attention_dim + if only_cross_attention + else attention_head_dim, + facein_dim=cross_attention_dim + if only_cross_attention + else attention_head_dim, + processor=processor, + ) + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + + self.attn2 = IPAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim + if not double_self_attention + else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attn_temporal_cond=ip_adapter_cross_attn, + need_t2i_facein=need_t2i_facein, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + image_scale=image_scale, + ip_adapter_dim=cross_attention_dim + if not double_self_attention + else attention_head_dim, + facein_dim=cross_attention_dim + if not double_self_attention + else attention_head_dim, + ip_adapter_face_dim=cross_attention_dim + if not double_self_attention + else attention_head_dim, + processor=processor, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + if self.attn1 is not None: + if not allow_xformers: + self.attn1.set_use_memory_efficient_attention_xformers = ( + not_use_xformers_anyway + ) + if self.attn2 is not None: + if not allow_xformers: + self.attn2.set_use_memory_efficient_attention_xformers = ( + not_use_xformers_anyway + ) + self.double_self_attention = double_self_attention + self.only_cross_attention = only_cross_attention + self.cross_attn_temporal_cond = cross_attn_temporal_cond + self.image_scale = image_scale + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + self_attn_block_embs: Optional[Tuple[List[torch.Tensor], List[None]]] = None, + self_attn_block_embs_mode: Literal["read", "write"] = "write", + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Retrieve lora scale. + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + # 特殊AttnProcessor需要的入参 在 cross_attention_kwargs 准备 + # special AttnProcessor needs input parameters in cross_attention_kwargs + original_cross_attention_kwargs = { + k: v + for k, v in cross_attention_kwargs.items() + if k + not in [ + "num_frames", + "sample_index", + "vision_conditon_frames_sample_index", + "vision_cond", + "vision_clip_emb", + "ip_adapter_scale", + "face_emb", + "facein_scale", + "ip_adapter_face_emb", + "ip_adapter_face_scale", + "do_classifier_free_guidance", + ] + } + + if "do_classifier_free_guidance" in cross_attention_kwargs: + do_classifier_free_guidance = cross_attention_kwargs[ + "do_classifier_free_guidance" + ] + else: + do_classifier_free_guidance = False + + # 2. Prepare GLIGEN inputs + original_cross_attention_kwargs = ( + original_cross_attention_kwargs.copy() + if original_cross_attention_kwargs is not None + else {} + ) + gligen_kwargs = original_cross_attention_kwargs.pop("gligen", None) + + # 返回self_attn的结果,适用于referencenet的输出给其他Unet来使用 + # return the result of self_attn, which is suitable for the output of referencenet to be used by other Unet + if ( + self_attn_block_embs is not None + and self_attn_block_embs_mode.lower() == "write" + ): + # self_attn_block_emb = self.attn1.head_to_batch_dim(attn_output, out_dim=4) + self_attn_block_emb = norm_hidden_states + if not hasattr(self, "spatial_self_attn_idx"): + raise ValueError( + "must call unet.insert_spatial_self_attn_idx to generate spatial attn index" + ) + basick_transformer_idx = self.spatial_self_attn_idx + if self.print_idx == 0: + logger.debug( + f"self_attn_block_embs, self_attn_block_embs_mode={self_attn_block_embs_mode}, " + f"basick_transformer_idx={basick_transformer_idx}, length={len(self_attn_block_embs)}, shape={self_attn_block_emb.shape}, " + # f"attn1 processor, {type(self.attn1.processor)}" + ) + self_attn_block_embs[basick_transformer_idx] = self_attn_block_emb + + # read and put referencenet emb into cross_attention_kwargs, which would be fused into attn_processor + if ( + self_attn_block_embs is not None + and self_attn_block_embs_mode.lower() == "read" + ): + basick_transformer_idx = self.spatial_self_attn_idx + if not hasattr(self, "spatial_self_attn_idx"): + raise ValueError( + "must call unet.insert_spatial_self_attn_idx to generate spatial attn index" + ) + if self.print_idx == 0: + logger.debug( + f"refer_self_attn_emb: , self_attn_block_embs_mode={self_attn_block_embs_mode}, " + f"length={len(self_attn_block_embs)}, idx={basick_transformer_idx}, " + # f"attn1 processor, {type(self.attn1.processor)}, " + ) + ref_emb = self_attn_block_embs[basick_transformer_idx] + cross_attention_kwargs["refer_emb"] = ref_emb + if self.print_idx == 0: + logger.debug( + f"unet attention read, {self.spatial_self_attn_idx}", + ) + # ------------------------------warning----------------------- + # 这两行由于使用了ref_emb会导致和checkpoint_train相关的训练错误,具体未知,留在这里作为警示 + # bellow annoated code will cause training error, keep it here as a warning + # logger.debug(f"ref_emb shape,{ref_emb.shape}, {ref_emb.mean()}") + # logger.debug( + # f"norm_hidden_states shape, {norm_hidden_states.shape}, {norm_hidden_states.mean()}", + # ) + if self.attn1 is None: + self.print_idx += 1 + return norm_hidden_states + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if self.only_cross_attention + else None, + attention_mask=attention_mask, + **( + cross_attention_kwargs + if isinstance(self.attn1.processor, BaseIPAttnProcessor) + else original_cross_attention_kwargs + ), + ) + + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + # 推断的时候,对于uncondition_部分独立生成,排除掉 refer_emb, + # 首帧等的影响,避免生成参考了refer_emb、首帧等,又在uncond上去除了 + # in inference stage, eliminate influence of refer_emb, vis_cond on unconditionpart + # to avoid use that, and then eliminate in pipeline + # refer to moore-animate anyone + + # do_classifier_free_guidance = False + if self.print_idx == 0: + logger.debug(f"do_classifier_free_guidance={do_classifier_free_guidance},") + if do_classifier_free_guidance: + hidden_states_c = attn_output.clone() + _uc_mask = ( + torch.Tensor( + [1] * (norm_hidden_states.shape[0] // 2) + + [0] * (norm_hidden_states.shape[0] // 2) + ) + .to(norm_hidden_states.device) + .bool() + ) + hidden_states_c[_uc_mask] = self.attn1( + norm_hidden_states[_uc_mask], + encoder_hidden_states=norm_hidden_states[_uc_mask], + attention_mask=attention_mask, + ) + attn_output = hidden_states_c.clone() + + if "refer_emb" in cross_attention_kwargs: + del cross_attention_kwargs["refer_emb"] + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + # 2.5 ends + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) + if self.use_ada_layer_norm + else self.norm2(hidden_states) + ) + + # 特殊AttnProcessor需要的入参 在 cross_attention_kwargs 准备 + # special AttnProcessor needs input parameters in cross_attention_kwargs + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states + if not self.double_self_attention + else None, + attention_mask=encoder_attention_mask, + **( + original_cross_attention_kwargs + if not isinstance(self.attn2.processor, BaseIPAttnProcessor) + else cross_attention_kwargs + ), + ) + if self.print_idx == 0: + logger.debug( + f"encoder_hidden_states, type={type(encoder_hidden_states)}" + ) + if encoder_hidden_states is not None: + logger.debug( + f"encoder_hidden_states, ={encoder_hidden_states.shape}" + ) + + # encoder_hidden_states_tmp = ( + # encoder_hidden_states + # if not self.double_self_attention + # else norm_hidden_states + # ) + # if do_classifier_free_guidance: + # hidden_states_c = attn_output.clone() + # _uc_mask = ( + # torch.Tensor( + # [1] * (norm_hidden_states.shape[0] // 2) + # + [0] * (norm_hidden_states.shape[0] // 2) + # ) + # .to(norm_hidden_states.device) + # .bool() + # ) + # hidden_states_c[_uc_mask] = self.attn2( + # norm_hidden_states[_uc_mask], + # encoder_hidden_states=encoder_hidden_states_tmp[_uc_mask], + # attention_mask=attention_mask, + # ) + # attn_output = hidden_states_c.clone() + hidden_states = attn_output + hidden_states + # 4. Feed-forward + if self.norm3 is not None and self.ff is not None: + norm_hidden_states = self.norm3(hidden_states) + if self.use_ada_layer_norm_zero: + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = ( + norm_hidden_states.shape[self._chunk_dim] // self._chunk_size + ) + ff_output = torch.cat( + [ + self.ff(hid_slice, scale=lora_scale) + for hid_slice in norm_hidden_states.chunk( + num_chunks, dim=self._chunk_dim + ) + ], + dim=self._chunk_dim, + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + self.print_idx += 1 + return hidden_states diff --git a/musev/models/attention_processor.py b/musev/models/attention_processor.py new file mode 100755 index 0000000000000000000000000000000000000000..6bd27e7c49254ef5bb3e9ddf6bfed824ee53c47e --- /dev/null +++ b/musev/models/attention_processor.py @@ -0,0 +1,750 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""该模型是自定义的attn_processor,实现特殊功能的 Attn功能。 + 相对而言,开源代码经常会重新定义Attention 类, + + This module implements special AttnProcessor function with custom attn_processor class. + While other open source code always modify Attention class. +""" +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +from __future__ import annotations + +import time +from typing import Any, Callable, Optional +import logging + +from einops import rearrange, repeat +import torch +import torch.nn as nn +import torch.nn.functional as F +import xformers +from diffusers.models.lora import LoRACompatibleLinear + +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.attention_processor import ( + Attention as DiffusersAttention, + AttnProcessor, + AttnProcessor2_0, +) +from ..data.data_util import ( + batch_concat_two_tensor_with_index, + batch_index_select, + align_repeat_tensor_single_dim, + batch_adain_conditioned_tensor, +) + +from . import Model_Register + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +@maybe_allow_in_graph +class IPAttention(DiffusersAttention): + r""" + Modified Attention class which has special layer, like ip_apadapter_to_k, ip_apadapter_to_v, + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: int | None = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: str | None = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: int | None = None, + norm_num_groups: int | None = None, + spatial_norm_dim: int | None = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 0.00001, + rescale_output_factor: float = 1, + residual_connection: bool = False, + _from_deprecated_attn_block=False, + processor: AttnProcessor | None = None, + cross_attn_temporal_cond: bool = False, + image_scale: float = 1.0, + ip_adapter_dim: int = None, + need_t2i_facein: bool = False, + facein_dim: int = None, + need_t2i_ip_adapter_face: bool = False, + ip_adapter_face_dim: int = None, + ): + super().__init__( + query_dim, + cross_attention_dim, + heads, + dim_head, + dropout, + bias, + upcast_attention, + upcast_softmax, + cross_attention_norm, + cross_attention_norm_num_groups, + added_kv_proj_dim, + norm_num_groups, + spatial_norm_dim, + out_bias, + scale_qk, + only_cross_attention, + eps, + rescale_output_factor, + residual_connection, + _from_deprecated_attn_block, + processor, + ) + self.cross_attn_temporal_cond = cross_attn_temporal_cond + self.image_scale = image_scale + # 面向首帧的 ip_adapter + # ip_apdater + if cross_attn_temporal_cond: + self.to_k_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False) + self.to_v_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False) + # facein + self.need_t2i_facein = need_t2i_facein + self.facein_dim = facein_dim + if need_t2i_facein: + raise NotImplementedError("facein") + + # ip_adapter_face + self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face + self.ip_adapter_face_dim = ip_adapter_face_dim + if need_t2i_ip_adapter_face: + self.ip_adapter_face_to_k_ip = LoRACompatibleLinear( + ip_adapter_face_dim, query_dim, bias=False + ) + self.ip_adapter_face_to_v_ip = LoRACompatibleLinear( + ip_adapter_face_dim, query_dim, bias=False + ) + + def set_use_memory_efficient_attention_xformers( + self, + use_memory_efficient_attention_xformers: bool, + attention_op: Callable[..., Any] | None = None, + ): + if ( + "XFormers" in self.processor.__class__.__name__ + or "IP" in self.processor.__class__.__name__ + ): + pass + else: + return super().set_use_memory_efficient_attention_xformers( + use_memory_efficient_attention_xformers, attention_op + ) + + +@Model_Register.register +class BaseIPAttnProcessor(nn.Module): + print_idx = 0 + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + +@Model_Register.register +class T2IReferencenetIPAdapterXFormersAttnProcessor(BaseIPAttnProcessor): + r""" + 面向 ref_image的 self_attn的 IPAdapter + """ + print_idx = 0 + + def __init__( + self, + attention_op: Optional[Callable] = None, + ): + super().__init__() + + self.attention_op = attention_op + + def __call__( + self, + attn: IPAttention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + num_frames: int = None, + sample_index: torch.LongTensor = None, + vision_conditon_frames_sample_index: torch.LongTensor = None, + refer_emb: torch.Tensor = None, + vision_clip_emb: torch.Tensor = None, + ip_adapter_scale: float = 1.0, + face_emb: torch.Tensor = None, + facein_scale: float = 1.0, + ip_adapter_face_emb: torch.Tensor = None, + ip_adapter_face_scale: float = 1.0, + do_classifier_free_guidance: bool = False, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, key_tokens, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask( + attention_mask, key_tokens, batch_size + ) + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states, scale=scale) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + encoder_hidden_states = align_repeat_tensor_single_dim( + encoder_hidden_states, target_length=hidden_states.shape[0], dim=0 + ) + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) + + # for facein + if self.print_idx == 0: + logger.debug( + f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(face_emb)={type(face_emb)}, facein_scale={facein_scale}" + ) + if facein_scale > 0 and face_emb is not None: + raise NotImplementedError("facein") + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + + # ip-adapter start + if self.print_idx == 0: + logger.debug( + f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(vision_clip_emb)={type(vision_clip_emb)}" + ) + if ip_adapter_scale > 0 and vision_clip_emb is not None: + if self.print_idx == 0: + logger.debug( + f"T2I cross_attn, ipadapter, vision_clip_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}" + ) + ip_key = attn.to_k_ip(vision_clip_emb) + ip_value = attn.to_v_ip(vision_clip_emb) + ip_key = align_repeat_tensor_single_dim( + ip_key, target_length=batch_size, dim=0 + ) + ip_value = align_repeat_tensor_single_dim( + ip_value, target_length=batch_size, dim=0 + ) + ip_key = attn.head_to_batch_dim(ip_key).contiguous() + ip_value = attn.head_to_batch_dim(ip_value).contiguous() + if self.print_idx == 0: + logger.debug( + f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}" + ) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states_from_ip = xformers.ops.memory_efficient_attention( + query, + ip_key, + ip_value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = hidden_states + ip_adapter_scale * hidden_states_from_ip + # ip-adapter end + + # ip-adapter face start + if self.print_idx == 0: + logger.debug( + f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(ip_adapter_face_emb)={type(ip_adapter_face_emb)}" + ) + if ip_adapter_face_scale > 0 and ip_adapter_face_emb is not None: + if self.print_idx == 0: + logger.debug( + f"T2I cross_attn, ipadapter face, ip_adapter_face_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}" + ) + ip_key = attn.ip_adapter_face_to_k_ip(ip_adapter_face_emb) + ip_value = attn.ip_adapter_face_to_v_ip(ip_adapter_face_emb) + ip_key = align_repeat_tensor_single_dim( + ip_key, target_length=batch_size, dim=0 + ) + ip_value = align_repeat_tensor_single_dim( + ip_value, target_length=batch_size, dim=0 + ) + ip_key = attn.head_to_batch_dim(ip_key).contiguous() + ip_value = attn.head_to_batch_dim(ip_value).contiguous() + if self.print_idx == 0: + logger.debug( + f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}" + ) + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states_from_ip = xformers.ops.memory_efficient_attention( + query, + ip_key, + ip_value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = ( + hidden_states + ip_adapter_face_scale * hidden_states_from_ip + ) + # ip-adapter face end + + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + self.print_idx += 1 + return hidden_states + + +@Model_Register.register +class NonParamT2ISelfReferenceXFormersAttnProcessor(BaseIPAttnProcessor): + r""" + 面向首帧的 referenceonly attn,适用于 T2I的 self_attn + referenceonly with vis_cond as key, value, in t2i self_attn. + """ + print_idx = 0 + + def __init__( + self, + attention_op: Optional[Callable] = None, + ): + super().__init__() + + self.attention_op = attention_op + + def __call__( + self, + attn: IPAttention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + num_frames: int = None, + sample_index: torch.LongTensor = None, + vision_conditon_frames_sample_index: torch.LongTensor = None, + refer_emb: torch.Tensor = None, + face_emb: torch.Tensor = None, + vision_clip_emb: torch.Tensor = None, + ip_adapter_scale: float = 1.0, + facein_scale: float = 1.0, + ip_adapter_face_emb: torch.Tensor = None, + ip_adapter_face_scale: float = 1.0, + do_classifier_free_guidance: bool = False, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, key_tokens, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + attention_mask = attn.prepare_attention_mask( + attention_mask, key_tokens, batch_size + ) + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + # vision_cond in same unet attn start + if ( + vision_conditon_frames_sample_index is not None and num_frames > 1 + ) or refer_emb is not None: + batchsize_timesize = hidden_states.shape[0] + if self.print_idx == 0: + logger.debug( + f"NonParamT2ISelfReferenceXFormersAttnProcessor 0, hidden_states={hidden_states.shape}, vision_conditon_frames_sample_index={vision_conditon_frames_sample_index}" + ) + encoder_hidden_states = rearrange( + hidden_states, "(b t) hw c -> b t hw c", t=num_frames + ) + # if False: + if vision_conditon_frames_sample_index is not None and num_frames > 1: + ip_hidden_states = batch_index_select( + encoder_hidden_states, + dim=1, + index=vision_conditon_frames_sample_index, + ).contiguous() + if self.print_idx == 0: + logger.debug( + f"NonParamT2ISelfReferenceXFormersAttnProcessor 1, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}" + ) + # + ip_hidden_states = rearrange( + ip_hidden_states, "b t hw c -> b 1 (t hw) c" + ) + ip_hidden_states = align_repeat_tensor_single_dim( + ip_hidden_states, + dim=1, + target_length=num_frames, + ) + # b t hw c -> b t hw + hw c + if self.print_idx == 0: + logger.debug( + f"NonParamT2ISelfReferenceXFormersAttnProcessor 2, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}" + ) + encoder_hidden_states = torch.concat( + [encoder_hidden_states, ip_hidden_states], dim=2 + ) + if self.print_idx == 0: + logger.debug( + f"NonParamT2ISelfReferenceXFormersAttnProcessor 3, hidden_states={hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}" + ) + # if False: + if refer_emb is not None: # and num_frames > 1: + refer_emb = rearrange(refer_emb, "b c t h w->b 1 (t h w) c") + refer_emb = align_repeat_tensor_single_dim( + refer_emb, target_length=num_frames, dim=1 + ) + if self.print_idx == 0: + logger.debug( + f"NonParamT2ISelfReferenceXFormersAttnProcessor4, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}" + ) + encoder_hidden_states = torch.concat( + [encoder_hidden_states, refer_emb], dim=2 + ) + if self.print_idx == 0: + logger.debug( + f"NonParamT2ISelfReferenceXFormersAttnProcessor5, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}" + ) + encoder_hidden_states = rearrange( + encoder_hidden_states, "b t hw c -> (b t) hw c" + ) + # vision_cond in same unet attn end + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states, scale=scale) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + encoder_hidden_states = align_repeat_tensor_single_dim( + encoder_hidden_states, target_length=hidden_states.shape[0], dim=0 + ) + key = attn.to_k(encoder_hidden_states, scale=scale) + value = attn.to_v(encoder_hidden_states, scale=scale) + + query = attn.head_to_batch_dim(query).contiguous() + key = attn.head_to_batch_dim(key).contiguous() + value = attn.head_to_batch_dim(value).contiguous() + + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + op=self.attention_op, + scale=attn.scale, + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + self.print_idx += 1 + + return hidden_states + + +@Model_Register.register +class NonParamReferenceIPXFormersAttnProcessor( + NonParamT2ISelfReferenceXFormersAttnProcessor +): + def __init__(self, attention_op: Callable[..., Any] | None = None): + super().__init__(attention_op) + + +@maybe_allow_in_graph +class ReferEmbFuseAttention(IPAttention): + """使用 attention 融合 refernet 中的 emb 到 unet 对应的 latens 中 + # TODO: 目前只支持 bt hw c 的融合,后续考虑增加对 视频 bhw t c、b thw c的融合 + residual_connection: bool = True, 默认, 从不产生影响开始学习 + + use attention to fuse referencenet emb into unet latents + # TODO: by now, only support bt hw c, later consider to support bhw t c, b thw c + residual_connection: bool = True, default, start from no effect + + Args: + IPAttention (_type_): _description_ + """ + + print_idx = 0 + + def __init__( + self, + query_dim: int, + cross_attention_dim: int | None = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: str | None = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: int | None = None, + norm_num_groups: int | None = None, + spatial_norm_dim: int | None = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 0.00001, + rescale_output_factor: float = 1, + residual_connection: bool = True, + _from_deprecated_attn_block=False, + processor: AttnProcessor | None = None, + cross_attn_temporal_cond: bool = False, + image_scale: float = 1, + ): + super().__init__( + query_dim, + cross_attention_dim, + heads, + dim_head, + dropout, + bias, + upcast_attention, + upcast_softmax, + cross_attention_norm, + cross_attention_norm_num_groups, + added_kv_proj_dim, + norm_num_groups, + spatial_norm_dim, + out_bias, + scale_qk, + only_cross_attention, + eps, + rescale_output_factor, + residual_connection, + _from_deprecated_attn_block, + processor, + cross_attn_temporal_cond, + image_scale, + ) + self.processor = None + # 配合residual,使一开始不影响之前结果 + nn.init.zeros_(self.to_out[0].weight) + nn.init.zeros_(self.to_out[0].bias) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + num_frames: int = None, + ) -> torch.Tensor: + """fuse referencenet emb b c t2 h2 w2 into unet latents b c t1 h1 w1 with attn + refer to musev/models/attention_processor.py::NonParamT2ISelfReferenceXFormersAttnProcessor + + Args: + hidden_states (torch.FloatTensor): unet latents, (b t1) c h1 w1 + encoder_hidden_states (Optional[torch.FloatTensor], optional): referencenet emb b c2 t2 h2 w2. Defaults to None. + attention_mask (Optional[torch.FloatTensor], optional): _description_. Defaults to None. + temb (Optional[torch.FloatTensor], optional): _description_. Defaults to None. + scale (float, optional): _description_. Defaults to 1.0. + num_frames (int, optional): _description_. Defaults to None. + + Returns: + torch.Tensor: _description_ + """ + residual = hidden_states + # start + hidden_states = rearrange( + hidden_states, "(b t) c h w -> b c t h w", t=num_frames + ) + batch_size, channel, t1, height, width = hidden_states.shape + if self.print_idx == 0: + logger.debug( + f"hidden_states={hidden_states.shape},encoder_hidden_states={encoder_hidden_states.shape}" + ) + # concat with hidden_states b c t1 h1 w1 in hw channel into bt (t2 + 1)hw c + encoder_hidden_states = rearrange( + encoder_hidden_states, " b c t2 h w-> b (t2 h w) c" + ) + encoder_hidden_states = repeat( + encoder_hidden_states, " b t2hw c -> (b t) t2hw c", t=t1 + ) + hidden_states = rearrange(hidden_states, " b c t h w-> (b t) (h w) c") + # bt (t2+1)hw d + encoder_hidden_states = torch.concat( + [encoder_hidden_states, hidden_states], dim=1 + ) + # encoder_hidden_states = align_repeat_tensor_single_dim( + # encoder_hidden_states, target_length=hidden_states.shape[0], dim=0 + # ) + # end + + if self.spatial_norm is not None: + hidden_states = self.spatial_norm(hidden_states, temb) + + _, key_tokens, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + attention_mask = self.prepare_attention_mask( + attention_mask, key_tokens, batch_size + ) + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = self.to_q(hidden_states, scale=scale) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = self.to_k(encoder_hidden_states, scale=scale) + value = self.to_v(encoder_hidden_states, scale=scale) + + query = self.head_to_batch_dim(query).contiguous() + key = self.head_to_batch_dim(key).contiguous() + value = self.head_to_batch_dim(value).contiguous() + + # query: b t hw d + # key/value: bt (t1+1)hw d + hidden_states = xformers.ops.memory_efficient_attention( + query, + key, + value, + attn_bias=attention_mask, + scale=self.scale, + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = self.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = self.to_out[0](hidden_states, scale=scale) + # dropout + hidden_states = self.to_out[1](hidden_states) + + hidden_states = rearrange( + hidden_states, + "bt (h w) c-> bt c h w", + h=height, + w=width, + ) + if self.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / self.rescale_output_factor + self.print_idx += 1 + return hidden_states diff --git a/musev/models/controlnet.py b/musev/models/controlnet.py new file mode 100755 index 0000000000000000000000000000000000000000..9daffed40653537b0dc8f00546e5efc759c24344 --- /dev/null +++ b/musev/models/controlnet.py @@ -0,0 +1,399 @@ +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union +import warnings +import os + + +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.modeling_utils import ModelMixin +import PIL +from einops import rearrange, repeat +import numpy as np +import torch +import torch.nn.init as init +from diffusers.models.controlnet import ControlNetModel +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers +from diffusers.utils.torch_utils import is_compiled_module + + +class ControlnetPredictor(object): + def __init__(self, controlnet_model_path: str, *args, **kwargs): + """Controlnet 推断函数,用于提取 controlnet backbone的emb,避免训练时重复抽取 + Controlnet inference predictor, used to extract the emb of the controlnet backbone to avoid repeated extraction during training + Args: + controlnet_model_path (str): controlnet 模型路径. controlnet model path. + """ + super(ControlnetPredictor, self).__init__(*args, **kwargs) + self.controlnet = ControlNetModel.from_pretrained( + controlnet_model_path, + ) + + def prepare_image( + self, + image, # b c t h w + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if height is None: + height = image.shape[-2] + if width is None: + width = image.shape[-1] + width, height = ( + x - x % self.control_image_processor.vae_scale_factor + for x in (width, height) + ) + image = rearrange(image, "b c t h w-> (b t) c h w") + image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0 + + image = ( + torch.nn.functional.interpolate( + image, + size=(height, width), + mode="bilinear", + ), + ) + + do_normalize = self.control_image_processor.config.do_normalize + if image.min() < 0: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", + FutureWarning, + ) + do_normalize = False + + if do_normalize: + image = self.control_image_processor.normalize(image) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + @torch.no_grad() + def __call__( + self, + batch_size: int, + device: str, + dtype: torch.dtype, + timesteps: List[float], + i: int, + scheduler: KarrasDiffusionSchedulers, + prompt_embeds: torch.Tensor, + do_classifier_free_guidance: bool = False, + # 2b co t ho wo + latent_model_input: torch.Tensor = None, + # b co t ho wo + latents: torch.Tensor = None, + # b c t h w + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + # b c t(1) hi wi + controlnet_condition_frames: Optional[torch.FloatTensor] = None, + # b c t ho wo + controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None, + # b c t(1) ho wo + controlnet_condition_latents: Optional[torch.FloatTensor] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_videos_per_prompt: Optional[int] = 1, + return_dict: bool = True, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + latent_index: torch.LongTensor = None, + vision_condition_latent_index: torch.LongTensor = None, + **kwargs, + ): + assert ( + image is None and controlnet_latents is None + ), "should set one of image and controlnet_latents" + + controlnet = ( + self.controlnet._orig_mod + if is_compiled_module(self.controlnet) + else self.controlnet + ) + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance( + control_guidance_end, list + ): + control_guidance_start = len(control_guidance_end) * [ + control_guidance_start + ] + elif not isinstance(control_guidance_end, list) and isinstance( + control_guidance_start, list + ): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance( + control_guidance_end, list + ): + mult = ( + len(controlnet.nets) + if isinstance(controlnet, MultiControlNetModel) + else 1 + ) + control_guidance_start, control_guidance_end = mult * [ + control_guidance_start + ], mult * [control_guidance_end] + + if isinstance(controlnet, MultiControlNetModel) and isinstance( + controlnet_conditioning_scale, float + ): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len( + controlnet.nets + ) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + if ( + controlnet_latents is not None + and controlnet_condition_latents is not None + ): + if isinstance(controlnet_latents, np.ndarray): + controlnet_latents = torch.from_numpy(controlnet_latents) + if isinstance(controlnet_condition_latents, np.ndarray): + controlnet_condition_latents = torch.from_numpy( + controlnet_condition_latents + ) + # TODO:使用index进行concat + controlnet_latents = torch.concat( + [controlnet_condition_latents, controlnet_latents], dim=2 + ) + if not guess_mode and do_classifier_free_guidance: + controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0) + controlnet_latents = rearrange( + controlnet_latents, "b c t h w->(b t) c h w" + ) + controlnet_latents = controlnet_latents.to(device=device, dtype=dtype) + else: + # TODO:使用index进行concat + # TODO: concat with index + if controlnet_condition_frames is not None: + if isinstance(controlnet_condition_frames, np.ndarray): + image = np.concatenate( + [controlnet_condition_frames, image], axis=2 + ) + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt, + num_images_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + # TODO: 支持直接使用controlnet_latent而不是frames + # TODO: support using controlnet_latent directly instead of frames + if controlnet_latents is not None: + raise NotImplementedError + else: + for i, image_ in enumerate(image): + if controlnet_condition_frames is not None and isinstance( + controlnet_condition_frames, list + ): + if isinstance(controlnet_condition_frames[i], np.ndarray): + image_ = np.concatenate( + [controlnet_condition_frames[i], image_], axis=2 + ) + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt, + num_images_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append( + keeps[0] if isinstance(controlnet, ControlNetModel) else keeps + ) + + t = timesteps[i] + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + if isinstance(controlnet_keep[i], list): + cond_scale = [ + c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i]) + ] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + control_model_input_reshape = rearrange( + control_model_input, "b c t h w -> (b t) c h w" + ) + encoder_hidden_states_repeat = repeat( + controlnet_prompt_embeds, + "b n q->(b t) n q", + t=control_model_input.shape[2], + ) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input_reshape, + t, + encoder_hidden_states_repeat, + controlnet_cond=image, + controlnet_cond_latents=controlnet_latents, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + return down_block_res_samples, mid_block_res_sample + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +def zero_module(module): + # Zero out the parameters of a module and return it. + for p in module.parameters(): + p.detach().zero_() + return module + + +class PoseGuider(ModelMixin): + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 64, 128), + ): + super().__init__() + self.conv_in = InflatedConv3d( + conditioning_channels, block_out_channels[0], kernel_size=3, padding=1 + ) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append( + InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1) + ) + self.blocks.append( + InflatedConv3d( + channel_in, channel_out, kernel_size=3, padding=1, stride=2 + ) + ) + + self.conv_out = zero_module( + InflatedConv3d( + block_out_channels[-1], + conditioning_embedding_channels, + kernel_size=3, + padding=1, + ) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + @classmethod + def from_pretrained( + cls, + pretrained_model_path, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int] = (16, 32, 64, 128), + ): + if not os.path.exists(pretrained_model_path): + print(f"There is no model file in {pretrained_model_path}") + print( + f"loaded PoseGuider's pretrained weights from {pretrained_model_path} ..." + ) + + state_dict = torch.load(pretrained_model_path, map_location="cpu") + model = PoseGuider( + conditioning_embedding_channels=conditioning_embedding_channels, + conditioning_channels=conditioning_channels, + block_out_channels=block_out_channels, + ) + + m, u = model.load_state_dict(state_dict, strict=False) + # print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") + params = [p.numel() for n, p in model.named_parameters()] + print(f"### PoseGuider's Parameters: {sum(params) / 1e6} M") + + return model diff --git a/musev/models/embeddings.py b/musev/models/embeddings.py new file mode 100755 index 0000000000000000000000000000000000000000..b1e0aa0c90ba78a9005d351224d4c14c39e7f5fb --- /dev/null +++ b/musev/models/embeddings.py @@ -0,0 +1,87 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from einops import rearrange +import torch +from torch.nn import functional as F +import numpy as np + +from diffusers.models.embeddings import get_2d_sincos_pos_embed_from_grid + + +# ref diffusers.models.embeddings.get_2d_sincos_pos_embed +def get_2d_sincos_pos_embed( + embed_dim, + grid_size_w, + grid_size_h, + cls_token=False, + extra_tokens=0, + norm_length: bool = False, + max_length: float = 2048, +): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if norm_length and grid_size_h <= max_length and grid_size_w <= max_length: + grid_h = np.linspace(0, max_length, grid_size_h) + grid_w = np.linspace(0, max_length, grid_size_w) + else: + grid_h = np.arange(grid_size_h, dtype=np.float32) + grid_w = np.arange(grid_size_w, dtype=np.float32) + grid = np.meshgrid(grid_h, grid_w) # here h goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate( + [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def resize_spatial_position_emb( + emb: torch.Tensor, + height: int, + width: int, + scale: float = None, + target_height: int = None, + target_width: int = None, +) -> torch.Tensor: + """_summary_ + + Args: + emb (torch.Tensor): b ( h w) d + height (int): _description_ + width (int): _description_ + scale (float, optional): _description_. Defaults to None. + target_height (int, optional): _description_. Defaults to None. + target_width (int, optional): _description_. Defaults to None. + + Returns: + torch.Tensor: b (target_height target_width) d + """ + if scale is not None: + target_height = int(height * scale) + target_width = int(width * scale) + emb = rearrange(emb, "(h w) (b d) ->b d h w", h=height, b=1) + emb = F.interpolate( + emb, + size=(target_height, target_width), + mode="bicubic", + align_corners=False, + ) + emb = rearrange(emb, "b d h w-> (h w) (b d)") + return emb diff --git a/musev/models/facein_loader.py b/musev/models/facein_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..38d4d17a4f97e1adb2406f3411c891f81cc51574 --- /dev/null +++ b/musev/models/facein_loader.py @@ -0,0 +1,120 @@ +import copy +from typing import Any, Callable, Dict, Iterable, Union +import PIL +import cv2 +import torch +import argparse +import datetime +import logging +import inspect +import math +import os +import shutil +from typing import Dict, List, Optional, Tuple +from pprint import pprint +from collections import OrderedDict +from dataclasses import dataclass +import gc +import time + +import numpy as np +from omegaconf import OmegaConf +from omegaconf import SCMode +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange, repeat +import pandas as pd +import h5py +from diffusers.models.modeling_utils import load_state_dict +from diffusers.utils import ( + logging, +) +from diffusers.utils.import_utils import is_xformers_available + +from mmcm.vision.feature_extractor.clip_vision_extractor import ( + ImageClipVisionFeatureExtractor, + ImageClipVisionFeatureExtractorV2, +) +from mmcm.vision.feature_extractor.insight_face_extractor import InsightFaceExtractor + +from ip_adapter.resampler import Resampler +from ip_adapter.ip_adapter import ImageProjModel + +from .unet_loader import update_unet_with_sd +from .unet_3d_condition import UNet3DConditionModel +from .ip_adapter_loader import ip_adapter_keys_list + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651 +unet_keys_list = [ + "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", + "mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_k_ip.weight", + "mid_block.attentions.0.transformer_blocks.0.attn2.processor.facein_to_v_ip.weight", +] + + +UNET2IPAadapter_Keys_MAPIING = { + k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list) +} + + +def load_facein_extractor_and_proj_by_name( + model_name: str, + ip_ckpt: Tuple[str, nn.Module], + ip_image_encoder: Tuple[str, nn.Module] = None, + cross_attention_dim: int = 768, + clip_embeddings_dim: int = 512, + clip_extra_context_tokens: int = 1, + ip_scale: float = 0.0, + dtype: torch.dtype = torch.float16, + device: str = "cuda", + unet: nn.Module = None, +) -> nn.Module: + pass + + +def update_unet_facein_cross_attn_param( + unet: UNet3DConditionModel, ip_adapter_state_dict: Dict +) -> None: + """use independent ip_adapter attn 中的 to_k, to_v in unet + ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight']的字典 + + + Args: + unet (UNet3DConditionModel): _description_ + ip_adapter_state_dict (Dict): _description_ + """ + pass diff --git a/musev/models/ip_adapter_face_loader.py b/musev/models/ip_adapter_face_loader.py new file mode 100755 index 0000000000000000000000000000000000000000..c71e63f79aaee6c45a9a9b3700674a11179bc43e --- /dev/null +++ b/musev/models/ip_adapter_face_loader.py @@ -0,0 +1,179 @@ +import copy +from typing import Any, Callable, Dict, Iterable, Union +import PIL +import cv2 +import torch +import argparse +import datetime +import logging +import inspect +import math +import os +import shutil +from typing import Dict, List, Optional, Tuple +from pprint import pprint +from collections import OrderedDict +from dataclasses import dataclass +import gc +import time + +import numpy as np +from omegaconf import OmegaConf +from omegaconf import SCMode +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange, repeat +import pandas as pd +import h5py +from diffusers.models.modeling_utils import load_state_dict +from diffusers.utils import ( + logging, +) +from diffusers.utils.import_utils import is_xformers_available + +from ip_adapter.resampler import Resampler +from ip_adapter.ip_adapter import ImageProjModel +from ip_adapter.ip_adapter_faceid import ProjPlusModel, MLPProjModel + +from mmcm.vision.feature_extractor.clip_vision_extractor import ( + ImageClipVisionFeatureExtractor, + ImageClipVisionFeatureExtractorV2, +) +from mmcm.vision.feature_extractor.insight_face_extractor import ( + InsightFaceExtractorNormEmb, +) + + +from .unet_loader import update_unet_with_sd +from .unet_3d_condition import UNet3DConditionModel +from .ip_adapter_loader import ip_adapter_keys_list + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651 +unet_keys_list = [ + "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", + "mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_k_ip.weight", + "mid_block.attentions.0.transformer_blocks.0.attn2.processor.ip_adapter_face_to_v_ip.weight", +] + + +UNET2IPAadapter_Keys_MAPIING = { + k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list) +} + + +def load_ip_adapter_face_extractor_and_proj_by_name( + model_name: str, + ip_ckpt: Tuple[str, nn.Module], + ip_image_encoder: Tuple[str, nn.Module] = None, + cross_attention_dim: int = 768, + clip_embeddings_dim: int = 1024, + clip_extra_context_tokens: int = 4, + ip_scale: float = 0.0, + dtype: torch.dtype = torch.float16, + device: str = "cuda", + unet: nn.Module = None, +) -> nn.Module: + if model_name == "IPAdapterFaceID": + if ip_image_encoder is not None: + ip_adapter_face_emb_extractor = InsightFaceExtractorNormEmb( + pretrained_model_name_or_path=ip_image_encoder, + dtype=dtype, + device=device, + ) + else: + ip_adapter_face_emb_extractor = None + ip_adapter_image_proj = MLPProjModel( + cross_attention_dim=cross_attention_dim, + id_embeddings_dim=clip_embeddings_dim, + num_tokens=clip_extra_context_tokens, + ).to(device, dtype=dtype) + else: + raise ValueError( + f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, IPAdapterFaceID" + ) + ip_adapter_state_dict = torch.load( + ip_ckpt, + map_location="cpu", + ) + ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"]) + if unet is not None and "ip_adapter" in ip_adapter_state_dict: + update_unet_ip_adapter_cross_attn_param( + unet, + ip_adapter_state_dict["ip_adapter"], + ) + logger.info( + f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}" + ) + return ( + ip_adapter_face_emb_extractor, + ip_adapter_image_proj, + ) + + +def update_unet_ip_adapter_cross_attn_param( + unet: UNet3DConditionModel, ip_adapter_state_dict: Dict +) -> None: + """use independent ip_adapter attn 中的 to_k, to_v in unet + ip_adapter: like ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight'] + + + Args: + unet (UNet3DConditionModel): _description_ + ip_adapter_state_dict (Dict): _description_ + """ + unet_spatial_cross_atnns = unet.spatial_cross_attns[0] + unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns} + for i, (unet_key_more, ip_adapter_key) in enumerate( + UNET2IPAadapter_Keys_MAPIING.items() + ): + ip_adapter_value = ip_adapter_state_dict[ip_adapter_key] + unet_key_more_spit = unet_key_more.split(".") + unet_key = ".".join(unet_key_more_spit[:-3]) + suffix = ".".join(unet_key_more_spit[-3:]) + logger.debug( + f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}", + ) + if ".ip_adapter_face_to_k" in suffix: + with torch.no_grad(): + unet_spatial_cross_atnns_dct[ + unet_key + ].ip_adapter_face_to_k_ip.weight.copy_(ip_adapter_value.data) + else: + with torch.no_grad(): + unet_spatial_cross_atnns_dct[ + unet_key + ].ip_adapter_face_to_v_ip.weight.copy_(ip_adapter_value.data) diff --git a/musev/models/ip_adapter_loader.py b/musev/models/ip_adapter_loader.py new file mode 100755 index 0000000000000000000000000000000000000000..58c9366f08ec91f6e828d47c662cc11f8dac63f0 --- /dev/null +++ b/musev/models/ip_adapter_loader.py @@ -0,0 +1,340 @@ +import copy +from typing import Any, Callable, Dict, Iterable, Union +import PIL +import cv2 +import torch +import argparse +import datetime +import logging +import inspect +import math +import os +import shutil +from typing import Dict, List, Optional, Tuple +from pprint import pprint +from collections import OrderedDict +from dataclasses import dataclass +import gc +import time + +import numpy as np +from omegaconf import OmegaConf +from omegaconf import SCMode +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange, repeat +import pandas as pd +import h5py +from diffusers.models.modeling_utils import load_state_dict +from diffusers.utils import ( + logging, +) +from diffusers.utils.import_utils import is_xformers_available + +from mmcm.vision.feature_extractor import clip_vision_extractor +from mmcm.vision.feature_extractor.clip_vision_extractor import ( + ImageClipVisionFeatureExtractor, + ImageClipVisionFeatureExtractorV2, + VerstailSDLastHiddenState2ImageEmb, +) + +from ip_adapter.resampler import Resampler +from ip_adapter.ip_adapter import ImageProjModel + +from .unet_loader import update_unet_with_sd +from .unet_3d_condition import UNet3DConditionModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def load_vision_clip_encoder_by_name( + ip_image_encoder: Tuple[str, nn.Module] = None, + dtype: torch.dtype = torch.float16, + device: str = "cuda", + vision_clip_extractor_class_name: str = None, +) -> nn.Module: + if vision_clip_extractor_class_name is not None: + vision_clip_extractor = getattr( + clip_vision_extractor, vision_clip_extractor_class_name + )( + pretrained_model_name_or_path=ip_image_encoder, + dtype=dtype, + device=device, + ) + else: + vision_clip_extractor = None + return vision_clip_extractor + + +def load_ip_adapter_image_proj_by_name( + model_name: str, + ip_ckpt: Tuple[str, nn.Module] = None, + cross_attention_dim: int = 768, + clip_embeddings_dim: int = 1024, + clip_extra_context_tokens: int = 4, + ip_scale: float = 0.0, + dtype: torch.dtype = torch.float16, + device: str = "cuda", + unet: nn.Module = None, + vision_clip_extractor_class_name: str = None, + ip_image_encoder: Tuple[str, nn.Module] = None, +) -> nn.Module: + if model_name in [ + "IPAdapter", + "musev_referencenet", + "musev_referencenet_pose", + ]: + ip_adapter_image_proj = ImageProjModel( + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=clip_embeddings_dim, + clip_extra_context_tokens=clip_extra_context_tokens, + ) + + elif model_name == "IPAdapterPlus": + vision_clip_extractor = ImageClipVisionFeatureExtractorV2( + pretrained_model_name_or_path=ip_image_encoder, + dtype=dtype, + device=device, + ) + ip_adapter_image_proj = Resampler( + dim=cross_attention_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=clip_extra_context_tokens, + embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size, + output_dim=cross_attention_dim, + ff_mult=4, + ) + elif model_name in [ + "VerstailSDLastHiddenState2ImageEmb", + "OriginLastHiddenState2ImageEmbd", + "OriginLastHiddenState2Poolout", + ]: + ip_adapter_image_proj = getattr( + clip_vision_extractor, model_name + ).from_pretrained(ip_image_encoder) + else: + raise ValueError( + f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus, VerstailSDLastHiddenState2ImageEmb" + ) + if ip_ckpt is not None: + ip_adapter_state_dict = torch.load( + ip_ckpt, + map_location="cpu", + ) + ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"]) + if ( + unet is not None + and unet.ip_adapter_cross_attn + and "ip_adapter" in ip_adapter_state_dict + ): + update_unet_ip_adapter_cross_attn_param( + unet, ip_adapter_state_dict["ip_adapter"] + ) + logger.info( + f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}" + ) + return ip_adapter_image_proj + + +def load_ip_adapter_vision_clip_encoder_by_name( + model_name: str, + ip_ckpt: Tuple[str, nn.Module], + ip_image_encoder: Tuple[str, nn.Module] = None, + cross_attention_dim: int = 768, + clip_embeddings_dim: int = 1024, + clip_extra_context_tokens: int = 4, + ip_scale: float = 0.0, + dtype: torch.dtype = torch.float16, + device: str = "cuda", + unet: nn.Module = None, + vision_clip_extractor_class_name: str = None, +) -> nn.Module: + if vision_clip_extractor_class_name is not None: + vision_clip_extractor = getattr( + clip_vision_extractor, vision_clip_extractor_class_name + )( + pretrained_model_name_or_path=ip_image_encoder, + dtype=dtype, + device=device, + ) + else: + vision_clip_extractor = None + if model_name in [ + "IPAdapter", + "musev_referencenet", + ]: + if ip_image_encoder is not None: + if vision_clip_extractor_class_name is None: + vision_clip_extractor = ImageClipVisionFeatureExtractor( + pretrained_model_name_or_path=ip_image_encoder, + dtype=dtype, + device=device, + ) + else: + vision_clip_extractor = None + ip_adapter_image_proj = ImageProjModel( + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=clip_embeddings_dim, + clip_extra_context_tokens=clip_extra_context_tokens, + ) + + elif model_name == "IPAdapterPlus": + if ip_image_encoder is not None: + if vision_clip_extractor_class_name is None: + vision_clip_extractor = ImageClipVisionFeatureExtractorV2( + pretrained_model_name_or_path=ip_image_encoder, + dtype=dtype, + device=device, + ) + else: + vision_clip_extractor = None + ip_adapter_image_proj = Resampler( + dim=cross_attention_dim, + depth=4, + dim_head=64, + heads=12, + num_queries=clip_extra_context_tokens, + embedding_dim=vision_clip_extractor.image_encoder.config.hidden_size, + output_dim=cross_attention_dim, + ff_mult=4, + ).to(dtype=torch.float16) + else: + raise ValueError( + f"unsupport model_name={model_name}, only support IPAdapter, IPAdapterPlus" + ) + ip_adapter_state_dict = torch.load( + ip_ckpt, + map_location="cpu", + ) + ip_adapter_image_proj.load_state_dict(ip_adapter_state_dict["image_proj"]) + if ( + unet is not None + and unet.ip_adapter_cross_attn + and "ip_adapter" in ip_adapter_state_dict + ): + update_unet_ip_adapter_cross_attn_param( + unet, ip_adapter_state_dict["ip_adapter"] + ) + logger.info( + f"update unet.spatial_cross_attn_ip_adapter parameter with {ip_ckpt}" + ) + return ( + vision_clip_extractor, + ip_adapter_image_proj, + ) + + +# refer https://github.com/tencent-ailab/IP-Adapter/issues/168#issuecomment-1846771651 +unet_keys_list = [ + "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_v_ip.weight", + "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_k_ip.weight", + "mid_block.attentions.0.transformer_blocks.0.attn2.processor.to_v_ip.weight", +] + + +ip_adapter_keys_list = [ + "1.to_k_ip.weight", + "1.to_v_ip.weight", + "3.to_k_ip.weight", + "3.to_v_ip.weight", + "5.to_k_ip.weight", + "5.to_v_ip.weight", + "7.to_k_ip.weight", + "7.to_v_ip.weight", + "9.to_k_ip.weight", + "9.to_v_ip.weight", + "11.to_k_ip.weight", + "11.to_v_ip.weight", + "13.to_k_ip.weight", + "13.to_v_ip.weight", + "15.to_k_ip.weight", + "15.to_v_ip.weight", + "17.to_k_ip.weight", + "17.to_v_ip.weight", + "19.to_k_ip.weight", + "19.to_v_ip.weight", + "21.to_k_ip.weight", + "21.to_v_ip.weight", + "23.to_k_ip.weight", + "23.to_v_ip.weight", + "25.to_k_ip.weight", + "25.to_v_ip.weight", + "27.to_k_ip.weight", + "27.to_v_ip.weight", + "29.to_k_ip.weight", + "29.to_v_ip.weight", + "31.to_k_ip.weight", + "31.to_v_ip.weight", +] + +UNET2IPAadapter_Keys_MAPIING = { + k: v for k, v in zip(unet_keys_list, ip_adapter_keys_list) +} + + +def update_unet_ip_adapter_cross_attn_param( + unet: UNet3DConditionModel, ip_adapter_state_dict: Dict +) -> None: + """use independent ip_adapter attn 中的 to_k, to_v in unet + ip_adapter: dict whose keys are ['1.to_k_ip.weight', '1.to_v_ip.weight', '3.to_k_ip.weight'] + + + Args: + unet (UNet3DConditionModel): _description_ + ip_adapter_state_dict (Dict): _description_ + """ + unet_spatial_cross_atnns = unet.spatial_cross_attns[0] + unet_spatial_cross_atnns_dct = {k: v for k, v in unet_spatial_cross_atnns} + for i, (unet_key_more, ip_adapter_key) in enumerate( + UNET2IPAadapter_Keys_MAPIING.items() + ): + ip_adapter_value = ip_adapter_state_dict[ip_adapter_key] + unet_key_more_spit = unet_key_more.split(".") + unet_key = ".".join(unet_key_more_spit[:-3]) + suffix = ".".join(unet_key_more_spit[-3:]) + logger.debug( + f"{i}: unet_key_more = {unet_key_more}, {unet_key}=unet_key, suffix={suffix}", + ) + if "to_k" in suffix: + with torch.no_grad(): + unet_spatial_cross_atnns_dct[unet_key].to_k_ip.weight.copy_( + ip_adapter_value.data + ) + else: + with torch.no_grad(): + unet_spatial_cross_atnns_dct[unet_key].to_v_ip.weight.copy_( + ip_adapter_value.data + ) diff --git a/musev/models/referencenet.py b/musev/models/referencenet.py new file mode 100755 index 0000000000000000000000000000000000000000..ddc32de16a5f319903886640871851d2b2ac4bb7 --- /dev/null +++ b/musev/models/referencenet.py @@ -0,0 +1,1216 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple, Union +import logging + +import torch +from diffusers.models.attention_processor import Attention, AttnProcessor +from einops import rearrange, repeat +import torch.nn as nn +import torch.nn.functional as F +import xformers +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.unet_2d_condition import ( + UNet2DConditionModel, + UNet2DConditionOutput, +) +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils.constants import USE_PEFT_BACKEND +from diffusers.utils.deprecation_utils import deprecate +from diffusers.utils.peft_utils import scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.modeling_utils import ModelMixin, load_state_dict +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import ( + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + PositionNet, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin + + +from ..data.data_util import align_repeat_tensor_single_dim +from .unet_3d_condition import UNet3DConditionModel +from .attention import BasicTransformerBlock, IPAttention +from .unet_2d_blocks import ( + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + get_down_block, + get_up_block, +) + +from . import Model_Register + + +logger = logging.getLogger(__name__) + + +@Model_Register.register +class ReferenceNet2D(UNet2DConditionModel, nn.Module): + """继承 UNet2DConditionModel. 新增功能,类似controlnet 返回模型中间特征,用于后续作用 + Inherit Unet2DConditionModel. Add new functions, similar to controlnet, return the intermediate features of the model for subsequent effects + Args: + UNet2DConditionModel (_type_): _description_ + """ + + _supports_gradient_checkpointing = True + print_idx = 0 + + @register_to_config + def __init__( + self, + sample_size: int | None = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: str | None = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), + only_cross_attention: bool | Tuple[bool] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int | Tuple[int] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0, + act_fn: str = "silu", + norm_num_groups: int | None = 32, + norm_eps: float = 0.00001, + cross_attention_dim: int | Tuple[int] = 1280, + transformer_layers_per_block: int | Tuple[int] | Tuple[Tuple] = 1, + reverse_transformer_layers_per_block: Tuple[Tuple[int]] | None = None, + encoder_hid_dim: int | None = None, + encoder_hid_dim_type: str | None = None, + attention_head_dim: int | Tuple[int] = 8, + num_attention_heads: int | Tuple[int] | None = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: str | None = None, + addition_embed_type: str | None = None, + addition_time_embed_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1, + time_embedding_type: str = "positional", + time_embedding_dim: int | None = None, + time_embedding_act_fn: str | None = None, + timestep_post_act: str | None = None, + time_cond_proj_dim: int | None = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: int | None = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: bool | None = None, + cross_attention_norm: str | None = None, + addition_embed_type_num_heads=64, + need_self_attn_block_embs: bool = False, + need_block_embs: bool = False, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len( + only_cross_attention + ) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if ( + isinstance(transformer_layers_per_block, list) + and reverse_transformer_layers_per_block is None + ): + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError( + "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." + ) + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift + ) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info( + "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined." + ) + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim, act_fn=act_fn + ) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear( + projection_class_embeddings_input_dim, time_embed_dim + ) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, + time_embed_dim, + num_heads=addition_embed_type_num_heads, + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, + image_embed_dim=cross_attention_dim, + time_embed_dim=time_embed_dim, + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps( + addition_time_embed_dim, flip_sin_to_cos, freq_shift + ) + self.add_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding( + image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding( + image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type is not None: + raise ValueError( + f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'." + ) + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len( + down_block_types + ) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] + if attention_head_dim[i] is not None + else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] + if attention_head_dim[i] is not None + else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding, + ) + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance( + cross_attention_dim, list + ): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = PositionNet( + positive_len=positive_len, + out_dim=cross_attention_dim, + feature_type=feature_type, + ) + self.need_block_embs = need_block_embs + self.need_self_attn_block_embs = need_self_attn_block_embs + + # only use referencenet soma layers, other layers set None + self.conv_norm_out = None + self.conv_act = None + self.conv_out = None + + self.up_blocks[-1].attentions[-1].proj_out = None + self.up_blocks[-1].attentions[-1].transformer_blocks[-1].attn1 = None + self.up_blocks[-1].attentions[-1].transformer_blocks[-1].attn2 = None + self.up_blocks[-1].attentions[-1].transformer_blocks[-1].norm2 = None + self.up_blocks[-1].attentions[-1].transformer_blocks[-1].ff = None + self.up_blocks[-1].attentions[-1].transformer_blocks[-1].norm3 = None + if not self.need_self_attn_block_embs: + self.up_blocks = None + + self.insert_spatial_self_attn_idx() + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + # update new paramestes start + num_frames: int = None, + return_ndim: int = 5, + # update new paramestes end + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(sample.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if ( + "image_embeds" not in added_cond_kwargs + or "hint" not in added_cond_kwargs + ): + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "text_proj" + ): + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "text_image_proj" + ): + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj( + encoder_hidden_states, image_embeds + ) + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "image_proj" + ): + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "ip_image_proj" + ): + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to( + encoder_hidden_states.dtype + ) + encoder_hidden_states = torch.cat( + [encoder_hidden_states, image_embeds], dim=1 + ) + + # need_self_attn_block_embs + # 初始化 + # 或在unet中运算中会不断 append self_attn_blocks_embs,用完需要清理, + if self.need_self_attn_block_embs: + self_attn_block_embs = [None] * self.self_attn_num + else: + self_attn_block_embs = None + # 2. pre-process + sample = self.conv_in(sample) + if self.print_idx == 0: + logger.debug(f"after conv in sample={sample.mean()}") + # 2.5 GLIGEN position net + if ( + cross_attention_kwargs is not None + and cross_attention_kwargs.get("gligen", None) is not None + ): + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = { + "objs": self.position_net(**gligen_args) + } + + # 3. down + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = ( + mid_block_additional_residual is not None + and down_block_additional_residuals is not None + ) + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if ( + not is_adapter + and mid_block_additional_residual is None + and down_block_additional_residuals is not None + ): + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for i_downsample_block, downsample_block in enumerate(self.down_blocks): + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals[ + "additional_residuals" + ] = down_intrablock_additional_residuals.pop(0) + if self.print_idx == 0: + logger.debug( + f"downsample_block {i_downsample_block} sample={sample.mean()}" + ) + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + self_attn_block_embs=self_attn_block_embs, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + scale=lora_scale, + self_attn_block_embs=self_attn_block_embs, + ) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples = new_down_block_res_samples + ( + down_block_res_sample, + ) + + down_block_res_samples = new_down_block_res_samples + + # update code start + def reshape_return_emb(tmp_emb): + if return_ndim == 4: + return tmp_emb + elif return_ndim == 5: + return rearrange(tmp_emb, "(b t) c h w-> b c t h w", t=num_frames) + else: + raise ValueError( + f"reshape_emb only support 4, 5 but given {return_ndim}" + ) + + if self.need_block_embs: + return_down_block_res_samples = [ + reshape_return_emb(tmp_emb) for tmp_emb in down_block_res_samples + ] + else: + return_down_block_res_samples = None + # update code end + + # 4. mid + if self.mid_block is not None: + if ( + hasattr(self.mid_block, "has_cross_attention") + and self.mid_block.has_cross_attention + ): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + self_attn_block_embs=self_attn_block_embs, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + if self.need_block_embs: + return_mid_block_res_samples = reshape_return_emb(sample) + logger.debug( + f"return_mid_block_res_samples, is_leaf={return_mid_block_res_samples.is_leaf}, requires_grad={return_mid_block_res_samples.requires_grad}" + ) + else: + return_mid_block_res_samples = None + + if self.up_blocks is not None: + # update code end + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + self_attn_block_embs=self_attn_block_embs, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + self_attn_block_embs=self_attn_block_embs, + ) + + # update code start + if self.need_block_embs or self.need_self_attn_block_embs: + if self_attn_block_embs is not None: + self_attn_block_embs = [ + reshape_return_emb(tmp_emb=tmp_emb) + for tmp_emb in self_attn_block_embs + ] + self.print_idx += 1 + return ( + return_down_block_res_samples, + return_mid_block_res_samples, + self_attn_block_embs, + ) + + if not self.need_block_embs and not self.need_self_attn_block_embs: + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + self.print_idx += 1 + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + + def insert_spatial_self_attn_idx(self): + attns, basic_transformers = self.spatial_self_attns + self.self_attn_num = len(attns) + for i, (name, layer) in enumerate(attns): + logger.debug(f"{self.__class__.__name__}, {i}, {name}, {type(layer)}") + if layer is not None: + layer.spatial_self_attn_idx = i + for i, (name, layer) in enumerate(basic_transformers): + logger.debug(f"{self.__class__.__name__}, {i}, {name}, {type(layer)}") + if layer is not None: + layer.spatial_self_attn_idx = i + + @property + def spatial_self_attns( + self, + ) -> List[Tuple[str, Attention]]: + attns, spatial_transformers = self.get_self_attns( + include="attentions", exclude="temp_attentions" + ) + attns = sorted(attns) + spatial_transformers = sorted(spatial_transformers) + return attns, spatial_transformers + + def get_self_attns( + self, include: str = None, exclude: str = None + ) -> List[Tuple[str, Attention]]: + r""" + Returns: + `dict` of attention attns: A dictionary containing all attention attns used in the model with + indexed by its weight name. + """ + # set recursively + attns = [] + spatial_transformers = [] + + def fn_recursive_add_attns( + name: str, + module: torch.nn.Module, + attns: List[Tuple[str, Attention]], + spatial_transformers: List[Tuple[str, BasicTransformerBlock]], + ): + is_target = False + if isinstance(module, BasicTransformerBlock) and hasattr(module, "attn1"): + is_target = True + if include is not None: + is_target = include in name + if exclude is not None: + is_target = exclude not in name + if is_target: + attns.append([f"{name}.attn1", module.attn1]) + spatial_transformers.append([f"{name}", module]) + for sub_name, child in module.named_children(): + fn_recursive_add_attns( + f"{name}.{sub_name}", child, attns, spatial_transformers + ) + + return attns + + for name, module in self.named_children(): + fn_recursive_add_attns(name, module, attns, spatial_transformers) + + return attns, spatial_transformers + + +class ReferenceNet3D(UNet3DConditionModel): + """继承 UNet3DConditionModel, 用于提取中间emb用于后续作用。 + Inherit Unet3DConditionModel, used to extract the middle emb for subsequent actions. + Args: + UNet3DConditionModel (_type_): _description_ + """ + + pass diff --git a/musev/models/referencenet_loader.py b/musev/models/referencenet_loader.py new file mode 100755 index 0000000000000000000000000000000000000000..41a7d0d63f423e9b5bd7486294e4cc9413ed4088 --- /dev/null +++ b/musev/models/referencenet_loader.py @@ -0,0 +1,124 @@ +import copy +from typing import Any, Callable, Dict, Iterable, Union +import PIL +import cv2 +import torch +import argparse +import datetime +import logging +import inspect +import math +import os +import shutil +from typing import Dict, List, Optional, Tuple +from pprint import pprint +from collections import OrderedDict +from dataclasses import dataclass +import gc +import time + +import numpy as np +from omegaconf import OmegaConf +from omegaconf import SCMode +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange, repeat +import pandas as pd +import h5py +from diffusers.models.modeling_utils import load_state_dict +from diffusers.utils import ( + logging, +) +from diffusers.utils.import_utils import is_xformers_available + +from .referencenet import ReferenceNet2D +from .unet_loader import update_unet_with_sd + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def load_referencenet( + sd_referencenet_model: Tuple[str, nn.Module], + sd_model: nn.Module = None, + need_self_attn_block_embs: bool = False, + need_block_embs: bool = False, + dtype: torch.dtype = torch.float16, + cross_attention_dim: int = 768, + subfolder: str = "unet", +): + """ + Loads the ReferenceNet model. + + Args: + sd_referencenet_model (Tuple[str, nn.Module] or str): The pretrained ReferenceNet model or the path to the model. + sd_model (nn.Module, optional): The sd_model to update the ReferenceNet with. Defaults to None. + need_self_attn_block_embs (bool, optional): Whether to compute self-attention block embeddings. Defaults to False. + need_block_embs (bool, optional): Whether to compute block embeddings. Defaults to False. + dtype (torch.dtype, optional): The data type of the tensors. Defaults to torch.float16. + cross_attention_dim (int, optional): The dimension of the cross-attention. Defaults to 768. + subfolder (str, optional): The subfolder of the model. Defaults to "unet". + + Returns: + nn.Module: The loaded ReferenceNet model. + """ + + if isinstance(sd_referencenet_model, str): + referencenet = ReferenceNet2D.from_pretrained( + sd_referencenet_model, + subfolder=subfolder, + need_self_attn_block_embs=need_self_attn_block_embs, + need_block_embs=need_block_embs, + torch_dtype=dtype, + cross_attention_dim=cross_attention_dim, + ) + elif isinstance(sd_referencenet_model, nn.Module): + referencenet = sd_referencenet_model + if sd_model is not None: + referencenet = update_unet_with_sd(referencenet, sd_model) + return referencenet + + +def load_referencenet_by_name( + model_name: str, + sd_referencenet_model: Tuple[str, nn.Module], + sd_model: nn.Module = None, + cross_attention_dim: int = 768, + dtype: torch.dtype = torch.float16, +) -> nn.Module: + """通过模型名字 初始化 referencenet,载入预训练参数, + 如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义 + init referencenet with model_name. + if you want to use pretrained model with simple name, you need to define it here. + Args: + model_name (str): _description_ + sd_unet_model (Tuple[str, nn.Module]): _description_ + sd_model (Tuple[str, nn.Module]): _description_ + cross_attention_dim (int, optional): _description_. Defaults to 768. + dtype (torch.dtype, optional): _description_. Defaults to torch.float16. + + Raises: + ValueError: _description_ + + Returns: + nn.Module: _description_ + """ + if model_name in [ + "musev_referencenet", + ]: + unet = load_referencenet( + sd_referencenet_model=sd_referencenet_model, + sd_model=sd_model, + cross_attention_dim=cross_attention_dim, + dtype=dtype, + need_self_attn_block_embs=False, + need_block_embs=True, + subfolder="referencenet", + ) + else: + raise ValueError( + f"unsupport model_name={model_name}, only support ReferenceNet_V0_block13, ReferenceNet_V1_block13, ReferenceNet_V2_block13, ReferenceNet_V0_sefattn16" + ) + return unet diff --git a/musev/models/resnet.py b/musev/models/resnet.py new file mode 100755 index 0000000000000000000000000000000000000000..2724fad83484010eab19838ff0f8b16b3b1ea8eb --- /dev/null +++ b/musev/models/resnet.py @@ -0,0 +1,135 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# `TemporalConvLayer` Copyright 2023 Alibaba DAMO-VILAB, The ModelScope Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py +from __future__ import annotations + +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from diffusers.models.resnet import TemporalConvLayer as DiffusersTemporalConvLayer +from ..data.data_util import batch_index_fill, batch_index_select +from . import Model_Register + + +@Model_Register.register +class TemporalConvLayer(nn.Module): + """ + Temporal convolutional layer that can be used for video (sequence of images) input Code mostly copied from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + """ + + def __init__( + self, + in_dim, + out_dim=None, + dropout=0.0, + keep_content_condition: bool = False, + femb_channels: Optional[int] = None, + need_temporal_weight: bool = True, + ): + super().__init__() + out_dim = out_dim or in_dim + self.in_dim = in_dim + self.out_dim = out_dim + self.keep_content_condition = keep_content_condition + self.femb_channels = femb_channels + self.need_temporal_weight = need_temporal_weight + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)), + ) + + # zero out the last layer params,so the conv block is identity + # nn.init.zeros_(self.conv4[-1].weight) + # nn.init.zeros_(self.conv4[-1].bias) + self.temporal_weight = nn.Parameter( + torch.tensor( + [ + 1e-5, + ] + ) + ) # initialize parameter with 0 + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + self.skip_temporal_layers = False # Whether to skip temporal layer + + def forward( + self, + hidden_states, + num_frames=1, + sample_index: torch.LongTensor = None, + vision_conditon_frames_sample_index: torch.LongTensor = None, + femb: torch.Tensor = None, + ): + if self.skip_temporal_layers is True: + return hidden_states + hidden_states_dtype = hidden_states.dtype + hidden_states = rearrange( + hidden_states, "(b t) c h w -> b c t h w", t=num_frames + ) + identity = hidden_states + hidden_states = self.conv1(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.conv3(hidden_states) + hidden_states = self.conv4(hidden_states) + # 保留condition对应的frames,便于保持前序内容帧,提升一致性 + if self.keep_content_condition: + mask = torch.ones_like(hidden_states, device=hidden_states.device) + mask = batch_index_fill( + mask, dim=2, index=vision_conditon_frames_sample_index, value=0 + ) + if self.need_temporal_weight: + hidden_states = ( + identity + torch.abs(self.temporal_weight) * mask * hidden_states + ) + else: + hidden_states = identity + mask * hidden_states + else: + if self.need_temporal_weight: + hidden_states = ( + identity + torch.abs(self.temporal_weight) * hidden_states + ) + else: + hidden_states = identity + hidden_states + hidden_states = rearrange(hidden_states, " b c t h w -> (b t) c h w") + hidden_states = hidden_states.to(dtype=hidden_states_dtype) + return hidden_states diff --git a/musev/models/super_model.py b/musev/models/super_model.py new file mode 100755 index 0000000000000000000000000000000000000000..f4afdb8b1f58dc610c39159b0010208363f520e1 --- /dev/null +++ b/musev/models/super_model.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +import logging + +from typing import Any, Dict, Tuple, Union, Optional +from einops import rearrange, repeat +from torch import nn +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin, load_state_dict + +from ..data.data_util import align_repeat_tensor_single_dim + +from .unet_3d_condition import UNet3DConditionModel +from .referencenet import ReferenceNet2D +from ip_adapter.ip_adapter import ImageProjModel + +logger = logging.getLogger(__name__) + + +class SuperUNet3DConditionModel(nn.Module): + """封装了各种子模型的超模型,与 diffusers 的 pipeline 很像,只不过这里是模型定义。 + 主要作用 + 1. 将支持controlnet、referencenet等功能的计算封装起来,简洁些; + 2. 便于 accelerator 的分布式训练; + + wrap the sub-models, such as unet, referencenet, controlnet, vae, text_encoder, tokenizer, text_emb_extractor, clip_vision_extractor, ip_adapter_image_proj + 1. support controlnet, referencenet, etc. + 2. support accelerator distributed training + """ + + _supports_gradient_checkpointing = True + print_idx = 0 + + # @register_to_config + def __init__( + self, + unet: nn.Module, + referencenet: nn.Module = None, + controlnet: nn.Module = None, + vae: nn.Module = None, + text_encoder: nn.Module = None, + tokenizer: nn.Module = None, + text_emb_extractor: nn.Module = None, + clip_vision_extractor: nn.Module = None, + ip_adapter_image_proj: nn.Module = None, + ) -> None: + """_summary_ + + Args: + unet (nn.Module): _description_ + referencenet (nn.Module, optional): _description_. Defaults to None. + controlnet (nn.Module, optional): _description_. Defaults to None. + vae (nn.Module, optional): _description_. Defaults to None. + text_encoder (nn.Module, optional): _description_. Defaults to None. + tokenizer (nn.Module, optional): _description_. Defaults to None. + text_emb_extractor (nn.Module, optional): wrap text_encoder and tokenizer for str2emb. Defaults to None. + clip_vision_extractor (nn.Module, optional): _description_. Defaults to None. + """ + super().__init__() + self.unet = unet + self.referencenet = referencenet + self.controlnet = controlnet + self.vae = vae + self.text_encoder = text_encoder + self.tokenizer = tokenizer + self.text_emb_extractor = text_emb_extractor + self.clip_vision_extractor = clip_vision_extractor + self.ip_adapter_image_proj = ip_adapter_image_proj + + def forward( + self, + unet_params: Dict, + encoder_hidden_states: torch.Tensor, + referencenet_params: Dict = None, + controlnet_params: Dict = None, + controlnet_scale: float = 1.0, + vision_clip_emb: Union[torch.Tensor, None] = None, + prompt_only_use_image_prompt: bool = False, + ): + """_summary_ + + Args: + unet_params (Dict): _description_ + encoder_hidden_states (torch.Tensor): b t n d + referencenet_params (Dict, optional): _description_. Defaults to None. + controlnet_params (Dict, optional): _description_. Defaults to None. + controlnet_scale (float, optional): _description_. Defaults to 1.0. + vision_clip_emb (Union[torch.Tensor, None], optional): b t d. Defaults to None. + prompt_only_use_image_prompt (bool, optional): _description_. Defaults to False. + + Returns: + _type_: _description_ + """ + batch_size = unet_params["sample"].shape[0] + time_size = unet_params["sample"].shape[2] + + # ip_adapter_cross_attn, prepare image prompt + if vision_clip_emb is not None: + # b t n d -> b t n d + if self.print_idx == 0: + logger.debug( + f"vision_clip_emb, before ip_adapter_image_proj, shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}" + ) + if vision_clip_emb.ndim == 3: + vision_clip_emb = rearrange(vision_clip_emb, "b t d-> b t 1 d") + if self.ip_adapter_image_proj is not None: + vision_clip_emb = rearrange(vision_clip_emb, "b t n d ->(b t) n d") + vision_clip_emb = self.ip_adapter_image_proj(vision_clip_emb) + if self.print_idx == 0: + logger.debug( + f"vision_clip_emb, after ip_adapter_image_proj shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}" + ) + if vision_clip_emb.ndim == 2: + vision_clip_emb = rearrange(vision_clip_emb, "b d-> b 1 d") + vision_clip_emb = rearrange( + vision_clip_emb, "(b t) n d -> b t n d", b=batch_size + ) + vision_clip_emb = align_repeat_tensor_single_dim( + vision_clip_emb, target_length=time_size, dim=1 + ) + if self.print_idx == 0: + logger.debug( + f"vision_clip_emb, after reshape shape={vision_clip_emb.shape} mean={torch.mean(vision_clip_emb)}" + ) + + if vision_clip_emb is None and encoder_hidden_states is not None: + vision_clip_emb = encoder_hidden_states + if vision_clip_emb is not None and encoder_hidden_states is None: + encoder_hidden_states = vision_clip_emb + # 当 prompt_only_use_image_prompt 为True时, + # 1. referencenet 都使用 vision_clip_emb + # 2. unet 如果没有dual_cross_attn,使用vision_clip_emb,有时不更新 + # 3. controlnet 当前使用 text_prompt + + # when prompt_only_use_image_prompt True, + # 1. referencenet use vision_clip_emb + # 2. unet use vision_clip_emb if no dual_cross_attn, sometimes not update + # 3. controlnet use text_prompt + + # extract referencenet emb + if self.referencenet is not None and referencenet_params is not None: + referencenet_encoder_hidden_states = align_repeat_tensor_single_dim( + vision_clip_emb, + target_length=referencenet_params["num_frames"], + dim=1, + ) + referencenet_params["encoder_hidden_states"] = rearrange( + referencenet_encoder_hidden_states, "b t n d->(b t) n d" + ) + referencenet_out = self.referencenet(**referencenet_params) + ( + down_block_refer_embs, + mid_block_refer_emb, + refer_self_attn_emb, + ) = referencenet_out + if down_block_refer_embs is not None: + if self.print_idx == 0: + logger.debug( + f"len(down_block_refer_embs)={len(down_block_refer_embs)}" + ) + for i, down_emb in enumerate(down_block_refer_embs): + if self.print_idx == 0: + logger.debug( + f"down_emb, {i}, {down_emb.shape}, mean={down_emb.mean()}" + ) + else: + if self.print_idx == 0: + logger.debug(f"down_block_refer_embs is None") + if mid_block_refer_emb is not None: + if self.print_idx == 0: + logger.debug( + f"mid_block_refer_emb, {mid_block_refer_emb.shape}, mean={mid_block_refer_emb.mean()}" + ) + else: + if self.print_idx == 0: + logger.debug(f"mid_block_refer_emb is None") + if refer_self_attn_emb is not None: + if self.print_idx == 0: + logger.debug(f"refer_self_attn_emb, num={len(refer_self_attn_emb)}") + for i, self_attn_emb in enumerate(refer_self_attn_emb): + if self.print_idx == 0: + logger.debug( + f"referencenet, self_attn_emb, {i}th, shape={self_attn_emb.shape}, mean={self_attn_emb.mean()}" + ) + else: + if self.print_idx == 0: + logger.debug(f"refer_self_attn_emb is None") + else: + down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb = ( + None, + None, + None, + ) + + # extract controlnet emb + if self.controlnet is not None and controlnet_params is not None: + controlnet_encoder_hidden_states = align_repeat_tensor_single_dim( + encoder_hidden_states, + target_length=unet_params["sample"].shape[2], + dim=1, + ) + controlnet_params["encoder_hidden_states"] = rearrange( + controlnet_encoder_hidden_states, " b t n d -> (b t) n d" + ) + ( + down_block_additional_residuals, + mid_block_additional_residual, + ) = self.controlnet(**controlnet_params) + if controlnet_scale != 1.0: + down_block_additional_residuals = [ + x * controlnet_scale for x in down_block_additional_residuals + ] + mid_block_additional_residual = ( + mid_block_additional_residual * controlnet_scale + ) + for i, down_block_additional_residual in enumerate( + down_block_additional_residuals + ): + if self.print_idx == 0: + logger.debug( + f"{i}, down_block_additional_residual mean={torch.mean(down_block_additional_residual)}" + ) + + if self.print_idx == 0: + logger.debug( + f"mid_block_additional_residual mean={torch.mean(mid_block_additional_residual)}" + ) + else: + down_block_additional_residuals = None + mid_block_additional_residual = None + + if prompt_only_use_image_prompt and vision_clip_emb is not None: + encoder_hidden_states = vision_clip_emb + + # run unet + out = self.unet( + **unet_params, + down_block_refer_embs=down_block_refer_embs, + mid_block_refer_emb=mid_block_refer_emb, + refer_self_attn_emb=refer_self_attn_emb, + down_block_additional_residuals=down_block_additional_residuals, + mid_block_additional_residual=mid_block_additional_residual, + encoder_hidden_states=encoder_hidden_states, + vision_clip_emb=vision_clip_emb, + ) + self.print_idx += 1 + return out + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (UNet3DConditionModel, ReferenceNet2D)): + module.gradient_checkpointing = value diff --git a/musev/models/temporal_transformer.py b/musev/models/temporal_transformer.py new file mode 100755 index 0000000000000000000000000000000000000000..efc0faecadd586e5b51f46adf8dd1b384b58e742 --- /dev/null +++ b/musev/models/temporal_transformer.py @@ -0,0 +1,308 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/transformer_temporal.py +from __future__ import annotations +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Literal, Optional +import logging + +import torch +from torch import nn +from einops import rearrange, repeat + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.transformer_temporal import ( + TransformerTemporalModelOutput, + TransformerTemporalModel as DiffusersTransformerTemporalModel, +) +from diffusers.models.attention_processor import AttnProcessor + +from mmcm.utils.gpu_util import get_gpu_status +from ..data.data_util import ( + batch_concat_two_tensor_with_index, + batch_index_fill, + batch_index_select, + concat_two_tensor, + align_repeat_tensor_single_dim, +) +from ..utils.attention_util import generate_sparse_causcal_attn_mask +from .attention import BasicTransformerBlock +from .attention_processor import ( + BaseIPAttnProcessor, +) +from . import Model_Register + +# https://github.com/facebookresearch/xformers/issues/845 +# 输入bs*n_frames*w*h太高,xformers报错。因此将transformer_temporal的allow_xformers均关掉 +# if bs*n_frames*w*h to large, xformers will raise error. So we close the allow_xformers in transformer_temporal +logger = logging.getLogger(__name__) + + +@Model_Register.register +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + double_self_attention (`bool`, *optional*): + Configure if each TransformerBlock should contain two self-attention layers + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + femb_channels: Optional[int] = None, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + allow_xformers: bool = False, + only_cross_attention: bool = False, + keep_content_condition: bool = False, + need_spatial_position_emb: bool = False, + need_temporal_weight: bool = True, + self_attn_mask: str = None, + # TODO: 运行参数,有待改到forward里面去 + # TODO: running parameters, need to be moved to forward + image_scale: float = 1.0, + processor: AttnProcessor | None = None, + remove_femb_non_linear: bool = False, + ): + super().__init__() + + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 2. Define temporal positional embedding + self.frame_emb_proj = torch.nn.Linear(femb_channels, inner_dim) + self.remove_femb_non_linear = remove_femb_non_linear + if not remove_femb_non_linear: + self.nonlinearity = nn.SiLU() + + # spatial_position_emb 使用femb_的参数配置 + self.need_spatial_position_emb = need_spatial_position_emb + if need_spatial_position_emb: + self.spatial_position_emb_proj = torch.nn.Linear(femb_channels, inner_dim) + # 3. Define transformers blocks + # TODO: 该实现方式不好,待优化 + # TODO: bad implementation, need to be optimized + self.need_ipadapter = False + self.cross_attn_temporal_cond = False + self.allow_xformers = allow_xformers + if processor is not None and isinstance(processor, BaseIPAttnProcessor): + self.cross_attn_temporal_cond = True + self.allow_xformers = False + if "NonParam" not in processor.__class__.__name__: + self.need_ipadapter = True + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + allow_xformers=allow_xformers, + only_cross_attention=only_cross_attention, + cross_attn_temporal_cond=self.need_ipadapter, + image_scale=image_scale, + processor=processor, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + self.need_temporal_weight = need_temporal_weight + if need_temporal_weight: + self.temporal_weight = nn.Parameter( + torch.tensor( + [ + 1e-5, + ] + ) + ) # initialize parameter with 0 + self.skip_temporal_layers = False # Whether to skip temporal layer + self.keep_content_condition = keep_content_condition + self.self_attn_mask = self_attn_mask + self.only_cross_attention = only_cross_attention + self.double_self_attention = double_self_attention + self.cross_attention_dim = cross_attention_dim + self.image_scale = image_scale + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.proj_out.weight) + nn.init.zeros_(self.proj_out.bias) + + def forward( + self, + hidden_states, + femb, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + num_frames=1, + cross_attention_kwargs=None, + sample_index: torch.LongTensor = None, + vision_conditon_frames_sample_index: torch.LongTensor = None, + spatial_position_emb: torch.Tensor = None, + return_dict: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: + [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. + When returning a tuple, the first element is the sample tensor. + """ + if self.skip_temporal_layers is True: + if not return_dict: + return (hidden_states,) + + return TransformerTemporalModelOutput(sample=hidden_states) + + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states = rearrange( + hidden_states, "(b t) c h w -> b c t h w", b=batch_size + ) + residual = hidden_states + + hidden_states = self.norm(hidden_states) + + hidden_states = rearrange(hidden_states, "b c t h w -> (b h w) t c") + + hidden_states = self.proj_in(hidden_states) + + # 2 Positional embedding + # adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/resnet.py#L574 + if not self.remove_femb_non_linear: + femb = self.nonlinearity(femb) + femb = self.frame_emb_proj(femb) + femb = align_repeat_tensor_single_dim(femb, hidden_states.shape[0], dim=0) + hidden_states = hidden_states + femb + + # 3. Blocks + if ( + (self.only_cross_attention or not self.double_self_attention) + and self.cross_attention_dim is not None + and encoder_hidden_states is not None + ): + encoder_hidden_states = align_repeat_tensor_single_dim( + encoder_hidden_states, + hidden_states.shape[0], + dim=0, + n_src_base_length=batch_size, + ) + + for i, block in enumerate(self.transformer_blocks): + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 4. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = rearrange( + hidden_states, "(b h w) t c -> b c t h w", b=batch_size, h=height, w=width + ).contiguous() + + # 保留condition对应的frames,便于保持前序内容帧,提升一致性 + # keep the frames corresponding to the condition to maintain the previous content frames and improve consistency + if ( + vision_conditon_frames_sample_index is not None + and self.keep_content_condition + ): + mask = torch.ones_like(hidden_states, device=hidden_states.device) + mask = batch_index_fill( + mask, dim=2, index=vision_conditon_frames_sample_index, value=0 + ) + if self.need_temporal_weight: + output = ( + residual + torch.abs(self.temporal_weight) * mask * hidden_states + ) + else: + output = residual + mask * hidden_states + else: + if self.need_temporal_weight: + output = residual + torch.abs(self.temporal_weight) * hidden_states + else: + output = residual + mask * hidden_states + + # output = torch.abs(self.temporal_weight) * hidden_states + residual + output = rearrange(output, "b c t h w -> (b t) c h w") + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/musev/models/text_model.py b/musev/models/text_model.py new file mode 100755 index 0000000000000000000000000000000000000000..98712c5dae8d6779fb209dc61792376b62ce21b7 --- /dev/null +++ b/musev/models/text_model.py @@ -0,0 +1,40 @@ +from typing import Any, Dict +from torch import nn + + +class TextEmbExtractor(nn.Module): + def __init__(self, tokenizer, text_encoder) -> None: + super(TextEmbExtractor, self).__init__() + self.tokenizer = tokenizer + self.text_encoder = text_encoder + + def forward( + self, + texts, + text_params: Dict = None, + ): + if text_params is None: + text_params = {} + special_prompt_input = self.tokenizer( + texts, + max_length=self.tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): + attention_mask = special_prompt_input.attention_mask.to( + self.text_encoder.device + ) + else: + attention_mask = None + + embeddings = self.text_encoder( + special_prompt_input.input_ids.to(self.text_encoder.device), + attention_mask=attention_mask, + **text_params + ) + return embeddings diff --git a/musev/models/transformer_2d.py b/musev/models/transformer_2d.py new file mode 100755 index 0000000000000000000000000000000000000000..b5a74eb9f63a7ae5e345de15035ba42c3919ab13 --- /dev/null +++ b/musev/models/transformer_2d.py @@ -0,0 +1,445 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional +import logging + +from einops import rearrange + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.models.transformer_2d import ( + Transformer2DModelOutput, + Transformer2DModel as DiffusersTransformer2DModel, +) + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +from diffusers.models.attention import ( + BasicTransformerBlock as DiffusersBasicTransformerBlock, +) +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils.constants import USE_PEFT_BACKEND + +from .attention import BasicTransformerBlock + +logger = logging.getLogger(__name__) + +# 本部分 与 diffusers/models/transformer_2d.py 几乎一样 +# 更新部分 +# 1. 替换自定义 BasicTransformerBlock 类 +# 2. 在forward 里增加了 self_attn_block_embs 用于 提取 self_attn 中的emb + +# this module is same as diffusers/models/transformer_2d.py. The update part is +# 1 redefine BasicTransformerBlock +# 2. add self_attn_block_embs in forward to extract emb from self_attn + + +class Transformer2DModel(DiffusersTransformer2DModel): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int | None = None, + out_channels: int | None = None, + num_layers: int = 1, + dropout: float = 0, + norm_num_groups: int = 32, + cross_attention_dim: int | None = None, + attention_bias: bool = False, + sample_size: int | None = None, + num_vector_embeds: int | None = None, + patch_size: int | None = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: int | None = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + attention_type: str = "default", + cross_attn_temporal_cond: bool = False, + ip_adapter_cross_attn: bool = False, + need_t2i_facein: bool = False, + need_t2i_ip_adapter_face: bool = False, + image_scale: float = 1.0, + ): + super().__init__( + num_attention_heads, + attention_head_dim, + in_channels, + out_channels, + num_layers, + dropout, + norm_num_groups, + cross_attention_dim, + attention_bias, + sample_size, + num_vector_embeds, + patch_size, + activation_fn, + num_embeds_ada_norm, + use_linear_projection, + only_cross_attention, + double_self_attention, + upcast_attention, + norm_type, + norm_elementwise_affine, + attention_type, + ) + inner_dim = num_attention_heads * attention_head_dim + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + attention_type=attention_type, + cross_attn_temporal_cond=cross_attn_temporal_cond, + ip_adapter_cross_attn=ip_adapter_cross_attn, + need_t2i_facein=need_t2i_facein, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + image_scale=image_scale, + ) + for d in range(num_layers) + ] + ) + self.num_layers = num_layers + self.cross_attn_temporal_cond = cross_attn_temporal_cond + self.ip_adapter_cross_attn = ip_adapter_cross_attn + + self.need_t2i_facein = need_t2i_facein + self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face + self.image_scale = image_scale + self.print_idx = 0 + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + self_attn_block_embs: Optional[List[torch.Tensor]] = None, + self_attn_block_embs_mode: Literal["read", "write"] = "write", + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = ( + hidden_states.shape[-2] // self.patch_size, + hidden_states.shape[-1] // self.patch_size, + ) + hidden_states = self.pos_embed(hidden_states) + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep, + added_cond_kwargs, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view( + batch_size, -1, hidden_states.shape[-1] + ) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + self_attn_block_embs, + self_attn_block_embs_mode, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + self_attn_block_embs=self_attn_block_embs, + self_attn_block_embs_mode=self_attn_block_embs_mode, + ) + # 将 转换 self_attn_emb的尺寸 + if ( + self_attn_block_embs is not None + and self_attn_block_embs_mode.lower() == "write" + ): + self_attn_idx = block.spatial_self_attn_idx + if self.print_idx == 0: + logger.debug( + f"self_attn_block_embs, num={len(self_attn_block_embs)}, before, shape={self_attn_block_embs[self_attn_idx].shape}, height={height}, width={width}" + ) + self_attn_block_embs[self_attn_idx] = rearrange( + self_attn_block_embs[self_attn_idx], + "bt (h w) c->bt c h w", + h=height, + w=width, + ) + if self.print_idx == 0: + logger.debug( + f"self_attn_block_embs, num={len(self_attn_block_embs)}, after ,shape={self_attn_block_embs[self_attn_idx].shape}, height={height}, width={width}" + ) + + if self.proj_out is None: + return hidden_states + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = ( + self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + ) + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = ( + self.scale_shift_table[None] + embedded_timestep[:, None] + ).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=( + -1, + height, + width, + self.patch_size, + self.patch_size, + self.out_channels, + ) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=( + -1, + self.out_channels, + height * self.patch_size, + width * self.patch_size, + ) + ) + self.print_idx += 1 + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/musev/models/unet_2d_blocks.py b/musev/models/unet_2d_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..4d30a9291441cd661e6e7e79cc0cc10ee35da7a4 --- /dev/null +++ b/musev/models/unet_2d_blocks.py @@ -0,0 +1,1537 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Literal, Optional, Tuple, Union, List + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import is_torch_version, logging +from diffusers.utils.torch_utils import apply_freeu +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import ( + Attention, + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, +) +from diffusers.models.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.normalization import AdaGroupNorm +from diffusers.models.resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + KDownsample2D, + KUpsample2D, + ResnetBlock2D, + Upsample2D, +) +from diffusers.models.unet_2d_blocks import ( + AttnDownBlock2D, + AttnDownEncoderBlock2D, + AttnSkipDownBlock2D, + AttnSkipUpBlock2D, + AttnUpBlock2D, + AttnUpDecoderBlock2D, + DownEncoderBlock2D, + KCrossAttnDownBlock2D, + KCrossAttnUpBlock2D, + KDownBlock2D, + KUpBlock2D, + ResnetDownsampleBlock2D, + ResnetUpsampleBlock2D, + SimpleCrossAttnDownBlock2D, + SimpleCrossAttnUpBlock2D, + SkipDownBlock2D, + SkipUpBlock2D, + UpDecoderBlock2D, +) + +from .transformer_2d import Transformer2DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = ( + down_block_type[7:] + if down_block_type.startswith("UNetRes") + else down_block_type + ) + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock2D" + ) + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D" + ) + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = ( + up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + ) + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock2D" + ) + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D" + ) + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock2D(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = ( + resnet_groups if resnet_time_scale_shift == "default" else None + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels + if resnet_time_scale_shift == "spatial" + else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + self_attn_block_embs: Optional[List[torch.Tensor]] = None, + self_attn_block_embs_mode: Literal["read", "write"] = "write", + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn( + hidden_states, + temb=temb, + self_attn_block_embs=self_attn_block_embs, + self_attn_block_embs_mode=self_attn_block_embs_mode, + ) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + self_attn_block_embs: Optional[List[torch.Tensor]] = None, + self_attn_block_embs_mode: Literal["read", "write"] = "write", + ) -> torch.FloatTensor: + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + self_attn_block_embs=self_attn_block_embs, + self_attn_block_embs_mode=self_attn_block_embs_mode, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + self_attn_block_embs=self_attn_block_embs, + )[0] + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") + else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + self_attn_block_embs: Optional[List[torch.Tensor]] = None, + self_attn_block_embs_mode: Literal["read", "write"] = "write", + ) -> torch.FloatTensor: + cross_attention_kwargs = ( + cross_attention_kwargs if cross_attention_kwargs is not None else {} + ) + lora_scale = cross_attention_kwargs.get("scale", 1.0) + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + self_attn_block_embs=self_attn_block_embs, + self_attn_block_embs_mode=self_attn_block_embs_mode, + ) + + # resnet + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class CrossAttnDownBlock2D(nn.Module): + print_idx = 0 + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals: Optional[torch.FloatTensor] = None, + self_attn_block_embs: Optional[List[torch.Tensor]] = None, + self_attn_block_embs_mode: Literal["read", "write"] = "write", + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + if self.print_idx == 0: + logger.debug(f"unet3d after resnet {hidden_states.mean()}") + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + self_attn_block_embs=self_attn_block_embs, + self_attn_block_embs_mode=self_attn_block_embs_mode, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + if self.print_idx == 0: + logger.debug(f"unet3d after resnet {hidden_states.mean()}") + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + self_attn_block_embs=self_attn_block_embs, + self_attn_block_embs_mode=self_attn_block_embs_mode, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + output_states = output_states + (hidden_states,) + + self.print_idx += 1 + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + self_attn_block_embs: Optional[List[torch.Tensor]] = None, + self_attn_block_embs_mode: Literal["read", "write"] = "write", + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + self_attn_block_embs: Optional[List[torch.Tensor]] = None, + self_attn_block_embs_mode: Literal["read", "write"] = "write", + ) -> torch.FloatTensor: + lora_scale = ( + cross_attention_kwargs.get("scale", 1.0) + if cross_attention_kwargs is not None + else 1.0 + ) + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + self_attn_block_embs=self_attn_block_embs, + self_attn_block_embs_mode=self_attn_block_embs_mode, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + self_attn_block_embs=self_attn_block_embs, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler( + hidden_states, upsample_size, scale=lora_scale + ) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + self_attn_block_embs: Optional[List[torch.Tensor]] = None, + self_attn_block_embs_mode: Literal["read", "write"] = "write", + ) -> torch.FloatTensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + + return hidden_states diff --git a/musev/models/unet_3d_blocks.py b/musev/models/unet_3d_blocks.py new file mode 100755 index 0000000000000000000000000000000000000000..d2d79bbfc6078735f77701f9621e34a454af82c7 --- /dev/null +++ b/musev/models/unet_3d_blocks.py @@ -0,0 +1,1413 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/unet_3d_blocks.py + +from typing import Any, Dict, List, Literal, Optional, Tuple, Union +import logging + +import torch +from torch import nn + +from diffusers.utils import is_torch_version +from diffusers.models.transformer_2d import ( + Transformer2DModel as DiffusersTransformer2DModel, +) +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from ..data.data_util import batch_adain_conditioned_tensor + +from .resnet import TemporalConvLayer +from .temporal_transformer import TransformerTemporalModel +from .transformer_2d import Transformer2DModel +from .attention_processor import ReferEmbFuseAttention + + +logger = logging.getLogger(__name__) + +# 注: +# (1) 原代码的`use_linear_projection`默认值均为True,与2D-SD模型不符,load时报错。因此均改为False +# (2) 原代码调用`Transformer2DModel`的输入参数顺序为n_channels // attn_num_head_channels, attn_num_head_channels, +# 与2D-SD模型不符。因此把顺序交换 +# (3) 增加了temporal attention用的frame embedding输入 + +# note: +# 1. The default value of `use_linear_projection` in the original code is True, which is inconsistent with the 2D-SD model and causes an error when loading. Therefore, it is changed to False. +# 2. The original code calls `Transformer2DModel` with the input parameter order of n_channels // attn_num_head_channels, attn_num_head_channels, which is inconsistent with the 2D-SD model. Therefore, the order is reversed. +# 3. Added the frame embedding input used by the temporal attention + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + femb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel, + temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer, + need_spatial_position_emb: bool = False, + need_t2i_ip_adapter: bool = False, + ip_adapter_cross_attn: bool = False, + need_t2i_facein: bool = False, + need_t2i_ip_adapter_face: bool = False, + need_adain_temporal_cond: bool = False, + resnet_2d_skip_time_act: bool = False, + need_refer_emb: bool = False, +): + if (isinstance(down_block_type, str) and down_block_type == "DownBlock3D") or ( + isinstance(down_block_type, nn.Module) + and down_block_type.__name__ == "DownBlock3D" + ): + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + femb_channels=femb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_conv_block=temporal_conv_block, + need_adain_temporal_cond=need_adain_temporal_cond, + resnet_2d_skip_time_act=resnet_2d_skip_time_act, + need_refer_emb=need_refer_emb, + attn_num_head_channels=attn_num_head_channels, + ) + elif ( + isinstance(down_block_type, str) and down_block_type == "CrossAttnDownBlock3D" + ) or ( + isinstance(down_block_type, nn.Module) + and down_block_type.__name__ == "CrossAttnDownBlock3D" + ): + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnDownBlock3D" + ) + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + femb_channels=femb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_conv_block=temporal_conv_block, + temporal_transformer=temporal_transformer, + need_spatial_position_emb=need_spatial_position_emb, + need_t2i_ip_adapter=need_t2i_ip_adapter, + ip_adapter_cross_attn=ip_adapter_cross_attn, + need_t2i_facein=need_t2i_facein, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + need_adain_temporal_cond=need_adain_temporal_cond, + resnet_2d_skip_time_act=resnet_2d_skip_time_act, + need_refer_emb=need_refer_emb, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + femb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer, + temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel, + need_spatial_position_emb: bool = False, + need_t2i_ip_adapter: bool = False, + ip_adapter_cross_attn: bool = False, + need_t2i_facein: bool = False, + need_t2i_ip_adapter_face: bool = False, + need_adain_temporal_cond: bool = False, + resnet_2d_skip_time_act: bool = False, +): + if (isinstance(up_block_type, str) and up_block_type == "UpBlock3D") or ( + isinstance(up_block_type, nn.Module) and up_block_type.__name__ == "UpBlock3D" + ): + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + femb_channels=femb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_conv_block=temporal_conv_block, + need_adain_temporal_cond=need_adain_temporal_cond, + resnet_2d_skip_time_act=resnet_2d_skip_time_act, + ) + elif (isinstance(up_block_type, str) and up_block_type == "CrossAttnUpBlock3D") or ( + isinstance(up_block_type, nn.Module) + and up_block_type.__name__ == "CrossAttnUpBlock3D" + ): + if cross_attention_dim is None: + raise ValueError( + "cross_attention_dim must be specified for CrossAttnUpBlock3D" + ) + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + femb_channels=femb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_conv_block=temporal_conv_block, + temporal_transformer=temporal_transformer, + need_spatial_position_emb=need_spatial_position_emb, + need_t2i_ip_adapter=need_t2i_ip_adapter, + ip_adapter_cross_attn=ip_adapter_cross_attn, + need_t2i_facein=need_t2i_facein, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + need_adain_temporal_cond=need_adain_temporal_cond, + resnet_2d_skip_time_act=resnet_2d_skip_time_act, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + print_idx = 0 + + def __init__( + self, + in_channels: int, + temb_channels: int, + femb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer, + temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel, + need_spatial_position_emb: bool = False, + need_t2i_ip_adapter: bool = False, + ip_adapter_cross_attn: bool = False, + need_t2i_facein: bool = False, + need_t2i_ip_adapter_face: bool = False, + need_adain_temporal_cond: bool = False, + resnet_2d_skip_time_act: bool = False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=resnet_2d_skip_time_act, + ) + ] + if temporal_conv_block is not None: + temp_convs = [ + temporal_conv_block( + in_channels, + in_channels, + dropout=0.1, + femb_channels=femb_channels, + ) + ] + else: + temp_convs = [None] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + cross_attn_temporal_cond=need_t2i_ip_adapter, + ip_adapter_cross_attn=ip_adapter_cross_attn, + need_t2i_facein=need_t2i_facein, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + ) + ) + if temporal_transformer is not None: + temp_attention = temporal_transformer( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + femb_channels=femb_channels, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + need_spatial_position_emb=need_spatial_position_emb, + ) + else: + temp_attention = None + temp_attentions.append(temp_attention) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=resnet_2d_skip_time_act, + ) + ) + if temporal_conv_block is not None: + temp_convs.append( + temporal_conv_block( + in_channels, + in_channels, + dropout=0.1, + femb_channels=femb_channels, + ) + ) + else: + temp_convs.append(None) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + self.need_adain_temporal_cond = need_adain_temporal_cond + + def forward( + self, + hidden_states, + temb=None, + femb=None, + encoder_hidden_states=None, + attention_mask=None, + num_frames=1, + cross_attention_kwargs=None, + sample_index: torch.LongTensor = None, + vision_conditon_frames_sample_index: torch.LongTensor = None, + spatial_position_emb: torch.Tensor = None, + refer_self_attn_emb: List[torch.Tensor] = None, + refer_self_attn_emb_mode: Literal["read", "write"] = "read", + ): + hidden_states = self.resnets[0](hidden_states, temb) + if self.temp_convs[0] is not None: + hidden_states = self.temp_convs[0]( + hidden_states, + femb=femb, + num_frames=num_frames, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + ) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + self_attn_block_embs=refer_self_attn_emb, + self_attn_block_embs_mode=refer_self_attn_emb_mode, + ).sample + if temp_attn is not None: + hidden_states = temp_attn( + hidden_states, + femb=femb, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + encoder_hidden_states=encoder_hidden_states, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + spatial_position_emb=spatial_position_emb, + ).sample + hidden_states = resnet(hidden_states, temb) + if temp_conv is not None: + hidden_states = temp_conv( + hidden_states, + femb=femb, + num_frames=num_frames, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + ) + if ( + self.need_adain_temporal_cond + and num_frames > 1 + and sample_index is not None + ): + if self.print_idx == 0: + logger.debug(f"adain to vision_condition") + hidden_states = batch_adain_conditioned_tensor( + hidden_states, + num_frames=num_frames, + need_style_fidelity=False, + src_index=sample_index, + dst_index=vision_conditon_frames_sample_index, + ) + self.print_idx += 1 + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + print_idx = 0 + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + femb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer, + temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel, + need_spatial_position_emb: bool = False, + need_t2i_ip_adapter: bool = False, + ip_adapter_cross_attn: bool = False, + need_t2i_facein: bool = False, + need_t2i_ip_adapter_face: bool = False, + need_adain_temporal_cond: bool = False, + resnet_2d_skip_time_act: bool = False, + need_refer_emb: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + self.need_refer_emb = need_refer_emb + if need_refer_emb: + refer_emb_attns = [] + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=resnet_2d_skip_time_act, + ) + ) + if temporal_conv_block is not None: + temp_convs.append( + temporal_conv_block( + out_channels, + out_channels, + dropout=0.1, + femb_channels=femb_channels, + ) + ) + else: + temp_convs.append(None) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + cross_attn_temporal_cond=need_t2i_ip_adapter, + ip_adapter_cross_attn=ip_adapter_cross_attn, + need_t2i_facein=need_t2i_facein, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + ) + ) + if temporal_transformer is not None: + temp_attention = temporal_transformer( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + femb_channels=femb_channels, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + need_spatial_position_emb=need_spatial_position_emb, + ) + else: + temp_attention = None + temp_attentions.append(temp_attention) + + if need_refer_emb: + refer_emb_attns.append( + ReferEmbFuseAttention( + query_dim=out_channels, + heads=attn_num_head_channels, + dim_head=out_channels // attn_num_head_channels, + dropout=0, + bias=False, + cross_attention_dim=None, + upcast_attention=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + if need_refer_emb: + refer_emb_attns.append( + ReferEmbFuseAttention( + query_dim=out_channels, + heads=attn_num_head_channels, + dim_head=out_channels // attn_num_head_channels, + dropout=0, + bias=False, + cross_attention_dim=None, + upcast_attention=False, + ) + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + self.need_adain_temporal_cond = need_adain_temporal_cond + if need_refer_emb: + self.refer_emb_attns = nn.ModuleList(refer_emb_attns) + logger.debug(f"cross attn downblock 3d need_refer_emb, {self.need_refer_emb}") + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + femb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + sample_index: torch.LongTensor = None, + vision_conditon_frames_sample_index: torch.LongTensor = None, + spatial_position_emb: torch.Tensor = None, + refer_embs: Optional[List[torch.Tensor]] = None, + refer_self_attn_emb: List[torch.Tensor] = None, + refer_self_attn_emb_mode: Literal["read", "write"] = "read", + ): + # TODO(Patrick, William) - attention mask is not used + output_states = () + for i_downblock, (resnet, temp_conv, attn, temp_attn) in enumerate( + zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions) + ): + # print("crossattndownblock3d, attn,", type(attn), cross_attention_kwargs) + if self.training and self.gradient_checkpointing: + if self.print_idx == 0: + logger.debug( + f"self.training and self.gradient_checkpointing={self.training and self.gradient_checkpointing}" + ) + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + if self.print_idx == 0: + logger.debug(f"unet3d after resnet {hidden_states.mean()}") + if temp_conv is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(temp_conv), + hidden_states, + num_frames, + sample_index, + vision_conditon_frames_sample_index, + femb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # added_cond_kwargs + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + refer_self_attn_emb, + refer_self_attn_emb_mode, + **ckpt_kwargs, + )[0] + if temp_attn is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(temp_attn, return_dict=False), + hidden_states, + femb, + # None, # encoder_hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + num_frames, + cross_attention_kwargs, + sample_index, + vision_conditon_frames_sample_index, + spatial_position_emb, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + if self.print_idx == 0: + logger.debug(f"unet3d after resnet {hidden_states.mean()}") + if temp_conv is not None: + hidden_states = temp_conv( + hidden_states, + femb=femb, + num_frames=num_frames, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + self_attn_block_embs=refer_self_attn_emb, + self_attn_block_embs_mode=refer_self_attn_emb_mode, + ).sample + if temp_attn is not None: + hidden_states = temp_attn( + hidden_states, + femb=femb, + num_frames=num_frames, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + spatial_position_emb=spatial_position_emb, + ).sample + if ( + self.need_adain_temporal_cond + and num_frames > 1 + and sample_index is not None + ): + if self.print_idx == 0: + logger.debug(f"adain to vision_condition") + hidden_states = batch_adain_conditioned_tensor( + hidden_states, + num_frames=num_frames, + need_style_fidelity=False, + src_index=sample_index, + dst_index=vision_conditon_frames_sample_index, + ) + # 使用 attn 的方式 来融合 down_block_refer_emb + if self.print_idx == 0: + logger.debug( + f"downblock, {i_downblock}, self.need_refer_emb={self.need_refer_emb}" + ) + if self.need_refer_emb and refer_embs is not None: + if self.print_idx == 0: + logger.debug( + f"{i_downblock}, self.refer_emb_attns {refer_embs[i_downblock].shape}" + ) + hidden_states = self.refer_emb_attns[i_downblock]( + hidden_states, refer_embs[i_downblock], num_frames=num_frames + ) + else: + if self.print_idx == 0: + logger.debug(f"crossattndownblock refer_emb_attns, no this step") + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + if ( + self.need_adain_temporal_cond + and num_frames > 1 + and sample_index is not None + ): + if self.print_idx == 0: + logger.debug(f"adain to vision_condition") + hidden_states = batch_adain_conditioned_tensor( + hidden_states, + num_frames=num_frames, + need_style_fidelity=False, + src_index=sample_index, + dst_index=vision_conditon_frames_sample_index, + ) + # 使用 attn 的方式 来融合 down_block_refer_emb + # TODO: adain和 refer_emb的顺序 + # TODO:adain 首帧特征还是refer_emb的 + if self.need_refer_emb and refer_embs is not None: + i_downblock += 1 + hidden_states = self.refer_emb_attns[i_downblock]( + hidden_states, refer_embs[i_downblock], num_frames=num_frames + ) + output_states += (hidden_states,) + self.print_idx += 1 + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + print_idx = 0 + + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + femb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer, + need_adain_temporal_cond: bool = False, + resnet_2d_skip_time_act: bool = False, + need_refer_emb: bool = False, + attn_num_head_channels: int = 1, + ): + super().__init__() + resnets = [] + temp_convs = [] + self.need_refer_emb = need_refer_emb + if need_refer_emb: + refer_emb_attns = [] + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=resnet_2d_skip_time_act, + ) + ) + if temporal_conv_block is not None: + temp_convs.append( + temporal_conv_block( + out_channels, + out_channels, + dropout=0.1, + femb_channels=femb_channels, + ) + ) + else: + temp_convs.append(None) + if need_refer_emb: + refer_emb_attns.append( + ReferEmbFuseAttention( + query_dim=out_channels, + heads=attn_num_head_channels, + dim_head=out_channels // attn_num_head_channels, + dropout=0, + bias=False, + cross_attention_dim=None, + upcast_attention=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + if need_refer_emb: + refer_emb_attns.append( + ReferEmbFuseAttention( + query_dim=out_channels, + heads=attn_num_head_channels, + dim_head=out_channels // attn_num_head_channels, + dropout=0, + bias=False, + cross_attention_dim=None, + upcast_attention=False, + ) + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + self.need_adain_temporal_cond = need_adain_temporal_cond + if need_refer_emb: + self.refer_emb_attns = nn.ModuleList(refer_emb_attns) + + def forward( + self, + hidden_states, + temb=None, + num_frames=1, + sample_index: torch.LongTensor = None, + vision_conditon_frames_sample_index: torch.LongTensor = None, + spatial_position_emb: torch.Tensor = None, + femb=None, + refer_embs: Optional[Tuple[torch.Tensor]] = None, + refer_self_attn_emb: List[torch.Tensor] = None, + refer_self_attn_emb_mode: Literal["read", "write"] = "read", + ): + output_states = () + + for i_downblock, (resnet, temp_conv) in enumerate( + zip(self.resnets, self.temp_convs) + ): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + if temp_conv is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(temp_conv), + hidden_states, + num_frames, + sample_index, + vision_conditon_frames_sample_index, + femb, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + if temp_conv is not None: + hidden_states = temp_conv( + hidden_states, + femb=femb, + num_frames=num_frames, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + ) + if ( + self.need_adain_temporal_cond + and num_frames > 1 + and sample_index is not None + ): + if self.print_idx == 0: + logger.debug(f"adain to vision_condition") + hidden_states = batch_adain_conditioned_tensor( + hidden_states, + num_frames=num_frames, + need_style_fidelity=False, + src_index=sample_index, + dst_index=vision_conditon_frames_sample_index, + ) + if self.need_refer_emb and refer_embs is not None: + hidden_states = self.refer_emb_attns[i_downblock]( + hidden_states, refer_embs[i_downblock], num_frames=num_frames + ) + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + if ( + self.need_adain_temporal_cond + and num_frames > 1 + and sample_index is not None + ): + if self.print_idx == 0: + logger.debug(f"adain to vision_condition") + hidden_states = batch_adain_conditioned_tensor( + hidden_states, + num_frames=num_frames, + need_style_fidelity=False, + src_index=sample_index, + dst_index=vision_conditon_frames_sample_index, + ) + if self.need_refer_emb and refer_embs is not None: + i_downblock += 1 + hidden_states = self.refer_emb_attns[i_downblock]( + hidden_states, refer_embs[i_downblock], num_frames=num_frames + ) + output_states += (hidden_states,) + self.print_idx += 1 + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + print_idx = 0 + + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + femb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer, + temporal_transformer: Union[nn.Module, None] = TransformerTemporalModel, + need_spatial_position_emb: bool = False, + need_t2i_ip_adapter: bool = False, + ip_adapter_cross_attn: bool = False, + need_t2i_facein: bool = False, + need_t2i_ip_adapter_face: bool = False, + need_adain_temporal_cond: bool = False, + resnet_2d_skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=resnet_2d_skip_time_act, + ) + ) + if temporal_conv_block is not None: + temp_convs.append( + temporal_conv_block( + out_channels, + out_channels, + dropout=0.1, + femb_channels=femb_channels, + ) + ) + else: + temp_convs.append(None) + attentions.append( + Transformer2DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + cross_attn_temporal_cond=need_t2i_ip_adapter, + ip_adapter_cross_attn=ip_adapter_cross_attn, + need_t2i_facein=need_t2i_facein, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + ) + ) + if temporal_transformer is not None: + temp_attention = temporal_transformer( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + femb_channels=femb_channels, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + need_spatial_position_emb=need_spatial_position_emb, + ) + else: + temp_attention = None + temp_attentions.append(temp_attention) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.need_adain_temporal_cond = need_adain_temporal_cond + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + femb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + sample_index: torch.LongTensor = None, + vision_conditon_frames_sample_index: torch.LongTensor = None, + spatial_position_emb: torch.Tensor = None, + refer_self_attn_emb: List[torch.Tensor] = None, + refer_self_attn_emb_mode: Literal["read", "write"] = "read", + ): + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + if temp_conv is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(temp_conv), + hidden_states, + num_frames, + sample_index, + vision_conditon_frames_sample_index, + femb, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + None, # timestep + None, # added_cond_kwargs + None, # class_labels + cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + refer_self_attn_emb, + refer_self_attn_emb_mode, + **ckpt_kwargs, + )[0] + if temp_attn is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(temp_attn, return_dict=False), + hidden_states, + femb, + # None, # encoder_hidden_states, + encoder_hidden_states, + None, # timestep + None, # class_labels + num_frames, + cross_attention_kwargs, + sample_index, + vision_conditon_frames_sample_index, + spatial_position_emb, + **ckpt_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + if temp_conv is not None: + hidden_states = temp_conv( + hidden_states, + num_frames=num_frames, + femb=femb, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + self_attn_block_embs=refer_self_attn_emb, + self_attn_block_embs_mode=refer_self_attn_emb_mode, + ).sample + if temp_attn is not None: + hidden_states = temp_attn( + hidden_states, + femb=femb, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + encoder_hidden_states=encoder_hidden_states, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + spatial_position_emb=spatial_position_emb, + ).sample + if ( + self.need_adain_temporal_cond + and num_frames > 1 + and sample_index is not None + ): + if self.print_idx == 0: + logger.debug(f"adain to vision_condition") + hidden_states = batch_adain_conditioned_tensor( + hidden_states, + num_frames=num_frames, + need_style_fidelity=False, + src_index=sample_index, + dst_index=vision_conditon_frames_sample_index, + ) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + if ( + self.need_adain_temporal_cond + and num_frames > 1 + and sample_index is not None + ): + if self.print_idx == 0: + logger.debug(f"adain to vision_condition") + hidden_states = batch_adain_conditioned_tensor( + hidden_states, + num_frames=num_frames, + need_style_fidelity=False, + src_index=sample_index, + dst_index=vision_conditon_frames_sample_index, + ) + self.print_idx += 1 + return hidden_states + + +class UpBlock3D(nn.Module): + print_idx = 0 + + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + femb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + temporal_conv_block: Union[nn.Module, None] = TemporalConvLayer, + need_adain_temporal_cond: bool = False, + resnet_2d_skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=resnet_2d_skip_time_act, + ) + ) + if temporal_conv_block is not None: + temp_convs.append( + temporal_conv_block( + out_channels, + out_channels, + dropout=0.1, + femb_channels=femb_channels, + ) + ) + else: + temp_convs.append(None) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.need_adain_temporal_cond = need_adain_temporal_cond + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + upsample_size=None, + num_frames=1, + sample_index: torch.LongTensor = None, + vision_conditon_frames_sample_index: torch.LongTensor = None, + spatial_position_emb: torch.Tensor = None, + femb=None, + refer_self_attn_emb: List[torch.Tensor] = None, + refer_self_attn_emb_mode: Literal["read", "write"] = "read", + ): + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + if temp_conv is not None: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(temp_conv), + hidden_states, + num_frames, + sample_index, + vision_conditon_frames_sample_index, + femb, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + if temp_conv is not None: + hidden_states = temp_conv( + hidden_states, + num_frames=num_frames, + femb=femb, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + ) + if ( + self.need_adain_temporal_cond + and num_frames > 1 + and sample_index is not None + ): + if self.print_idx == 0: + logger.debug(f"adain to vision_condition") + hidden_states = batch_adain_conditioned_tensor( + hidden_states, + num_frames=num_frames, + need_style_fidelity=False, + src_index=sample_index, + dst_index=vision_conditon_frames_sample_index, + ) + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + if ( + self.need_adain_temporal_cond + and num_frames > 1 + and sample_index is not None + ): + if self.print_idx == 0: + logger.debug(f"adain to vision_condition") + hidden_states = batch_adain_conditioned_tensor( + hidden_states, + num_frames=num_frames, + need_style_fidelity=False, + src_index=sample_index, + dst_index=vision_conditon_frames_sample_index, + ) + self.print_idx += 1 + return hidden_states diff --git a/musev/models/unet_3d_condition.py b/musev/models/unet_3d_condition.py new file mode 100755 index 0000000000000000000000000000000000000000..1cce55790252a5569cc6a5bb423a6595bc141cbb --- /dev/null +++ b/musev/models/unet_3d_condition.py @@ -0,0 +1,1740 @@ +# Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved. +# Copyright 2023 The ModelScope Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/unet_3d_condition.py + +# 1. 增加了from_pretrained,将模型从2D blocks改为3D blocks +# 1. add from_pretrained, change model from 2D blocks to 3D blocks + +from copy import deepcopy +from dataclasses import dataclass +import inspect +from pprint import pprint, pformat +from typing import Any, Dict, List, Optional, Tuple, Union, Literal +import os +import logging + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from einops import rearrange, repeat +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import BaseOutput + +# from diffusers.utils import logging +from diffusers.models.embeddings import ( + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin, load_state_dict +from diffusers import __version__ +from diffusers.utils import ( + CONFIG_NAME, + DIFFUSERS_CACHE, + FLAX_WEIGHTS_NAME, + HF_HUB_OFFLINE, + SAFETENSORS_WEIGHTS_NAME, + WEIGHTS_NAME, + _add_variant, + _get_model_file, + is_accelerate_available, + is_torch_version, +) +from diffusers.utils.import_utils import _safetensors_available +from diffusers.models.unet_3d_condition import ( + UNet3DConditionOutput, + UNet3DConditionModel as DiffusersUNet3DConditionModel, +) +from diffusers.models.attention_processor import ( + Attention, + AttentionProcessor, + AttnProcessor, + AttnProcessor2_0, + XFormersAttnProcessor, +) + +from ..models import Model_Register + +from .resnet import TemporalConvLayer +from .temporal_transformer import ( + TransformerTemporalModel, +) +from .embeddings import get_2d_sincos_pos_embed, resize_spatial_position_emb +from .unet_3d_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from ..data.data_util import ( + adaptive_instance_normalization, + align_repeat_tensor_single_dim, + batch_adain_conditioned_tensor, + batch_concat_two_tensor_with_index, + concat_two_tensor, + concat_two_tensor_with_index, +) +from .attention_processor import BaseIPAttnProcessor +from .attention_processor import ReferEmbFuseAttention +from .transformer_2d import Transformer2DModel +from .attention import BasicTransformerBlock + + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +# if is_torch_version(">=", "1.9.0"): +# _LOW_CPU_MEM_USAGE_DEFAULT = True +# else: +# _LOW_CPU_MEM_USAGE_DEFAULT = False +_LOW_CPU_MEM_USAGE_DEFAULT = False + +if is_accelerate_available(): + import accelerate + from accelerate.utils import set_module_tensor_to_device + from accelerate.utils.versions import is_torch_version + + +import safetensors + + +def hack_t2i_sd_layer_attn_with_ip( + unet: nn.Module, + self_attn_class: BaseIPAttnProcessor = None, + cross_attn_class: BaseIPAttnProcessor = None, +): + attn_procs = {} + for name in unet.attn_processors.keys(): + if "temp_attentions" in name or "transformer_in" in name: + continue + if name.endswith("attn1.processor") and self_attn_class is not None: + attn_procs[name] = self_attn_class() + if unet.print_idx == 0: + logger.debug( + f"hack attn_processor of {name} to {attn_procs[name].__class__.__name__}" + ) + elif name.endswith("attn2.processor") and cross_attn_class is not None: + attn_procs[name] = cross_attn_class() + if unet.print_idx == 0: + logger.debug( + f"hack attn_processor of {name} to {attn_procs[name].__class__.__name__}" + ) + unet.set_attn_processor(attn_procs, strict=False) + + +def convert_2D_to_3D( + module_names, + valid_modules=( + "CrossAttnDownBlock2D", + "CrossAttnUpBlock2D", + "DownBlock2D", + "UNetMidBlock2DCrossAttn", + "UpBlock2D", + ), +): + if not isinstance(module_names, list): + return module_names.replace("2D", "3D") + + return_modules = [] + for module_name in module_names: + if module_name in valid_modules: + return_modules.append(module_name.replace("2D", "3D")) + else: + return_modules.append(module_name) + return return_modules + + +def insert_spatial_self_attn_idx(unet): + pass + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + """ + The output of [`UNet3DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep + and returns sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the models (such as downloading or saving, etc.) + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + The tuple of upsample blocks to use. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, it will skip the normalization and activation layers in post-processing + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + """ + + _supports_gradient_checkpointing = True + print_idx = 0 + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + ), + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + attention_head_dim: Union[int, Tuple[int]] = 8, + temporal_conv_block: str = "TemporalConvLayer", + temporal_transformer: str = "TransformerTemporalModel", + need_spatial_position_emb: bool = False, + need_transformer_in: bool = True, + need_t2i_ip_adapter: bool = False, # self_attn, t2i.attn1 + need_adain_temporal_cond: bool = False, + t2i_ip_adapter_attn_processor: str = "NonParamT2ISelfReferenceXFormersAttnProcessor", + keep_vision_condtion: bool = False, + use_anivv1_cfg: bool = False, + resnet_2d_skip_time_act: bool = False, + need_zero_vis_cond_temb: bool = True, + norm_spatial_length: bool = False, + spatial_max_length: int = 2048, + need_refer_emb: bool = False, + ip_adapter_cross_attn: bool = False, # cross_attn, t2i.attn2 + t2i_crossattn_ip_adapter_attn_processor: str = "T2IReferencenetIPAdapterXFormersAttnProcessor", + need_t2i_facein: bool = False, + need_t2i_ip_adapter_face: bool = False, + need_vis_cond_mask: bool = False, + ): + """_summary_ + + Args: + sample_size (Optional[int], optional): _description_. Defaults to None. + in_channels (int, optional): _description_. Defaults to 4. + out_channels (int, optional): _description_. Defaults to 4. + down_block_types (Tuple[str], optional): _description_. Defaults to ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ). + up_block_types (Tuple[str], optional): _description_. Defaults to ( "UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", ). + block_out_channels (Tuple[int], optional): _description_. Defaults to (320, 640, 1280, 1280). + layers_per_block (int, optional): _description_. Defaults to 2. + downsample_padding (int, optional): _description_. Defaults to 1. + mid_block_scale_factor (float, optional): _description_. Defaults to 1. + act_fn (str, optional): _description_. Defaults to "silu". + norm_num_groups (Optional[int], optional): _description_. Defaults to 32. + norm_eps (float, optional): _description_. Defaults to 1e-5. + cross_attention_dim (int, optional): _description_. Defaults to 1024. + attention_head_dim (Union[int, Tuple[int]], optional): _description_. Defaults to 8. + temporal_conv_block (str, optional): 3D卷积字符串,需要注册在 Model_Register. Defaults to "TemporalConvLayer". + temporal_transformer (str, optional): 时序 Transformer block字符串,需要定义在 Model_Register. Defaults to "TransformerTemporalModel". + need_spatial_position_emb (bool, optional): 是否需要 spatial hw 的emb,需要配合 thw attn使用. Defaults to False. + need_transformer_in (bool, optional): 是否需要 第一个 temporal_transformer_block. Defaults to True. + need_t2i_ip_adapter (bool, optional): T2I 模块是否需要面向视觉条件帧的 attn. Defaults to False. + need_adain_temporal_cond (bool, optional): 是否需要面向首帧 使用Adain. Defaults to False. + t2i_ip_adapter_attn_processor (str, optional): + t2i attn_processor的优化版,需配合need_t2i_ip_adapter使用, + 有 NonParam 表示无参ReferenceOnly-attn,没有表示有参 IpAdapter. + Defaults to "NonParamT2ISelfReferenceXFormersAttnProcessor". + keep_vision_condtion (bool, optional): 是否对视觉条件帧不加 timestep emb. Defaults to False. + use_anivv1_cfg (bool, optional): 一些基本配置 是否延续AnivV设计. Defaults to False. + resnet_2d_skip_time_act (bool, optional): 配合use_anivv1_cfg,修改 transformer 2d block. Defaults to False. + need_zero_vis_cond_temb (bool, optional): 目前无效参数. Defaults to True. + norm_spatial_length (bool, optional): 是否需要 norm_spatial_length,只有当 need_spatial_position_emb= True时,才有效. Defaults to False. + spatial_max_length (int, optional): 归一化长度. Defaults to 2048. + + Raises: + ValueError: _description_ + ValueError: _description_ + ValueError: _description_ + """ + super(UNet3DConditionModel, self).__init__() + self.keep_vision_condtion = keep_vision_condtion + self.use_anivv1_cfg = use_anivv1_cfg + self.sample_size = sample_size + self.resnet_2d_skip_time_act = resnet_2d_skip_time_act + self.need_zero_vis_cond_temb = need_zero_vis_cond_temb + self.norm_spatial_length = norm_spatial_length + self.spatial_max_length = spatial_max_length + self.need_refer_emb = need_refer_emb + self.ip_adapter_cross_attn = ip_adapter_cross_attn + self.need_t2i_facein = need_t2i_facein + self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face + + logger.debug(f"need_t2i_ip_adapter_face={need_t2i_ip_adapter_face}") + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + if use_anivv1_cfg: + self.time_nonlinearity = nn.SiLU() + + # frame + frame_embed_dim = block_out_channels[0] * 4 + self.frame_proj = Timesteps(block_out_channels[0], True, 0) + frame_input_dim = block_out_channels[0] + if temporal_transformer is not None: + self.frame_embedding = TimestepEmbedding( + frame_input_dim, + frame_embed_dim, + act_fn=act_fn, + ) + else: + self.frame_embedding = None + if use_anivv1_cfg: + self.femb_nonlinearity = nn.SiLU() + + # spatial_position_emb + self.need_spatial_position_emb = need_spatial_position_emb + if need_spatial_position_emb: + self.spatial_position_input_dim = block_out_channels[0] * 2 + self.spatial_position_embed_dim = block_out_channels[0] * 4 + + self.spatial_position_embedding = TimestepEmbedding( + self.spatial_position_input_dim, + self.spatial_position_embed_dim, + act_fn=act_fn, + ) + + # 从模型注册表中获取 模型类 + temporal_conv_block = ( + Model_Register[temporal_conv_block] + if isinstance(temporal_conv_block, str) + and temporal_conv_block.lower() != "none" + else None + ) + self.need_transformer_in = need_transformer_in + + temporal_transformer = ( + Model_Register[temporal_transformer] + if isinstance(temporal_transformer, str) + and temporal_transformer.lower() != "none" + else None + ) + self.need_vis_cond_mask = need_vis_cond_mask + + if need_transformer_in and temporal_transformer is not None: + self.transformer_in = temporal_transformer( + num_attention_heads=attention_head_dim, + attention_head_dim=block_out_channels[0] // attention_head_dim, + in_channels=block_out_channels[0], + num_layers=1, + femb_channels=frame_embed_dim, + need_spatial_position_emb=need_spatial_position_emb, + cross_attention_dim=cross_attention_dim, + ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + self.need_t2i_ip_adapter = need_t2i_ip_adapter + # 确定T2I Attn 是否加入 ReferenceOnly机制或Ipadaper机制 + # TODO:有待更好的实现机制, + need_t2i_ip_adapter_param = ( + t2i_ip_adapter_attn_processor is not None + and "NonParam" not in t2i_ip_adapter_attn_processor + and need_t2i_ip_adapter + ) + self.need_adain_temporal_cond = need_adain_temporal_cond + self.t2i_ip_adapter_attn_processor = t2i_ip_adapter_attn_processor + + if need_refer_emb: + self.first_refer_emb_attns = ReferEmbFuseAttention( + query_dim=block_out_channels[0], + heads=attention_head_dim[0], + dim_head=block_out_channels[0] // attention_head_dim[0], + dropout=0, + bias=False, + cross_attention_dim=None, + upcast_attention=False, + ) + self.mid_block_refer_emb_attns = ReferEmbFuseAttention( + query_dim=block_out_channels[-1], + heads=attention_head_dim[-1], + dim_head=block_out_channels[-1] // attention_head_dim[-1], + dropout=0, + bias=False, + cross_attention_dim=None, + upcast_attention=False, + ) + else: + self.first_refer_emb_attns = None + self.mid_block_refer_emb_attns = None + # down + output_channel = block_out_channels[0] + self.layers_per_block = layers_per_block + self.block_out_channels = block_out_channels + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + femb_channels=frame_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=False, + temporal_conv_block=temporal_conv_block, + temporal_transformer=temporal_transformer, + need_spatial_position_emb=need_spatial_position_emb, + need_t2i_ip_adapter=need_t2i_ip_adapter_param, + ip_adapter_cross_attn=ip_adapter_cross_attn, + need_t2i_facein=need_t2i_facein, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + need_adain_temporal_cond=need_adain_temporal_cond, + resnet_2d_skip_time_act=resnet_2d_skip_time_act, + need_refer_emb=need_refer_emb, + ) + self.down_blocks.append(down_block) + # mid + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + femb_channels=frame_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + temporal_conv_block=temporal_conv_block, + temporal_transformer=temporal_transformer, + need_spatial_position_emb=need_spatial_position_emb, + need_t2i_ip_adapter=need_t2i_ip_adapter_param, + ip_adapter_cross_attn=ip_adapter_cross_attn, + need_t2i_facein=need_t2i_facein, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + need_adain_temporal_cond=need_adain_temporal_cond, + resnet_2d_skip_time_act=resnet_2d_skip_time_act, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + femb_channels=frame_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=False, + temporal_conv_block=temporal_conv_block, + temporal_transformer=temporal_transformer, + need_spatial_position_emb=need_spatial_position_emb, + need_t2i_ip_adapter=need_t2i_ip_adapter_param, + ip_adapter_cross_attn=ip_adapter_cross_attn, + need_t2i_facein=need_t2i_facein, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + need_adain_temporal_cond=need_adain_temporal_cond, + resnet_2d_skip_time_act=resnet_2d_skip_time_act, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding, + ) + self.insert_spatial_self_attn_idx() + + # 根据需要hack attn_processor,实现ip_adapter等功能 + if need_t2i_ip_adapter or ip_adapter_cross_attn: + hack_t2i_sd_layer_attn_with_ip( + self, + self_attn_class=Model_Register[t2i_ip_adapter_attn_processor] + if t2i_ip_adapter_attn_processor is not None and need_t2i_ip_adapter + else None, + cross_attn_class=Model_Register[t2i_crossattn_ip_adapter_attn_processor] + if t2i_crossattn_ip_adapter_attn_processor is not None + and ( + ip_adapter_cross_attn or need_t2i_facein or need_t2i_ip_adapter_face + ) + else None, + ) + # logger.debug(pformat(self.attn_processors)) + + # 非参数AttnProcessor,就不需要to_k_ip、to_v_ip参数了 + if ( + t2i_ip_adapter_attn_processor is None + or "NonParam" in t2i_ip_adapter_attn_processor + ): + need_t2i_ip_adapter = False + + if self.print_idx == 0: + logger.debug("Unet3Model Parameters") + # logger.debug(pformat(self.__dict__)) + + # 会在 set_skip_temporal_layers 设置 skip_refer_downblock_emb + # 当为 True 时,会跳过 referencenet_block_emb的影响,主要用于首帧生成 + self.skip_refer_downblock_emb = False + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): + if hasattr(module, "set_processor"): + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = ( + num_sliceable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice( + module: torch.nn.Module, slice_size: List[int] + ): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, + processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], + strict: bool = True, + ): + r""" + Parameters: + `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + of **all** `Attention` layers. + In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count and strict: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + logger.debug( + f"module {name} set attn processor {processor.__class__.__name__}" + ) + module.set_processor(processor) + else: + if f"{name}.processor" in processor: + logger.debug( + "module {} set attn processor {}".format( + name, processor[f"{name}.processor"].__class__.__name__ + ) + ) + module.set_processor(processor.pop(f"{name}.processor")) + else: + logger.debug( + f"module {name} has no new target attn_processor, still use {module.processor.__class__.__name__} " + ) + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance( + module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D) + ): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + sample_index: torch.LongTensor = None, + vision_condition_frames_sample: torch.Tensor = None, + vision_conditon_frames_sample_index: torch.LongTensor = None, + sample_frame_rate: int = 10, + skip_temporal_layers: bool = None, + frame_index: torch.LongTensor = None, + down_block_refer_embs: Optional[Tuple[torch.Tensor]] = None, + mid_block_refer_emb: Optional[torch.Tensor] = None, + refer_self_attn_emb: Optional[List[torch.Tensor]] = None, + refer_self_attn_emb_mode: Literal["read", "write"] = "read", + vision_clip_emb: torch.Tensor = None, + ip_adapter_scale: float = 1.0, + face_emb: torch.Tensor = None, + facein_scale: float = 1.0, + ip_adapter_face_emb: torch.Tensor = None, + ip_adapter_face_scale: float = 1.0, + do_classifier_free_guidance: bool = False, + pose_guider_emb: torch.Tensor = None, + ) -> Union[UNet3DConditionOutput, Tuple]: + """_summary_ + + Args: + sample (torch.FloatTensor): _description_ + timestep (Union[torch.Tensor, float, int]): _description_ + encoder_hidden_states (torch.Tensor): _description_ + class_labels (Optional[torch.Tensor], optional): _description_. Defaults to None. + timestep_cond (Optional[torch.Tensor], optional): _description_. Defaults to None. + attention_mask (Optional[torch.Tensor], optional): _description_. Defaults to None. + cross_attention_kwargs (Optional[Dict[str, Any]], optional): _description_. Defaults to None. + down_block_additional_residuals (Optional[Tuple[torch.Tensor]], optional): _description_. Defaults to None. + mid_block_additional_residual (Optional[torch.Tensor], optional): _description_. Defaults to None. + return_dict (bool, optional): _description_. Defaults to True. + sample_index (torch.LongTensor, optional): _description_. Defaults to None. + vision_condition_frames_sample (torch.Tensor, optional): _description_. Defaults to None. + vision_conditon_frames_sample_index (torch.LongTensor, optional): _description_. Defaults to None. + sample_frame_rate (int, optional): _description_. Defaults to 10. + skip_temporal_layers (bool, optional): _description_. Defaults to None. + frame_index (torch.LongTensor, optional): _description_. Defaults to None. + up_block_additional_residual (Optional[torch.Tensor], optional): 用于up_block的 参考latent. Defaults to None. + down_block_refer_embs (Optional[torch.Tensor], optional): 用于 download 的 参考latent. Defaults to None. + how_fuse_referencenet_emb (Literal, optional): 如何融合 参考 latent. Defaults to ["add", "attn"]="add". + add: 要求 additional_latent 和 latent hw 同尺寸. hw of addtional_latent should be same as of latent + attn: concat bt*h1w1*c and bt*h2w2*c into bt*(h1w1+h2w2)*c, and then as key,value into attn + Raises: + ValueError: _description_ + + Returns: + Union[UNet3DConditionOutput, Tuple]: _description_ + """ + + if skip_temporal_layers is not None: + self.set_skip_temporal_layers(skip_temporal_layers) + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + # logger.debug("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + batch_size = sample.shape[0] + + # when vision_condition_frames_sample is not None and vision_conditon_frames_sample_index is not None + # if not None, b c t h w -> b c (t + n_content ) h w + + if vision_condition_frames_sample is not None: + sample = batch_concat_two_tensor_with_index( + sample, + sample_index, + vision_condition_frames_sample, + vision_conditon_frames_sample_index, + dim=2, + ) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + batch_size, channel, num_frames, height, width = sample.shape + + # 准备 timestep emb + timesteps = timesteps.expand(sample.shape[0]) + temb = self.time_proj(timesteps) + temb = temb.to(dtype=self.dtype) + emb = self.time_embedding(temb, timestep_cond) + if self.use_anivv1_cfg: + emb = self.time_nonlinearity(emb) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + + # 一致性保持,使条件时序帧的 首帧 timesteps emb 为 0,即不影响视觉条件帧 + # keep consistent with the first frame of vision condition frames + if ( + self.keep_vision_condtion + and num_frames > 1 + and sample_index is not None + and vision_conditon_frames_sample_index is not None + ): + emb = rearrange(emb, "(b t) d -> b t d", t=num_frames) + emb[:, vision_conditon_frames_sample_index, :] = 0 + emb = rearrange(emb, "b t d->(b t) d") + + # temporal positional embedding + femb = None + if self.temporal_transformer is not None: + if frame_index is None: + frame_index = torch.arange( + num_frames, dtype=torch.long, device=sample.device + ) + if self.use_anivv1_cfg: + frame_index = (frame_index * sample_frame_rate).to(dtype=torch.long) + femb = self.frame_proj(frame_index) + if self.print_idx == 0: + logger.debug( + f"unet prepare frame_index, {femb.shape}, {batch_size}" + ) + femb = repeat(femb, "t d-> b t d", b=batch_size) + else: + # b t -> b t d + assert frame_index.ndim == 2, ValueError( + "ndim of given frame_index should be 2, but {frame_index.ndim}" + ) + femb = torch.stack( + [self.frame_proj(frame_index[i]) for i in range(batch_size)], dim=0 + ) + if self.temporal_transformer is not None: + femb = femb.to(dtype=self.dtype) + femb = self.frame_embedding( + femb, + ) + if self.use_anivv1_cfg: + femb = self.femb_nonlinearity(femb) + if encoder_hidden_states.ndim == 3: + encoder_hidden_states = align_repeat_tensor_single_dim( + encoder_hidden_states, target_length=emb.shape[0], dim=0 + ) + elif encoder_hidden_states.ndim == 4: + encoder_hidden_states = rearrange( + encoder_hidden_states, "b t n q-> (b t) n q" + ) + else: + raise ValueError( + f"only support ndim in [3, 4], but given {encoder_hidden_states.ndim}" + ) + if vision_clip_emb is not None: + if vision_clip_emb.ndim == 4: + vision_clip_emb = rearrange(vision_clip_emb, "b t n q-> (b t) n q") + # 准备 hw 层面的 spatial positional embedding + # prepare spatial_position_emb + if self.need_spatial_position_emb: + # height * width, self.spatial_position_input_dim + spatial_position_emb = get_2d_sincos_pos_embed( + embed_dim=self.spatial_position_input_dim, + grid_size_w=width, + grid_size_h=height, + cls_token=False, + norm_length=self.norm_spatial_length, + max_length=self.spatial_max_length, + ) + spatial_position_emb = torch.from_numpy(spatial_position_emb).to( + device=sample.device, dtype=self.dtype + ) + # height * width, self.spatial_position_embed_dim + spatial_position_emb = self.spatial_position_embedding(spatial_position_emb) + else: + spatial_position_emb = None + + # prepare cross_attention_kwargs,ReferenceOnly/IpAdapter的attn_processor需要这些参数 进行 latenst和viscond_latents拆分运算 + if ( + self.need_t2i_ip_adapter + or self.ip_adapter_cross_attn + or self.need_t2i_facein + or self.need_t2i_ip_adapter_face + ): + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + cross_attention_kwargs["num_frames"] = num_frames + cross_attention_kwargs[ + "do_classifier_free_guidance" + ] = do_classifier_free_guidance + cross_attention_kwargs["sample_index"] = sample_index + cross_attention_kwargs[ + "vision_conditon_frames_sample_index" + ] = vision_conditon_frames_sample_index + if self.ip_adapter_cross_attn: + cross_attention_kwargs["vision_clip_emb"] = vision_clip_emb + cross_attention_kwargs["ip_adapter_scale"] = ip_adapter_scale + if self.need_t2i_facein: + if self.print_idx == 0: + logger.debug( + f"face_emb={type(face_emb)}, facein_scale={facein_scale}" + ) + cross_attention_kwargs["face_emb"] = face_emb + cross_attention_kwargs["facein_scale"] = facein_scale + if self.need_t2i_ip_adapter_face: + if self.print_idx == 0: + logger.debug( + f"ip_adapter_face_emb={type(ip_adapter_face_emb)}, ip_adapter_face_scale={ip_adapter_face_scale}" + ) + cross_attention_kwargs["ip_adapter_face_emb"] = ip_adapter_face_emb + cross_attention_kwargs["ip_adapter_face_scale"] = ip_adapter_face_scale + # 2. pre-process + sample = rearrange(sample, "b c t h w -> (b t) c h w") + sample = self.conv_in(sample) + + if pose_guider_emb is not None: + if self.print_idx == 0: + logger.debug( + f"sample={sample.shape}, pose_guider_emb={pose_guider_emb.shape}" + ) + sample = sample + pose_guider_emb + + if self.print_idx == 0: + logger.debug(f"after conv in sample={sample.mean()}") + if spatial_position_emb is not None: + if self.print_idx == 0: + logger.debug( + f"unet3d, transformer_in, spatial_position_emb={spatial_position_emb.shape}" + ) + if self.print_idx == 0: + logger.debug( + f"unet vision_conditon_frames_sample_index, {type(vision_conditon_frames_sample_index)}", + ) + if vision_conditon_frames_sample_index is not None: + if self.print_idx == 0: + logger.debug( + f"vision_conditon_frames_sample_index shape {vision_conditon_frames_sample_index.shape}", + ) + if self.print_idx == 0: + logger.debug(f"unet sample_index {type(sample_index)}") + if sample_index is not None: + if self.print_idx == 0: + logger.debug(f"sample_index shape {sample_index.shape}") + if self.need_transformer_in: + if self.print_idx == 0: + logger.debug(f"unet3d, transformer_in, sample={sample.shape}") + sample = self.transformer_in( + sample, + femb=femb, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + encoder_hidden_states=encoder_hidden_states, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + spatial_position_emb=spatial_position_emb, + ).sample + if ( + self.need_refer_emb + and down_block_refer_embs is not None + and not self.skip_refer_downblock_emb + ): + if self.print_idx == 0: + logger.debug( + f"self.first_refer_emb_attns, {self.first_refer_emb_attns.__class__.__name__} {down_block_refer_embs[0].shape}" + ) + sample = self.first_refer_emb_attns( + sample, down_block_refer_embs[0], num_frames=num_frames + ) + if self.print_idx == 0: + logger.debug( + f"first_refer_emb_attns, sample is_leaf={sample.is_leaf}, requires_grad={sample.requires_grad}, down_block_refer_embs, {down_block_refer_embs[0].is_leaf}, {down_block_refer_embs[0].requires_grad}," + ) + else: + if self.print_idx == 0: + logger.debug(f"first_refer_emb_attns, no this step") + # 将 refer_self_attn_emb 转化成字典,增加一个当前index,表示block 的对应关系 + # convert refer_self_attn_emb to dict, add a current index to represent the corresponding relationship of the block + + # 3. down + down_block_res_samples = (sample,) + for i_down_block, downsample_block in enumerate(self.down_blocks): + # 使用 attn 的方式 来融合 refer_emb,这里是准备 downblock 对应的 refer_emb + # fuse refer_emb with attn, here is to prepare the refer_emb corresponding to downblock + if ( + not self.need_refer_emb + or down_block_refer_embs is None + or self.skip_refer_downblock_emb + ): + this_down_block_refer_embs = None + if self.print_idx == 0: + logger.debug( + f"{i_down_block}, prepare this_down_block_refer_embs, is None" + ) + else: + is_final_block = i_down_block == len(self.block_out_channels) - 1 + num_block = self.layers_per_block + int(not is_final_block * 1) + this_downblock_start_idx = 1 + num_block * i_down_block + this_down_block_refer_embs = down_block_refer_embs[ + this_downblock_start_idx : this_downblock_start_idx + num_block + ] + if self.print_idx == 0: + logger.debug( + f"prepare this_down_block_refer_embs, {len(this_down_block_refer_embs)}, {this_down_block_refer_embs[0].shape}" + ) + if self.print_idx == 0: + logger.debug(f"downsample_block {i_down_block}, sample={sample.mean()}") + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + femb=femb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + spatial_position_emb=spatial_position_emb, + refer_embs=this_down_block_refer_embs, + refer_self_attn_emb=refer_self_attn_emb, + refer_self_attn_emb_mode=refer_self_attn_emb_mode, + ) + else: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + femb=femb, + num_frames=num_frames, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + spatial_position_emb=spatial_position_emb, + refer_embs=this_down_block_refer_embs, + refer_self_attn_emb=refer_self_attn_emb, + refer_self_attn_emb_mode=refer_self_attn_emb_mode, + ) + + # resize spatial_position_emb + if self.need_spatial_position_emb: + has_downblock = i_down_block < len(self.down_blocks) - 1 + if has_downblock: + spatial_position_emb = resize_spatial_position_emb( + spatial_position_emb, + scale=0.5, + height=sample.shape[2] * 2, + width=sample.shape[3] * 2, + ) + down_block_res_samples += res_samples + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + hidden_states=sample, + temb=emb, + femb=femb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + spatial_position_emb=spatial_position_emb, + refer_self_attn_emb=refer_self_attn_emb, + refer_self_attn_emb_mode=refer_self_attn_emb_mode, + ) + # 使用 attn 的方式 来融合 mid_block_refer_emb + # fuse mid_block_refer_emb with attn + if ( + self.mid_block_refer_emb_attns is not None + and mid_block_refer_emb is not None + and not self.skip_refer_downblock_emb + ): + if self.print_idx == 0: + logger.debug( + f"self.mid_block_refer_emb_attns={self.mid_block_refer_emb_attns}, mid_block_refer_emb={mid_block_refer_emb.shape}" + ) + sample = self.mid_block_refer_emb_attns( + sample, mid_block_refer_emb, num_frames=num_frames + ) + if self.print_idx == 0: + logger.debug( + f"mid_block_refer_emb_attns, sample is_leaf={sample.is_leaf}, requires_grad={sample.requires_grad}, mid_block_refer_emb, {mid_block_refer_emb[0].is_leaf}, {mid_block_refer_emb[0].requires_grad}," + ) + else: + if self.print_idx == 0: + logger.debug(f"mid_block_refer_emb_attns, no this step") + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i_up_block, upsample_block in enumerate(self.up_blocks): + is_final_block = i_up_block == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): + sample = upsample_block( + hidden_states=sample, + temb=emb, + femb=femb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + spatial_position_emb=spatial_position_emb, + refer_self_attn_emb=refer_self_attn_emb, + refer_self_attn_emb_mode=refer_self_attn_emb_mode, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + femb=femb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + sample_index=sample_index, + vision_conditon_frames_sample_index=vision_conditon_frames_sample_index, + spatial_position_emb=spatial_position_emb, + refer_self_attn_emb=refer_self_attn_emb, + refer_self_attn_emb_mode=refer_self_attn_emb_mode, + ) + # resize spatial_position_emb + if self.need_spatial_position_emb: + has_upblock = i_up_block < len(self.up_blocks) - 1 + if has_upblock: + spatial_position_emb = resize_spatial_position_emb( + spatial_position_emb, + scale=2, + height=int(sample.shape[2] / 2), + width=int(sample.shape[3] / 2), + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + sample = rearrange(sample, "(b t) c h w -> b c t h w", t=num_frames) + + # if self.need_adain_temporal_cond and num_frames > 1: + # sample = batch_adain_conditioned_tensor( + # sample, + # num_frames=num_frames, + # need_style_fidelity=False, + # src_index=sample_index, + # dst_index=vision_conditon_frames_sample_index, + # ) + self.print_idx += 1 + + if skip_temporal_layers is not None: + self.set_skip_temporal_layers(not skip_temporal_layers) + if not return_dict: + return (sample,) + else: + return UNet3DConditionOutput(sample=sample) + + # from https://github.com/huggingface/diffusers/blob/v0.16.1/src/diffusers/models/modeling_utils.py#L328 + @classmethod + def from_pretrained_2d( + cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs + ): + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., + `./my_model_directory/`. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype + will be automatically derived from the model's weights. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file. + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, + setting this argument to `True` will raise an error. + variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_flax`. + use_safetensors (`bool`, *optional* ): + If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to + `None` (the default). The pipeline will load using `safetensors` if safetensors weights are available + *and* if `safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use + this method in a firewalled environment. + + + + """ + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + variant = kwargs.pop("variant", None) + use_safetensors = kwargs.pop("use_safetensors", None) + strict = kwargs.pop("strict", True) + + allow_pickle = False + if use_safetensors is None: + allow_pickle = True + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if device_map is not None and not is_accelerate_available(): + raise NotImplementedError( + "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" + " `device_map=None`. You can install accelerate with `pip install accelerate`." + ) + + # Check if we can handle device_map and dispatching the weights + if device_map is not None and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `device_map=None`." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + if low_cpu_mem_usage is False and device_map is not None: + raise ValueError( + f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" + " dispatching. Please make sure to set `low_cpu_mem_usage=True`." + ) + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + # load config + config, unused_kwargs, commit_hash = cls.load_config( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + return_commit_hash=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + user_agent=user_agent, + **kwargs, + ) + + config["_class_name"] = cls.__name__ + config["down_block_types"] = convert_2D_to_3D(config["down_block_types"]) + if "mid_block_type" in config: + config["mid_block_type"] = convert_2D_to_3D(config["mid_block_type"]) + else: + config["mid_block_type"] = "UNetMidBlock3DCrossAttn" + config["up_block_types"] = convert_2D_to_3D(config["up_block_types"]) + + # load model + model_file = None + if from_flax: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=FLAX_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + model = cls.from_config(config, **unused_kwargs) + + # Convert the weights + from diffusers.models.modeling_pytorch_flax_utils import ( + load_flax_checkpoint_in_pytorch_model, + ) + + model = load_flax_checkpoint_in_pytorch_model(model, model_file) + else: + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + except IOError as e: + if not allow_pickle: + raise e + pass + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=_add_variant(WEIGHTS_NAME, variant), + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + commit_hash=commit_hash, + ) + + if low_cpu_mem_usage: + # Instantiate model with empty weights + with accelerate.init_empty_weights(): + model = cls.from_config(config, **unused_kwargs) + + # if device_map is None, load the state dict and move the params from meta device to the cpu + if device_map is None: + param_device = "cpu" + state_dict = load_state_dict(model_file, variant=variant) + # move the params from meta device to cpu + missing_keys = set(model.state_dict().keys()) - set( + state_dict.keys() + ) + if len(missing_keys) > 0: + if strict: + raise ValueError( + f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" + f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" + " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" + " those weights or else make sure your checkpoint file is correct." + ) + else: + logger.warning( + f"model{cls} has no target pretrained paramter from {pretrained_model_name_or_path}, {', '.join(missing_keys)}" + ) + + empty_state_dict = model.state_dict() + for param_name, param in state_dict.items(): + accepts_dtype = "dtype" in set( + inspect.signature( + set_module_tensor_to_device + ).parameters.keys() + ) + + if empty_state_dict[param_name].shape != param.shape: + raise ValueError( + f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." + ) + + if accepts_dtype: + set_module_tensor_to_device( + model, + param_name, + param_device, + value=param, + dtype=torch_dtype, + ) + else: + set_module_tensor_to_device( + model, param_name, param_device, value=param + ) + else: # else let accelerate handle loading and dispatching. + # Load weights and dispatch according to the device_map + # by default the device_map is None and the weights are loaded on the CPU + accelerate.load_checkpoint_and_dispatch( + model, model_file, device_map, dtype=torch_dtype + ) + + loading_info = { + "missing_keys": [], + "unexpected_keys": [], + "mismatched_keys": [], + "error_msgs": [], + } + else: + model = cls.from_config(config, **unused_kwargs) + + state_dict = load_state_dict(model_file, variant=variant) + + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + model_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + ) + + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + + if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): + raise ValueError( + f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." + ) + elif torch_dtype is not None: + model = model.to(torch_dtype) + + model.register_to_config(_name_or_path=pretrained_model_name_or_path) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + if output_loading_info: + return model, loading_info + + return model + + def set_skip_temporal_layers( + self, + valid: bool, + ) -> None: # turn 3Dunet to 2Dunet + # Recursively walk through all the children. + # Any children which exposes the skip_temporal_layers parameter gets the message + + # 推断时使用参数控制refer_image和ip_adapter_image来控制,不需要这里了 + # if hasattr(self, "skip_refer_downblock_emb"): + # self.skip_refer_downblock_emb = valid + + def fn_recursive_set_mem_eff(module: torch.nn.Module): + if hasattr(module, "skip_temporal_layers"): + module.skip_temporal_layers = valid + # if hasattr(module, "skip_refer_downblock_emb"): + # module.skip_refer_downblock_emb = valid + + for child in module.children(): + fn_recursive_set_mem_eff(child) + + for module in self.children(): + if isinstance(module, torch.nn.Module): + fn_recursive_set_mem_eff(module) + + def insert_spatial_self_attn_idx(self): + attns, basic_transformers = self.spatial_self_attns + self.self_attn_num = len(attns) + for i, (name, layer) in enumerate(attns): + logger.debug( + f"{self.__class__.__name__}, {i}, {name}, {layer.__class__.__name__}" + ) + layer.spatial_self_attn_idx = i + for i, (name, layer) in enumerate(basic_transformers): + logger.debug( + f"{self.__class__.__name__}, {i}, {name}, {layer.__class__.__name__}" + ) + layer.spatial_self_attn_idx = i + + @property + def spatial_self_attns( + self, + ) -> List[Tuple[str, Attention]]: + attns, spatial_transformers = self.get_attns( + include="attentions", exclude="temp_attentions", attn_name="attn1" + ) + attns = sorted(attns) + spatial_transformers = sorted(spatial_transformers) + return attns, spatial_transformers + + @property + def spatial_cross_attns( + self, + ) -> List[Tuple[str, Attention]]: + attns, spatial_transformers = self.get_attns( + include="attentions", exclude="temp_attentions", attn_name="attn2" + ) + attns = sorted(attns) + spatial_transformers = sorted(spatial_transformers) + return attns, spatial_transformers + + def get_attns( + self, + attn_name: str, + include: str = None, + exclude: str = None, + ) -> List[Tuple[str, Attention]]: + r""" + Returns: + `dict` of attention attns: A dictionary containing all attention attns used in the model with + indexed by its weight name. + """ + # set recursively + attns = [] + spatial_transformers = [] + + def fn_recursive_add_attns( + name: str, + module: torch.nn.Module, + attns: List[Tuple[str, Attention]], + spatial_transformers: List[Tuple[str, BasicTransformerBlock]], + ): + is_target = False + if isinstance(module, BasicTransformerBlock) and hasattr(module, attn_name): + is_target = True + if include is not None: + is_target = include in name + if exclude is not None: + is_target = exclude not in name + if is_target: + attns.append([f"{name}.{attn_name}", getattr(module, attn_name)]) + spatial_transformers.append([f"{name}", module]) + for sub_name, child in module.named_children(): + fn_recursive_add_attns( + f"{name}.{sub_name}", child, attns, spatial_transformers + ) + + return attns + + for name, module in self.named_children(): + fn_recursive_add_attns(name, module, attns, spatial_transformers) + + return attns, spatial_transformers diff --git a/musev/models/unet_loader.py b/musev/models/unet_loader.py new file mode 100755 index 0000000000000000000000000000000000000000..965e84099bda3871ffd11fe20027f28c305ecdef --- /dev/null +++ b/musev/models/unet_loader.py @@ -0,0 +1,273 @@ +import copy +from typing import Any, Callable, Dict, Iterable, Union +import PIL +import cv2 +import torch +import argparse +import datetime +import logging +import inspect +import math +import os +import shutil +from typing import Dict, List, Optional, Tuple +from pprint import pprint +from collections import OrderedDict +from dataclasses import dataclass +import gc +import time + +import numpy as np +from omegaconf import OmegaConf +from omegaconf import SCMode +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange, repeat +import pandas as pd +import h5py +from diffusers.models.modeling_utils import load_state_dict +from diffusers.utils import ( + logging, +) +from diffusers.utils.import_utils import is_xformers_available + +from ..models.unet_3d_condition import UNet3DConditionModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def update_unet_with_sd( + unet: nn.Module, sd_model: Tuple[str, nn.Module], subfolder: str = "unet" +): + """更新T2V模型中的T2I参数. update t2i parameters in t2v model + + Args: + unet (nn.Module): _description_ + sd_model (Tuple[str, nn.Module]): _description_ + + Returns: + _type_: _description_ + """ + # dtype = unet.dtype + # TODO: in this way, sd_model_path must be absolute path, to be more dynamic + if isinstance(sd_model, str): + if os.path.isdir(sd_model): + unet_state_dict = load_state_dict( + os.path.join(sd_model, subfolder, "diffusion_pytorch_model.bin"), + ) + elif os.path.isfile(sd_model): + if sd_model.endswith("pth"): + unet_state_dict = torch.load(sd_model, map_location="cpu") + print(f"referencenet successful load ={sd_model} with torch.load") + else: + try: + unet_state_dict = load_state_dict(sd_model) + print( + f"referencenet successful load with {sd_model} with load_state_dict" + ) + except Exception as e: + print(e) + + elif isinstance(sd_model, nn.Module): + unet_state_dict = sd_model.state_dict() + else: + raise ValueError(f"given {type(sd_model)}, but only support nn.Module or str") + missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False) + assert len(unexpected) == 0, f"unet load_state_dict error, unexpected={unexpected}" + # unet.to(dtype=dtype) + return unet + + +def load_unet( + sd_unet_model: Tuple[str, nn.Module], + sd_model: Tuple[str, nn.Module] = None, + cross_attention_dim: int = 768, + temporal_transformer: str = "TransformerTemporalModel", + temporal_conv_block: str = "TemporalConvLayer", + need_spatial_position_emb: bool = False, + need_transformer_in: bool = True, + need_t2i_ip_adapter: bool = False, + need_adain_temporal_cond: bool = False, + t2i_ip_adapter_attn_processor: str = "IPXFormersAttnProcessor", + keep_vision_condtion: bool = False, + use_anivv1_cfg: bool = False, + resnet_2d_skip_time_act: bool = False, + dtype: torch.dtype = torch.float16, + need_zero_vis_cond_temb: bool = True, + norm_spatial_length: bool = True, + spatial_max_length: int = 2048, + need_refer_emb: bool = False, + ip_adapter_cross_attn=False, + t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor", + need_t2i_facein: bool = False, + need_t2i_ip_adapter_face: bool = False, + strict: bool = True, +): + """通过模型名字 初始化Unet,载入预训练参数. init unet with model_name. + 该部分都是通过 models.unet_3d_condition.py:UNet3DConditionModel 定义、训练的模型 + model is defined and trained in models.unet_3d_condition.py:UNet3DConditionModel + + Args: + sd_unet_model (Tuple[str, nn.Module]): _description_ + sd_model (Tuple[str, nn.Module]): _description_ + cross_attention_dim (int, optional): _description_. Defaults to 768. + temporal_transformer (str, optional): _description_. Defaults to "TransformerTemporalModel". + temporal_conv_block (str, optional): _description_. Defaults to "TemporalConvLayer". + need_spatial_position_emb (bool, optional): _description_. Defaults to False. + need_transformer_in (bool, optional): _description_. Defaults to True. + need_t2i_ip_adapter (bool, optional): _description_. Defaults to False. + need_adain_temporal_cond (bool, optional): _description_. Defaults to False. + t2i_ip_adapter_attn_processor (str, optional): _description_. Defaults to "IPXFormersAttnProcessor". + keep_vision_condtion (bool, optional): _description_. Defaults to False. + use_anivv1_cfg (bool, optional): _description_. Defaults to False. + resnet_2d_skip_time_act (bool, optional): _description_. Defaults to False. + dtype (torch.dtype, optional): _description_. Defaults to torch.float16. + need_zero_vis_cond_temb (bool, optional): _description_. Defaults to True. + norm_spatial_length (bool, optional): _description_. Defaults to True. + spatial_max_length (int, optional): _description_. Defaults to 2048. + + Returns: + _type_: _description_ + """ + if isinstance(sd_unet_model, str): + unet = UNet3DConditionModel.from_pretrained_2d( + sd_unet_model, + subfolder="unet", + temporal_transformer=temporal_transformer, + temporal_conv_block=temporal_conv_block, + cross_attention_dim=cross_attention_dim, + need_spatial_position_emb=need_spatial_position_emb, + need_transformer_in=need_transformer_in, + need_t2i_ip_adapter=need_t2i_ip_adapter, + need_adain_temporal_cond=need_adain_temporal_cond, + t2i_ip_adapter_attn_processor=t2i_ip_adapter_attn_processor, + keep_vision_condtion=keep_vision_condtion, + use_anivv1_cfg=use_anivv1_cfg, + resnet_2d_skip_time_act=resnet_2d_skip_time_act, + torch_dtype=dtype, + need_zero_vis_cond_temb=need_zero_vis_cond_temb, + norm_spatial_length=norm_spatial_length, + spatial_max_length=spatial_max_length, + need_refer_emb=need_refer_emb, + ip_adapter_cross_attn=ip_adapter_cross_attn, + t2i_crossattn_ip_adapter_attn_processor=t2i_crossattn_ip_adapter_attn_processor, + need_t2i_facein=need_t2i_facein, + strict=strict, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + ) + elif isinstance(sd_unet_model, nn.Module): + unet = sd_unet_model + if sd_model is not None: + unet = update_unet_with_sd(unet, sd_model) + return unet + + +def load_unet_custom_unet( + sd_unet_model: Tuple[str, nn.Module], + sd_model: Tuple[str, nn.Module], + unet_class: nn.Module, +): + """ + 通过模型名字 初始化Unet,载入预训练参数. init unet with model_name. + 该部分都是通过 不通过models.unet_3d_condition.py:UNet3DConditionModel 定义、训练的模型 + model is not defined in models.unet_3d_condition.py:UNet3DConditionModel + Args: + sd_unet_model (Tuple[str, nn.Module]): _description_ + sd_model (Tuple[str, nn.Module]): _description_ + unet_class (nn.Module): _description_ + + Returns: + _type_: _description_ + """ + if isinstance(sd_unet_model, str): + unet = unet_class.from_pretrained( + sd_unet_model, + subfolder="unet", + ) + elif isinstance(sd_unet_model, nn.Module): + unet = sd_unet_model + + # TODO: in this way, sd_model_path must be absolute path, to be more dynamic + if isinstance(sd_model, str): + unet_state_dict = load_state_dict( + os.path.join(sd_model, "unet/diffusion_pytorch_model.bin"), + ) + elif isinstance(sd_model, nn.Module): + unet_state_dict = sd_model.state_dict() + missing, unexpected = unet.load_state_dict(unet_state_dict, strict=False) + assert ( + len(unexpected) == 0 + ), "unet load_state_dict error" # Load scheduler, tokenizer and models. + return unet + + +def load_unet_by_name( + model_name: str, + sd_unet_model: Tuple[str, nn.Module], + sd_model: Tuple[str, nn.Module] = None, + cross_attention_dim: int = 768, + dtype: torch.dtype = torch.float16, + need_t2i_facein: bool = False, + need_t2i_ip_adapter_face: bool = False, + strict: bool = True, +) -> nn.Module: + """通过模型名字 初始化Unet,载入预训练参数. init unet with model_name. + 如希望后续通过简单名字就可以使用预训练模型,需要在这里完成定义 + if you want to use pretrained model with simple name, you need to define it here. + Args: + model_name (str): _description_ + sd_unet_model (Tuple[str, nn.Module]): _description_ + sd_model (Tuple[str, nn.Module]): _description_ + cross_attention_dim (int, optional): _description_. Defaults to 768. + dtype (torch.dtype, optional): _description_. Defaults to torch.float16. + + Raises: + ValueError: _description_ + + Returns: + nn.Module: _description_ + """ + if model_name in ["musev"]: + unet = load_unet( + sd_unet_model=sd_unet_model, + sd_model=sd_model, + need_spatial_position_emb=False, + cross_attention_dim=cross_attention_dim, + need_t2i_ip_adapter=True, + need_adain_temporal_cond=True, + t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor", + dtype=dtype, + ) + elif model_name in [ + "musev_referencenet", + "musev_referencenet_pose", + ]: + unet = load_unet( + sd_unet_model=sd_unet_model, + sd_model=sd_model, + cross_attention_dim=cross_attention_dim, + temporal_conv_block="TemporalConvLayer", + need_transformer_in=False, + temporal_transformer="TransformerTemporalModel", + use_anivv1_cfg=True, + resnet_2d_skip_time_act=True, + need_t2i_ip_adapter=True, + need_adain_temporal_cond=True, + keep_vision_condtion=True, + t2i_ip_adapter_attn_processor="NonParamReferenceIPXFormersAttnProcessor", + dtype=dtype, + need_refer_emb=True, + need_zero_vis_cond_temb=True, + ip_adapter_cross_attn=True, + t2i_crossattn_ip_adapter_attn_processor="T2IReferencenetIPAdapterXFormersAttnProcessor", + need_t2i_facein=need_t2i_facein, + strict=strict, + need_t2i_ip_adapter_face=need_t2i_ip_adapter_face, + ) + else: + raise ValueError( + f"unsupport model_name={model_name}, only support musev, musev_referencenet, musev_referencenet_pose" + ) + return unet diff --git a/musev/pipelines/__init__.py b/musev/pipelines/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/musev/pipelines/context.py b/musev/pipelines/context.py new file mode 100755 index 0000000000000000000000000000000000000000..6d55ca6a75235f3917ab566163f9d8f5bd61c62e --- /dev/null +++ b/musev/pipelines/context.py @@ -0,0 +1,149 @@ +# TODO: Adapted from cli +import math +from typing import Callable, List, Optional + +import numpy as np + +from mmcm.utils.itertools_util import generate_sample_idxs + +# copy from https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/context.py + + +def ordered_halving(val): + bin_str = f"{val:064b}" + bin_flip = bin_str[::-1] + as_int = int(bin_flip, 2) + + return as_int / (1 << 64) + + +# TODO: closed_loop not work, to fix it +def uniform( + step: int = ..., + num_steps: Optional[int] = None, + num_frames: int = ..., + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = True, +): + if num_frames <= context_size: + yield list(range(num_frames)) + return + + context_stride = min( + context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1 + ) + + for context_step in 1 << np.arange(context_stride): + pad = int(round(num_frames * ordered_halving(step))) + for j in range( + int(ordered_halving(step) * context_step) + pad, + num_frames + pad + (0 if closed_loop else -context_overlap), + (context_size * context_step - context_overlap), + ): + yield [ + e % num_frames + for e in range(j, j + context_size * context_step, context_step) + ] + + +def uniform_v2( + step: int = ..., + num_steps: Optional[int] = None, + num_frames: int = ..., + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = True, +): + return generate_sample_idxs( + total=num_frames, + window_size=context_size, + step=context_size - context_overlap, + sample_rate=1, + drop_last=False, + ) + + +def get_context_scheduler(name: str) -> Callable: + if name == "uniform": + return uniform + elif name == "uniform_v2": + return uniform_v2 + else: + raise ValueError(f"Unknown context_overlap policy {name}") + + +def get_total_steps( + scheduler, + timesteps: List[int], + num_steps: Optional[int] = None, + num_frames: int = ..., + context_size: Optional[int] = None, + context_stride: int = 3, + context_overlap: int = 4, + closed_loop: bool = True, +): + return sum( + len( + list( + scheduler( + i, + num_steps, + num_frames, + context_size, + context_stride, + context_overlap, + ) + ) + ) + for i in range(len(timesteps)) + ) + + +def drop_last_repeat_context(contexts: List[List[int]]) -> List[List[int]]: + """if len(contexts)>=2 and the max value the oenultimate list same as of the last list + + Args: + List (_type_): _description_ + + Returns: + List[List[int]]: _description_ + """ + if len(contexts) >= 2 and contexts[-1][-1] == contexts[-2][-1]: + return contexts[:-1] + else: + return contexts + + +def prepare_global_context( + context_schedule: str, + num_inference_steps: int, + time_size: int, + context_frames: int, + context_stride: int, + context_overlap: int, + context_batch_size: int, +): + context_scheduler = get_context_scheduler(context_schedule) + context_queue = list( + context_scheduler( + step=0, + num_steps=num_inference_steps, + num_frames=time_size, + context_size=context_frames, + context_stride=context_stride, + context_overlap=context_overlap, + ) + ) + # 如果context_queue的最后一个索引最大值和倒数第二个索引最大值相同,说明最后一个列表就是因为step带来的冗余项,可以去掉 + # remove the last context if max index of the last context is the same as the max index of the second last context + context_queue = drop_last_repeat_context(context_queue) + num_context_batches = math.ceil(len(context_queue) / context_batch_size) + global_context = [] + for i_tmp in range(num_context_batches): + global_context.append( + context_queue[i_tmp * context_batch_size : (i_tmp + 1) * context_batch_size] + ) + return global_context diff --git a/musev/pipelines/pipeline_controlnet.py b/musev/pipelines/pipeline_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..814ae0dcfdf7605ff448d21800c60ff74a3deb5b --- /dev/null +++ b/musev/pipelines/pipeline_controlnet.py @@ -0,0 +1,2202 @@ +from __future__ import annotations + +import inspect +import math +import time +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from dataclasses import dataclass + +from einops import rearrange, repeat +import PIL.Image +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from diffusers.pipelines.controlnet.pipeline_controlnet import ( + StableDiffusionSafetyChecker, + EXAMPLE_DOC_STRING, +) +from diffusers.pipelines.controlnet.pipeline_controlnet_img2img import ( + StableDiffusionControlNetImg2ImgPipeline as DiffusersStableDiffusionControlNetImg2ImgPipeline, +) +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, ControlNetModel +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) + +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + logging, + BaseOutput, + replace_example_docstring, +) +from diffusers.utils.torch_utils import is_compiled_module +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.models.attention import ( + BasicTransformerBlock as DiffusersBasicTransformerBlock, +) +from mmcm.vision.process.correct_color import ( + hist_match_color_video_batch, + hist_match_video_bcthw, +) + +from ..models.attention import BasicTransformerBlock +from ..models.unet_3d_condition import UNet3DConditionModel +from ..utils.noise_util import random_noise, video_fusion_noise +from ..data.data_util import ( + adaptive_instance_normalization, + align_repeat_tensor_single_dim, + batch_adain_conditioned_tensor, + batch_concat_two_tensor_with_index, + batch_index_select, + fuse_part_tensor, +) +from ..utils.text_emb_util import encode_weighted_prompt +from ..utils.tensor_util import his_match +from ..utils.timesteps_util import generate_parameters_with_timesteps +from .context import get_context_scheduler, prepare_global_context + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class VideoPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + latents: Union[torch.Tensor, np.ndarray] + videos_mid: Union[torch.Tensor, np.ndarray] + down_block_res_samples: Tuple[torch.FloatTensor] = None + mid_block_res_samples: torch.FloatTensor = None + up_block_res_samples: torch.FloatTensor = None + mid_video_latents: List[torch.FloatTensor] = None + mid_video_noises: List[torch.FloatTensor] = None + + +def torch_dfs(model: torch.nn.Module): + result = [model] + for child in model.children(): + result += torch_dfs(child) + return result + + +def prepare_image( + image, # b c t h w + batch_size, + device, + dtype, + image_processor: Callable, + num_images_per_prompt: int = 1, + width=None, + height=None, +): + if isinstance(image, List) and isinstance(image[0], str): + raise NotImplementedError + if isinstance(image, List) and isinstance(image[0], np.ndarray): + image = np.concatenate(image, axis=0) + if isinstance(image, np.ndarray): + image = torch.from_numpy(image) + if image.ndim == 5: + image = rearrange(image, "b c t h w-> (b t) c h w") + if height is None: + height = image.shape[-2] + if width is None: + width = image.shape[-1] + width, height = (x - x % image_processor.vae_scale_factor for x in (width, height)) + if height != image.shape[-2] or width != image.shape[-1]: + image = torch.nn.functional.interpolate( + image, size=(height, width), mode="bilinear" + ) + image = image.to(dtype=torch.float32) / 255.0 + do_normalize = image_processor.config.do_normalize + if image.min() < 0: + warnings.warn( + "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] " + f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]", + FutureWarning, + ) + do_normalize = False + + if do_normalize: + image = image_processor.normalize(image) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + return image + + +class MusevControlNetPipeline( + DiffusersStableDiffusionControlNetImg2ImgPipeline, TextualInversionLoaderMixin +): + """ + a union diffusers pipeline, support + 1. text2image model only, or text2video model, by setting skip_temporal_layer + 2. text2video, image2video, video2video; + 3. multi controlnet + 4. IPAdapter + 5. referencenet + 6. IPAdapterFaceID + """ + + _optional_components = [ + "safety_checker", + "feature_extractor", + ] + print_idx = 0 + + def __init__( + self, + vae: AutoencoderKL, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + controlnet: ControlNetModel + | List[ControlNetModel] + | Tuple[ControlNetModel] + | MultiControlNetModel, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + # | MultiControlNetModel = None, + # text_encoder: CLIPTextModel = None, + # tokenizer: CLIPTokenizer = None, + # safety_checker: StableDiffusionSafetyChecker = None, + # feature_extractor: CLIPImageProcessor = None, + requires_safety_checker: bool = False, + referencenet: nn.Module = None, + vision_clip_extractor: nn.Module = None, + ip_adapter_image_proj: nn.Module = None, + face_emb_extractor: nn.Module = None, + facein_image_proj: nn.Module = None, + ip_adapter_face_emb_extractor: nn.Module = None, + ip_adapter_face_image_proj: nn.Module = None, + pose_guider: nn.Module = None, + ): + super().__init__( + vae, + text_encoder, + tokenizer, + unet, + controlnet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker, + ) + self.referencenet = referencenet + + # ip_adapter + if isinstance(vision_clip_extractor, nn.Module): + vision_clip_extractor.to(dtype=self.unet.dtype, device=self.unet.device) + self.vision_clip_extractor = vision_clip_extractor + if isinstance(ip_adapter_image_proj, nn.Module): + ip_adapter_image_proj.to(dtype=self.unet.dtype, device=self.unet.device) + self.ip_adapter_image_proj = ip_adapter_image_proj + + # facein + if isinstance(face_emb_extractor, nn.Module): + face_emb_extractor.to(dtype=self.unet.dtype, device=self.unet.device) + self.face_emb_extractor = face_emb_extractor + if isinstance(facein_image_proj, nn.Module): + facein_image_proj.to(dtype=self.unet.dtype, device=self.unet.device) + self.facein_image_proj = facein_image_proj + + # ip_adapter_face + if isinstance(ip_adapter_face_emb_extractor, nn.Module): + ip_adapter_face_emb_extractor.to( + dtype=self.unet.dtype, device=self.unet.device + ) + self.ip_adapter_face_emb_extractor = ip_adapter_face_emb_extractor + if isinstance(ip_adapter_face_image_proj, nn.Module): + ip_adapter_face_image_proj.to( + dtype=self.unet.dtype, device=self.unet.device + ) + self.ip_adapter_face_image_proj = ip_adapter_face_image_proj + + if isinstance(pose_guider, nn.Module): + pose_guider.to(dtype=self.unet.dtype, device=self.unet.device) + self.pose_guider = pose_guider + + def decode_latents(self, latents): + batch_size = latents.shape[0] + latents = rearrange(latents, "b c f h w -> (b f) c h w") + video = super().decode_latents(latents=latents) + video = rearrange(video, "(b f) h w c -> b c f h w", b=batch_size) + return video + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int, + video_length: int, + height: int, + width: int, + dtype: torch.dtype, + device: torch.device, + generator: torch.Generator, + latents: torch.Tensor = None, + w_ind_noise: float = 0.5, + image: torch.Tensor = None, + timestep: int = None, + initial_common_latent: torch.Tensor = None, + noise_type: str = "random", + add_latents_noise: bool = False, + need_img_based_video_noise: bool = False, + condition_latents: torch.Tensor = None, + img_weight=1e-3, + ) -> torch.Tensor: + """ + 支持多种情况下的latens: + img_based_latents: 当Image t=1,latents=None时,使用image赋值到shape,然后加噪;适用于text2video、middle2video。 + video_based_latents:image =shape或Latents!=None时,加噪,适用于video2video; + noise_latents:当image 和latents都为None时,生成随机噪声,适用于text2video + + support multi latents condition: + img_based_latents: when Image t=1, latents=None, use image to assign to shape, then add noise; suitable for text2video, middle2video. + video_based_latents: image =shape or Latents!=None, add noise, suitable for video2video; + noise_laten: when image and latents are both None, generate random noise, suitable for text2video + + Args: + batch_size (int): _description_ + num_channels_latents (int): _description_ + video_length (int): _description_ + height (int): _description_ + width (int): _description_ + dtype (torch.dtype): _description_ + device (torch.device): _description_ + generator (torch.Generator): _description_ + latents (torch.Tensor, optional): _description_. Defaults to None. + w_ind_noise (float, optional): _description_. Defaults to 0.5. + image (torch.Tensor, optional): _description_. Defaults to None. + timestep (int, optional): _description_. Defaults to None. + initial_common_latent (torch.Tensor, optional): _description_. Defaults to None. + noise_type (str, optional): _description_. Defaults to "random". + add_latents_noise (bool, optional): _description_. Defaults to False. + need_img_based_video_noise (bool, optional): _description_. Defaults to False. + condition_latents (torch.Tensor, optional): _description_. Defaults to None. + img_weight (_type_, optional): _description_. Defaults to 1e-3. + + Raises: + ValueError: _description_ + ValueError: _description_ + ValueError: _description_ + + Returns: + torch.Tensor: latents + """ + + # ref https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py#L691 + # ref https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet.py#L659 + shape = ( + batch_size, + num_channels_latents, + video_length, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None or (latents is not None and add_latents_noise): + if noise_type == "random": + noise = random_noise( + shape=shape, dtype=dtype, device=device, generator=generator + ) + elif noise_type == "video_fusion": + noise = video_fusion_noise( + shape=shape, + dtype=dtype, + device=device, + generator=generator, + w_ind_noise=w_ind_noise, + initial_common_noise=initial_common_latent, + ) + if ( + need_img_based_video_noise + and condition_latents is not None + and image is None + and latents is None + ): + if self.print_idx == 0: + logger.debug( + ( + f"need_img_based_video_noise, condition_latents={condition_latents.shape}," + f"batch_size={batch_size}, noise={noise.shape}, video_length={video_length}" + ) + ) + condition_latents = condition_latents.mean(dim=2, keepdim=True) + condition_latents = repeat( + condition_latents, "b c t h w->b c (t x) h w", x=video_length + ) + noise = ( + img_weight**0.5 * condition_latents + + (1 - img_weight) ** 0.5 * noise + ) + if self.print_idx == 0: + logger.debug(f"noise={noise.shape}") + + if image is not None: + if image.ndim == 5: + image = rearrange(image, "b c t h w->(b t) c h w") + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + # self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) + self.vae.encode(image[i : i + 1]).latent_dist.mean + for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + # init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = self.vae.encode(image).latent_dist.mean + init_latents = self.vae.config.scaling_factor * init_latents + # scale the initial noise by the standard deviation required by the scheduler + if ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] == 0 + ): + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate( + "len(prompt) != len(image)", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat( + [init_latents] * additional_image_per_prompt, dim=0 + ) + elif ( + batch_size > init_latents.shape[0] + and batch_size % init_latents.shape[0] != 0 + ): + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + if init_latents.shape[2] != shape[3] and init_latents.shape[3] != shape[4]: + init_latents = torch.nn.functional.interpolate( + init_latents, + size=(shape[3], shape[4]), + mode="bilinear", + ) + init_latents = rearrange( + init_latents, "(b t) c h w-> b c t h w", t=video_length + ) + if self.print_idx == 0: + logger.debug(f"init_latensts={init_latents.shape}") + if latents is None: + if image is None: + latents = noise * self.scheduler.init_noise_sigma + else: + if self.print_idx == 0: + logger.debug(f"prepare latents, image is not None") + latents = self.scheduler.add_noise(init_latents, noise, timestep) + else: + if isinstance(latents, np.ndarray): + latents = torch.from_numpy(latents) + latents = latents.to(device=device, dtype=dtype) + if add_latents_noise: + latents = self.scheduler.add_noise(latents, noise, timestep) + else: + latents = latents * self.scheduler.init_noise_sigma + if latents.shape != shape: + raise ValueError( + f"Unexpected latents shape, got {latents.shape}, expected {shape}" + ) + latents = latents.to(device, dtype=dtype) + return latents + + def prepare_image( + self, + image, # b c t h w + batch_size, + num_images_per_prompt, + device, + dtype, + width=None, + height=None, + ): + return prepare_image( + image=image, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + width=width, + height=height, + image_processor=self.image_processor, + ) + + def prepare_control_image( + self, + image, # b c t h w + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = prepare_image( + image=image, + batch_size=batch_size, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=dtype, + width=width, + height=height, + image_processor=self.control_image_processor, + ) + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + return image + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1, + control_guidance_start=0, + control_guidance_end=1, + ): + # TODO: to implement + if image is not None: + return super().check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + def hist_match_with_vis_cond( + self, video: np.ndarray, target: np.ndarray + ) -> np.ndarray: + """ + video: b c t1 h w + target: b c t2(=1) h w + """ + video = hist_match_video_bcthw(video, target, value=255.0) + return video + + def get_facein_image_emb( + self, refer_face_image, device, dtype, batch_size, do_classifier_free_guidance + ): + # refer_face_image and its face_emb + if self.print_idx == 0: + logger.debug( + f"face_emb_extractor={type(self.face_emb_extractor)}, facein_image_proj={type(self.facein_image_proj)}, refer_face_image={type(refer_face_image)}, " + ) + if ( + self.face_emb_extractor is not None + and self.facein_image_proj is not None + and refer_face_image is not None + ): + if self.print_idx == 0: + logger.debug(f"refer_face_image={refer_face_image.shape}") + if isinstance(refer_face_image, np.ndarray): + refer_face_image = torch.from_numpy(refer_face_image) + refer_face_image_facein = refer_face_image + n_refer_face_image = refer_face_image_facein.shape[2] + refer_face_image_facein = rearrange( + refer_face_image, "b c t h w-> (b t) h w c" + ) + # refer_face_image_emb: bt d或者 bt h w d + ( + refer_face_image_emb, + refer_align_face_image, + ) = self.face_emb_extractor.extract_images( + refer_face_image_facein, return_type="torch" + ) + refer_face_image_emb = refer_face_image_emb.to(device=device, dtype=dtype) + if self.print_idx == 0: + logger.debug(f"refer_face_image_emb={refer_face_image_emb.shape}") + if refer_face_image_emb.shape == 2: + refer_face_image_emb = rearrange(refer_face_image_emb, "bt d-> bt 1 d") + elif refer_face_image_emb.shape == 4: + refer_face_image_emb = rearrange( + refer_face_image_emb, "bt h w d-> bt (h w) d" + ) + refer_face_image_emb_bk = refer_face_image_emb + refer_face_image_emb = self.facein_image_proj(refer_face_image_emb) + # Todo:当前不支持 IPAdapterPlus的vision_clip的输出 + refer_face_image_emb = rearrange( + refer_face_image_emb, + "(b t) n q-> b (t n) q", + t=n_refer_face_image, + ) + refer_face_image_emb = align_repeat_tensor_single_dim( + refer_face_image_emb, target_length=batch_size, dim=0 + ) + if do_classifier_free_guidance: + # TODO:固定特征,有优化空间 + # TODO: fix the feature, there is optimization space + uncond_refer_face_image_emb = self.facein_image_proj( + torch.zeros_like(refer_face_image_emb_bk).to( + device=device, dtype=dtype + ) + ) + # Todo:当前可能不支持 IPAdapterPlus的vision_clip的输出 + # TODO: do not support IPAdapterPlus's vision_clip's output + uncond_refer_face_image_emb = rearrange( + uncond_refer_face_image_emb, + "(b t) n q-> b (t n) q", + t=n_refer_face_image, + ) + uncond_refer_face_image_emb = align_repeat_tensor_single_dim( + uncond_refer_face_image_emb, target_length=batch_size, dim=0 + ) + if self.print_idx == 0: + logger.debug( + f"uncond_refer_face_image_emb, {uncond_refer_face_image_emb.shape}" + ) + logger.debug(f"refer_face_image_emb, {refer_face_image_emb.shape}") + refer_face_image_emb = torch.concat( + [ + uncond_refer_face_image_emb, + refer_face_image_emb, + ], + ) + else: + refer_face_image_emb = None + if self.print_idx == 0: + logger.debug(f"refer_face_image_emb={type(refer_face_image_emb)}") + + return refer_face_image_emb + + def get_ip_adapter_face_emb( + self, refer_face_image, device, dtype, batch_size, do_classifier_free_guidance + ): + # refer_face_image and its ip_adapter_face_emb + if self.print_idx == 0: + logger.debug( + f"face_emb_extractor={type(self.face_emb_extractor)}, ip_adapter__image_proj={type(self.facein_image_proj)}, refer_face_image={type(refer_face_image)}, " + ) + if ( + self.ip_adapter_face_emb_extractor is not None + and self.ip_adapter_face_image_proj is not None + and refer_face_image is not None + ): + if self.print_idx == 0: + logger.debug(f"refer_face_image={refer_face_image.shape}") + if isinstance(refer_face_image, np.ndarray): + refer_face_image = torch.from_numpy(refer_face_image) + refer_ip_adapter_face_image = refer_face_image + n_refer_face_image = refer_ip_adapter_face_image.shape[2] + refer_ip_adapter_face_image = rearrange( + refer_ip_adapter_face_image, "b c t h w-> (b t) h w c" + ) + # refer_face_image_emb: bt d or bt h w d + ( + refer_face_image_emb, + refer_align_face_image, + ) = self.ip_adapter_face_emb_extractor.extract_images( + refer_ip_adapter_face_image, return_type="torch" + ) + refer_face_image_emb = refer_face_image_emb.to(device=device, dtype=dtype) + if self.print_idx == 0: + logger.debug(f"refer_face_image_emb={refer_face_image_emb.shape}") + if refer_face_image_emb.shape == 2: + refer_face_image_emb = rearrange(refer_face_image_emb, "bt d-> bt 1 d") + elif refer_face_image_emb.shape == 4: + refer_face_image_emb = rearrange( + refer_face_image_emb, "bt h w d-> bt (h w) d" + ) + refer_face_image_emb_bk = refer_face_image_emb + refer_face_image_emb = self.ip_adapter_face_image_proj(refer_face_image_emb) + + refer_face_image_emb = rearrange( + refer_face_image_emb, + "(b t) n q-> b (t n) q", + t=n_refer_face_image, + ) + refer_face_image_emb = align_repeat_tensor_single_dim( + refer_face_image_emb, target_length=batch_size, dim=0 + ) + if do_classifier_free_guidance: + # TODO:固定特征,有优化空间 + # TODO: fix the feature, there is optimization space + uncond_refer_face_image_emb = self.ip_adapter_face_image_proj( + torch.zeros_like(refer_face_image_emb_bk).to( + device=device, dtype=dtype + ) + ) + # TODO: 当前可能不支持 IPAdapterPlus的vision_clip的输出 + # TODO: do not support IPAdapterPlus's vision_clip's output + uncond_refer_face_image_emb = rearrange( + uncond_refer_face_image_emb, + "(b t) n q-> b (t n) q", + t=n_refer_face_image, + ) + uncond_refer_face_image_emb = align_repeat_tensor_single_dim( + uncond_refer_face_image_emb, target_length=batch_size, dim=0 + ) + if self.print_idx == 0: + logger.debug( + f"uncond_refer_face_image_emb, {uncond_refer_face_image_emb.shape}" + ) + logger.debug(f"refer_face_image_emb, {refer_face_image_emb.shape}") + refer_face_image_emb = torch.concat( + [ + uncond_refer_face_image_emb, + refer_face_image_emb, + ], + ) + else: + refer_face_image_emb = None + if self.print_idx == 0: + logger.debug(f"ip_adapter_face_emb={type(refer_face_image_emb)}") + + return refer_face_image_emb + + def get_ip_adapter_image_emb( + self, + ip_adapter_image, + device, + dtype, + batch_size, + do_classifier_free_guidance, + height, + width, + ): + # refer_image vision_clip and its ipadapter_emb + if self.print_idx == 0: + logger.debug( + f"vision_clip_extractor={type(self.vision_clip_extractor)}," + f"ip_adapter_image_proj={type(self.ip_adapter_image_proj)}," + f"ip_adapter_image={type(ip_adapter_image)}," + ) + if self.vision_clip_extractor is not None and ip_adapter_image is not None: + if self.print_idx == 0: + logger.debug(f"ip_adapter_image={ip_adapter_image.shape}") + if isinstance(ip_adapter_image, np.ndarray): + ip_adapter_image = torch.from_numpy(ip_adapter_image) + # ip_adapter_image = ip_adapter_image.to(device=device, dtype=dtype) + n_ip_adapter_image = ip_adapter_image.shape[2] + ip_adapter_image = rearrange(ip_adapter_image, "b c t h w-> (b t) h w c") + ip_adapter_image_emb = self.vision_clip_extractor.extract_images( + ip_adapter_image, + target_height=height, + target_width=width, + return_type="torch", + ) + if ip_adapter_image_emb.ndim == 2: + ip_adapter_image_emb = rearrange(ip_adapter_image_emb, "b q-> b 1 q") + + ip_adapter_image_emb_bk = ip_adapter_image_emb + # 存在只需要image_prompt、但不需要 proj的场景,如使用image_prompt替代text_prompt + # There are scenarios where only image_prompt is needed, but proj is not needed, such as using image_prompt instead of text_prompt + if self.ip_adapter_image_proj is not None: + logger.debug(f"ip_adapter_image_proj is None, ") + ip_adapter_image_emb = self.ip_adapter_image_proj(ip_adapter_image_emb) + # TODO: 当前不支持 IPAdapterPlus的vision_clip的输出 + # TODO: do not support IPAdapterPlus's vision_clip's output + ip_adapter_image_emb = rearrange( + ip_adapter_image_emb, + "(b t) n q-> b (t n) q", + t=n_ip_adapter_image, + ) + ip_adapter_image_emb = align_repeat_tensor_single_dim( + ip_adapter_image_emb, target_length=batch_size, dim=0 + ) + if do_classifier_free_guidance: + # TODO:固定特征,有优化空间 + # TODO: fix the feature, there is optimization space + if self.ip_adapter_image_proj is not None: + uncond_ip_adapter_image_emb = self.ip_adapter_image_proj( + torch.zeros_like(ip_adapter_image_emb_bk).to( + device=device, dtype=dtype + ) + ) + if self.print_idx == 0: + logger.debug( + f"uncond_ip_adapter_image_emb use ip_adapter_image_proj(zero_like)" + ) + else: + uncond_ip_adapter_image_emb = torch.zeros_like(ip_adapter_image_emb) + if self.print_idx == 0: + logger.debug(f"uncond_ip_adapter_image_emb use zero_like") + # TODO:当前可能不支持 IPAdapterPlus的vision_clip的输出 + # TODO: do not support IPAdapterPlus's vision_clip's output + uncond_ip_adapter_image_emb = rearrange( + uncond_ip_adapter_image_emb, + "(b t) n q-> b (t n) q", + t=n_ip_adapter_image, + ) + uncond_ip_adapter_image_emb = align_repeat_tensor_single_dim( + uncond_ip_adapter_image_emb, target_length=batch_size, dim=0 + ) + if self.print_idx == 0: + logger.debug( + f"uncond_ip_adapter_image_emb, {uncond_ip_adapter_image_emb.shape}" + ) + logger.debug(f"ip_adapter_image_emb, {ip_adapter_image_emb.shape}") + # uncond_ip_adapter_image_emb = torch.zeros_like(ip_adapter_image_emb) + ip_adapter_image_emb = torch.concat( + [ + uncond_ip_adapter_image_emb, + ip_adapter_image_emb, + ], + ) + + else: + ip_adapter_image_emb = None + if self.print_idx == 0: + logger.debug(f"ip_adapter_image_emb={type(ip_adapter_image_emb)}") + return ip_adapter_image_emb + + def get_referencenet_image_vae_emb( + self, + refer_image, + batch_size, + num_videos_per_prompt, + device, + dtype, + do_classifier_free_guidance, + width: int = None, + height: int = None, + ): + # prepare_referencenet_emb + if self.print_idx == 0: + logger.debug( + f"referencenet={type(self.referencenet)}, refer_image={type(refer_image)}" + ) + if self.referencenet is not None and refer_image is not None: + n_refer_image = refer_image.shape[2] + refer_image_vae = self.prepare_image( + refer_image, + batch_size=batch_size * num_videos_per_prompt, + num_images_per_prompt=num_videos_per_prompt, + device=device, + dtype=dtype, + width=width, + height=height, + ) + # ref_hidden_states = self.vae.encode(refer_image_vae).latent_dist.sample() + refer_image_vae_emb = self.vae.encode(refer_image_vae).latent_dist.mean + refer_image_vae_emb = self.vae.config.scaling_factor * refer_image_vae_emb + + logger.debug(f"refer_image_vae_emb={refer_image_vae_emb.shape}") + + if do_classifier_free_guidance: + # 1. zeros_like image + # uncond_refer_image_vae_emb = self.vae.encode( + # torch.zeros_like(refer_image_vae) + # ).latent_dist.mean + # uncond_refer_image_vae_emb = ( + # self.vae.config.scaling_factor * uncond_refer_image_vae_emb + # ) + + # 2. zeros_like image vae emb + # uncond_refer_image_vae_emb = torch.zeros_like(refer_image_vae_emb) + + # uncond_refer_image_vae_emb = rearrange( + # uncond_refer_image_vae_emb, + # "(b t) c h w-> b c t h w", + # t=n_refer_image, + # ) + + # refer_image_vae_emb = rearrange( + # refer_image_vae_emb, "(b t) c h w-> b c t h w", t=n_refer_image + # ) + # refer_image_vae_emb = torch.concat( + # [uncond_refer_image_vae_emb, refer_image_vae_emb], dim=0 + # ) + # refer_image_vae_emb = rearrange( + # refer_image_vae_emb, "b c t h w-> (b t) c h w" + # ) + # logger.debug(f"refer_image_vae_emb={refer_image_vae_emb.shape}") + + # 3. uncond_refer_image_vae_emb = refer_image_vae_emb + uncond_refer_image_vae_emb = refer_image_vae_emb + + uncond_refer_image_vae_emb = rearrange( + uncond_refer_image_vae_emb, + "(b t) c h w-> b c t h w", + t=n_refer_image, + ) + + refer_image_vae_emb = rearrange( + refer_image_vae_emb, "(b t) c h w-> b c t h w", t=n_refer_image + ) + refer_image_vae_emb = torch.concat( + [uncond_refer_image_vae_emb, refer_image_vae_emb], dim=0 + ) + refer_image_vae_emb = rearrange( + refer_image_vae_emb, "b c t h w-> (b t) c h w" + ) + logger.debug(f"refer_image_vae_emb={refer_image_vae_emb.shape}") + else: + refer_image_vae_emb = None + return refer_image_vae_emb + + def get_referencenet_emb( + self, + refer_image_vae_emb, + refer_image, + batch_size, + num_videos_per_prompt, + device, + dtype, + ip_adapter_image_emb, + do_classifier_free_guidance, + prompt_embeds, + ref_timestep_int: int = 0, + ): + # prepare_referencenet_emb + if self.print_idx == 0: + logger.debug( + f"referencenet={type(self.referencenet)}, refer_image={type(refer_image)}" + ) + if ( + self.referencenet is not None + and refer_image_vae_emb is not None + and refer_image is not None + ): + n_refer_image = refer_image.shape[2] + # ref_timestep = ( + # torch.ones((refer_image_vae_emb.shape[0],), device=device) + # * ref_timestep_int + # ) + ref_timestep = torch.zeros_like(ref_timestep_int) + # referencenet 优先使用 ip_adapter 中图像提取到的 clip_vision_emb + if ip_adapter_image_emb is not None: + refer_prompt_embeds = ip_adapter_image_emb + else: + refer_prompt_embeds = prompt_embeds + if self.print_idx == 0: + logger.debug( + f"use referencenet: n_refer_image={n_refer_image}, refer_image_vae_emb={refer_image_vae_emb.shape}, ref_timestep={ref_timestep.shape}" + ) + if prompt_embeds is not None: + logger.debug(f"prompt_embeds={prompt_embeds.shape},") + + # refer_image_vae_emb = self.scheduler.scale_model_input( + # refer_image_vae_emb, ref_timestep + # ) + # self.scheduler._step_index = None + # self.scheduler.is_scale_input_called = False + referencenet_params = { + "sample": refer_image_vae_emb, + "encoder_hidden_states": refer_prompt_embeds, + "timestep": ref_timestep, + "num_frames": n_refer_image, + "return_ndim": 5, + } + ( + down_block_refer_embs, + mid_block_refer_emb, + refer_self_attn_emb, + ) = self.referencenet(**referencenet_params) + + # many ways to prepare negative referencenet emb + # mode 1 + # zero shape like ref_image + # if do_classifier_free_guidance: + # # mode 2: + # # if down_block_refer_embs is not None: + # # down_block_refer_embs = [ + # # torch.cat([x] * 2) for x in down_block_refer_embs + # # ] + # # if mid_block_refer_emb is not None: + # # mid_block_refer_emb = torch.cat([mid_block_refer_emb] * 2) + # # if refer_self_attn_emb is not None: + # # refer_self_attn_emb = [ + # # torch.cat([x] * 2) for x in refer_self_attn_emb + # # ] + + # # mode 3 + # if down_block_refer_embs is not None: + # down_block_refer_embs = [ + # torch.cat([torch.zeros_like(x), x]) + # for x in down_block_refer_embs + # ] + # if mid_block_refer_emb is not None: + # mid_block_refer_emb = torch.cat( + # [torch.zeros_like(mid_block_refer_emb), mid_block_refer_emb] * 2 + # ) + # if refer_self_attn_emb is not None: + # refer_self_attn_emb = [ + # torch.cat([torch.zeros_like(x), x]) for x in refer_self_attn_emb + # ] + else: + down_block_refer_embs = None + mid_block_refer_emb = None + refer_self_attn_emb = None + if self.print_idx == 0: + logger.debug(f"down_block_refer_embs={type(down_block_refer_embs)}") + logger.debug(f"mid_block_refer_emb={type(mid_block_refer_emb)}") + logger.debug(f"refer_self_attn_emb={type(refer_self_attn_emb)}") + return down_block_refer_embs, mid_block_refer_emb, refer_self_attn_emb + + def prepare_condition_latents_and_index( + self, + condition_images, + condition_latents, + video_length, + batch_size, + dtype, + device, + latent_index, + vision_condition_latent_index, + ): + # prepare condition_latents + if condition_images is not None and condition_latents is None: + # condition_latents = self.vae.encode(condition_images).latent_dist.sample() + condition_latents = self.vae.encode(condition_images).latent_dist.mean + condition_latents = self.vae.config.scaling_factor * condition_latents + condition_latents = rearrange( + condition_latents, "(b t) c h w-> b c t h w", b=batch_size + ) + if self.print_idx == 0: + logger.debug( + f"condition_latents from condition_images, shape is condition_latents={condition_latents.shape}", + ) + if condition_latents is not None: + total_frames = condition_latents.shape[2] + video_length + if isinstance(condition_latents, np.ndarray): + condition_latents = torch.from_numpy(condition_latents) + condition_latents = condition_latents.to(dtype=dtype, device=device) + # if condition is None, mean condition_latents head, generated video is tail + if vision_condition_latent_index is not None: + # vision_condition_latent_index should be list, whose length is condition_latents.shape[2] + # -1 -> will be converted to condition_latents.shape[2]+video_length + vision_condition_latent_index_lst = [ + i_v if i_v != -1 else total_frames - 1 + for i_v in vision_condition_latent_index + ] + vision_condition_latent_index = torch.LongTensor( + vision_condition_latent_index_lst, + ).to(device=device) + if self.print_idx == 0: + logger.debug( + f"vision_condition_latent_index {type(vision_condition_latent_index)}, {vision_condition_latent_index}" + ) + else: + # [0, condition_latents.shape[2]] + vision_condition_latent_index = torch.arange( + condition_latents.shape[2], dtype=torch.long, device=device + ) + vision_condition_latent_index_lst = ( + vision_condition_latent_index.tolist() + ) + if latent_index is None: + # [condition_latents.shape[2], condition_latents.shape[2]+video_length] + latent_index_lst = sorted( + list( + set(range(total_frames)) + - set(vision_condition_latent_index_lst) + ) + ) + latent_index = torch.LongTensor( + latent_index_lst, + ).to(device=device) + + if vision_condition_latent_index is not None: + vision_condition_latent_index = vision_condition_latent_index.to( + device=device + ) + if self.print_idx == 0: + logger.debug( + f"pipeline vision_condition_latent_index ={vision_condition_latent_index.shape}, {vision_condition_latent_index}" + ) + if latent_index is not None: + latent_index = latent_index.to(device=device) + if self.print_idx == 0: + logger.debug( + f"pipeline latent_index ={latent_index.shape}, {latent_index}" + ) + logger.debug(f"condition_latents={type(condition_latents)}") + logger.debug(f"latent_index={type(latent_index)}") + logger.debug( + f"vision_condition_latent_index={type(vision_condition_latent_index)}" + ) + return condition_latents, latent_index, vision_condition_latent_index + + def prepare_controlnet_and_guidance_parameter( + self, control_guidance_start, control_guidance_end + ): + controlnet = ( + self.controlnet._orig_mod + if is_compiled_module(self.controlnet) + else self.controlnet + ) + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance( + control_guidance_end, list + ): + control_guidance_start = len(control_guidance_end) * [ + control_guidance_start + ] + elif not isinstance(control_guidance_end, list) and isinstance( + control_guidance_start, list + ): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance( + control_guidance_end, list + ): + mult = ( + len(controlnet.nets) + if isinstance(controlnet, MultiControlNetModel) + else 1 + ) + control_guidance_start, control_guidance_end = mult * [ + control_guidance_start + ], mult * [control_guidance_end] + return controlnet, control_guidance_start, control_guidance_end + + def prepare_controlnet_guess_mode(self, controlnet, guess_mode): + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + return guess_mode + + def prepare_controlnet_image_and_latents( + self, + controlnet, + width, + height, + batch_size, + num_videos_per_prompt, + device, + dtype, + controlnet_latents=None, + controlnet_condition_latents=None, + control_image=None, + controlnet_condition_images=None, + guess_mode=False, + do_classifier_free_guidance=False, + ): + if isinstance(controlnet, ControlNetModel): + if controlnet_latents is not None: + if isinstance(controlnet_latents, np.ndarray): + controlnet_latents = torch.from_numpy(controlnet_latents) + if controlnet_condition_latents is not None: + if isinstance(controlnet_condition_latents, np.ndarray): + controlnet_condition_latents = torch.from_numpy( + controlnet_condition_latents + ) + # TODO:使用index进行concat + controlnet_latents = torch.concat( + [controlnet_condition_latents, controlnet_latents], dim=2 + ) + if not guess_mode and do_classifier_free_guidance: + controlnet_latents = torch.concat([controlnet_latents] * 2, dim=0) + controlnet_latents = rearrange( + controlnet_latents, "b c t h w->(b t) c h w" + ) + controlnet_latents = controlnet_latents.to(device=device, dtype=dtype) + if self.print_idx == 0: + logger.debug( + f"call, controlnet_latents.shape, f{controlnet_latents.shape}" + ) + else: + # TODO: concat with index + if isinstance(control_image, np.ndarray): + control_image = torch.from_numpy(control_image) + if controlnet_condition_images is not None: + if isinstance(controlnet_condition_images, np.ndarray): + controlnet_condition_images = torch.from_numpy( + controlnet_condition_images + ) + control_image = torch.concatenate( + [controlnet_condition_images, control_image], dim=2 + ) + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt, + num_images_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = control_image.shape[-2:] + if self.print_idx == 0: + logger.debug(f"call, control_image.shape , {control_image.shape}") + + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + # TODO: directly support contronet_latent instead of frames + if ( + controlnet_latents is not None + and controlnet_condition_latents is not None + ): + raise NotImplementedError + for i, control_image_ in enumerate(control_image): + if controlnet_condition_images is not None and isinstance( + controlnet_condition_images, list + ): + if isinstance(controlnet_condition_images[i], np.ndarray): + control_image_ = np.concatenate( + [controlnet_condition_images[i], control_image_], axis=2 + ) + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt, + num_images_per_prompt=num_videos_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + height, width = control_image[0].shape[-2:] + else: + assert False + if control_image is not None: + if not isinstance(control_image, list): + if self.print_idx == 0: + logger.debug(f"control_image shape is {control_image.shape}") + else: + if self.print_idx == 0: + logger.debug(f"control_image shape is {control_image[0].shape}") + + return control_image, controlnet_latents + + def get_controlnet_emb( + self, + run_controlnet, + guess_mode, + do_classifier_free_guidance, + latents, + prompt_embeds, + latent_model_input, + controlnet_keep, + controlnet_conditioning_scale, + control_image, + controlnet_latents, + i, + t, + ): + if run_controlnet and self.pose_guider is None: + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input( + control_model_input, t + ) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + if isinstance(controlnet_keep[i], list): + cond_scale = [ + c * s + for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i]) + ] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + control_model_input_reshape = rearrange( + control_model_input, "b c t h w -> (b t) c h w" + ) + logger.debug( + f"control_model_input_reshape={control_model_input_reshape.shape}, controlnet_prompt_embeds={controlnet_prompt_embeds.shape}" + ) + encoder_hidden_states_repeat = align_repeat_tensor_single_dim( + controlnet_prompt_embeds, + target_length=control_model_input_reshape.shape[0], + dim=0, + ) + + if self.print_idx == 0: + logger.debug( + f"control_model_input_reshape={control_model_input_reshape.shape}, " + f"encoder_hidden_states_repeat={encoder_hidden_states_repeat.shape}, " + ) + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input_reshape, + t, + encoder_hidden_states_repeat, + controlnet_cond=control_image, + controlnet_cond_latents=controlnet_latents, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + if self.print_idx == 0: + logger.debug( + f"controlnet, len(down_block_res_samples, {len(down_block_res_samples)}", + ) + for i_tmp, tmp in enumerate(down_block_res_samples): + logger.debug( + f"controlnet down_block_res_samples i={i_tmp}, down_block_res_sample={tmp.shape}" + ) + logger.debug( + f"controlnet mid_block_res_sample, {mid_block_res_sample.shape}" + ) + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [ + torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples + ] + mid_block_res_sample = torch.cat( + [ + torch.zeros_like(mid_block_res_sample), + mid_block_res_sample, + ] + ) + else: + down_block_res_samples = None + mid_block_res_sample = None + + return down_block_res_samples, mid_block_res_sample + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + video_length: Optional[int], + prompt: Union[str, List[str]] = None, + # b c t h w + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + control_image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + # b c t(1) ho wo + condition_images: Optional[torch.FloatTensor] = None, + condition_latents: Optional[torch.FloatTensor] = None, + latents: Optional[torch.FloatTensor] = None, + add_latents_noise: bool = False, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + guidance_scale_end: float = None, + guidance_scale_method: str = "linear", + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + # b c t(1) hi wi + controlnet_condition_images: Optional[torch.FloatTensor] = None, + # b c t(1) ho wo + controlnet_condition_latents: Optional[torch.FloatTensor] = None, + controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + need_middle_latents: bool = False, + w_ind_noise: float = 0.5, + initial_common_latent: Optional[torch.FloatTensor] = None, + latent_index: torch.LongTensor = None, + vision_condition_latent_index: torch.LongTensor = None, + # noise parameters + noise_type: str = "random", + need_img_based_video_noise: bool = False, + skip_temporal_layer: bool = False, + img_weight: float = 1e-3, + need_hist_match: bool = False, + motion_speed: float = 8.0, + refer_image: Optional[Tuple[torch.Tensor, np.array]] = None, + ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None, + refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, + ip_adapter_scale: float = 1.0, + facein_scale: float = 1.0, + ip_adapter_face_scale: float = 1.0, + ip_adapter_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, + prompt_only_use_image_prompt: bool = False, + # serial_denoise parameter start + record_mid_video_noises: bool = False, + last_mid_video_noises: List[torch.Tensor] = None, + record_mid_video_latents: bool = False, + last_mid_video_latents: List[torch.TensorType] = None, + video_overlap: int = 1, + # serial_denoise parameter end + # parallel_denoise parameter start + # refer to https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/src/pipelines/pipeline_pose2vid_long.py#L354 + context_schedule="uniform", + context_frames=12, + context_stride=1, + context_overlap=4, + context_batch_size=1, + interpolation_factor=1, + # parallel_denoise parameter end + decoder_t_segment: int = 200, + ): + r""" + 旨在兼容text2video、text2image、img2img、video2video、是否有controlnet等的通用pipeline。目前仅不支持img2img、video2video。 + 支持多片段同时denoise,交叉部分加权平均 + + 当 skip_temporal_layer 为 False 时, unet 起 video 生成作用;skip_temporal_layer为True时,unet起原image作用。 + 当controlnet的所有入参为None,等价于走的是text2video pipeline; + 当 condition_latents、controlnet_condition_images、controlnet_condition_latents为None时,表示不走首帧条件生成的时序condition pipeline + 现在没有考虑对 `num_videos_per_prompt` 的兼容性,不是1可能报错; + + if skip_temporal_layer is False, unet motion layer works, else unet only run text2image layers. + if parameters about controlnet are None, means text2video pipeline; + if ondition_latents、controlnet_condition_images、controlnet_condition_latents are None, means only run text2video without vision condition images. + By now, code works well with `num_videos_per_prpmpt=1`, !=1 may be wrong. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + condition_latents: + 与latents相对应,是Latents的时序condition,一般为首帧,b c t(1) ho wo + be corresponding to latents, vision condtion latents, usually first frame, should be b c t(1) ho wo. + controlnet_latents: + 与image二选一,image会被转化成controlnet_latents + Choose either image or controlnet_latents. If image is chosen, it will be converted to controlnet_latents. + controlnet_condition_images: + Optional[torch.FloatTensor]# b c t(1) ho wo,与image相对应,会和image在t通道concat一起,然后转化成 controlnet_latents + b c t(1) ho wo, corresponding to image, will be concatenated along the t channel with image and then converted to controlnet_latents. + controlnet_condition_latents: Optional[torch.FloatTensor]:# + b c t(1) ho wo,会和 controlnet_latents 在t 通道concat一起,转化成 controlnet_latents + b c t(1) ho wo will be concatenated along the t channel with controlnet_latents and converted to controlnet_latents. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + strength (`float`, *optional*, defaults to 0.8): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + skip_temporal_layer (`bool`: default to False) 为False时,unet起video生成作用,会运行时序生成的block;skip_temporal_layer为True时,unet起原image作用,跳过时序生成的block。 + need_img_based_video_noise: bool = False, 当只有首帧latents时,是否需要扩展为video noise; + num_videos_per_prompt: now only support 1. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + run_controlnet = control_image is not None or controlnet_latents is not None + + if run_controlnet: + ( + controlnet, + control_guidance_start, + control_guidance_end, + ) = self.prepare_controlnet_and_guidance_parameter( + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + dtype = self.unet.dtype + # print("pipeline unet dtype", dtype) + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + if run_controlnet: + if isinstance(controlnet, MultiControlNetModel) and isinstance( + controlnet_conditioning_scale, float + ): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len( + controlnet.nets + ) + guess_mode = self.prepare_controlnet_guess_mode( + controlnet=controlnet, + guess_mode=guess_mode, + ) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) + if cross_attention_kwargs is not None + else None + ) + if self.text_encoder is not None: + prompt_embeds = encode_weighted_prompt( + self, + prompt, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + # lora_scale=text_encoder_lora_scale, + ) + logger.debug(f"use text_encoder prepare prompt_emb={prompt_embeds.shape}") + else: + prompt_embeds = None + if image is not None: + image = self.prepare_image( + image, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt, + num_images_per_prompt=num_videos_per_prompt, + device=device, + dtype=dtype, + ) + if self.print_idx == 0: + logger.debug(f"image={image.shape}") + if condition_images is not None: + condition_images = self.prepare_image( + condition_images, + width=width, + height=height, + batch_size=batch_size * num_videos_per_prompt, + num_images_per_prompt=num_videos_per_prompt, + device=device, + dtype=dtype, + ) + if self.print_idx == 0: + logger.debug(f"condition_images={condition_images.shape}") + # 4. Prepare image + if run_controlnet: + ( + control_image, + controlnet_latents, + ) = self.prepare_controlnet_image_and_latents( + controlnet=controlnet, + width=width, + height=height, + batch_size=batch_size, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + dtype=dtype, + controlnet_condition_latents=controlnet_condition_latents, + control_image=control_image, + controlnet_condition_images=controlnet_condition_images, + guess_mode=guess_mode, + do_classifier_free_guidance=do_classifier_free_guidance, + controlnet_latents=controlnet_latents, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + if strength and (image is not None and latents is not None): + if self.print_idx == 0: + logger.debug( + f"prepare timesteps, with get_timesteps strength={strength}, num_inference_steps={num_inference_steps}" + ) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength, device + ) + else: + if self.print_idx == 0: + logger.debug(f"prepare timesteps, without get_timesteps") + timesteps = self.scheduler.timesteps + latent_timestep = timesteps[:1].repeat( + batch_size * num_videos_per_prompt + ) # 6. Prepare latent variables + + ( + condition_latents, + latent_index, + vision_condition_latent_index, + ) = self.prepare_condition_latents_and_index( + condition_images=condition_images, + condition_latents=condition_latents, + video_length=video_length, + batch_size=batch_size, + dtype=dtype, + device=device, + latent_index=latent_index, + vision_condition_latent_index=vision_condition_latent_index, + ) + if vision_condition_latent_index is None: + n_vision_cond = 0 + else: + n_vision_cond = vision_condition_latent_index.shape[0] + + num_channels_latents = self.unet.config.in_channels + if self.print_idx == 0: + logger.debug(f"pipeline controlnet, start prepare latents") + + latents = self.prepare_latents( + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=num_channels_latents, + video_length=video_length, + height=height, + width=width, + dtype=dtype, + device=device, + generator=generator, + latents=latents, + image=image, + timestep=latent_timestep, + w_ind_noise=w_ind_noise, + initial_common_latent=initial_common_latent, + noise_type=noise_type, + add_latents_noise=add_latents_noise, + need_img_based_video_noise=need_img_based_video_noise, + condition_latents=condition_latents, + img_weight=img_weight, + ) + if self.print_idx == 0: + logger.debug(f"pipeline controlnet, finish prepare latents={latents.shape}") + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + if noise_type == "video_fusion" and "noise_type" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ): + extra_step_kwargs["w_ind_noise"] = w_ind_noise + extra_step_kwargs["noise_type"] = noise_type + # extra_step_kwargs["noise_offset"] = noise_offset + + # 7.1 Create tensor stating which controlnets to keep + if run_controlnet: + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append( + keeps[0] if isinstance(controlnet, ControlNetModel) else keeps + ) + else: + controlnet_keep = None + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + if skip_temporal_layer: + self.unet.set_skip_temporal_layers(True) + + n_timesteps = len(timesteps) + guidance_scale_lst = generate_parameters_with_timesteps( + start=guidance_scale, + stop=guidance_scale_end, + num=n_timesteps, + method=guidance_scale_method, + ) + if self.print_idx == 0: + logger.debug( + f"guidance_scale_lst, {guidance_scale_method}, {guidance_scale}, {guidance_scale_end}, {guidance_scale_lst}" + ) + + ip_adapter_image_emb = self.get_ip_adapter_image_emb( + ip_adapter_image=ip_adapter_image, + batch_size=batch_size, + device=device, + dtype=dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + height=height, + width=width, + ) + + # 当前仅当没有ip_adapter时,按照参数 prompt_only_use_image_prompt 要求是否完全替换 image_prompt_emb + # only if ip_adapter is None and prompt_only_use_image_prompt is True, use image_prompt_emb replace text_prompt + if ( + ip_adapter_image_emb is not None + and prompt_only_use_image_prompt + and not self.unet.ip_adapter_cross_attn + ): + prompt_embeds = ip_adapter_image_emb + logger.debug(f"use ip_adapter_image_emb replace prompt_embeds") + refer_face_image_emb = self.get_facein_image_emb( + refer_face_image=refer_face_image, + batch_size=batch_size, + device=device, + dtype=dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + ip_adapter_face_emb = self.get_ip_adapter_face_emb( + refer_face_image=ip_adapter_face_image, + batch_size=batch_size, + device=device, + dtype=dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + refer_image_vae_emb = self.get_referencenet_image_vae_emb( + refer_image=refer_image, + device=device, + dtype=dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + batch_size=batch_size, + width=width, + height=height, + ) + + if self.pose_guider is not None and control_image is not None: + if self.print_idx == 0: + logger.debug(f"pose_guider, controlnet_image={control_image.shape}") + control_image = rearrange( + control_image, " (b t) c h w->b c t h w", t=video_length + ) + pose_guider_emb = self.pose_guider(control_image) + pose_guider_emb = rearrange(pose_guider_emb, "b c t h w-> (b t) c h w") + else: + pose_guider_emb = None + logger.debug(f"prompt_embeds={prompt_embeds.shape}") + + if control_image is not None: + if isinstance(control_image, list): + logger.debug(f"control_imageis list, num={len(control_image)}") + control_image = [ + rearrange( + control_image_tmp, + " (b t) c h w->b c t h w", + b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, + ) + for control_image_tmp in control_image + ] + else: + logger.debug(f"control_image={control_image.shape}, before") + control_image = rearrange( + control_image, + " (b t) c h w->b c t h w", + b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, + ) + logger.debug(f"control_image={control_image.shape}, after") + + if controlnet_latents is not None: + if isinstance(controlnet_latents, list): + logger.debug( + f"controlnet_latents is list, num={len(controlnet_latents)}" + ) + controlnet_latents = [ + rearrange( + controlnet_latents_tmp, + " (b t) c h w->b c t h w", + b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, + ) + for controlnet_latents_tmp in controlnet_latents + ] + else: + logger.debug(f"controlnet_latents={controlnet_latents.shape}, before") + controlnet_latents = rearrange( + controlnet_latents, + " (b t) c h w->b c t h w", + b=(int(do_classifier_free_guidance) * 1 + 1) * batch_size, + ) + logger.debug(f"controlnet_latents={controlnet_latents.shape}, after") + + videos_mid = [] + mid_video_noises = [] if record_mid_video_noises else None + mid_video_latents = [] if record_mid_video_latents else None + + global_context = prepare_global_context( + context_schedule=context_schedule, + num_inference_steps=num_inference_steps, + time_size=latents.shape[2], + context_frames=context_frames, + context_stride=context_stride, + context_overlap=context_overlap, + context_batch_size=context_batch_size, + ) + logger.debug( + f"context_schedule={context_schedule}, time_size={latents.shape[2]}, context_frames={context_frames}, context_stride={context_stride}, context_overlap={context_overlap}, context_batch_size={context_batch_size}" + ) + logger.debug(f"global_context={global_context}") + # iterative denoise + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # 使用 last_mid_video_latents 来影响初始化latent,该部分效果较差,暂留代码 + # use last_mide_video_latents to affect initial latent. works bad, Temporarily reserved + if i == 0: + if record_mid_video_latents: + mid_video_latents.append(latents[:, :, -video_overlap:]) + if record_mid_video_noises: + mid_video_noises.append(None) + if ( + last_mid_video_latents is not None + and len(last_mid_video_latents) > 0 + ): + if self.print_idx == 1: + logger.debug( + f"{i}, last_mid_video_latents={last_mid_video_latents[i].shape}" + ) + latents = fuse_part_tensor( + last_mid_video_latents[0], + latents, + video_overlap, + weight=0.1, + skip_step=0, + ) + noise_pred = torch.zeros( + ( + latents.shape[0] * (2 if do_classifier_free_guidance else 1), + *latents.shape[1:], + ), + device=latents.device, + dtype=latents.dtype, + ) + counter = torch.zeros( + (1, 1, latents.shape[2], 1, 1), + device=latents.device, + dtype=latents.dtype, + ) + if i == 0: + ( + down_block_refer_embs, + mid_block_refer_emb, + refer_self_attn_emb, + ) = self.get_referencenet_emb( + refer_image_vae_emb=refer_image_vae_emb, + refer_image=refer_image, + device=device, + dtype=dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + ip_adapter_image_emb=ip_adapter_image_emb, + batch_size=batch_size, + ref_timestep_int=t, + ) + for context in global_context: + # expand the latents if we are doing classifier free guidance + latents_c = torch.cat([latents[:, :, c] for c in context]) + latent_index_c = ( + torch.cat([latent_index[c] for c in context]) + if latent_index is not None + else None + ) + latent_model_input = latents_c.to(device).repeat( + 2 if do_classifier_free_guidance else 1, 1, 1, 1, 1 + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + sub_latent_index_c = ( + torch.LongTensor( + torch.arange(latent_index_c.shape[-1]) + n_vision_cond + ).to(device=latents_c.device) + if latent_index is not None + else None + ) + if condition_latents is not None: + latent_model_condition = ( + torch.cat([condition_latents] * 2) + if do_classifier_free_guidance + else latents + ) + + if self.print_idx == 0: + logger.debug( + f"vision_condition_latent_index, {vision_condition_latent_index.shape}, vision_condition_latent_index" + ) + logger.debug( + f"latent_model_condition, {latent_model_condition.shape}" + ) + logger.debug(f"latent_index, {latent_index_c.shape}") + logger.debug( + f"latent_model_input, {latent_model_input.shape}" + ) + logger.debug(f"sub_latent_index_c, {sub_latent_index_c}") + latent_model_input = batch_concat_two_tensor_with_index( + data1=latent_model_condition, + data1_index=vision_condition_latent_index, + data2=latent_model_input, + data2_index=sub_latent_index_c, + dim=2, + ) + if control_image is not None: + if vision_condition_latent_index is not None: + # 获取 vision_condition 对应的 control_imgae/control_latent 部分 + # generate control_image/control_latent corresponding to vision_condition + controlnet_condtion_latent_index = ( + vision_condition_latent_index.clone().cpu().tolist() + ) + if self.print_idx == 0: + logger.debug( + f"context={context}, controlnet_condtion_latent_index={controlnet_condtion_latent_index}" + ) + controlnet_context = [ + controlnet_condtion_latent_index + + [c_i + n_vision_cond for c_i in c] + for c in context + ] + else: + controlnet_context = context + if self.print_idx == 0: + logger.debug( + f"controlnet_context={controlnet_context}, latent_model_input={latent_model_input.shape}" + ) + if isinstance(control_image, list): + control_image_c = [ + torch.cat( + [ + control_image_tmp[:, :, c] + for c in controlnet_context + ] + ) + for control_image_tmp in control_image + ] + control_image_c = [ + rearrange(control_image_tmp, " b c t h w-> (b t) c h w") + for control_image_tmp in control_image_c + ] + else: + control_image_c = torch.cat( + [control_image[:, :, c] for c in controlnet_context] + ) + control_image_c = rearrange( + control_image_c, " b c t h w-> (b t) c h w" + ) + else: + control_image_c = None + if controlnet_latents is not None: + if vision_condition_latent_index is not None: + # 获取 vision_condition 对应的 control_imgae/control_latent 部分 + # generate control_image/control_latent corresponding to vision_condition + controlnet_condtion_latent_index = ( + vision_condition_latent_index.clone().cpu().tolist() + ) + if self.print_idx == 0: + logger.debug( + f"context={context}, controlnet_condtion_latent_index={controlnet_condtion_latent_index}" + ) + controlnet_context = [ + controlnet_condtion_latent_index + + [c_i + n_vision_cond for c_i in c] + for c in context + ] + else: + controlnet_context = context + if self.print_idx == 0: + logger.debug( + f"controlnet_context={controlnet_context}, controlnet_latents={controlnet_latents.shape}, latent_model_input={latent_model_input.shape}," + ) + controlnet_latents_c = torch.cat( + [controlnet_latents[:, :, c] for c in controlnet_context] + ) + controlnet_latents_c = rearrange( + controlnet_latents_c, " b c t h w-> (b t) c h w" + ) + else: + controlnet_latents_c = None + ( + down_block_res_samples, + mid_block_res_sample, + ) = self.get_controlnet_emb( + run_controlnet=run_controlnet, + guess_mode=guess_mode, + do_classifier_free_guidance=do_classifier_free_guidance, + latents=latents_c, + prompt_embeds=prompt_embeds, + latent_model_input=latent_model_input, + control_image=control_image_c, + controlnet_latents=controlnet_latents_c, + controlnet_keep=controlnet_keep, + t=t, + i=i, + controlnet_conditioning_scale=controlnet_conditioning_scale, + ) + if self.print_idx == 0: + logger.debug( + f"{i}, latent_model_input={latent_model_input.shape}, sub_latent_index_c={sub_latent_index_c}" + f"{vision_condition_latent_index}" + ) + # time.sleep(10) + noise_pred_c = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + sample_index=sub_latent_index_c, + vision_conditon_frames_sample_index=vision_condition_latent_index, + sample_frame_rate=motion_speed, + down_block_refer_embs=down_block_refer_embs, + mid_block_refer_emb=mid_block_refer_emb, + refer_self_attn_emb=refer_self_attn_emb, + vision_clip_emb=ip_adapter_image_emb, + face_emb=refer_face_image_emb, + ip_adapter_scale=ip_adapter_scale, + facein_scale=facein_scale, + ip_adapter_face_emb=ip_adapter_face_emb, + ip_adapter_face_scale=ip_adapter_face_scale, + do_classifier_free_guidance=do_classifier_free_guidance, + pose_guider_emb=pose_guider_emb, + )[0] + if condition_latents is not None: + noise_pred_c = batch_index_select( + noise_pred_c, dim=2, index=sub_latent_index_c + ).contiguous() + if self.print_idx == 0: + logger.debug( + f"{i}, latent_model_input={latent_model_input.shape}, noise_pred_c={noise_pred_c.shape}, {len(context)}, {len(context[0])}" + ) + for j, c in enumerate(context): + noise_pred[:, :, c] = noise_pred[:, :, c] + noise_pred_c + counter[:, :, c] = counter[:, :, c] + 1 + noise_pred = noise_pred / counter + + if ( + last_mid_video_noises is not None + and len(last_mid_video_noises) > 0 + and i <= num_inference_steps // 2 # 是个超参数 super paramter + ): + if self.print_idx == 1: + logger.debug( + f"{i}, last_mid_video_noises={last_mid_video_noises[i].shape}" + ) + noise_pred = fuse_part_tensor( + last_mid_video_noises[i + 1], + noise_pred, + video_overlap, + weight=0.01, + skip_step=1, + ) + if record_mid_video_noises: + mid_video_noises.append(noise_pred[:, :, -video_overlap:]) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale_lst[i] * ( + noise_pred_text - noise_pred_uncond + ) + + if self.print_idx == 0: + logger.debug( + f"before step, noise_pred={noise_pred.shape}, {noise_pred.device}, latents={latents.shape}, {latents.device}, t={t}" + ) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + **extra_step_kwargs, + ).prev_sample + + if ( + last_mid_video_latents is not None + and len(last_mid_video_latents) > 0 + and i <= 1 # 超参数, super parameter + ): + if self.print_idx == 1: + logger.debug( + f"{i}, last_mid_video_latents={last_mid_video_latents[i].shape}" + ) + latents = fuse_part_tensor( + last_mid_video_latents[i + 1], + latents, + video_overlap, + weight=0.1, + skip_step=0, + ) + if record_mid_video_latents: + mid_video_latents.append(latents[:, :, -video_overlap:]) + + if need_middle_latents is True: + videos_mid.append(self.decode_latents(latents)) + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + self.print_idx += 1 + + if condition_latents is not None: + latents = batch_concat_two_tensor_with_index( + data1=condition_latents, + data1_index=vision_condition_latent_index, + data2=latents, + data2_index=latent_index, + dim=2, + ) + b, c, t, h, w = latents.shape + num_segments = (t + decoder_t_segment - 1) // decoder_t_segment + + video_segments = [] + # to avoid t chanel too large causing gpu memory error + # split video latents in slices along t channel, decode each slice, and then concatenate them + for i in range(num_segments): + logger.debug(f"Decoding {i} th segment") + start_t = i * decoder_t_segment + end_t = min((i + 1) * decoder_t_segment, t) + latents_segment = latents[:, :, start_t:end_t, :, :] + video_segment = self.decode_latents(latents_segment) + video_segments.append(video_segment) + video_segments_np = np.concatenate(video_segments, axis=2) + video = torch.from_numpy(video_segments_np) + + if skip_temporal_layer: + self.unet.set_skip_temporal_layers(False) + if need_hist_match: + video[:, :, latent_index, :, :] = self.hist_match_with_vis_cond( + batch_index_select(video, index=latent_index, dim=2), + batch_index_select(video, index=vision_condition_latent_index, dim=2), + ) + # Convert to tensor + if output_type == "tensor": + videos_mid = [torch.from_numpy(x) for x in videos_mid] + video = torch.from_numpy(video) + else: + latents = latents.cpu().numpy() + + if not return_dict: + return ( + video, + latents, + videos_mid, + mid_video_latents, + mid_video_noises, + ) + + return VideoPipelineOutput( + videos=video, + latents=latents, + videos_mid=videos_mid, + mid_video_latents=mid_video_latents, + mid_video_noises=mid_video_noises, + ) diff --git a/musev/pipelines/pipeline_controlnet_predictor.py b/musev/pipelines/pipeline_controlnet_predictor.py new file mode 100755 index 0000000000000000000000000000000000000000..2c07c8bdb638f904e69b297cd25f11078cc7bb1b --- /dev/null +++ b/musev/pipelines/pipeline_controlnet_predictor.py @@ -0,0 +1,1290 @@ +import copy +from typing import Any, Callable, Dict, Iterable, Union +import PIL +import cv2 +import torch +import argparse +import datetime +import logging +import inspect +import math +import os +import shutil +from typing import Dict, List, Optional, Tuple +from pprint import pformat, pprint +from collections import OrderedDict +from dataclasses import dataclass +import gc +import time + +import numpy as np +from omegaconf import OmegaConf +from omegaconf import SCMode +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.checkpoint +from einops import rearrange, repeat +import pandas as pd +import h5py +from diffusers.models.autoencoder_kl import AutoencoderKL + +from diffusers.models.modeling_utils import load_state_dict +from diffusers.utils import ( + logging, + BaseOutput, + logging, +) +from diffusers.utils.dummy_pt_objects import ConsistencyDecoderVAE +from diffusers.utils.import_utils import is_xformers_available + +from mmcm.utils.seed_util import set_all_seed +from mmcm.vision.data.video_dataset import DecordVideoDataset +from mmcm.vision.process.correct_color import hist_match_video_bcthw +from mmcm.vision.process.image_process import ( + batch_dynamic_crop_resize_images, + batch_dynamic_crop_resize_images_v2, +) +from mmcm.vision.utils.data_type_util import is_video +from mmcm.vision.feature_extractor.controlnet import load_controlnet_model + +from ..schedulers import ( + EulerDiscreteScheduler, + LCMScheduler, + DDIMScheduler, + DDPMScheduler, +) +from ..models.unet_3d_condition import UNet3DConditionModel +from .pipeline_controlnet import ( + MusevControlNetPipeline, + VideoPipelineOutput as PipelineVideoPipelineOutput, +) +from ..utils.util import save_videos_grid_with_opencv +from ..utils.model_util import ( + update_pipeline_basemodel, + update_pipeline_lora_model, + update_pipeline_lora_models, + update_pipeline_model_parameters, +) + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class VideoPipelineOutput(BaseOutput): + videos: Union[torch.Tensor, np.ndarray] + latents: Union[torch.Tensor, np.ndarray] + videos_mid: Union[torch.Tensor, np.ndarray] + controlnet_cond: Union[torch.Tensor, np.ndarray] + generated_videos: Union[torch.Tensor, np.ndarray] + + +def update_controlnet_processor_params( + src: Union[Dict, List[Dict]], dst: Union[Dict, List[Dict]] +): + """merge dst into src""" + if isinstance(src, list) and not isinstance(dst, List): + dst = [dst] * len(src) + if isinstance(src, list) and isinstance(dst, list): + return [ + update_controlnet_processor_params(src[i], dst[i]) for i in range(len(src)) + ] + if src is None: + dct = {} + else: + dct = copy.deepcopy(src) + if dst is None: + dst = {} + dct.update(dst) + return dct + + +class DiffusersPipelinePredictor(object): + """wraper of diffusers pipeline, support generation function interface. support + 1. text2video: inputs include text, image(optional), refer_image(optional) + 2. video2video: + 1. use controlnet to control spatial + 2. or use video fuse noise to denoise + """ + + def __init__( + self, + sd_model_path: str, + unet: nn.Module, + controlnet_name: Union[str, List[str]] = None, + controlnet: nn.Module = None, + lora_dict: Dict[str, Dict] = None, + requires_safety_checker: bool = False, + device: str = "cuda", + dtype: torch.dtype = torch.float16, + # controlnet parameters start + need_controlnet_processor: bool = True, + need_controlnet: bool = True, + image_resolution: int = 512, + detect_resolution: int = 512, + include_body: bool = True, + hand_and_face: bool = None, + include_face: bool = False, + include_hand: bool = True, + negative_embedding: List = None, + # controlnet parameters end + enable_xformers_memory_efficient_attention: bool = True, + lcm_lora_dct: Dict = None, + referencenet: nn.Module = None, + ip_adapter_image_proj: nn.Module = None, + vision_clip_extractor: nn.Module = None, + face_emb_extractor: nn.Module = None, + facein_image_proj: nn.Module = None, + ip_adapter_face_emb_extractor: nn.Module = None, + ip_adapter_face_image_proj: nn.Module = None, + vae_model: Optional[Tuple[nn.Module, str]] = None, + pose_guider: Optional[nn.Module] = None, + enable_zero_snr: bool = False, + ) -> None: + self.sd_model_path = sd_model_path + self.unet = unet + self.controlnet_name = controlnet_name + self.controlnet = controlnet + self.requires_safety_checker = requires_safety_checker + self.device = device + self.dtype = dtype + self.need_controlnet_processor = need_controlnet_processor + self.need_controlnet = need_controlnet + self.need_controlnet_processor = need_controlnet_processor + self.image_resolution = image_resolution + self.detect_resolution = detect_resolution + self.include_body = include_body + self.hand_and_face = hand_and_face + self.include_face = include_face + self.include_hand = include_hand + self.negative_embedding = negative_embedding + self.device = device + self.dtype = dtype + self.lcm_lora_dct = lcm_lora_dct + if controlnet is None and controlnet_name is not None: + controlnet, controlnet_processor, processor_params = load_controlnet_model( + controlnet_name, + device=device, + dtype=dtype, + need_controlnet_processor=need_controlnet_processor, + need_controlnet=need_controlnet, + image_resolution=image_resolution, + detect_resolution=detect_resolution, + include_body=include_body, + include_face=include_face, + hand_and_face=hand_and_face, + include_hand=include_hand, + ) + self.controlnet_processor = controlnet_processor + self.controlnet_processor_params = processor_params + logger.debug(f"init controlnet controlnet_name={controlnet_name}") + + if controlnet is not None: + controlnet = controlnet.to(device=device, dtype=dtype) + controlnet.eval() + if pose_guider is not None: + pose_guider = pose_guider.to(device=device, dtype=dtype) + pose_guider.eval() + unet.to(device=device, dtype=dtype) + unet.eval() + if referencenet is not None: + referencenet.to(device=device, dtype=dtype) + referencenet.eval() + if ip_adapter_image_proj is not None: + ip_adapter_image_proj.to(device=device, dtype=dtype) + ip_adapter_image_proj.eval() + if vision_clip_extractor is not None: + vision_clip_extractor.to(device=device, dtype=dtype) + vision_clip_extractor.eval() + if face_emb_extractor is not None: + face_emb_extractor.to(device=device, dtype=dtype) + face_emb_extractor.eval() + if facein_image_proj is not None: + facein_image_proj.to(device=device, dtype=dtype) + facein_image_proj.eval() + + if isinstance(vae_model, str): + # TODO: poor implementation, to improve + if "consistency" in vae_model: + vae = ConsistencyDecoderVAE.from_pretrained(vae_model) + else: + vae = AutoencoderKL.from_pretrained(vae_model) + elif isinstance(vae_model, nn.Module): + vae = vae_model + else: + vae = None + if vae is not None: + vae.to(device=device, dtype=dtype) + vae.eval() + if ip_adapter_face_emb_extractor is not None: + ip_adapter_face_emb_extractor.to(device=device, dtype=dtype) + ip_adapter_face_emb_extractor.eval() + if ip_adapter_face_image_proj is not None: + ip_adapter_face_image_proj.to(device=device, dtype=dtype) + ip_adapter_face_image_proj.eval() + params = { + "pretrained_model_name_or_path": sd_model_path, + "controlnet": controlnet, + "unet": unet, + "requires_safety_checker": requires_safety_checker, + "torch_dtype": dtype, + "torch_device": device, + "referencenet": referencenet, + "ip_adapter_image_proj": ip_adapter_image_proj, + "vision_clip_extractor": vision_clip_extractor, + "facein_image_proj": facein_image_proj, + "face_emb_extractor": face_emb_extractor, + "ip_adapter_face_emb_extractor": ip_adapter_face_emb_extractor, + "ip_adapter_face_image_proj": ip_adapter_face_image_proj, + "pose_guider": pose_guider, + } + if vae is not None: + params["vae"] = vae + pipeline = MusevControlNetPipeline.from_pretrained(**params) + pipeline = pipeline.to(torch_device=device, torch_dtype=dtype) + logger.debug( + f"init pipeline from sd_model_path={sd_model_path}, device={device}, dtype={dtype}" + ) + if ( + negative_embedding is not None + and pipeline.text_encoder is not None + and pipeline.tokenizer is not None + ): + for neg_emb_path, neg_token in negative_embedding: + pipeline.load_textual_inversion(neg_emb_path, token=neg_token) + + # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + # pipe.enable_model_cpu_offload() + if not enable_zero_snr: + pipeline.scheduler = EulerDiscreteScheduler.from_config( + pipeline.scheduler.config + ) + # pipeline.scheduler = DDIMScheduler.from_config( + # pipeline.scheduler.config, + # 该部分会影响生成视频的亮度,不适用于首帧给定的视频生成 + # this part will change brightness of video, not suitable for image2video mode + # rescale_betas_zero_snr affect the brightness of the generated video, not suitable for vision condition images mode + # # rescale_betas_zero_snr=True, + # ) + # pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config) + else: + # moore scheduler, just for codetest + pipeline.scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="linear", + clip_sample=False, + steps_offset=1, + ### Zero-SNR params + prediction_type="v_prediction", + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + ) + + pipeline.enable_vae_slicing() + self.enable_xformers_memory_efficient_attention = ( + enable_xformers_memory_efficient_attention + ) + if enable_xformers_memory_efficient_attention: + if is_xformers_available(): + pipeline.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) + self.pipeline = pipeline + self.unload_dict = [] # keep lora state + if lora_dict is not None: + self.load_lora(lora_dict=lora_dict) + logger.debug("load lora {}".format(" ".join(list(lora_dict.keys())))) + + if lcm_lora_dct is not None: + self.pipeline.scheduler = LCMScheduler.from_config( + self.pipeline.scheduler.config + ) + self.load_lora(lora_dict=lcm_lora_dct) + logger.debug("load lcm lora {}".format(" ".join(list(lcm_lora_dct.keys())))) + + # logger.debug("Unet3Model Parameters") + # logger.debug(pformat(self.__dict__)) + + def load_lora( + self, + lora_dict: Dict[str, Dict], + ): + self.pipeline, unload_dict = update_pipeline_lora_models( + self.pipeline, lora_dict, device=self.device + ) + self.unload_dict += unload_dict + + def unload_lora(self): + for layer_data in self.unload_dict: + layer = layer_data["layer"] + added_weight = layer_data["added_weight"] + layer.weight.data -= added_weight + self.unload_dict = [] + gc.collect() + torch.cuda.empty_cache() + + def update_unet(self, unet: nn.Module): + self.pipeline.unet = unet.to(device=self.device, dtype=self.dtype) + + def update_sd_model(self, model_path: str, text_model_path: str): + self.pipeline = update_pipeline_basemodel( + self.pipeline, + model_path, + text_sd_model_path=text_model_path, + device=self.device, + ) + + def update_sd_model_and_unet( + self, lora_sd_path: str, lora_path: str, sd_model_path: str = None + ): + self.pipeline = update_pipeline_model_parameters( + self.pipeline, + model_path=lora_sd_path, + lora_path=lora_path, + text_model_path=sd_model_path, + device=self.device, + ) + + def update_controlnet(self, controlnet_name=Union[str, List[str]]): + self.pipeline.controlnet = load_controlnet_model(controlnet_name).to( + device=self.device, dtype=self.dtype + ) + + def run_pipe_text2video( + self, + video_length: int, + prompt: Union[str, List[str]] = None, + # b c t h w + height: Optional[int] = None, + width: Optional[int] = None, + video_num_inference_steps: int = 50, + video_guidance_scale: float = 7.5, + video_guidance_scale_end: float = 3.5, + video_guidance_scale_method: str = "linear", + strength: float = 0.8, + video_negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + same_seed: Optional[Union[int, List[int]]] = None, + # b c t(1) ho wo + condition_latents: Optional[torch.FloatTensor] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + guidance_scale: float = 7.5, + num_inference_steps: int = 50, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + need_middle_latents: bool = False, + w_ind_noise: float = 0.5, + initial_common_latent: Optional[torch.FloatTensor] = None, + latent_index: torch.LongTensor = None, + vision_condition_latent_index: torch.LongTensor = None, + n_vision_condition: int = 1, + noise_type: str = "random", + max_batch_num: int = 30, + need_img_based_video_noise: bool = False, + condition_images: torch.Tensor = None, + fix_condition_images: bool = False, + redraw_condition_image: bool = False, + img_weight: float = 1e-3, + motion_speed: float = 8.0, + need_hist_match: bool = False, + refer_image: Optional[ + Tuple[np.ndarray, torch.Tensor, List[str], List[np.ndarray]] + ] = None, + ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None, + fixed_refer_image: bool = True, + fixed_ip_adapter_image: bool = True, + redraw_condition_image_with_ipdapter: bool = True, + redraw_condition_image_with_referencenet: bool = True, + refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, + fixed_refer_face_image: bool = True, + redraw_condition_image_with_facein: bool = True, + ip_adapter_scale: float = 1.0, + redraw_condition_image_with_ip_adapter_face: bool = True, + facein_scale: float = 1.0, + ip_adapter_face_scale: float = 1.0, + prompt_only_use_image_prompt: bool = False, + # serial_denoise parameter start + record_mid_video_noises: bool = False, + record_mid_video_latents: bool = False, + video_overlap: int = 1, + # serial_denoise parameter end + # parallel_denoise parameter start + context_schedule="uniform", + context_frames=12, + context_stride=1, + context_overlap=4, + context_batch_size=1, + interpolation_factor=1, + # parallel_denoise parameter end + ): + """ + generate long video with end2end mode + 1. prepare vision condition image by assingning, redraw, or generation with text2image module with skip_temporal_layer=True; + 2. use image or latest of vision condition image to generate first shot; + 3. use last n (1) image or last latent of last shot as new vision condition latent to generate next shot + 4. repeat n_batch times between 2 and 3 + + 类似img2img pipeline + refer_image和ip_adapter_image的来源: + 1. 输入给定; + 2. 当未输入时,纯text2video生成首帧,并赋值更新refer_image和ip_adapter_image; + 3. 当有输入,但是因为redraw更新了首帧时,也需要赋值更新refer_image和ip_adapter_image; + + refer_image和ip_adapter_image的作用: + 1. 当无首帧图像时,用于生成首帧; + 2. 用于生成视频。 + + + similar to diffusers img2img pipeline. + three ways to prepare refer_image and ip_adapter_image + 1. from input parameter + 2. when input paramter is None, use text2video to generate vis cond image, and use as refer_image and ip_adapter_image too. + 3. given from input paramter, but still redraw, update with redrawn vis cond image. + """ + # crop resize images + if condition_images is not None: + logger.debug( + f"center crop resize condition_images={condition_images.shape}, to height={height}, width={width}" + ) + condition_images = batch_dynamic_crop_resize_images_v2( + condition_images, + target_height=height, + target_width=width, + ) + if refer_image is not None: + logger.debug( + f"center crop resize refer_image to height={height}, width={width}" + ) + refer_image = batch_dynamic_crop_resize_images_v2( + refer_image, + target_height=height, + target_width=width, + ) + if ip_adapter_image is not None: + logger.debug( + f"center crop resize ip_adapter_image to height={height}, width={width}" + ) + ip_adapter_image = batch_dynamic_crop_resize_images_v2( + ip_adapter_image, + target_height=height, + target_width=width, + ) + if refer_face_image is not None: + logger.debug( + f"center crop resize refer_face_image to height={height}, width={width}" + ) + refer_face_image = batch_dynamic_crop_resize_images_v2( + refer_face_image, + target_height=height, + target_width=width, + ) + run_video_length = video_length + # generate vision condition frame start + # if condition_images is None, generate with refer_image, ip_adapter_image + # if condition_images not None and need redraw, according to redraw_condition_image_with_ipdapter, redraw_condition_image_with_referencenet, refer_image, ip_adapter_image + if n_vision_condition > 0: + if condition_images is None and condition_latents is None: + logger.debug("run_pipe_text2video, generate first_image") + ( + condition_images, + condition_latents, + _, + _, + _, + ) = self.pipeline( + prompt=prompt, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + video_length=1, + height=height, + width=width, + return_dict=False, + skip_temporal_layer=True, + output_type="np", + generator=generator, + w_ind_noise=w_ind_noise, + need_img_based_video_noise=need_img_based_video_noise, + refer_image=refer_image + if redraw_condition_image_with_referencenet + else None, + ip_adapter_image=ip_adapter_image + if redraw_condition_image_with_ipdapter + else None, + refer_face_image=refer_face_image + if redraw_condition_image_with_facein + else None, + ip_adapter_scale=ip_adapter_scale, + facein_scale=facein_scale, + ip_adapter_face_scale=ip_adapter_face_scale, + ip_adapter_face_image=refer_face_image + if redraw_condition_image_with_ip_adapter_face + else None, + prompt_only_use_image_prompt=prompt_only_use_image_prompt, + ) + run_video_length = video_length - 1 + elif ( + condition_images is not None + and redraw_condition_image + and condition_latents is None + ): + logger.debug("run_pipe_text2video, redraw first_image") + + ( + condition_images, + condition_latents, + _, + _, + _, + ) = self.pipeline( + prompt=prompt, + image=condition_images, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + strength=strength, + video_length=condition_images.shape[2], + height=height, + width=width, + return_dict=False, + skip_temporal_layer=True, + output_type="np", + generator=generator, + w_ind_noise=w_ind_noise, + need_img_based_video_noise=need_img_based_video_noise, + refer_image=refer_image + if redraw_condition_image_with_referencenet + else None, + ip_adapter_image=ip_adapter_image + if redraw_condition_image_with_ipdapter + else None, + refer_face_image=refer_face_image + if redraw_condition_image_with_facein + else None, + ip_adapter_scale=ip_adapter_scale, + facein_scale=facein_scale, + ip_adapter_face_scale=ip_adapter_face_scale, + ip_adapter_face_image=refer_face_image + if redraw_condition_image_with_ip_adapter_face + else None, + prompt_only_use_image_prompt=prompt_only_use_image_prompt, + ) + else: + condition_images = None + condition_latents = None + # generate vision condition frame end + + # refer_image and ip_adapter_image, update mode from 2 and 3 as mentioned above start + if ( + refer_image is not None + and redraw_condition_image + and condition_images is not None + ): + refer_image = condition_images * 255.0 + logger.debug(f"update refer_image because of redraw_condition_image") + elif ( + refer_image is None + and self.pipeline.referencenet is not None + and condition_images is not None + ): + refer_image = condition_images * 255.0 + logger.debug(f"update refer_image because of generate first_image") + + # ipadapter_image + if ( + ip_adapter_image is not None + and redraw_condition_image + and condition_images is not None + ): + ip_adapter_image = condition_images * 255.0 + logger.debug(f"update ip_adapter_image because of redraw_condition_image") + elif ( + ip_adapter_image is None + and self.pipeline.ip_adapter_image_proj is not None + and condition_images is not None + ): + ip_adapter_image = condition_images * 255.0 + logger.debug(f"update ip_adapter_image because of generate first_image") + # refer_image and ip_adapter_image, update mode from 2 and 3 as mentioned above end + + # refer_face_image, update mode from 2 and 3 as mentioned above start + if ( + refer_face_image is not None + and redraw_condition_image + and condition_images is not None + ): + refer_face_image = condition_images * 255.0 + logger.debug(f"update refer_face_image because of redraw_condition_image") + elif ( + refer_face_image is None + and self.pipeline.facein_image_proj is not None + and condition_images is not None + ): + refer_face_image = condition_images * 255.0 + logger.debug(f"update face_image because of generate first_image") + # refer_face_image, update mode from 2 and 3 as mentioned above end + + last_mid_video_noises = None + last_mid_video_latents = None + initial_common_latent = None + + out_videos = [] + for i_batch in range(max_batch_num): + logger.debug(f"sd_pipeline_predictor, run_pipe_text2video: {i_batch}") + if max_batch_num is not None and i_batch == max_batch_num: + break + + if i_batch == 0: + result_overlap = 0 + else: + if n_vision_condition > 0: + # ignore condition_images if condition_latents is not None in pipeline + if not fix_condition_images: + logger.debug(f"{i_batch}, update condition_latents") + condition_latents = out_latents_batch[ + :, :, -n_vision_condition:, :, : + ] + else: + logger.debug(f"{i_batch}, do not update condition_latents") + result_overlap = n_vision_condition + + if not fixed_refer_image and n_vision_condition > 0: + logger.debug("ref_image use last frame of last generated out video") + refer_image = out_batch[:, :, -n_vision_condition:, :, :] * 255.0 + else: + logger.debug("use given fixed ref_image") + + if not fixed_ip_adapter_image and n_vision_condition > 0: + logger.debug( + "ip_adapter_image use last frame of last generated out video" + ) + ip_adapter_image = ( + out_batch[:, :, -n_vision_condition:, :, :] * 255.0 + ) + else: + logger.debug("use given fixed ip_adapter_image") + + if not fixed_refer_face_image and n_vision_condition > 0: + logger.debug( + "refer_face_image use last frame of last generated out video" + ) + refer_face_image = ( + out_batch[:, :, -n_vision_condition:, :, :] * 255.0 + ) + else: + logger.debug("use given fixed ip_adapter_image") + + run_video_length = video_length + if same_seed is not None: + _, generator = set_all_seed(same_seed) + + out = self.pipeline( + video_length=run_video_length, # int + prompt=prompt, + num_inference_steps=video_num_inference_steps, + height=height, + width=width, + generator=generator, + condition_images=condition_images, + condition_latents=condition_latents, # b co t(1) ho wo + skip_temporal_layer=False, + output_type="np", + noise_type=noise_type, + negative_prompt=video_negative_prompt, + guidance_scale=video_guidance_scale, + guidance_scale_end=video_guidance_scale_end, + guidance_scale_method=video_guidance_scale_method, + w_ind_noise=w_ind_noise, + need_img_based_video_noise=need_img_based_video_noise, + img_weight=img_weight, + motion_speed=motion_speed, + vision_condition_latent_index=vision_condition_latent_index, + refer_image=refer_image, + ip_adapter_image=ip_adapter_image, + refer_face_image=refer_face_image, + ip_adapter_scale=ip_adapter_scale, + facein_scale=facein_scale, + ip_adapter_face_scale=ip_adapter_face_scale, + ip_adapter_face_image=refer_face_image, + prompt_only_use_image_prompt=prompt_only_use_image_prompt, + initial_common_latent=initial_common_latent, + # serial_denoise parameter start + record_mid_video_noises=record_mid_video_noises, + last_mid_video_noises=last_mid_video_noises, + record_mid_video_latents=record_mid_video_latents, + last_mid_video_latents=last_mid_video_latents, + video_overlap=video_overlap, + # serial_denoise parameter end + # parallel_denoise parameter start + context_schedule=context_schedule, + context_frames=context_frames, + context_stride=context_stride, + context_overlap=context_overlap, + context_batch_size=context_batch_size, + interpolation_factor=interpolation_factor, + # parallel_denoise parameter end + ) + logger.debug( + f"run_pipe_text2video, out.videos.shape, i_batch={i_batch}, videos={out.videos.shape}, result_overlap={result_overlap}" + ) + out_batch = out.videos[:, :, result_overlap:, :, :] + out_latents_batch = out.latents[:, :, result_overlap:, :, :] + out_videos.append(out_batch) + + out_videos = np.concatenate(out_videos, axis=2) + if need_hist_match: + out_videos[:, :, 1:, :, :] = hist_match_video_bcthw( + out_videos[:, :, 1:, :, :], out_videos[:, :, :1, :, :], value=255.0 + ) + return out_videos + + def run_pipe_with_latent_input( + self, + ): + pass + + def run_pipe_middle2video_with_middle(self, middle: Tuple[str, Iterable]): + pass + + def run_pipe_video2video( + self, + video: Tuple[str, Iterable], + time_size: int = None, + sample_rate: int = None, + overlap: int = None, + step: int = None, + prompt: Union[str, List[str]] = None, + # b c t h w + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + video_num_inference_steps: int = 50, + guidance_scale: float = 7.5, + video_guidance_scale: float = 7.5, + video_guidance_scale_end: float = 3.5, + video_guidance_scale_method: str = "linear", + video_negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_videos_per_prompt: Optional[int] = 1, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + controlnet_latents: Union[torch.FloatTensor, np.ndarray] = None, + # b c t(1) hi wi + controlnet_condition_images: Optional[torch.FloatTensor] = None, + # b c t(1) ho wo + controlnet_condition_latents: Optional[torch.FloatTensor] = None, + # b c t(1) ho wo + condition_latents: Optional[torch.FloatTensor] = None, + condition_images: Optional[torch.FloatTensor] = None, + fix_condition_images: bool = False, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "tensor", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + need_middle_latents: bool = False, + w_ind_noise: float = 0.5, + img_weight: float = 0.001, + initial_common_latent: Optional[torch.FloatTensor] = None, + latent_index: torch.LongTensor = None, + vision_condition_latent_index: torch.LongTensor = None, + noise_type: str = "random", + controlnet_processor_params: Dict = None, + need_return_videos: bool = False, + need_return_condition: bool = False, + max_batch_num: int = 30, + strength: float = 0.8, + video_strength: float = 0.8, + need_video2video: bool = False, + need_img_based_video_noise: bool = False, + need_hist_match: bool = False, + end_to_end: bool = True, + refer_image: Optional[ + Tuple[np.ndarray, torch.Tensor, List[str], List[np.ndarray]] + ] = None, + ip_adapter_image: Optional[Tuple[torch.Tensor, np.array]] = None, + fixed_refer_image: bool = True, + fixed_ip_adapter_image: bool = True, + redraw_condition_image: bool = False, + redraw_condition_image_with_ipdapter: bool = True, + redraw_condition_image_with_referencenet: bool = True, + refer_face_image: Optional[Tuple[torch.Tensor, np.array]] = None, + fixed_refer_face_image: bool = True, + redraw_condition_image_with_facein: bool = True, + ip_adapter_scale: float = 1.0, + facein_scale: float = 1.0, + ip_adapter_face_scale: float = 1.0, + redraw_condition_image_with_ip_adapter_face: bool = False, + n_vision_condition: int = 1, + prompt_only_use_image_prompt: bool = False, + motion_speed: float = 8.0, + # serial_denoise parameter start + record_mid_video_noises: bool = False, + record_mid_video_latents: bool = False, + video_overlap: int = 1, + # serial_denoise parameter end + # parallel_denoise parameter start + context_schedule="uniform", + context_frames=12, + context_stride=1, + context_overlap=4, + context_batch_size=1, + interpolation_factor=1, + # parallel_denoise parameter end + # 支持 video_path 时多种输入 + # TODO:// video_has_condition =False,当且仅支持 video_is_middle=True, 待后续重构 + # TODO:// when video_has_condition =False, video_is_middle should be True. + video_is_middle: bool = False, + video_has_condition: bool = True, + ): + """ + 类似controlnet text2img pipeline。 输入视频,用视频得到controlnet condition。 + 目前仅支持time_size == step,overlap=0 + 输出视频长度=输入视频长度 + + similar to controlnet text2image pipeline, generate video with controlnet condition from given video. + By now, sliding window only support time_size == step, overlap = 0. + """ + if isinstance(video, str): + video_reader = DecordVideoDataset( + video, + time_size=time_size, + step=step, + overlap=overlap, + sample_rate=sample_rate, + device="cpu", + data_type="rgb", + channels_order="c t h w", + drop_last=True, + ) + else: + video_reader = video + videos = [] if need_return_videos else None + out_videos = [] + out_condition = ( + [] + if need_return_condition and self.pipeline.controlnet is not None + else None + ) + # crop resize images + if condition_images is not None: + logger.debug( + f"center crop resize condition_images={condition_images.shape}, to height={height}, width={width}" + ) + condition_images = batch_dynamic_crop_resize_images_v2( + condition_images, + target_height=height, + target_width=width, + ) + if refer_image is not None: + logger.debug( + f"center crop resize refer_image to height={height}, width={width}" + ) + refer_image = batch_dynamic_crop_resize_images_v2( + refer_image, + target_height=height, + target_width=width, + ) + if ip_adapter_image is not None: + logger.debug( + f"center crop resize ip_adapter_image to height={height}, width={width}" + ) + ip_adapter_image = batch_dynamic_crop_resize_images_v2( + ip_adapter_image, + target_height=height, + target_width=width, + ) + if refer_face_image is not None: + logger.debug( + f"center crop resize refer_face_image to height={height}, width={width}" + ) + refer_face_image = batch_dynamic_crop_resize_images_v2( + refer_face_image, + target_height=height, + target_width=width, + ) + first_image = None + last_mid_video_noises = None + last_mid_video_latents = None + initial_common_latent = None + # initial_common_latent = torch.randn((1, 4, 1, 112, 64)).to( + # device=self.device, dtype=self.dtype + # ) + + for i_batch, item in enumerate(video_reader): + logger.debug(f"\n sd_pipeline_predictor, run_pipe_video2video: {i_batch}") + if max_batch_num is not None and i_batch == max_batch_num: + break + # read and prepare video batch + batch = item.data + batch = batch_dynamic_crop_resize_images( + batch, + target_height=height, + target_width=width, + ) + + batch = batch[np.newaxis, ...] + batch_size, channel, video_length, video_height, video_width = batch.shape + # extract controlnet middle + if self.pipeline.controlnet is not None: + batch = rearrange(batch, "b c t h w-> (b t) h w c") + controlnet_processor_params = update_controlnet_processor_params( + src=self.controlnet_processor_params, + dst=controlnet_processor_params, + ) + if not video_is_middle: + batch_condition = self.controlnet_processor( + data=batch, + data_channel_order="b h w c", + target_height=height, + target_width=width, + return_type="np", + return_data_channel_order="b c h w", + input_rgb_order="rgb", + processor_params=controlnet_processor_params, + ) + else: + # TODO: 临时用于可视化输入的 controlnet middle 序列,后续待拆到 middl2video中,也可以增加参数支持 + # TODO: only use video_path is controlnet middle output, to improved + batch_condition = rearrange( + copy.deepcopy(batch), " b h w c-> b c h w" + ) + + # 当前仅当 输入是 middle、condition_image的pose在middle首帧之前,需要重新生成condition_images的pose并绑定到middle_batch上 + # when video_path is middle seq and condition_image is not aligned with middle seq, + # regenerate codntion_images pose, and then concat into middle_batch, + if ( + i_batch == 0 + and not video_has_condition + and video_is_middle + and condition_images is not None + ): + condition_images_reshape = rearrange( + condition_images, "b c t h w-> (b t) h w c" + ) + condition_images_condition = self.controlnet_processor( + data=condition_images_reshape, + data_channel_order="b h w c", + target_height=height, + target_width=width, + return_type="np", + return_data_channel_order="b c h w", + input_rgb_order="rgb", + processor_params=controlnet_processor_params, + ) + condition_images_condition = rearrange( + condition_images_condition, + "(b t) c h w-> b c t h w", + b=batch_size, + ) + else: + condition_images_condition = None + if not isinstance(batch_condition, list): + batch_condition = rearrange( + batch_condition, "(b t) c h w-> b c t h w", b=batch_size + ) + if condition_images_condition is not None: + batch_condition = np.concatenate( + [ + condition_images_condition, + batch_condition, + ], + axis=2, + ) + # 此时 batch_condition 比 batch 多了一帧,为了最终视频能 concat 存储,替换下 + # 当前仅适用于 condition_images_condition 不为None + # when condition_images_condition is not None, batch_condition has more frames than batch + batch = rearrange(batch_condition, "b c t h w ->(b t) h w c") + else: + batch_condition = [ + rearrange(x, "(b t) c h w-> b c t h w", b=batch_size) + for x in batch_condition + ] + if condition_images_condition is not None: + batch_condition = [ + np.concatenate( + [condition_images_condition, batch_condition_tmp], + axis=2, + ) + for batch_condition_tmp in batch_condition + ] + batch = rearrange(batch, "(b t) h w c -> b c t h w", b=batch_size) + else: + batch_condition = None + # condition [0,255] + # latent: [0,1] + # 按需求生成多个片段, + # generate multi video_shot + # 第一个片段 会特殊处理,需要生成首帧 + # first shot is special because of first frame. + # 后续片段根据拿前一个片段结果,首尾相连的方式生成。 + # use last frame of last shot as the first frame of the current shot + # TODO: 当前独立拆开实现,待后续融合到一起实现 + # TODO: to optimize implementation way + if n_vision_condition == 0: + actual_video_length = video_length + control_image = batch_condition + first_image_controlnet_condition = None + first_image_latents = None + if need_video2video: + video = batch + else: + video = None + result_overlap = 0 + else: + if i_batch == 0: + if self.pipeline.controlnet is not None: + if not isinstance(batch_condition, list): + first_image_controlnet_condition = batch_condition[ + :, :, :1, :, : + ] + else: + first_image_controlnet_condition = [ + x[:, :, :1, :, :] for x in batch_condition + ] + else: + first_image_controlnet_condition = None + if need_video2video: + if condition_images is None: + video = batch[:, :, :1, :, :] + else: + video = condition_images + else: + video = None + if condition_images is not None and not redraw_condition_image: + first_image = condition_images + first_image_latents = None + else: + ( + first_image, + first_image_latents, + _, + _, + _, + ) = self.pipeline( + prompt=prompt, + image=video, + control_image=first_image_controlnet_condition, + num_inference_steps=num_inference_steps, + video_length=1, + height=height, + width=width, + return_dict=False, + skip_temporal_layer=True, + output_type="np", + generator=generator, + negative_prompt=negative_prompt, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + w_ind_noise=w_ind_noise, + strength=strength, + refer_image=refer_image + if redraw_condition_image_with_referencenet + else None, + ip_adapter_image=ip_adapter_image + if redraw_condition_image_with_ipdapter + else None, + refer_face_image=refer_face_image + if redraw_condition_image_with_facein + else None, + ip_adapter_scale=ip_adapter_scale, + facein_scale=facein_scale, + ip_adapter_face_scale=ip_adapter_face_scale, + ip_adapter_face_image=refer_face_image + if redraw_condition_image_with_ip_adapter_face + else None, + prompt_only_use_image_prompt=prompt_only_use_image_prompt, + ) + if refer_image is not None: + refer_image = first_image * 255.0 + if ip_adapter_image is not None: + ip_adapter_image = first_image * 255.0 + # 首帧用于后续推断可以直接用first_image_latent不需要 first_image了 + first_image = None + if self.pipeline.controlnet is not None: + if not isinstance(batch_condition, list): + control_image = batch_condition[:, :, 1:, :, :] + logger.debug(f"control_image={control_image.shape}") + else: + control_image = [x[:, :, 1:, :, :] for x in batch_condition] + else: + control_image = None + + actual_video_length = time_size - int(video_has_condition) + if need_video2video: + video = batch[:, :, 1:, :, :] + else: + video = None + + result_overlap = 0 + else: + actual_video_length = time_size + if self.pipeline.controlnet is not None: + if not fix_condition_images: + logger.debug( + f"{i_batch}, update first_image_controlnet_condition" + ) + + if not isinstance(last_batch_condition, list): + first_image_controlnet_condition = last_batch_condition[ + :, :, -1:, :, : + ] + else: + first_image_controlnet_condition = [ + x[:, :, -1:, :, :] for x in last_batch_condition + ] + else: + logger.debug( + f"{i_batch}, do not update first_image_controlnet_condition" + ) + control_image = batch_condition + else: + control_image = None + first_image_controlnet_condition = None + if not fix_condition_images: + logger.debug(f"{i_batch}, update condition_images") + first_image_latents = out_latents_batch[:, :, -1:, :, :] + else: + logger.debug(f"{i_batch}, do not update condition_images") + + if need_video2video: + video = batch + else: + video = None + result_overlap = 1 + + # 更新 ref_image和 ipadapter_image + if not fixed_refer_image: + logger.debug( + "ref_image use last frame of last generated out video" + ) + refer_image = ( + out_batch[:, :, -n_vision_condition:, :, :] * 255.0 + ) + else: + logger.debug("use given fixed ref_image") + + if not fixed_ip_adapter_image: + logger.debug( + "ip_adapter_image use last frame of last generated out video" + ) + ip_adapter_image = ( + out_batch[:, :, -n_vision_condition:, :, :] * 255.0 + ) + else: + logger.debug("use given fixed ip_adapter_image") + + # face image + if not fixed_ip_adapter_image: + logger.debug( + "refer_face_image use last frame of last generated out video" + ) + refer_face_image = ( + out_batch[:, :, -n_vision_condition:, :, :] * 255.0 + ) + else: + logger.debug("use given fixed ip_adapter_image") + + out = self.pipeline( + video_length=actual_video_length, # int + prompt=prompt, + num_inference_steps=video_num_inference_steps, + height=height, + width=width, + generator=generator, + image=video, + control_image=control_image, # b ci(3) t hi wi + controlnet_condition_images=first_image_controlnet_condition, # b ci(3) t(1) hi wi + # controlnet_condition_images=np.zeros_like( + # first_image_controlnet_condition + # ), # b ci(3) t(1) hi wi + condition_images=first_image, + condition_latents=first_image_latents, # b co t(1) ho wo + skip_temporal_layer=False, + output_type="np", + noise_type=noise_type, + negative_prompt=video_negative_prompt, + need_img_based_video_noise=need_img_based_video_noise, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + w_ind_noise=w_ind_noise, + img_weight=img_weight, + motion_speed=video_reader.sample_rate, + guidance_scale=video_guidance_scale, + guidance_scale_end=video_guidance_scale_end, + guidance_scale_method=video_guidance_scale_method, + strength=video_strength, + refer_image=refer_image, + ip_adapter_image=ip_adapter_image, + refer_face_image=refer_face_image, + ip_adapter_scale=ip_adapter_scale, + facein_scale=facein_scale, + ip_adapter_face_scale=ip_adapter_face_scale, + ip_adapter_face_image=refer_face_image, + prompt_only_use_image_prompt=prompt_only_use_image_prompt, + initial_common_latent=initial_common_latent, + # serial_denoise parameter start + record_mid_video_noises=record_mid_video_noises, + last_mid_video_noises=last_mid_video_noises, + record_mid_video_latents=record_mid_video_latents, + last_mid_video_latents=last_mid_video_latents, + video_overlap=video_overlap, + # serial_denoise parameter end + # parallel_denoise parameter start + context_schedule=context_schedule, + context_frames=context_frames, + context_stride=context_stride, + context_overlap=context_overlap, + context_batch_size=context_batch_size, + interpolation_factor=interpolation_factor, + # parallel_denoise parameter end + ) + last_batch = batch + last_batch_condition = batch_condition + last_mid_video_latents = out.mid_video_latents + last_mid_video_noises = out.mid_video_noises + out_batch = out.videos[:, :, result_overlap:, :, :] + out_latents_batch = out.latents[:, :, result_overlap:, :, :] + out_videos.append(out_batch) + if need_return_videos: + videos.append(batch) + if out_condition is not None: + out_condition.append(batch_condition) + + out_videos = np.concatenate(out_videos, axis=2) + if need_return_videos: + videos = np.concatenate(videos, axis=2) + if out_condition is not None: + if not isinstance(out_condition[0], list): + out_condition = np.concatenate(out_condition, axis=2) + else: + out_condition = [ + [out_condition[j][i] for j in range(len(out_condition))] + for i in range(len(out_condition[0])) + ] + out_condition = [np.concatenate(x, axis=2) for x in out_condition] + if need_hist_match: + videos[:, :, 1:, :, :] = hist_match_video_bcthw( + videos[:, :, 1:, :, :], videos[:, :, :1, :, :], value=255.0 + ) + return out_videos, out_condition, videos diff --git a/musev/schedulers/__init__.py b/musev/schedulers/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e1e491a502d517a114a93cec4c6080dbe2119990 --- /dev/null +++ b/musev/schedulers/__init__.py @@ -0,0 +1,6 @@ +from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler +from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler +from .scheduling_euler_discrete import EulerDiscreteScheduler +from .scheduling_lcm import LCMScheduler +from .scheduling_ddim import DDIMScheduler +from .scheduling_ddpm import DDPMScheduler diff --git a/musev/schedulers/scheduling_ddim.py b/musev/schedulers/scheduling_ddim.py new file mode 100755 index 0000000000000000000000000000000000000000..bcfa43ef0233372d8cc21c825299a605262f911b --- /dev/null +++ b/musev/schedulers/scheduling_ddim.py @@ -0,0 +1,302 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +from numpy import ndarray +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, +) +from diffusers.schedulers.scheduling_ddim import ( + DDIMSchedulerOutput, + rescale_zero_terminal_snr, + betas_for_alpha_bar, + DDIMScheduler as DiffusersDDIMScheduler, +) +from ..utils.noise_util import video_fusion_noise + + +class DDIMScheduler(DiffusersDDIMScheduler): + """ + `DDIMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: ndarray | List[float] | None = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1, + sample_max_value: float = 1, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, + ): + super().__init__( + num_train_timesteps, + beta_start, + beta_end, + beta_schedule, + trained_betas, + clip_sample, + set_alpha_to_one, + steps_offset, + prediction_type, + thresholding, + dynamic_thresholding_ratio, + clip_sample_range, + sample_max_value, + timestep_spacing, + rescale_betas_zero_snr, + ) + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + w_ind_noise: float = 0.5, + noise_type: str = "random", + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.FloatTensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = ( + timestep - self.config.num_train_timesteps // self.num_inference_steps + ) + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = ( + sample - beta_prod_t ** (0.5) * model_output + ) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = ( + sample - alpha_prod_t ** (0.5) * pred_original_sample + ) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - ( + beta_prod_t**0.5 + ) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + ( + beta_prod_t**0.5 + ) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = ( + sample - alpha_prod_t ** (0.5) * pred_original_sample + ) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** ( + 0.5 + ) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = ( + alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + ) + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + # if variance_noise is None: + # variance_noise = randn_tensor( + # model_output.shape, + # generator=generator, + # device=model_output.device, + # dtype=model_output.dtype, + # ) + device = model_output.device + + if noise_type == "random": + variance_noise = randn_tensor( + model_output.shape, + dtype=model_output.dtype, + device=device, + generator=generator, + ) + elif noise_type == "video_fusion": + variance_noise = video_fusion_noise( + model_output, w_ind_noise=w_ind_noise, generator=generator + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) diff --git a/musev/schedulers/scheduling_ddpm.py b/musev/schedulers/scheduling_ddpm.py new file mode 100755 index 0000000000000000000000000000000000000000..b55e30fa08ea9249c3fff20e88fd5997f5fee9b9 --- /dev/null +++ b/musev/schedulers/scheduling_ddpm.py @@ -0,0 +1,262 @@ +# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +from numpy import ndarray +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, +) +from diffusers.schedulers.scheduling_ddpm import ( + DDPMSchedulerOutput, + betas_for_alpha_bar, + DDPMScheduler as DiffusersDDPMScheduler, +) +from ..utils.noise_util import video_fusion_noise + + +class DDPMScheduler(DiffusersDDPMScheduler): + """ + `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + variance_type (`str`, defaults to `"fixed_small"`): + Clip the variance when adding noise to the denoised sample. Choose from `fixed_small`, `fixed_small_log`, + `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: ndarray | List[float] | None = None, + variance_type: str = "fixed_small", + clip_sample: bool = True, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1, + sample_max_value: float = 1, + timestep_spacing: str = "leading", + steps_offset: int = 0, + ): + super().__init__( + num_train_timesteps, + beta_start, + beta_end, + beta_schedule, + trained_betas, + variance_type, + clip_sample, + prediction_type, + thresholding, + dynamic_thresholding_ratio, + clip_sample_range, + sample_max_value, + timestep_spacing, + steps_offset, + ) + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + w_ind_noise: float = 0.5, + noise_type: str = "random", + ) -> Union[DDPMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + t = timestep + + prev_t = self.previous_timestep(t) + + if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [ + "learned", + "learned_range", + ]: + model_output, predicted_variance = torch.split( + model_output, sample.shape[1], dim=1 + ) + else: + predicted_variance = None + + # 1. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[t] + alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + current_alpha_t = alpha_prod_t / alpha_prod_t_prev + current_beta_t = 1 - current_alpha_t + + # 2. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = ( + sample - beta_prod_t ** (0.5) * model_output + ) / alpha_prod_t ** (0.5) + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - ( + beta_prod_t**0.5 + ) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for the DDPMScheduler." + ) + + # 3. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_original_sample_coeff = ( + alpha_prod_t_prev ** (0.5) * current_beta_t + ) / beta_prod_t + current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t + + # 5. Compute predicted previous sample µ_t + # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf + pred_prev_sample = ( + pred_original_sample_coeff * pred_original_sample + + current_sample_coeff * sample + ) + + # 6. Add noise + variance = 0 + if t > 0: + device = model_output.device + # if variance_noise is None: + # variance_noise = randn_tensor( + # model_output.shape, + # generator=generator, + # device=model_output.device, + # dtype=model_output.dtype, + # ) + device = model_output.device + + if noise_type == "random": + variance_noise = randn_tensor( + model_output.shape, + dtype=model_output.dtype, + device=device, + generator=generator, + ) + elif noise_type == "video_fusion": + variance_noise = video_fusion_noise( + model_output, w_ind_noise=w_ind_noise, generator=generator + ) + if self.variance_type == "fixed_small_log": + variance = ( + self._get_variance(t, predicted_variance=predicted_variance) + * variance_noise + ) + elif self.variance_type == "learned_range": + variance = self._get_variance(t, predicted_variance=predicted_variance) + variance = torch.exp(0.5 * variance) * variance_noise + else: + variance = ( + self._get_variance(t, predicted_variance=predicted_variance) ** 0.5 + ) * variance_noise + + pred_prev_sample = pred_prev_sample + variance + + if not return_dict: + return (pred_prev_sample,) + + return DDPMSchedulerOutput( + prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample + ) diff --git a/musev/schedulers/scheduling_dpmsolver_multistep.py b/musev/schedulers/scheduling_dpmsolver_multistep.py new file mode 100755 index 0000000000000000000000000000000000000000..82d865c54200692b19e70889361a04f374fd8802 --- /dev/null +++ b/musev/schedulers/scheduling_dpmsolver_multistep.py @@ -0,0 +1,815 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config + +try: + from diffusers.utils import randn_tensor +except: + from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + DPM-Solver (and the improved version DPM-Solver++) is a fast dedicated high-order solver for diffusion ODEs with + the convergence order guarantee. Empirically, sampling by DPM-Solver with only 20 steps can generate high-quality + samples, and it can generate quite good samples even in only 10 steps. + + For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 + + Currently, we support the multistep DPM-Solver for both noise prediction models and data prediction models. We + recommend to use `solver_order=2` for guided sampling, and `solver_order=3` for unconditional sampling. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse + diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the + second-order `sde-dpmsolver++`. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or + `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and + the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use + `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + lambda_min_clipped (`float`, default `-inf`): + the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for + cosine (squaredcos_cap_v2) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = True, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace( + beta_start, beta_end, num_train_timesteps, dtype=torch.float32 + ) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f"{beta_schedule} does is not implemented for {self.__class__}" + ) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in [ + "dpmsolver", + "dpmsolver++", + "sde-dpmsolver", + "sde-dpmsolver++", + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type} does is not implemented for {self.__class__}" + ) + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type} does is not implemented for {self.__class__}" + ) + + # setable values + self.num_inference_steps = None + timesteps = np.linspace( + 0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32 + )[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self.use_karras_sigmas = use_karras_sigmas + + def set_timesteps( + self, num_inference_steps: int = None, device: Union[str, torch.device] = None + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted( + torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped + ) + timesteps = ( + np.linspace( + 0, + self.config.num_train_timesteps - 1 - clipped_idx, + num_inference_steps + 1, + ) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + + if self.use_karras_sigmas: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras( + in_sigmas=sigmas, num_inference_steps=num_inference_steps + ) + timesteps = np.array( + [self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas] + ).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = ( + sample.float() + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = ( + torch.clamp(sample, -s, s) / s + ) # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = ( + np.cumsum((dists >= 0), axis=0) + .argmax(axis=0) + .clip(max=log_sigmas.shape[0] - 2) + ) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras( + self, in_sigmas: torch.FloatTensor, num_inference_steps + ) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. + """ + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + timestep: int, + prev_timestep: int, + sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + One step for the first-order DPM-Solver (equivalent to DDIM). + + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] + alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - ( + alpha_t * (torch.exp(-h) - 1.0) + ) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - ( + sigma_t * (torch.exp(h) - 1.0) + ) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + ) + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = ( + prev_timestep, + timestep_list[-1], + timestep_list[-2], + timestep_list[-3], + ) + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ) + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + w_ind_noise: float = 0.5, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the multistep DPM-Solver. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + prev_timestep = ( + 0 + if step_index == len(self.timesteps) - 1 + else self.timesteps[step_index + 1] + ) + lower_order_final = ( + (step_index == len(self.timesteps) - 1) + and self.config.lower_order_final + and len(self.timesteps) < 15 + ) + lower_order_second = ( + (step_index == len(self.timesteps) - 2) + and self.config.lower_order_final + and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, timestep, sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + # noise = randn_tensor( + # model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + # ) + common_noise = torch.randn( + model_output.shape[:2] + (1,) + model_output.shape[3:], + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) # common noise + ind_noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=model_output.dtype, + ) + s = torch.tensor( + w_ind_noise, device=model_output.device, dtype=model_output.dtype + ).to(device) + noise = torch.sqrt(1 - s) * common_noise + torch.sqrt(s) * ind_noise + + else: + noise = None + + if ( + self.config.solver_order == 1 + or self.lower_order_nums < 1 + or lower_order_final + ): + prev_sample = self.dpm_solver_first_order_update( + model_output, timestep, prev_timestep, sample, noise=noise + ) + elif ( + self.config.solver_order == 2 + or self.lower_order_nums < 2 + or lower_order_second + ): + timestep_list = [self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, timestep_list, prev_timestep, sample, noise=noise + ) + else: + timestep_list = [ + self.timesteps[step_index - 2], + self.timesteps[step_index - 1], + timestep, + ] + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input( + self, sample: torch.FloatTensor, *args, **kwargs + ) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to( + device=original_samples.device, dtype=original_samples.dtype + ) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = ( + sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + ) + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/musev/schedulers/scheduling_euler_ancestral_discrete.py b/musev/schedulers/scheduling_euler_ancestral_discrete.py new file mode 100755 index 0000000000000000000000000000000000000000..71dadb2ecf9150b0ae15e727fbc9bb068ecc4a42 --- /dev/null +++ b/musev/schedulers/scheduling_euler_ancestral_discrete.py @@ -0,0 +1,356 @@ +# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config + +from diffusers.utils import BaseOutput, logging + +try: + from diffusers.utils import randn_tensor +except: + from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, +) + +from ..utils.noise_util import video_fusion_noise + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete +class EulerAncestralDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): + """ + Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace( + beta_start, beta_end, num_train_timesteps, dtype=torch.float32 + ) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace( + beta_start**0.5, + beta_end**0.5, + num_train_timesteps, + dtype=torch.float32, + ) + ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError( + f"{beta_schedule} does is not implemented for {self.__class__}" + ) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + # setable values + self.num_inference_steps = None + timesteps = np.linspace( + 0, num_train_timesteps - 1, num_train_timesteps, dtype=float + )[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] + ) -> torch.FloatTensor: + """ + Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain + + Returns: + `torch.FloatTensor`: scaled input sample + """ + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps( + self, num_inference_steps: int, device: Union[str, torch.device] = None + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace( + 0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float + )[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=device) + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = torch.from_numpy(timesteps).to(device, dtype=torch.float32) + else: + self.timesteps = torch.from_numpy(timesteps).to(device=device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + w_ind_noise: float = 0.5, + noise_type: str = "random", + ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + generator (`torch.Generator`, optional): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise + a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + ( + sample / (sigma**2 + 1) + ) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] + sigma_up = ( + sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2 + ) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = sigma_down - sigma + + prev_sample = sample + derivative * dt + + device = model_output.device + if noise_type == "random": + noise = randn_tensor( + model_output.shape, + dtype=model_output.dtype, + device=device, + generator=generator, + ) + elif noise_type == "video_fusion": + noise = video_fusion_noise( + model_output, w_ind_noise=w_ind_noise, generator=generator + ) + + prev_sample = prev_sample + noise * sigma_up + + if not return_dict: + return (prev_sample,) + + return EulerAncestralDiscreteSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype + ) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32 + ) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/musev/schedulers/scheduling_euler_discrete.py b/musev/schedulers/scheduling_euler_discrete.py new file mode 100755 index 0000000000000000000000000000000000000000..e62a22c07b4a3e0c62c2d82e583c41806ea9ce26 --- /dev/null +++ b/musev/schedulers/scheduling_euler_discrete.py @@ -0,0 +1,293 @@ +from __future__ import annotations +import logging + +from typing import List, Optional, Tuple, Union +import numpy as np +from numpy import ndarray +import torch +from torch import Generator, FloatTensor +from diffusers.schedulers.scheduling_euler_discrete import ( + EulerDiscreteScheduler as DiffusersEulerDiscreteScheduler, + EulerDiscreteSchedulerOutput, +) +from diffusers.utils.torch_utils import randn_tensor + +from ..utils.noise_util import video_fusion_noise + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +class EulerDiscreteScheduler(DiffusersEulerDiscreteScheduler): + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: ndarray | List[float] | None = None, + prediction_type: str = "epsilon", + interpolation_type: str = "linear", + use_karras_sigmas: bool | None = False, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + ): + super().__init__( + num_train_timesteps, + beta_start, + beta_end, + beta_schedule, + trained_betas, + prediction_type, + interpolation_type, + use_karras_sigmas, + timestep_spacing, + steps_offset, + ) + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + w_ind_noise: float = 0.5, + noise_type: str = "random", + ) -> Union[EulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + + gamma = ( + min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigma <= s_tmax + else 0.0 + ) + device = model_output.device + + if noise_type == "random": + noise = randn_tensor( + model_output.shape, + dtype=model_output.dtype, + device=device, + generator=generator, + ) + elif noise_type == "video_fusion": + noise = video_fusion_noise( + model_output, w_ind_noise=w_ind_noise, generator=generator + ) + + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + if ( + self.config.prediction_type == "original_sample" + or self.config.prediction_type == "sample" + ): + pred_original_sample = model_output + elif self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma_hat * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + ( + sample / (sigma**2 + 1) + ) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + + dt = self.sigmas[self.step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return EulerDiscreteSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) + + def step_bk( + self, + model_output: FloatTensor, + timestep: float | FloatTensor, + sample: FloatTensor, + s_churn: float = 0, + s_tmin: float = 0, + s_tmax: float = float("inf"), + s_noise: float = 1, + generator: Generator | None = None, + return_dict: bool = True, + w_ind_noise: float = 0.5, + noise_type: str = "random", + ) -> EulerDiscreteSchedulerOutput | Tuple: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`float`): current timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + s_churn (`float`) + s_tmin (`float`) + s_tmax (`float`) + s_noise (`float`) + generator (`torch.Generator`, optional): Random number generator. + return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if not self.is_scale_input_called: + logger.warning( + "The `scale_model_input` function should be called before `step` to ensure correct denoising. " + "See `StableDiffusionPipeline` for a usage example." + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + gamma = ( + min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) + if s_tmin <= sigma <= s_tmax + else 0.0 + ) + + device = model_output.device + if noise_type == "random": + noise = randn_tensor( + model_output.shape, + dtype=model_output.dtype, + device=device, + generator=generator, + ) + elif noise_type == "video_fusion": + noise = video_fusion_noise( + model_output, w_ind_noise=w_ind_noise, generator=generator + ) + eps = noise * s_noise + sigma_hat = sigma * (gamma + 1) + + if gamma > 0: + sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5 + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + # NOTE: "original_sample" should not be an expected prediction_type but is left in for + # backwards compatibility + if ( + self.config.prediction_type == "original_sample" + or self.config.prediction_type == "sample" + ): + pred_original_sample = model_output + elif self.config.prediction_type == "epsilon": + pred_original_sample = sample - sigma_hat * model_output + elif self.config.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + ( + sample / (sigma**2 + 1) + ) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma_hat + + dt = self.sigmas[step_index + 1] - sigma_hat + + prev_sample = sample + derivative * dt + + if not return_dict: + return (prev_sample,) + + return EulerDiscreteSchedulerOutput( + prev_sample=prev_sample, pred_original_sample=pred_original_sample + ) diff --git a/musev/schedulers/scheduling_lcm.py b/musev/schedulers/scheduling_lcm.py new file mode 100755 index 0000000000000000000000000000000000000000..235bb8e91e069fc837e4cb3d1d71a4d1e83e40cc --- /dev/null +++ b/musev/schedulers/scheduling_lcm.py @@ -0,0 +1,312 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from numpy import ndarray + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.schedulers.scheduling_lcm import ( + LCMSchedulerOutput, + betas_for_alpha_bar, + rescale_zero_terminal_snr, + LCMScheduler as DiffusersLCMScheduler, +) +from ..utils.noise_util import video_fusion_noise + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class LCMScheduler(DiffusersLCMScheduler): + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + beta_schedule: str = "scaled_linear", + trained_betas: ndarray | List[float] | None = None, + original_inference_steps: int = 50, + clip_sample: bool = False, + clip_sample_range: float = 1, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1, + timestep_spacing: str = "leading", + timestep_scaling: float = 10, + rescale_betas_zero_snr: bool = False, + ): + super().__init__( + num_train_timesteps, + beta_start, + beta_end, + beta_schedule, + trained_betas, + original_inference_steps, + clip_sample, + clip_sample_range, + set_alpha_to_one, + steps_offset, + prediction_type, + thresholding, + dynamic_thresholding_ratio, + sample_max_value, + timestep_spacing, + timestep_scaling, + rescale_betas_zero_snr, + ) + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + w_ind_noise: float = 0.5, + noise_type: str = "random", + ) -> Union[LCMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # 1. get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. Compute the predicted original sample x_0 based on the model parameterization + if self.config.prediction_type == "epsilon": # noise-prediction + predicted_original_sample = ( + sample - beta_prod_t.sqrt() * model_output + ) / alpha_prod_t.sqrt() + elif self.config.prediction_type == "sample": # x-prediction + predicted_original_sample = model_output + elif self.config.prediction_type == "v_prediction": # v-prediction + predicted_original_sample = ( + alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + ) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `LCMScheduler`." + ) + + # 5. Clip or threshold "predicted x_0" + if self.config.thresholding: + predicted_original_sample = self._threshold_sample( + predicted_original_sample + ) + elif self.config.clip_sample: + predicted_original_sample = predicted_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 6. Denoise model output using boundary conditions + denoised = c_out * predicted_original_sample + c_skip * sample + + # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. + device = model_output.device + + if self.step_index != self.num_inference_steps - 1: + if noise_type == "random": + noise = randn_tensor( + model_output.shape, + dtype=model_output.dtype, + device=device, + generator=generator, + ) + elif noise_type == "video_fusion": + noise = video_fusion_noise( + model_output, w_ind_noise=w_ind_noise, generator=generator + ) + prev_sample = ( + alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise + ) + else: + prev_sample = denoised + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample, denoised) + + return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) + + def step_bk( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[LCMSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # 1. get previous step value + prev_step_index = self.step_index + 1 + if prev_step_index < len(self.timesteps): + prev_timestep = self.timesteps[prev_step_index] + else: + prev_timestep = timestep + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_timestep] + if prev_timestep >= 0 + else self.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + # 3. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + + # 4. Compute the predicted original sample x_0 based on the model parameterization + if self.config.prediction_type == "epsilon": # noise-prediction + predicted_original_sample = ( + sample - beta_prod_t.sqrt() * model_output + ) / alpha_prod_t.sqrt() + elif self.config.prediction_type == "sample": # x-prediction + predicted_original_sample = model_output + elif self.config.prediction_type == "v_prediction": # v-prediction + predicted_original_sample = ( + alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output + ) + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `LCMScheduler`." + ) + + # 5. Clip or threshold "predicted x_0" + if self.config.thresholding: + predicted_original_sample = self._threshold_sample( + predicted_original_sample + ) + elif self.config.clip_sample: + predicted_original_sample = predicted_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 6. Denoise model output using boundary conditions + denoised = c_out * predicted_original_sample + c_skip * sample + + # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference + # Noise is not used on the final timestep of the timestep schedule. + # This also means that noise is not used for one-step sampling. + if self.step_index != self.num_inference_steps - 1: + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=denoised.dtype, + ) + prev_sample = ( + alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise + ) + else: + prev_sample = denoised + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample, denoised) + + return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) diff --git a/musev/utils/__init__.py b/musev/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/musev/utils/attention_util.py b/musev/utils/attention_util.py new file mode 100755 index 0000000000000000000000000000000000000000..94a0c8f2b002d0f47eecb69689289d03b270152f --- /dev/null +++ b/musev/utils/attention_util.py @@ -0,0 +1,74 @@ +from typing import Tuple, Union, Literal + +from einops import repeat +import torch +import numpy as np + + +def get_diags_indices( + shape: Union[int, Tuple[int, int]], k_min: int = 0, k_max: int = 0 +): + if isinstance(shape, int): + shape = (shape, shape) + rows, cols = np.indices(shape) + diag = cols - rows + return np.where((diag >= k_min) & (diag <= k_max)) + + +def generate_mask_from_indices( + shape: Tuple[int, int], + indices: Tuple[np.ndarray, np.ndarray], + big_value: float = 0, + small_value: float = -1e9, +): + matrix = np.ones(shape) * small_value + matrix[indices] = big_value + return matrix + + +def generate_sparse_causcal_attn_mask( + batch_size: int, + n: int, + n_near: int = 1, + big_value: float = 0, + small_value: float = -1e9, + out_type: Literal["torch", "numpy"] = "numpy", + expand: int = 1, +) -> np.ndarray: + """generate b (n expand) (n expand) mask, + where value of diag (0<=<=n_near) and first column of shape mat (n n) is set as big_value, others as small value + expand的概念: + attn 是 b n d 时,mask 是 b n n, 当 attn 是 b (expand n) d 时, mask 是 b (n expand) (n expand) + Args: + batch_size (int): _description_ + n (int): _description_ + n_near (int, optional): _description_. Defaults to 1. + big_value (float, optional): _description_. Defaults to 0. + small_value (float, optional): _description_. Defaults to -1e9. + out_type (Literal["torch", "numpy"], optional): _description_. Defaults to "numpy". + expand (int, optional): _description_. Defaults to 1. + + Returns: + np.ndarray: _description_ + """ + shape = (n, n) + diag_indices = get_diags_indices(n, k_min=-n_near, k_max=0) + first_column = (np.arange(n), np.zeros(n).astype(np.int)) + indices = ( + np.concatenate([diag_indices[0], first_column[0]]), + np.concatenate([diag_indices[1], first_column[1]]), + ) + mask = generate_mask_from_indices( + shape=shape, indices=indices, big_value=big_value, small_value=small_value + ) + mask = repeat(mask, "m n-> b m n", b=batch_size) + if expand > 1: + mask = repeat( + mask, + "b m n -> b (m d1) (n d2)", + d1=expand, + d2=expand, + ) + if out_type == "torch": + mask = torch.from_numpy(mask) + return mask diff --git a/musev/utils/convert_from_ckpt.py b/musev/utils/convert_from_ckpt.py new file mode 100755 index 0000000000000000000000000000000000000000..c8dd541d0f7b75e2f9702fbae0e3610c1584cbda --- /dev/null +++ b/musev/utils/convert_from_ckpt.py @@ -0,0 +1,963 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Conversion script for the Stable Diffusion checkpoints.""" + +import re +from io import BytesIO +from typing import Optional + +import requests +import torch +from transformers import ( + AutoFeatureExtractor, + BertTokenizerFast, + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from diffusers.models import ( + AutoencoderKL, + PriorTransformer, + UNet2DConditionModel, +) +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UnCLIPScheduler, +) +from diffusers.utils.import_utils import BACKENDS_MAPPING + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + if controlnet: + unet_params = original_config.model.params.control_stage_config.params + else: + unet_params = original_config.model.params.unet_config.params + + vae_params = original_config.model.params.first_stage_config.params.ddconfig + + block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) + + head_dim = unet_params.num_heads if "num_heads" in unet_params else None + use_linear_projection = ( + unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False + ) + if use_linear_projection: + # stable diffusion 2-base-512 and 2-768 + if head_dim is None: + head_dim = [5, 10, 20, 20] + + class_embed_type = None + projection_class_embeddings_input_dim = None + + if "num_classes" in unet_params: + if unet_params.num_classes == "sequential": + class_embed_type = "projection" + assert "adm_in_channels" in unet_params + projection_class_embeddings_input_dim = unet_params.adm_in_channels + else: + raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") + + config = { + "sample_size": image_size // vae_scale_factor, + "in_channels": unet_params.in_channels, + "down_block_types": tuple(down_block_types), + "block_out_channels": tuple(block_out_channels), + "layers_per_block": unet_params.num_res_blocks, + "cross_attention_dim": unet_params.context_dim, + "attention_head_dim": head_dim, + "use_linear_projection": use_linear_projection, + "class_embed_type": class_embed_type, + "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, + } + + if not controlnet: + config["out_channels"] = unet_params.out_channels + config["up_block_types"] = tuple(up_block_types) + + return config + + +def create_vae_diffusers_config(original_config, image_size: int): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + vae_params = original_config.model.params.first_stage_config.params.ddconfig + _ = original_config.model.params.first_stage_config.params.embed_dim + + block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = { + "sample_size": image_size, + "in_channels": vae_params.in_channels, + "out_channels": vae_params.out_ch, + "down_block_types": tuple(down_block_types), + "up_block_types": tuple(up_block_types), + "block_out_channels": tuple(block_out_channels), + "latent_channels": vae_params.z_channels, + "layers_per_block": vae_params.num_res_blocks, + } + return config + + +def create_diffusers_schedular(original_config): + schedular = DDIMScheduler( + num_train_timesteps=original_config.model.params.timesteps, + beta_start=original_config.model.params.linear_start, + beta_end=original_config.model.params.linear_end, + beta_schedule="scaled_linear", + ) + return schedular + + +def create_ldm_bert_config(original_config): + bert_params = original_config.model.parms.cond_stage_config.params + config = LDMBertConfig( + d_model=bert_params.n_embed, + encoder_layers=bert_params.n_layer, + encoder_ffn_dim=bert_params.n_embed * 4, + ) + return config + + +def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") + print( + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." + ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) + + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + if config["class_embed_type"] is None: + # No parameters to port + ... + elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": + new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] + new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] + new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] + new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] + else: + raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + if not controlnet: + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] + for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] + for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] + for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [ + key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key + ] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.bias" + ) + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + + output_block_list = {k: sorted(v) for k, v in output_block_list.items()} + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint( + paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config + ) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + if controlnet: + # conditioning embedding + + orig_index = 0 + + new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + orig_index += 2 + + diffusers_index = 0 + + while diffusers_index < 6: + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + diffusers_index += 1 + orig_index += 2 + + new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.weight" + ) + new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( + f"input_hint_block.{orig_index}.bias" + ) + + # down blocks + for i in range(num_input_blocks): + new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") + new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") + + # mid block + new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") + new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def convert_ldm_bert_checkpoint(checkpoint, config): + def _copy_attn_layer(hf_attn_layer, pt_attn_layer): + hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight + hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight + hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight + + hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight + hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias + + def _copy_linear(hf_linear, pt_linear): + hf_linear.weight = pt_linear.weight + hf_linear.bias = pt_linear.bias + + def _copy_layer(hf_layer, pt_layer): + # copy layer norms + _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) + _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) + + # copy attn + _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) + + # copy MLP + pt_mlp = pt_layer[1][1] + _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) + _copy_linear(hf_layer.fc2, pt_mlp.net[2]) + + def _copy_layers(hf_layers, pt_layers): + for i, hf_layer in enumerate(hf_layers): + if i != 0: + i += i + pt_layer = pt_layers[i : i + 2] + _copy_layer(hf_layer, pt_layer) + + hf_model = LDMBertModel(config).eval() + + # copy embeds + hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight + hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight + + # copy layer norm + _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) + + # copy hidden layers + _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) + + _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) + + return hf_model + + +def convert_ldm_clip_checkpoint(checkpoint, pretrained_model_path): + text_model = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + text_model.load_state_dict(text_model_dict) + + return text_model + + +textenc_conversion_lst = [ + ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"), + ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"), + ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), + ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), +] +textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} + +textenc_transformer_conversion_lst = [ + # (stable-diffusion, HF Diffusers) + ("resblocks.", "text_model.encoder.layers."), + ("ln_1", "layer_norm1"), + ("ln_2", "layer_norm2"), + (".c_fc.", ".fc1."), + (".c_proj.", ".fc2."), + (".attn", ".self_attn"), + ("ln_final.", "transformer.text_model.final_layer_norm."), + ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), + ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), +] +protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} +textenc_pattern = re.compile("|".join(protected.keys())) + + +def convert_paint_by_example_checkpoint(checkpoint): + config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14") + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def convert_open_clip_checkpoint(checkpoint): + text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + if "cond_stage_model.model.text_projection" in checkpoint: + d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) + else: + d_model = 1024 + + text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") + + for key in keys: + if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer + continue + if key in textenc_conversion_map: + text_model_dict[textenc_conversion_map[key]] = checkpoint[key] + if key.startswith("cond_stage_model.model.transformer."): + new_key = key[len("cond_stage_model.model.transformer.") :] + if new_key.endswith(".in_proj_weight"): + new_key = new_key[: -len(".in_proj_weight")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] + text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] + text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] + elif new_key.endswith(".in_proj_bias"): + new_key = new_key[: -len(".in_proj_bias")] + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] + text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] + text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] + else: + new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) + + text_model_dict[new_key] = checkpoint[key] + + text_model.load_state_dict(text_model_dict) + + return text_model + + +def stable_unclip_image_encoder(original_config): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. + """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config.model.params.noise_aug_config + noise_aug_class = noise_aug_config.target + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + +def convert_controlnet_checkpoint( + checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema +): + ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) + ctrlnet_config["upcast_attention"] = upcast_attention + + ctrlnet_config.pop("sample_size") + + controlnet_model = ControlNetModel(**ctrlnet_config) + + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( + checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True + ) + + controlnet_model.load_state_dict(converted_ctrl_checkpoint) + + return controlnet_model diff --git a/musev/utils/convert_lora_safetensor_to_diffusers.py b/musev/utils/convert_lora_safetensor_to_diffusers.py new file mode 100755 index 0000000000000000000000000000000000000000..7490e38ecfc2a00d90bb97205f32546384443aee --- /dev/null +++ b/musev/utils/convert_lora_safetensor_to_diffusers.py @@ -0,0 +1,154 @@ +# coding=utf-8 +# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Conversion script for the LoRA's safetensors checkpoints. """ + +import argparse + +import torch +from safetensors.torch import load_file + +from diffusers import StableDiffusionPipeline +import pdb + + + +def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0): + # directly update weight in diffusers model + for key in state_dict: + # only process lora down key + if "up." in key: continue + + up_key = key.replace(".down.", ".up.") + model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") + model_key = model_key.replace("to_out.", "to_out.0.") + layer_infos = model_key.split(".")[:-1] + + curr_layer = pipeline.unet + while len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + curr_layer = curr_layer.__getattr__(temp_name) + + weight_down = state_dict[key] + weight_up = state_dict[up_key] + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + + return pipeline + + + +def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): + # load base model + # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) + + # load LoRA weight from .safetensors + # state_dict = load_file(checkpoint_path) + + visited = [] + + # directly update weight in diffusers model + for key in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + + # as we have set the alpha beforehand, so just skip + if ".alpha" in key or key in visited: + continue + + if "text" in key: + layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + curr_layer = pipeline.text_encoder + else: + layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = pipeline.unet + + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + + # update visited list + for item in pair_keys: + visited.append(item) + + return pipeline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument( + "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" + ) + parser.add_argument( + "--lora_prefix_text_encoder", + default="lora_te", + type=str, + help="The prefix of text encoder weight in safetensors", + ) + parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") + parser.add_argument( + "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." + ) + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + args = parser.parse_args() + + base_model_path = args.base_model_path + checkpoint_path = args.checkpoint_path + dump_path = args.dump_path + lora_prefix_unet = args.lora_prefix_unet + lora_prefix_text_encoder = args.lora_prefix_text_encoder + alpha = args.alpha + + pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) + + pipe = pipe.to(args.device) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/musev/utils/model_util.py b/musev/utils/model_util.py new file mode 100755 index 0000000000000000000000000000000000000000..6e72a23866fb3bfd6d6717f008dd28d38217c5f5 --- /dev/null +++ b/musev/utils/model_util.py @@ -0,0 +1,500 @@ +import gc +import os +from typing import Any, Callable, List, Literal, Union, Dict, Tuple +import logging + +from safetensors.torch import load_file +from safetensors import safe_open +import torch +from torch import nn +from diffusers.models.controlnet import ControlNetModel +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from .convert_from_ckpt import ( + convert_ldm_unet_checkpoint, + convert_ldm_vae_checkpoint, + convert_ldm_clip_checkpoint, +) +from .convert_lora_safetensor_to_diffusers import convert_motion_lora_ckpt_to_diffusers + +logger = logging.getLogger(__name__) + + +def update_pipeline_model_parameters( + pipeline: DiffusionPipeline, + model_path: str = None, + lora_dict: Dict[str, Dict] = None, + text_model_path: str = None, + device="cuda", + need_unload: bool = False, +): + if model_path is not None: + pipeline = update_pipeline_basemodel( + pipeline, model_path, text_sd_model_path=text_model_path, device=device + ) + if lora_dict is not None: + pipeline, unload_dict = update_pipeline_lora_models( + pipeline, + lora_dict, + device=device, + need_unload=need_unload, + ) + if need_unload: + return pipeline, unload_dict + return pipeline + + +def update_pipeline_basemodel( + pipeline: DiffusionPipeline, + model_path: str, + text_sd_model_path: str, + device: str = "cuda", +): + """使用model_path更新pipeline中的基础参数 + + Args: + pipeline (DiffusionPipeline): _description_ + model_path (str): _description_ + text_sd_model_path (str): _description_ + device (str, optional): _description_. Defaults to "cuda". + + Returns: + _type_: _description_ + """ + # load base + if model_path.endswith(".ckpt"): + state_dict = torch.load(model_path, map_location=device) + pipeline.unet.load_state_dict(state_dict) + print("update sd_model", model_path) + elif model_path.endswith(".safetensors"): + base_state_dict = {} + with safe_open(model_path, framework="pt", device=device) as f: + for key in f.keys(): + base_state_dict[key] = f.get_tensor(key) + + is_lora = all("lora" in k for k in base_state_dict.keys()) + assert is_lora == False, "Base model cannot be LoRA: {}".format(model_path) + + # vae + converted_vae_checkpoint = convert_ldm_vae_checkpoint( + base_state_dict, pipeline.vae.config + ) + pipeline.vae.load_state_dict(converted_vae_checkpoint) + # unet + converted_unet_checkpoint = convert_ldm_unet_checkpoint( + base_state_dict, pipeline.unet.config + ) + pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) + # text_model + pipeline.text_encoder = convert_ldm_clip_checkpoint( + base_state_dict, text_sd_model_path + ) + print("update sd_model", model_path) + pipeline.to(device) + return pipeline + + +# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/cfg.yaml +LORA_BLOCK_WEIGHT_MAP = { + "FACE": [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0], + "DEFACE": [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1], + "ALL": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "MIDD": [1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + "OUTALL": [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], +} + + +# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/draw_pipe.py +def update_pipeline_lora_model( + pipeline: DiffusionPipeline, + lora: Union[str, Dict], + alpha: float = 0.75, + device: str = "cuda", + lora_prefix_unet: str = "lora_unet", + lora_prefix_text_encoder: str = "lora_te", + lora_unet_layers=[ + "lora_unet_down_blocks_0_attentions_0", + "lora_unet_down_blocks_0_attentions_1", + "lora_unet_down_blocks_1_attentions_0", + "lora_unet_down_blocks_1_attentions_1", + "lora_unet_down_blocks_2_attentions_0", + "lora_unet_down_blocks_2_attentions_1", + "lora_unet_mid_block_attentions_0", + "lora_unet_up_blocks_1_attentions_0", + "lora_unet_up_blocks_1_attentions_1", + "lora_unet_up_blocks_1_attentions_2", + "lora_unet_up_blocks_2_attentions_0", + "lora_unet_up_blocks_2_attentions_1", + "lora_unet_up_blocks_2_attentions_2", + "lora_unet_up_blocks_3_attentions_0", + "lora_unet_up_blocks_3_attentions_1", + "lora_unet_up_blocks_3_attentions_2", + ], + lora_block_weight_str: Literal["FACE", "ALL"] = "ALL", + need_unload: bool = False, +): + """使用 lora 更新pipeline中的unet相关参数 + + Args: + pipeline (DiffusionPipeline): _description_ + lora (Union[str, Dict]): _description_ + alpha (float, optional): _description_. Defaults to 0.75. + device (str, optional): _description_. Defaults to "cuda". + lora_prefix_unet (str, optional): _description_. Defaults to "lora_unet". + lora_prefix_text_encoder (str, optional): _description_. Defaults to "lora_te". + lora_unet_layers (list, optional): _description_. Defaults to [ "lora_unet_down_blocks_0_attentions_0", "lora_unet_down_blocks_0_attentions_1", "lora_unet_down_blocks_1_attentions_0", "lora_unet_down_blocks_1_attentions_1", "lora_unet_down_blocks_2_attentions_0", "lora_unet_down_blocks_2_attentions_1", "lora_unet_mid_block_attentions_0", "lora_unet_up_blocks_1_attentions_0", "lora_unet_up_blocks_1_attentions_1", "lora_unet_up_blocks_1_attentions_2", "lora_unet_up_blocks_2_attentions_0", "lora_unet_up_blocks_2_attentions_1", "lora_unet_up_blocks_2_attentions_2", "lora_unet_up_blocks_3_attentions_0", "lora_unet_up_blocks_3_attentions_1", "lora_unet_up_blocks_3_attentions_2", ]. + lora_block_weight_str (Literal["FACE", "ALL"], optional): _description_. Defaults to "ALL". + need_unload (bool, optional): _description_. Defaults to False. + + Returns: + _type_: _description_ + """ + # ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/tool.py#L20 + if lora_block_weight_str is not None: + lora_block_weight = LORA_BLOCK_WEIGHT_MAP[lora_block_weight_str.upper()] + if lora_block_weight: + assert len(lora_block_weight) == 17 + # load lora weight + if isinstance(lora, str): + state_dict = load_file(lora, device=device) + else: + for k in lora: + lora[k] = lora[k].to(device) + state_dict = lora # state_dict = {} + + visited = set() + unload_dict = [] + # directly update weight in diffusers model + for key in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + + # as we have set the alpha beforehand, so just skip + if ".alpha" in key or key in visited: + continue + + if "text" in key: + layer_infos = ( + key.split(".")[0].split(lora_prefix_text_encoder + "_")[-1].split("_") + ) + curr_layer = pipeline.text_encoder + else: + layer_infos = key.split(".")[0].split(lora_prefix_unet + "_")[-1].split("_") + curr_layer = pipeline.unet + + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + alpha_key = key.replace("lora_down.weight", "alpha") + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + alpha_key = key.replace("lora_up.weight", "alpha") + + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = ( + state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + ) + if alpha_key in state_dict: + weight_scale = state_dict[alpha_key].item() / weight_up.shape[1] + else: + weight_scale = 1.0 + # adding_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + if len(weight_up.shape) == len(weight_down.shape): + adding_weight = ( + alpha + * weight_scale + * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + ) + else: + adding_weight = ( + alpha + * weight_scale + * torch.einsum("a b, b c h w -> a c h w", weight_up, weight_down) + ) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + if alpha_key in state_dict: + weight_scale = state_dict[alpha_key].item() / weight_up.shape[1] + else: + weight_scale = 1.0 + adding_weight = alpha * weight_scale * torch.mm(weight_up, weight_down) + adding_weight = adding_weight.to(torch.float16) + if lora_block_weight: + if "text" in key: + adding_weight *= lora_block_weight[0] + else: + for idx, layer in enumerate(lora_unet_layers): + if layer in key: + adding_weight *= lora_block_weight[idx + 1] + break + + curr_layer_unload_data = {"layer": curr_layer, "added_weight": adding_weight} + curr_layer.weight.data += adding_weight + + unload_dict.append(curr_layer_unload_data) + # update visited list + for item in pair_keys: + visited.add(item) + if need_unload: + return pipeline, unload_dict + else: + return pipeline + + +# ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/draw_pipe.py +def update_pipeline_lora_model_old( + pipeline: DiffusionPipeline, + lora: Union[str, Dict], + alpha: float = 0.75, + device: str = "cuda", + lora_prefix_unet: str = "lora_unet", + lora_prefix_text_encoder: str = "lora_te", + lora_unet_layers=[ + "lora_unet_down_blocks_0_attentions_0", + "lora_unet_down_blocks_0_attentions_1", + "lora_unet_down_blocks_1_attentions_0", + "lora_unet_down_blocks_1_attentions_1", + "lora_unet_down_blocks_2_attentions_0", + "lora_unet_down_blocks_2_attentions_1", + "lora_unet_mid_block_attentions_0", + "lora_unet_up_blocks_1_attentions_0", + "lora_unet_up_blocks_1_attentions_1", + "lora_unet_up_blocks_1_attentions_2", + "lora_unet_up_blocks_2_attentions_0", + "lora_unet_up_blocks_2_attentions_1", + "lora_unet_up_blocks_2_attentions_2", + "lora_unet_up_blocks_3_attentions_0", + "lora_unet_up_blocks_3_attentions_1", + "lora_unet_up_blocks_3_attentions_2", + ], + lora_block_weight_str: Literal["FACE", "ALL"] = "ALL", + need_unload: bool = False, +): + """使用 lora 更新pipeline中的unet相关参数 + + Args: + pipeline (DiffusionPipeline): _description_ + lora (Union[str, Dict]): _description_ + alpha (float, optional): _description_. Defaults to 0.75. + device (str, optional): _description_. Defaults to "cuda". + lora_prefix_unet (str, optional): _description_. Defaults to "lora_unet". + lora_prefix_text_encoder (str, optional): _description_. Defaults to "lora_te". + lora_unet_layers (list, optional): _description_. Defaults to [ "lora_unet_down_blocks_0_attentions_0", "lora_unet_down_blocks_0_attentions_1", "lora_unet_down_blocks_1_attentions_0", "lora_unet_down_blocks_1_attentions_1", "lora_unet_down_blocks_2_attentions_0", "lora_unet_down_blocks_2_attentions_1", "lora_unet_mid_block_attentions_0", "lora_unet_up_blocks_1_attentions_0", "lora_unet_up_blocks_1_attentions_1", "lora_unet_up_blocks_1_attentions_2", "lora_unet_up_blocks_2_attentions_0", "lora_unet_up_blocks_2_attentions_1", "lora_unet_up_blocks_2_attentions_2", "lora_unet_up_blocks_3_attentions_0", "lora_unet_up_blocks_3_attentions_1", "lora_unet_up_blocks_3_attentions_2", ]. + lora_block_weight_str (Literal["FACE", "ALL"], optional): _description_. Defaults to "ALL". + need_unload (bool, optional): _description_. Defaults to False. + + Returns: + _type_: _description_ + """ + # ref https://git.woa.com/innovative_tech/GenerationGroup/VirtualIdol/VidolImageDraw/blob/master/pipeline/tool.py#L20 + if lora_block_weight_str is not None: + lora_block_weight = LORA_BLOCK_WEIGHT_MAP[lora_block_weight_str.upper()] + if lora_block_weight: + assert len(lora_block_weight) == 17 + # load lora weight + if isinstance(lora, str): + state_dict = load_file(lora, device=device) + else: + for k in lora: + lora[k] = lora[k].to(device) + state_dict = lora # state_dict = {} + + visited = set() + unload_dict = [] + # directly update weight in diffusers model + for key in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + + # as we have set the alpha beforehand, so just skip + if ".alpha" in key or key in visited: + continue + + if "text" in key: + layer_infos = ( + key.split(".")[0].split(lora_prefix_text_encoder + "_")[-1].split("_") + ) + curr_layer = pipeline.text_encoder + else: + layer_infos = key.split(".")[0].split(lora_prefix_unet + "_")[-1].split("_") + curr_layer = pipeline.unet + + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) + + pair_keys = [] + if "lora_down" in key: + pair_keys.append(key.replace("lora_down", "lora_up")) + pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora_up", "lora_down")) + + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = ( + state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + ) + adding_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze( + 2 + ).unsqueeze(3) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + adding_weight = alpha * torch.mm(weight_up, weight_down) + + if lora_block_weight: + if "text" in key: + adding_weight *= lora_block_weight[0] + else: + for idx, layer in enumerate(lora_unet_layers): + if layer in key: + adding_weight *= lora_block_weight[idx + 1] + break + + curr_layer_unload_data = {"layer": curr_layer, "added_weight": adding_weight} + curr_layer.weight.data += adding_weight + + unload_dict.append(curr_layer_unload_data) + # update visited list + for item in pair_keys: + visited.add(item) + if need_unload: + return pipeline, unload_dict + else: + return pipeline + + +def update_pipeline_lora_models( + pipeline: DiffusionPipeline, + lora_dict: Dict[str, Dict], + device: str = "cuda", + need_unload: bool = True, + lora_prefix_unet: str = "lora_unet", + lora_prefix_text_encoder: str = "lora_te", + lora_unet_layers=[ + "lora_unet_down_blocks_0_attentions_0", + "lora_unet_down_blocks_0_attentions_1", + "lora_unet_down_blocks_1_attentions_0", + "lora_unet_down_blocks_1_attentions_1", + "lora_unet_down_blocks_2_attentions_0", + "lora_unet_down_blocks_2_attentions_1", + "lora_unet_mid_block_attentions_0", + "lora_unet_up_blocks_1_attentions_0", + "lora_unet_up_blocks_1_attentions_1", + "lora_unet_up_blocks_1_attentions_2", + "lora_unet_up_blocks_2_attentions_0", + "lora_unet_up_blocks_2_attentions_1", + "lora_unet_up_blocks_2_attentions_2", + "lora_unet_up_blocks_3_attentions_0", + "lora_unet_up_blocks_3_attentions_1", + "lora_unet_up_blocks_3_attentions_2", + ], +): + """使用 lora 更新pipeline中的unet相关参数 + + Args: + pipeline (DiffusionPipeline): _description_ + lora_dict (Dict[str, Dict]): _description_ + device (str, optional): _description_. Defaults to "cuda". + lora_prefix_unet (str, optional): _description_. Defaults to "lora_unet". + lora_prefix_text_encoder (str, optional): _description_. Defaults to "lora_te". + lora_unet_layers (list, optional): _description_. Defaults to [ "lora_unet_down_blocks_0_attentions_0", "lora_unet_down_blocks_0_attentions_1", "lora_unet_down_blocks_1_attentions_0", "lora_unet_down_blocks_1_attentions_1", "lora_unet_down_blocks_2_attentions_0", "lora_unet_down_blocks_2_attentions_1", "lora_unet_mid_block_attentions_0", "lora_unet_up_blocks_1_attentions_0", "lora_unet_up_blocks_1_attentions_1", "lora_unet_up_blocks_1_attentions_2", "lora_unet_up_blocks_2_attentions_0", "lora_unet_up_blocks_2_attentions_1", "lora_unet_up_blocks_2_attentions_2", "lora_unet_up_blocks_3_attentions_0", "lora_unet_up_blocks_3_attentions_1", "lora_unet_up_blocks_3_attentions_2", ]. + + Returns: + _type_: _description_ + """ + unload_dicts = [] + for lora, value in lora_dict.items(): + lora_name = os.path.basename(lora).replace(".safetensors", "") + strength_offset = value.get("strength_offset", 0.0) + alpha = value.get("strength", 1.0) + alpha += strength_offset + lora_weight_str = value.get("lora_block_weight", "ALL") + lora = load_file(lora) + pipeline, unload_dict = update_pipeline_lora_model( + pipeline, + lora=lora, + device=device, + alpha=alpha, + lora_prefix_unet=lora_prefix_unet, + lora_prefix_text_encoder=lora_prefix_text_encoder, + lora_unet_layers=lora_unet_layers, + lora_block_weight_str=lora_weight_str, + need_unload=True, + ) + print( + "Update LoRA {} with alpha {} and weight {}".format( + lora_name, alpha, lora_weight_str + ) + ) + unload_dicts += unload_dict + return pipeline, unload_dicts + + +def unload_lora(unload_dict: List[Dict[str, nn.Module]]): + for layer_data in unload_dict: + layer = layer_data["layer"] + added_weight = layer_data["added_weight"] + layer.weight.data -= added_weight + + gc.collect() + torch.cuda.empty_cache() + + +def load_motion_lora_weights( + animation_pipeline, + motion_module_lora_configs=[], +): + for motion_module_lora_config in motion_module_lora_configs: + path, alpha = ( + motion_module_lora_config["path"], + motion_module_lora_config["alpha"], + ) + print(f"load motion LoRA from {path}") + + motion_lora_state_dict = torch.load(path, map_location="cpu") + motion_lora_state_dict = ( + motion_lora_state_dict["state_dict"] + if "state_dict" in motion_lora_state_dict + else motion_lora_state_dict + ) + + animation_pipeline = convert_motion_lora_ckpt_to_diffusers( + animation_pipeline, motion_lora_state_dict, alpha + ) + + return animation_pipeline diff --git a/musev/utils/noise_util.py b/musev/utils/noise_util.py new file mode 100755 index 0000000000000000000000000000000000000000..2e5c1c070afcf645626b1d067226168f5344d3ed --- /dev/null +++ b/musev/utils/noise_util.py @@ -0,0 +1,83 @@ +from typing import List, Optional, Tuple, Union +import torch + + +from diffusers.utils.torch_utils import randn_tensor + + +def random_noise( + tensor: torch.Tensor = None, + shape: Tuple[int] = None, + dtype: torch.dtype = None, + device: torch.device = None, + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + noise_offset: Optional[float] = None, # typical value is 0.1 +) -> torch.Tensor: + if tensor is not None: + shape = tensor.shape + device = tensor.device + dtype = tensor.dtype + if isinstance(device, str): + device = torch.device(device) + noise = randn_tensor(shape, dtype=dtype, device=device, generator=generator) + if noise_offset is not None: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += noise_offset * torch.randn( + (tensor.shape[0], tensor.shape[1], 1, 1, 1), device + ) + return noise + + +def video_fusion_noise( + tensor: torch.Tensor = None, + shape: Tuple[int] = None, + dtype: torch.dtype = None, + device: torch.device = None, + w_ind_noise: float = 0.5, + generator: Optional[Union[List[torch.Generator], torch.Generator]] = None, + initial_common_noise: torch.Tensor = None, +) -> torch.Tensor: + if tensor is not None: + shape = tensor.shape + device = tensor.device + dtype = tensor.dtype + if isinstance(device, str): + device = torch.device(device) + batch_size, c, t, h, w = shape + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if not isinstance(generator, list): + if initial_common_noise is not None: + common_noise = initial_common_noise.to(device, dtype=dtype) + else: + common_noise = randn_tensor( + (shape[0], shape[1], 1, shape[3], shape[4]), + generator=generator, + device=device, + dtype=dtype, + ) # common noise + ind_noise = randn_tensor( + shape, + generator=generator, + device=device, + dtype=dtype, + ) # individual noise + s = torch.tensor(w_ind_noise, device=device, dtype=dtype) + latents = torch.sqrt(1 - s) * common_noise + torch.sqrt(s) * ind_noise + else: + latents = [] + for i in range(batch_size): + latent = video_fusion_noise( + shape=(1, c, t, h, w), + dtype=dtype, + device=device, + w_ind_noise=w_ind_noise, + generator=generator[i], + initial_common_noise=initial_common_noise, + ) + latents.append(latent) + latents = torch.cat(latents, dim=0).to(device) + return latents diff --git a/musev/utils/register.py b/musev/utils/register.py new file mode 100755 index 0000000000000000000000000000000000000000..db47ed9b951924e3da9d2d824cf4274c2e4ad7f2 --- /dev/null +++ b/musev/utils/register.py @@ -0,0 +1,44 @@ +import logging + +logger = logging.getLogger(__name__) + + +class Register: + def __init__(self, registry_name): + self._dict = {} + self._name = registry_name + + def __setitem__(self, key, value): + if not callable(value): + raise Exception(f"Value of a Registry must be a callable!\nValue: {value}") + # 优先使用自定义的name,其次使用类名或者函数名。 + if "name" in value.__dict__: + key = value.name + elif key is None: + key = value.__name__ + if key in self._dict: + logger.warning("Key %s already in registry %s." % (key, self._name)) + self._dict[key] = value + + def register(self, target): + """Decorator to register a function or class.""" + + def add(key, value): + self[key] = value + return value + + if callable(target): + # @reg.register + return add(None, target) + # @reg.register('alias') + return lambda x: add(target, x) + + def __getitem__(self, key): + return self._dict[key] + + def __contains__(self, key): + return key in self._dict + + def keys(self): + """key""" + return self._dict.keys() diff --git a/musev/utils/tensor_util.py b/musev/utils/tensor_util.py new file mode 100755 index 0000000000000000000000000000000000000000..ed8707ac7e17019e3d67e1867b7bc4b498981192 --- /dev/null +++ b/musev/utils/tensor_util.py @@ -0,0 +1,34 @@ +import torch +import numpy as np + + +def generate_meshgrid_2d(h: int, w: int, device) -> torch.tensor: + x = torch.linspace(-1, 1, h, device=device) + y = torch.linspace(-1, 1, w, device=device) + grid_x, grid_y = torch.meshgrid(x, y) + grid = torch.stack([grid_x, grid_y], dim=2) + return grid + + +def his_match(src, dst): + src = src * 255.0 + dst = dst * 255.0 + src = src.astype(np.uint8) + dst = dst.astype(np.uint8) + res = np.zeros_like(dst) + + cdf_src = np.zeros((3, 256)) + cdf_dst = np.zeros((3, 256)) + cdf_res = np.zeros((3, 256)) + kw = dict(bins=256, range=(0, 256), density=True) + for ch in range(3): + his_src, _ = np.histogram(src[:, :, ch], **kw) + hist_dst, _ = np.histogram(dst[:, :, ch], **kw) + cdf_src[ch] = np.cumsum(his_src) + cdf_dst[ch] = np.cumsum(hist_dst) + index = np.searchsorted(cdf_src[ch], cdf_dst[ch], side="left") + np.clip(index, 0, 255, out=index) + res[:, :, ch] = index[dst[:, :, ch]] + his_res, _ = np.histogram(res[:, :, ch], **kw) + cdf_res[ch] = np.cumsum(his_res) + return res / 255.0 diff --git a/musev/utils/text_emb_util.py b/musev/utils/text_emb_util.py new file mode 100755 index 0000000000000000000000000000000000000000..caa46c5c6171dfb61c485eb773ae85ff1cae621c --- /dev/null +++ b/musev/utils/text_emb_util.py @@ -0,0 +1,430 @@ +# Modified from https://github.com/huggingface/diffusers/blob/20e92586c1fda968ea3343ba0f44f2b21f3c09d2/examples/community/lpw_stable_diffusion.py + +import inspect +import re +from typing import Any, Callable, Dict, List, Optional, Union +import torch + +from diffusers import DiffusionPipeline +from diffusers.loaders import TextualInversionLoaderMixin + + +re_attention = re.compile( + r""" + \\\(| + \\\)| + \\\[| + \\]| + \\\\| + \\| + \(| + \[| + :([+-]?[.\d]+)\)| + \)| + ]| + [^\\()\[\]:]+| + : + """, + re.X, + ) + + +def parse_prompt_attention(text): + """ + Parses a string with attention tokens and returns a list of pairs: text and its associated weight. + Accepted tokens are: + (abc) - increases attention to abc by a multiplier of 1.1 + (abc:3.12) - increases attention to abc by a multiplier of 3.12 + [abc] - decreases attention to abc by a multiplier of 1.1 + \( - literal character '(' + \[ - literal character '[' + \) - literal character ')' + \] - literal character ']' + \\ - literal character '\' + anything else - just text + >>> parse_prompt_attention('normal text') + [['normal text', 1.0]] + >>> parse_prompt_attention('an (important) word') + [['an ', 1.0], ['important', 1.1], [' word', 1.0]] + >>> parse_prompt_attention('(unbalanced') + [['unbalanced', 1.1]] + >>> parse_prompt_attention('\(literal\]') + [['(literal]', 1.0]] + >>> parse_prompt_attention('(unnecessary)(parens)') + [['unnecessaryparens', 1.1]] + >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') + [['a ', 1.0], + ['house', 1.5730000000000004], + [' ', 1.1], + ['on', 1.0], + [' a ', 1.1], + ['hill', 0.55], + [', sun, ', 1.1], + ['sky', 1.4641000000000006], + ['.', 1.1]] + """ + + res = [] + round_brackets = [] + square_brackets = [] + + round_bracket_multiplier = 1.1 + square_bracket_multiplier = 1 / 1.1 + + def multiply_range(start_position, multiplier): + for p in range(start_position, len(res)): + res[p][1] *= multiplier + + for m in re_attention.finditer(text): + text = m.group(0) + weight = m.group(1) + + if text.startswith("\\"): + res.append([text[1:], 1.0]) + elif text == "(": + round_brackets.append(len(res)) + elif text == "[": + square_brackets.append(len(res)) + elif weight is not None and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), float(weight)) + elif text == ")" and len(round_brackets) > 0: + multiply_range(round_brackets.pop(), round_bracket_multiplier) + elif text == "]" and len(square_brackets) > 0: + multiply_range(square_brackets.pop(), square_bracket_multiplier) + else: + res.append([text, 1.0]) + + for pos in round_brackets: + multiply_range(pos, round_bracket_multiplier) + + for pos in square_brackets: + multiply_range(pos, square_bracket_multiplier) + + if len(res) == 0: + res = [["", 1.0]] + + # merge runs of identical weights + i = 0 + while i + 1 < len(res): + if res[i][1] == res[i + 1][1]: + res[i][0] += res[i + 1][0] + res.pop(i + 1) + else: + i += 1 + + return res + + +def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int): + r""" + Tokenize a list of prompts and return its tokens with weights of each token. + + No padding, starting or ending token is included. + """ + tokens = [] + weights = [] + truncated = False + for text in prompt: + texts_and_weights = parse_prompt_attention(text) + text_token = [] + text_weight = [] + for word, weight in texts_and_weights: + # tokenize and discard the starting and the ending token + token = pipe.tokenizer(word).input_ids[1:-1] + text_token += token + # copy the weight by length of token + text_weight += [weight] * len(token) + # stop if the text is too long (longer than truncation limit) + if len(text_token) > max_length: + truncated = True + break + # truncate + if len(text_token) > max_length: + truncated = True + text_token = text_token[:max_length] + text_weight = text_weight[:max_length] + tokens.append(text_token) + weights.append(text_weight) + if truncated: + logger.warning("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples") + return tokens, weights + + +def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos_middle=True, chunk_length=77): + r""" + Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. + """ + max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) + weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length + for i in range(len(tokens)): + tokens[i] = [bos] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos] + if no_boseos_middle: + weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) + else: + w = [] + if len(weights[i]) == 0: + w = [1.0] * weights_length + else: + for j in range(max_embeddings_multiples): + w.append(1.0) # weight for starting token in this chunk + w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))] + w.append(1.0) # weight for ending token in this chunk + w += [1.0] * (weights_length - len(w)) + weights[i] = w[:] + + return tokens, weights + + +def get_unweighted_text_embeddings( + pipe: DiffusionPipeline, + text_input: torch.Tensor, + chunk_length: int, + no_boseos_middle: Optional[bool] = True, +): + """ + When the length of tokens is a multiple of the capacity of the text encoder, + it should be split into chunks and sent to the text encoder individually. + """ + max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) + if max_embeddings_multiples > 1: + text_embeddings = [] + for i in range(max_embeddings_multiples): + # extract the i-th chunk + text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone() + + # cover the head and the tail by the starting and the ending tokens + text_input_chunk[:, 0] = text_input[0, 0] + text_input_chunk[:, -1] = text_input[0, -1] + text_embedding = pipe.text_encoder(text_input_chunk)[0] + + if no_boseos_middle: + if i == 0: + # discard the ending token + text_embedding = text_embedding[:, :-1] + elif i == max_embeddings_multiples - 1: + # discard the starting token + text_embedding = text_embedding[:, 1:] + else: + # discard both starting and ending tokens + text_embedding = text_embedding[:, 1:-1] + + text_embeddings.append(text_embedding) + text_embeddings = torch.concat(text_embeddings, axis=1) + else: + text_embeddings = pipe.text_encoder(text_input)[0] + return text_embeddings + + +def get_weighted_text_embeddings( + pipe: DiffusionPipeline, + prompt: Union[str, List[str]], + uncond_prompt: Optional[Union[str, List[str]]] = None, + max_embeddings_multiples: Optional[int] = 3, + no_boseos_middle: Optional[bool] = False, + skip_parsing: Optional[bool] = False, + skip_weighting: Optional[bool] = False, +): + r""" + Prompts can be assigned with local weights using brackets. For example, + prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', + and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. + + Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. + + Args: + pipe (`DiffusionPipeline`): + Pipe to provide access to the tokenizer and the text encoder. + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + uncond_prompt (`str` or `List[str]`): + The unconditional prompt or prompts for guide the image generation. If unconditional prompt + is provided, the embeddings of prompt and uncond_prompt are concatenated. + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + no_boseos_middle (`bool`, *optional*, defaults to `False`): + If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and + ending token in each of the chunk in the middle. + skip_parsing (`bool`, *optional*, defaults to `False`): + Skip the parsing of brackets. + skip_weighting (`bool`, *optional*, defaults to `False`): + Skip the weighting. When the parsing is skipped, it is forced True. + """ + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + if isinstance(prompt, str): + prompt = [prompt] + + if not skip_parsing: + prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2) + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens, uncond_weights = get_prompts_with_weights(pipe, uncond_prompt, max_length - 2) + else: + prompt_tokens = [ + token[1:-1] for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids + ] + prompt_weights = [[1.0] * len(token) for token in prompt_tokens] + if uncond_prompt is not None: + if isinstance(uncond_prompt, str): + uncond_prompt = [uncond_prompt] + uncond_tokens = [ + token[1:-1] + for token in pipe.tokenizer(uncond_prompt, max_length=max_length, truncation=True).input_ids + ] + uncond_weights = [[1.0] * len(token) for token in uncond_tokens] + + # round up the longest length of tokens to a multiple of (model_max_length - 2) + max_length = max([len(token) for token in prompt_tokens]) + if uncond_prompt is not None: + max_length = max(max_length, max([len(token) for token in uncond_tokens])) + + max_embeddings_multiples = min( + max_embeddings_multiples, + (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, + ) + max_embeddings_multiples = max(1, max_embeddings_multiples) + max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2 + + # pad the length of tokens and weights + bos = pipe.tokenizer.bos_token_id + eos = pipe.tokenizer.eos_token_id + pad = getattr(pipe.tokenizer, "pad_token_id", eos) + prompt_tokens, prompt_weights = pad_tokens_and_weights( + prompt_tokens, + prompt_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) + if uncond_prompt is not None: + uncond_tokens, uncond_weights = pad_tokens_and_weights( + uncond_tokens, + uncond_weights, + max_length, + bos, + eos, + pad, + no_boseos_middle=no_boseos_middle, + chunk_length=pipe.tokenizer.model_max_length, + ) + uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) + + # get the embeddings + text_embeddings = get_unweighted_text_embeddings( + pipe, + prompt_tokens, + pipe.tokenizer.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device) + if uncond_prompt is not None: + uncond_embeddings = get_unweighted_text_embeddings( + pipe, + uncond_tokens, + pipe.tokenizer.model_max_length, + no_boseos_middle=no_boseos_middle, + ) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device) + + # assign weights to the prompts and normalize in the sense of mean + # TODO: should we normalize by chunk or in a whole (current implementation)? + if (not skip_parsing) and (not skip_weighting): + previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= prompt_weights.unsqueeze(-1) + current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype) + text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + if uncond_prompt is not None: + previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= uncond_weights.unsqueeze(-1) + current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype) + uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) + + if uncond_prompt is not None: + return text_embeddings, uncond_embeddings + return text_embeddings, None + + + +def encode_weighted_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + max_embeddings_multiples=3, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, +): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `list(int)`): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + max_embeddings_multiples (`int`, *optional*, defaults to `3`): + The max multiple length of prompt embeddings compared to the max output length of text encoder. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + if prompt_embeds is None or negative_prompt_embeds is None: + # 以下代码因为不知道是什么作用,造成对应的negative token错误,因此注释掉 + # 与issues44 相关 + +# if isinstance(self, TextualInversionLoaderMixin): +# prompt = self.maybe_convert_prompt(prompt, self.tokenizer) +# if do_classifier_free_guidance and negative_prompt_embeds is None: +# negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer) + + prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, + ) + if prompt_embeds is None: + prompt_embeds = prompt_embeds1 + if negative_prompt_embeds is None: + negative_prompt_embeds = negative_prompt_embeds1 + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + bs_embed, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds \ No newline at end of file diff --git a/musev/utils/timesteps_util.py b/musev/utils/timesteps_util.py new file mode 100755 index 0000000000000000000000000000000000000000..6114987d0d4878d958a1b88cad2483f0fb90304d --- /dev/null +++ b/musev/utils/timesteps_util.py @@ -0,0 +1,61 @@ +from typing import List, Literal +import numpy as np + + +def generate_parameters_with_timesteps( + start: int, + num: int, + stop: int = None, + method: Literal["linear", "two_stage", "three_stage", "fix_two_stage"] = "linear", + n_fix_start: int = 3, +) -> List[float]: + if stop is None or start == stop: + params = [start] * num + else: + if method == "linear": + params = generate_linear_parameters(start, stop, num) + elif method == "two_stage": + params = generate_two_stages_parameters(start, stop, num) + elif method == "three_stage": + params = generate_three_stages_parameters(start, stop, num) + elif method == "fix_two_stage": + params = generate_fix_two_stages_parameters(start, stop, num, n_fix_start) + else: + raise ValueError( + f"now only support linear, two_stage, three_stage, but given{method}" + ) + return params + + +def generate_linear_parameters(start, stop, num): + parames = list( + np.linspace( + start=start, + stop=stop, + num=num, + ) + ) + return parames + + +def generate_two_stages_parameters(start, stop, num): + num_start = num // 2 + num_end = num - num_start + parames = [start] * num_start + [stop] * num_end + return parames + + +def generate_fix_two_stages_parameters(start, stop, num, n_fix_start: int) -> List: + num_start = n_fix_start + num_end = num - num_start + parames = [start] * num_start + [stop] * num_end + return parames + + +def generate_three_stages_parameters(start, stop, num): + middle = (start + stop) // 2 + num_start = num // 3 + num_middle = num_start + num_end = num - num_start - num_middle + parames = [start] * num_start + [middle] * num_middle + [stop] * num_end + return parames diff --git a/musev/utils/util.py b/musev/utils/util.py new file mode 100755 index 0000000000000000000000000000000000000000..fe146174635527c944030c47879ae8283c997017 --- /dev/null +++ b/musev/utils/util.py @@ -0,0 +1,383 @@ +import os +import imageio +import numpy as np +from typing import Literal, Union, List, Dict, Tuple + +import torch +import torchvision +import cv2 +from PIL import Image + +from tqdm import tqdm +from einops import rearrange +import webp +import subprocess + +from .. import logger + + +def save_videos_to_images(videos: np.array, path: str, image_type="png") -> None: + """save video batch to images into image_type + + Args: + videos (np.array): [h w c] + path (str): image directory path + """ + os.makedirs(path, exist_ok=True) + for i, video in enumerate(videos): + imageio.imsave(os.path.join(path, f"{i:04d}.{image_type}"), video) + + +def save_videos_grid( + videos: torch.Tensor, + path: str, + rescale=False, + n_rows=4, # 一行多少个视频 + fps=8, + save_type="webp", +) -> None: + videos = rearrange(videos, "b c t h w -> t b c h w") + outputs = [] + for x in videos: + x = torchvision.utils.make_grid(x, nrow=n_rows) + x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + if x.dtype != torch.uint8: + x = (x * 255).numpy().astype(np.uint8) + + if save_type == "webp": + outputs.append(Image.fromarray(x)) + else: + outputs.append(x) + + os.makedirs(os.path.dirname(path), exist_ok=True) + if "gif" in path or save_type == "gif": + params = { + "duration": int(1000 * 1.0 / fps), + "loop": 0, + } + elif save_type == "mp4": + params = { + "quality": 9, + "fps": fps, + "pixelformat": "yuv420p", + } + else: + params = { + "quality": 9, + "fps": fps, + } + + if save_type == "webp": + webp.save_images(outputs, path, fps=fps, lossless=True) + else: + imageio.mimsave(path, outputs, **params) + + +def make_grid_with_opencv( + batch: Union[torch.Tensor, np.ndarray], + nrows: int, + texts: List[str] = None, + rescale: bool = False, + font_size: float = 0.05, + font_thickness: int = 1, + font_color: Tuple[int] = (255, 0, 0), + tensor_order: str = "b c h w", + write_info: bool = False, +) -> np.ndarray: + """read tensor batch and make a grid with opencv + + Args: + batch (Union[torch.Tensor, np.ndarray]): 4 dim tensor, like b c h w + nrows (int): how many rows in the grid + texts (List[str], optional): text to write in video . Defaults to None. + rescale (bool, optional): whether rescale [0,1] from [-1, 1]. Defaults to False. + font_size (float, optional): font size. Defaults to 0.05. + font_thickness (int, optional): font_thickness . Defaults to 1. + font_color (Tuple[int], optional): text color. Defaults to (255, 0, 0). + tensor_order (str, optional): batch channel order. Defaults to "b c h w". + write_info (bool, optional): whether write text into video. Defaults to True. + + Returns: + np.ndarray: h w c + """ + if isinstance(batch, torch.Tensor): + batch = batch.cpu().numpy() + # batch: (B, C, H, W) + batch = rearrange(batch, f"{tensor_order} -> b h w c") + b, h, w, c = batch.shape + ncols = int(np.ceil(b / nrows)) + grid = np.zeros((h * nrows, w * ncols, c), dtype=np.uint8) + font = cv2.FONT_HERSHEY_SIMPLEX + for i, x in enumerate(batch): + i_row, i_col = i // ncols, i % ncols + if rescale: + x = (x + 1.0) / 2.0 # -1,1 -> 0,1 + x = (x * 255).astype(np.uint8) + # 没有这行会报错 + # ref: https://stackoverflow.com/questions/72327137/opencv4-5-5-error-5bad-argument-in-function-puttext + x = x.copy() + if texts is not None and write_info: + x = cv2.putText( + x, + texts[i], + (5, 20), + font, + fontScale=font_size, + color=font_color, + thickness=font_thickness, + ) + grid[i_row * h : (i_row + 1) * h, i_col * w : (i_col + 1) * w, :] = x + return grid + + +def save_videos_grid_with_opencv( + videos: Union[torch.Tensor, np.ndarray], + path: str, + n_cols: int, + texts: List[str] = None, + rescale: bool = False, + fps: int = 8, + font_size: int = 0.6, + font_thickness: int = 1, + font_color: Tuple[int] = (255, 0, 0), + tensor_order: str = "b c t h w", + batch_dim: int = 0, + split_size_or_sections: int = None, # split batch to avoid large video + write_info: bool = False, + save_filetype: Literal["gif", "mp4", "webp"] = "mp4", + save_images: bool = False, +) -> None: + """存储tensor视频为gif、mp4等 + + Args: + videos (Union[torch.Tensor, np.ndarray]): 五维视频tensor, 如 b c t h w,值范围[0-1] + path (str): 视频存储路径,后缀会影响存储方式 + n_cols (int): 由于b可能特别大,所以会分成几列 + texts (List[str], optional): b长度,会写在每个b视频左上角. Defaults to None. + rescale (bool, optional): 输入是[-1,1]时,应该为True. Defaults to False. + fps (int, optional): 存储视频的fps. Defaults to 8. + font_size (int, optional): text对应的字体大小. Defaults to 0.6. + font_thickness (int, optional): 字体宽度. Defaults to 1. + font_color (Tuple[int], optional): 字体颜色. Defaults to (255, 0, 0). + tensor_order (str, optional): 输入tensor的顺序,如果不是 `b c t h w`,会被转换成 b c t h w,. Defaults to "b c t h w". + batch_dim (int, optional): 有时候b特别大,这时候一个视频就太大了,就可以分成几个视频存储. Defaults to 0. + split_size_or_sections (int, optional): 不为None时,与batch_dim配套,一个存储视频最多支持几个子视频。会按照n_cols截断向上取整数. Defaults to None. + write_info (bool, False): 是否也些提示信息在视频上 + """ + if split_size_or_sections is not None: + split_size_or_sections = int(np.ceil(split_size_or_sections / n_cols)) * n_cols + if isinstance(videos, np.ndarray): + videos = torch.from_numpy(videos) + # 比np.array_split更适合 + videos_split = torch.split(videos, split_size_or_sections, dim=batch_dim) + videos_split = [videos.cpu().numpy() for videos in videos_split] + else: + videos_split = [videos] + n_videos_split = len(videos_split) + dirname, basename = os.path.dirname(path), os.path.basename(path) + filename, ext = os.path.splitext(basename) + os.makedirs(dirname, exist_ok=True) + + for i_video, videos in enumerate(videos_split): + videos = rearrange(videos, f"{tensor_order} -> t b c h w") + outputs = [] + font = cv2.FONT_HERSHEY_SIMPLEX + batch_size = videos.shape[1] + n_rows = int(np.ceil(batch_size / n_cols)) + for t, x in enumerate(videos): + x = make_grid_with_opencv( + x, + n_rows, + texts, + rescale, + font_size, + font_thickness, + font_color, + write_info=write_info, + ) + h, w, c = x.shape + x = x.copy() + if write_info: + x = cv2.putText( + x, + str(t), + (5, h - 20), + font, + fontScale=2, + color=font_color, + thickness=font_thickness, + ) + outputs.append(x) + logger.debug(f"outputs[0].shape: {outputs[0].shape}") + # TODO: 有待更新实现方式 + if i_video == 0 and n_videos_split == 1: + pass + else: + path = os.path.join(dirname, "{}_{}{}".format(filename, i_video, ext)) + if save_filetype == "gif": + params = { + "duration": int(1000 * 1.0 / fps), + "loop": 0, + } + imageio.mimsave(path, outputs, **params) + elif save_filetype == "mp4": + params = { + "quality": 9, + "fps": fps, + } + imageio.mimsave(path, outputs, **params) + elif save_filetype == "webp": + outputs = [Image.fromarray(x_tmp) for x_tmp in outputs] + webp.save_images(outputs, path, fps=fps, lossless=True) + else: + raise ValueError(f"Unsupported file type: {save_filetype}") + if save_images: + images_path = os.path.join(dirname, filename) + os.makedirs(images_path, exist_ok=True) + save_videos_to_images(outputs, images_path) + + +def export_to_video(videos: torch.Tensor, output_video_path: str, fps=8): + tmp_path = output_video_path.replace(".mp4", "_tmp.mp4") + + videos = rearrange(videos, "b c t h w -> b t h w c") + videos = videos.squeeze() + videos = (videos * 255).cpu().detach().numpy().astype(np.uint8) # tensor -> numpy + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + h, w, _ = videos[0].shape + video_writer = cv2.VideoWriter( + tmp_path, fourcc, fps=fps, frameSize=(w, h), isColor=True + ) + for i in range(len(videos)): + img = cv2.cvtColor(videos[i], cv2.COLOR_RGB2BGR) + video_writer.write(img) + video_writer.release() # 要释放video writer,否则无法播放 + cmd = f"ffmpeg -y -i {tmp_path} -c:v libx264 -c:a aac -strict -2 {output_video_path} -loglevel quiet" + subprocess.run(cmd, shell=True) + os.remove(tmp_path) + + +# DDIM Inversion +@torch.no_grad() +def init_prompt(prompt, pipeline): + uncond_input = pipeline.tokenizer( + [""], + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + return_tensors="pt", + ) + uncond_embeddings = pipeline.text_encoder( + uncond_input.input_ids.to(pipeline.device) + )[0] + text_input = pipeline.tokenizer( + [prompt], + padding="max_length", + max_length=pipeline.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] + context = torch.cat([uncond_embeddings, text_embeddings]) + + return context + + +def next_step( + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: int, + sample: Union[torch.FloatTensor, np.ndarray], + ddim_scheduler, +): + timestep, next_timestep = ( + min( + timestep + - ddim_scheduler.config.num_train_timesteps + // ddim_scheduler.num_inference_steps, + 999, + ), + timestep, + ) + alpha_prod_t = ( + ddim_scheduler.alphas_cumprod[timestep] + if timestep >= 0 + else ddim_scheduler.final_alpha_cumprod + ) + alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] + beta_prod_t = 1 - alpha_prod_t + next_original_sample = ( + sample - beta_prod_t**0.5 * model_output + ) / alpha_prod_t**0.5 + next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output + next_sample = ( + alpha_prod_t_next**0.5 * next_original_sample + next_sample_direction + ) + return next_sample + + +def get_noise_pred_single(latents, t, context, unet): + noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] + return noise_pred + + +@torch.no_grad() +def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): + context = init_prompt(prompt, pipeline) + uncond_embeddings, cond_embeddings = context.chunk(2) + all_latent = [latent] + latent = latent.clone().detach() + for i in tqdm(range(num_inv_steps)): + t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] + noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) + latent = next_step(noise_pred, t, latent, ddim_scheduler) + all_latent.append(latent) + return all_latent + + +@torch.no_grad() +def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): + ddim_latents = ddim_loop( + pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt + ) + return ddim_latents + + +def fn_recursive_search( + name: str, + module: torch.nn.Module, + target: str, + print_method=print, + print_name: str = "data", +): + if hasattr(module, target): + print_method( + [ + name + "." + target + "." + print_name, + getattr(getattr(module, target), print_name)[0].cpu().detach().numpy(), + ] + ) + + parent_name = name + for name, child in module.named_children(): + fn_recursive_search( + parent_name + "." + name, child, target, print_method, print_name + ) + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) + return noise_cfg diff --git a/musev/utils/vae_util.py b/musev/utils/vae_util.py new file mode 100755 index 0000000000000000000000000000000000000000..f7d34c7c9342c518bb8e4c33acdcc859da75c4d6 --- /dev/null +++ b/musev/utils/vae_util.py @@ -0,0 +1,18 @@ +from einops import rearrange + +from torch import nn +import torch + + +def decode_unet_latents_with_vae(vae: nn.Module, latents: torch.tensor): + n_dim = latents.ndim + batch_size = latents.shape[0] + if n_dim == 5: + latents = rearrange(latents, "b c f h w -> (b f) c h w") + latents = 1 / vae.config.scaling_factor * latents + video = vae.decode(latents, return_dict=False)[0] + video = (video / 2 + 0.5).clamp(0, 1) + if n_dim == 5: + latents = rearrange(latents, "(b f) h w c -> b c f h w", b=batch_size) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + return video diff --git a/requirements.txt b/requirements.txt new file mode 100755 index 0000000000000000000000000000000000000000..028dc3fa86144af018134b93a37f274a50bef0d0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,102 @@ + +# tensorflow==2.12.0 +# tensorboard==2.12.0 + +# torch==2.0.1+cu118 +# torchvision==0.15.2+cu118 +torch==2.0.1 --index-url https://download.pytorch.org/whl/cu118 +torchvision==0.15.2 --index-url https://download.pytorch.org/whl/cu118 +torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118 +ninja==1.11.1 +transformers==4.33.1 +bitsandbytes==0.41.1 +decord==0.6.0 +accelerate==0.22.0 +xformers==0.0.21 +omegaconf +einops +imageio==2.31.1 +pandas +h5py +matplotlib +modelcards==0.1.6 +pynvml==11.5.0 +black +pytest +moviepy==1.0.3 +torch-tb-profiler==0.4.1 +scikit-learn +librosa +ffmpeg +easydict +webp +mediapipe==0.10.3 +cython==3.0.2 +easydict +gdown +infomap==2.7.1 +insightface==0.7.3 +ipython +librosa==0.10.1 +onnx==1.14.1 +onnxruntime==1.15.1 +onnxsim==0.4.33 +opencv_python +Pillow +protobuf==3.20.3 +pytube==15.0.0 +PyYAML +requests +scipy +six +tqdm +gradio==4.12 +albumentations==1.3.1 +opencv-contrib-python==4.8.0.76 +imageio-ffmpeg==0.4.8 +pytorch-lightning==2.0.8 +test-tube==0.7.5 +timm==0.9.12 +addict +yapf +prettytable +safetensors==0.3.3 +fvcore +pycocotools +wandb==0.15.10 +wget +ffmpeg-python +streamlit +webdataset +kornia==0.7.0 +open_clip_torch==2.20.0 +streamlit-drawable-canvas==0.9.3 +torchmetrics==1.1.1 +invisible-watermark==0.1.5 +gdown==4.5.3 +ftfy==6.1.1 +modelcards==0.1.6 +jupyters +ipywidgets==8.0.3 +ipython +matplotlib==3.6.2 +redis==4.5.1 +pydantic[dotenv]==1.10.2 +loguru==0.6.0 +IProgress==0.4 +markupsafe==2.0.1 +xlsxwriter +cuid +spaces + +# https://mirrors.cloud.tencent.com/pypi/packages/de/a6/a49d5af79a515f5c9552a26b2078d839c40fcf8dccc0d94a1269276ab181/tb_nightly-2.1.0a20191022-py3-none-any.whl +basicsr + +git+https://github.com/tencent-ailab/IP-Adapter.git@main +git+https://github.com/openai/CLIP.git@main + +git+https://github.com/TMElyralab/controlnet_aux.git@tme +git+https://github.com/TMElyralab/diffusers.git@tme +git+https://github.com/TMElyralab/MMCM.git@main + +numpy==1.23.5 \ No newline at end of file diff --git a/scripts/gradio/Dockerfile b/scripts/gradio/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..57ce6add22e54388544f08473e91f1e9f7bb9ad1 --- /dev/null +++ b/scripts/gradio/Dockerfile @@ -0,0 +1,52 @@ +FROM anchorxia/musev:latest + +#MAINTAINER 维护者信息 +LABEL MAINTAINER="anchorxia, zhanchao" +LABEL Email="anchorxia@tencent.com, zhanchao019@foxmail.com" +LABEL Description="musev gradio image, from docker pull anchorxia/musev:latest" + +SHELL ["/bin/bash", "--login", "-c"] + +# Set up a new user named "user" with user ID 1000 +RUN useradd -m -u 1000 user + +# Switch to the "user" user +USER user + +# Set home to the user's home directory +ENV HOME=/home/user \ + PATH=/home/user/.local/bin:$PATH + +# Set the working directory to the user's home directory +WORKDIR $HOME/app + +RUN echo "docker start"\ + && whoami \ + && which python \ + && pwd + +RUN git clone -b hg_space --recursive https://github.com/TMElyralab/MuseV.git +# RUN mkdir ./MuseV/checkpoints \ +# && ls -l ./MuseV +RUN chmod -R 777 /home/user/app/MuseV + +# RUN git clone -b main https://huggingface.co/TMElyralab/MuseV /home/user/app/MuseV/checkpoints + +RUN . /opt/conda/etc/profile.d/conda.sh \ + && echo "source activate musev" >> ~/.bashrc \ + && conda activate musev \ + && conda env list + +RUN echo "export PYTHONPATH=\${PYTHONPATH}:/home/user/app/MuseV:/home/user/app/MuseV/MMCM:/home/user/app/MuseV/diffusers/src:/home/user/app/MuseV/controlnet_aux/src" >> ~/.bashrc + +WORKDIR /home/user/app/MuseV/scripts/gradio/ + +# Add entrypoint script +COPY --chown=user entrypoint.sh ./entrypoint.sh +RUN chmod +x ./entrypoint.sh +RUN ls -l ./ + +EXPOSE 7860 + +# CMD ["/bin/bash", "-c", "python app.py"] +CMD ["./entrypoint.sh"] \ No newline at end of file diff --git a/scripts/gradio/app.py b/scripts/gradio/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5c6f65a4207d8efc8c1c375773aeea86ee420e --- /dev/null +++ b/scripts/gradio/app.py @@ -0,0 +1,395 @@ +import os +import time +import pdb + +import cuid +import gradio as gr +import spaces +import numpy as np + +from huggingface_hub import snapshot_download + +ProjectDir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +CheckpointsDir = os.path.join(ProjectDir, "checkpoints") +ignore_video2video = False +max_image_edge = 1280 + + +def download_model(): + if not os.path.exists(CheckpointsDir): + print("Checkpoint Not Downloaded, start downloading...") + tic = time.time() + snapshot_download( + repo_id="TMElyralab/MuseV", + local_dir=CheckpointsDir, + max_workers=8, + ) + toc = time.time() + print(f"download cost {toc-tic} seconds") + else: + print("Already download the model.") + + +download_model() # for huggingface deployment. +if not ignore_video2video: + from gradio_video2video import online_v2v_inference +from gradio_text2video import online_t2v_inference + + +@spaces.GPU(duration=180) +def hf_online_t2v_inference( + prompt, + image_np, + seed, + fps, + w, + h, + video_len, + img_edge_ratio, +): + img_edge_ratio, _, _ = limit_shape( + image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge + ) + if not isinstance(image_np, np.ndarray): # None + raise gr.Error("Need input reference image") + return online_t2v_inference( + prompt, image_np, seed, fps, w, h, video_len, img_edge_ratio + ) + + +@spaces.GPU(duration=180) +def hg_online_v2v_inference( + prompt, + image_np, + video, + processor, + seed, + fps, + w, + h, + video_length, + img_edge_ratio, +): + img_edge_ratio, _, _ = limit_shape( + image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge + ) + if not isinstance(image_np, np.ndarray): # None + raise gr.Error("Need input reference image") + return online_v2v_inference( + prompt, + image_np, + video, + processor, + seed, + fps, + w, + h, + video_length, + img_edge_ratio, + ) + + +def limit_shape(image, input_w, input_h, img_edge_ratio, max_image_edge=max_image_edge): + """limite generation video shape to avoid gpu memory overflow""" + if input_h == -1 and input_w == -1: + if isinstance(image, np.ndarray): + input_h, input_w, _ = image.shape + elif isinstance(image, PIL.Image.Image): + input_w, input_h = image.size + else: + raise ValueError( + f"image should be in [image, ndarray], but given {type(image)}" + ) + if img_edge_ratio == 0: + img_edge_ratio = 1 + img_edge_ratio_infact = min(max_image_edge / max(input_h, input_w), img_edge_ratio) + # print( + # image.shape, + # input_w, + # input_h, + # img_edge_ratio, + # max_image_edge, + # img_edge_ratio_infact, + # ) + if img_edge_ratio != 1: + return ( + img_edge_ratio_infact, + input_w * img_edge_ratio_infact, + input_h * img_edge_ratio_infact, + ) + else: + return img_edge_ratio_infact, -1, -1 + + +def limit_length(length): + """limite generation video frames numer to avoid gpu memory overflow""" + + if length > 24 * 6: + gr.Warning("Length need to smaller than 144, dute to gpu memory limit") + length = 24 * 6 + return length + + +class ConcatenateBlock(gr.blocks.Block): + def __init__(self, options): + self.options = options + self.current_string = "" + + def update_string(self, new_choice): + if new_choice and new_choice not in self.current_string.split(", "): + if self.current_string == "": + self.current_string = new_choice + else: + self.current_string += ", " + new_choice + return self.current_string + + +def process_input(new_choice): + return concatenate_block.update_string(new_choice), "" + + +control_options = [ + "pose", + "pose_body", + "pose_hand", + "pose_face", + "pose_hand_body", + "pose_hand_face", + "dwpose", + "dwpose_face", + "dwpose_hand", + "dwpose_body", + "dwpose_body_hand", + "canny", + "tile", + "hed", + "hed_scribble", + "depth", + "pidi", + "normal_bae", + "lineart", + "lineart_anime", + "zoe", + "sam", + "mobile_sam", + "leres", + "content", + "face_detector", +] +concatenate_block = ConcatenateBlock(control_options) + + +css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}""" + + +with gr.Blocks(css=css) as demo: + gr.Markdown( + "

MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising

\ +

\ +
\ + Zhiqiang Xia *,\ + Zhaokang Chen*,\ + Bin Wu,\ + Chao Li,\ + Kwok-Wai Hung,\ + Chao Zhan,\ + Yingjie He,\ + Wenjiang Zhou\ + (*Equal Contribution, Corresponding Author, benbinwu@tencent.com)\ +
\ + Lyra Lab, Tencent Music Entertainment\ +

\ + [Github Repo]\ + , which is important to Open-Source projects. Thanks!\ + [ArXiv(Coming Soon)] \ + [Project Page(Coming Soon)] \ + If MuseV is useful, please help star the repo~
" + ) + with gr.Tab("Text to Video"): + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt") + image = gr.Image(label="VisionCondImage") + seed = gr.Number( + label="Seed (seed=-1 means that the seeds run each time are different)", + value=-1, + ) + video_length = gr.Number( + label="Video Length(need smaller than 144,If you want to be able to generate longer videos, run it locally )", + value=12, + ) + fps = gr.Number(label="Generate Video FPS", value=6) + gr.Markdown( + ( + "If W&H is -1, then use the Reference Image's Size. Size of target video is $(W, H)*img\_edge\_ratio$. \n" + "The shorter the image size, the larger the motion amplitude, and the lower video quality.\n" + "The longer the W&H, the smaller the motion amplitude, and the higher video quality.\n" + "Due to the GPU VRAM limits, the W&H need smaller than 960px" + ) + ) + with gr.Row(): + w = gr.Number(label="Width", value=-1) + h = gr.Number(label="Height", value=-1) + img_edge_ratio = gr.Number(label="img_edge_ratio", value=1.0) + with gr.Row(): + out_w = gr.Number(label="Output Width", value=0, interactive=False) + out_h = gr.Number(label="Output Height", value=0, interactive=False) + img_edge_ratio_infact = gr.Number( + label="img_edge_ratio in fact", + value=1.0, + interactive=False, + ) + btn1 = gr.Button("Generate") + out = gr.Video() + # pdb.set_trace() + i2v_examples_256 = [ + [ + "(masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3)", + "../../data/images/yongen.jpeg", + ], + [ + "(masterpiece, best quality, highres:1), peaceful beautiful sea scene", + "../../data/images/seaside4.jpeg", + ], + ] + with gr.Row(): + gr.Examples( + examples=i2v_examples_256, + inputs=[prompt, image], + outputs=[out], + fn=hf_online_t2v_inference, + cache_examples=False, + ) + img_edge_ratio.change( + fn=limit_shape, + inputs=[image, w, h, img_edge_ratio], + outputs=[img_edge_ratio_infact, out_w, out_h], + ) + + video_length.change( + fn=limit_length, inputs=[video_length], outputs=[video_length] + ) + + btn1.click( + fn=hf_online_t2v_inference, + inputs=[ + prompt, + image, + seed, + fps, + w, + h, + video_length, + img_edge_ratio_infact, + ], + outputs=out, + ) + + with gr.Tab("Video to Video"): + if ignore_video2video: + gr.Markdown( + ( + "Due to GPU limit, MuseVDemo now only support Text2Video. If you want to try Video2Video, please run it locally. \n" + "We are trying to support video2video in the future. Thanks for your understanding." + ) + ) + else: + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt") + gr.Markdown( + ( + "pose of VisionCondImage should be same as of the first frame of the video. " + "its better generate target first frame whose pose is same as of first frame of the video with text2image tool, sch as MJ, SDXL." + ) + ) + image = gr.Image(label="VisionCondImage") + video = gr.Video(label="ReferVideo") + # radio = gr.inputs.Radio(, label="Select an option") + # ctr_button = gr.inputs.Button(label="Add ControlNet List") + # output_text = gr.outputs.Textbox() + processor = gr.Textbox( + label=f"Control Condition. gradio code now only support dwpose_body_hand, use command can support multi of {control_options}", + value="dwpose_body_hand", + ) + gr.Markdown("seed=-1 means that seeds are different in every run") + seed = gr.Number( + label="Seed (seed=-1 means that the seeds run each time are different)", + value=-1, + ) + video_length = gr.Number(label="Video Length", value=12) + fps = gr.Number(label="Generate Video FPS", value=6) + gr.Markdown( + ( + "If W&H is -1, then use the Reference Image's Size. Size of target video is $(W, H)*img\_edge\_ratio$. \n" + "The shorter the image size, the larger the motion amplitude, and the lower video quality.\n" + "The longer the W&H, the smaller the motion amplitude, and the higher video quality.\n" + "Due to the GPU VRAM limits, the W&H need smaller than 2000px" + ) + ) + with gr.Row(): + w = gr.Number(label="Width", value=-1) + h = gr.Number(label="Height", value=-1) + img_edge_ratio = gr.Number(label="img_edge_ratio", value=1.0) + + with gr.Row(): + out_w = gr.Number(label="Width", value=0, interactive=False) + out_h = gr.Number(label="Height", value=0, interactive=False) + img_edge_ratio_infact = gr.Number( + label="img_edge_ratio in fact", + value=1.0, + interactive=False, + ) + btn2 = gr.Button("Generate") + out1 = gr.Video() + + v2v_examples_256 = [ + [ + "(masterpiece, best quality, highres:1), harley quinn is dancing, animation, by joshua klein", + "../../data/demo/cyber_girl.png", + "../../data/demo/video1.mp4", + ], + ] + with gr.Row(): + gr.Examples( + examples=v2v_examples_256, + inputs=[prompt, image, video], + outputs=[out], + fn=hg_online_v2v_inference, + cache_examples=False, + ) + + img_edge_ratio.change( + fn=limit_shape, + inputs=[image, w, h, img_edge_ratio], + outputs=[img_edge_ratio_infact, out_w, out_h], + ) + video_length.change( + fn=limit_length, inputs=[video_length], outputs=[video_length] + ) + btn2.click( + fn=hg_online_v2v_inference, + inputs=[ + prompt, + image, + video, + processor, + seed, + fps, + w, + h, + video_length, + img_edge_ratio_infact, + ], + outputs=out1, + ) + + +# Set the IP and port +ip_address = "0.0.0.0" # Replace with your desired IP address +port_number = 7860 # Replace with your desired port number + + +demo.queue().launch( + share=True, debug=True, server_name=ip_address, server_port=port_number +) diff --git a/scripts/gradio/app_docker_space.py b/scripts/gradio/app_docker_space.py new file mode 100644 index 0000000000000000000000000000000000000000..6e22c1503dfb474c5d6bde937bec952211de604a --- /dev/null +++ b/scripts/gradio/app_docker_space.py @@ -0,0 +1,397 @@ +import os +import time +import pdb + +import PIL.Image +import cuid +import gradio as gr +import spaces +import numpy as np + +import PIL +from huggingface_hub import snapshot_download + +ProjectDir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +CheckpointsDir = os.path.join(ProjectDir, "checkpoints") +ignore_video2video = True +max_image_edge = 960 + + +def download_model(): + if not os.path.exists(CheckpointsDir): + print("Checkpoint Not Downloaded, start downloading...") + tic = time.time() + snapshot_download( + repo_id="TMElyralab/MuseV", + local_dir=CheckpointsDir, + max_workers=8, + ) + toc = time.time() + print(f"download cost {toc-tic} seconds") + else: + print("Already download the model.") + + +download_model() # for huggingface deployment. +if not ignore_video2video: + from gradio_video2video import online_v2v_inference +from gradio_text2video import online_t2v_inference + + +@spaces.GPU(duration=180) +def hf_online_t2v_inference( + prompt, + image_np, + seed, + fps, + w, + h, + video_len, + img_edge_ratio, +): + img_edge_ratio, _, _ = limit_shape( + image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge + ) + if not isinstance(image_np, np.ndarray): # None + raise gr.Error("Need input reference image") + return online_t2v_inference( + prompt, image_np, seed, fps, w, h, video_len, img_edge_ratio + ) + + +@spaces.GPU(duration=180) +def hg_online_v2v_inference( + prompt, + image_np, + video, + processor, + seed, + fps, + w, + h, + video_length, + img_edge_ratio, +): + img_edge_ratio, _, _ = limit_shape( + image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge + ) + if not isinstance(image_np, np.ndarray): # None + raise gr.Error("Need input reference image") + return online_v2v_inference( + prompt, + image_np, + video, + processor, + seed, + fps, + w, + h, + video_length, + img_edge_ratio, + ) + + +def limit_shape(image, input_w, input_h, img_edge_ratio, max_image_edge=max_image_edge): + """limite generation video shape to avoid gpu memory overflow""" + if input_h == -1 and input_w == -1: + if isinstance(image, np.ndarray): + input_h, input_w, _ = image.shape + elif isinstance(image, PIL.Image.Image): + input_w, input_h = image.size + else: + raise ValueError( + f"image should be in [image, ndarray], but given {type(image)}" + ) + if img_edge_ratio == 0: + img_edge_ratio = 1 + img_edge_ratio_infact = min(max_image_edge / max(input_h, input_w), img_edge_ratio) + # print( + # image.shape, + # input_w, + # input_h, + # img_edge_ratio, + # max_image_edge, + # img_edge_ratio_infact, + # ) + if img_edge_ratio != 1: + return ( + img_edge_ratio_infact, + input_w * img_edge_ratio_infact, + input_h * img_edge_ratio_infact, + ) + else: + return img_edge_ratio_infact, -1, -1 + + +def limit_length(length): + """limite generation video frames numer to avoid gpu memory overflow""" + + if length > 24 * 6: + gr.Warning("Length need to smaller than 144, dute to gpu memory limit") + length = 24 * 6 + return length + + +class ConcatenateBlock(gr.blocks.Block): + def __init__(self, options): + self.options = options + self.current_string = "" + + def update_string(self, new_choice): + if new_choice and new_choice not in self.current_string.split(", "): + if self.current_string == "": + self.current_string = new_choice + else: + self.current_string += ", " + new_choice + return self.current_string + + +def process_input(new_choice): + return concatenate_block.update_string(new_choice), "" + + +control_options = [ + "pose", + "pose_body", + "pose_hand", + "pose_face", + "pose_hand_body", + "pose_hand_face", + "dwpose", + "dwpose_face", + "dwpose_hand", + "dwpose_body", + "dwpose_body_hand", + "canny", + "tile", + "hed", + "hed_scribble", + "depth", + "pidi", + "normal_bae", + "lineart", + "lineart_anime", + "zoe", + "sam", + "mobile_sam", + "leres", + "content", + "face_detector", +] +concatenate_block = ConcatenateBlock(control_options) + + +css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}""" + + +with gr.Blocks(css=css) as demo: + gr.Markdown( + "

MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising

\ +

\ +
\ + Zhiqiang Xia *,\ + Zhaokang Chen*,\ + Bin Wu,\ + Chao Li,\ + Kwok-Wai Hung,\ + Chao Zhan,\ + Yingjie He,\ + Wenjiang Zhou\ + (*Equal Contribution, Corresponding Author, benbinwu@tencent.com)\ +
\ + Lyra Lab, Tencent Music Entertainment\ +

\ + [Github Repo]\ + , which is important to Open-Source projects. Thanks!\ + [ArXiv(Coming Soon)] \ + [Project Page(Coming Soon)] \ + If MuseV is useful, please help star the repo~
" + ) + with gr.Tab("Text to Video"): + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt") + image = gr.Image(label="VisionCondImage") + seed = gr.Number( + label="Seed (seed=-1 means that the seeds run each time are different)", + value=-1, + ) + video_length = gr.Number( + label="Video Length(need smaller than 144,If you want to be able to generate longer videos, run it locally )", + value=12, + ) + fps = gr.Number(label="Generate Video FPS", value=6) + gr.Markdown( + ( + "If W&H is -1, then use the Reference Image's Size. Size of target video is $(W, H)*img\_edge\_ratio$. \n" + "The shorter the image size, the larger the motion amplitude, and the lower video quality.\n" + "The longer the W&H, the smaller the motion amplitude, and the higher video quality.\n" + "Due to the GPU VRAM limits, the W&H need smaller than 960px" + ) + ) + with gr.Row(): + w = gr.Number(label="Width", value=-1) + h = gr.Number(label="Height", value=-1) + img_edge_ratio = gr.Number(label="img_edge_ratio", value=1.0) + with gr.Row(): + out_w = gr.Number(label="Output Width", value=0, interactive=False) + out_h = gr.Number(label="Output Height", value=0, interactive=False) + img_edge_ratio_infact = gr.Number( + label="img_edge_ratio in fact", + value=1.0, + interactive=False, + ) + btn1 = gr.Button("Generate") + out = gr.Video() + # pdb.set_trace() + i2v_examples_256 = [ + [ + "(masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3)", + "../../data/images/yongen.jpeg", + ], + [ + "(masterpiece, best quality, highres:1), peaceful beautiful sea scene", + "../../data/images/seaside4.jpeg", + ], + ] + with gr.Row(): + gr.Examples( + examples=i2v_examples_256, + inputs=[prompt, image], + outputs=[out], + fn=hf_online_t2v_inference, + cache_examples=False, + ) + img_edge_ratio.change( + fn=limit_shape, + inputs=[image, w, h, img_edge_ratio], + outputs=[img_edge_ratio_infact, out_w, out_h], + ) + + video_length.change( + fn=limit_length, inputs=[video_length], outputs=[video_length] + ) + + btn1.click( + fn=hf_online_t2v_inference, + inputs=[ + prompt, + image, + seed, + fps, + w, + h, + video_length, + img_edge_ratio_infact, + ], + outputs=out, + ) + + with gr.Tab("Video to Video"): + if ignore_video2video: + gr.Markdown( + ( + "Due to GPU limit, MuseVDemo now only support Text2Video. If you want to try Video2Video, please run it locally. \n" + "We are trying to support video2video in the future. Thanks for your understanding." + ) + ) + else: + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt") + gr.Markdown( + ( + "pose of VisionCondImage should be same as of the first frame of the video. " + "its better generate target first frame whose pose is same as of first frame of the video with text2image tool, sch as MJ, SDXL." + ) + ) + image = gr.Image(label="VisionCondImage") + video = gr.Video(label="ReferVideo") + # radio = gr.inputs.Radio(, label="Select an option") + # ctr_button = gr.inputs.Button(label="Add ControlNet List") + # output_text = gr.outputs.Textbox() + processor = gr.Textbox( + label=f"Control Condition. gradio code now only support dwpose_body_hand, use command can support multi of {control_options}", + value="dwpose_body_hand", + ) + gr.Markdown("seed=-1 means that seeds are different in every run") + seed = gr.Number( + label="Seed (seed=-1 means that the seeds run each time are different)", + value=-1, + ) + video_length = gr.Number(label="Video Length", value=12) + fps = gr.Number(label="Generate Video FPS", value=6) + gr.Markdown( + ( + "If W&H is -1, then use the Reference Image's Size. Size of target video is $(W, H)*img\_edge\_ratio$. \n" + "The shorter the image size, the larger the motion amplitude, and the lower video quality.\n" + "The longer the W&H, the smaller the motion amplitude, and the higher video quality.\n" + "Due to the GPU VRAM limits, the W&H need smaller than 2000px" + ) + ) + with gr.Row(): + w = gr.Number(label="Width", value=-1) + h = gr.Number(label="Height", value=-1) + img_edge_ratio = gr.Number(label="img_edge_ratio", value=1.0) + + with gr.Row(): + out_w = gr.Number(label="Width", value=0, interactive=False) + out_h = gr.Number(label="Height", value=0, interactive=False) + img_edge_ratio_infact = gr.Number( + label="img_edge_ratio in fact", + value=1.0, + interactive=False, + ) + btn2 = gr.Button("Generate") + out1 = gr.Video() + + v2v_examples_256 = [ + [ + "(masterpiece, best quality, highres:1), harley quinn is dancing, animation, by joshua klein", + "../../data/demo/cyber_girl.png", + "../../data/demo/video1.mp4", + ], + ] + with gr.Row(): + gr.Examples( + examples=v2v_examples_256, + inputs=[prompt, image, video], + outputs=[out], + fn=hg_online_v2v_inference, + cache_examples=False, + ) + + img_edge_ratio.change( + fn=limit_shape, + inputs=[image, w, h, img_edge_ratio], + outputs=[img_edge_ratio_infact, out_w, out_h], + ) + video_length.change( + fn=limit_length, inputs=[video_length], outputs=[video_length] + ) + btn2.click( + fn=hg_online_v2v_inference, + inputs=[ + prompt, + image, + video, + processor, + seed, + fps, + w, + h, + video_length, + img_edge_ratio_infact, + ], + outputs=out1, + ) + + +# Set the IP and port +ip_address = "0.0.0.0" # Replace with your desired IP address +port_number = 7860 # Replace with your desired port number + + +demo.queue().launch( + share=True, debug=True, server_name=ip_address, server_port=port_number +) diff --git a/scripts/gradio/app_gradio_space.py b/scripts/gradio/app_gradio_space.py new file mode 100644 index 0000000000000000000000000000000000000000..b84e9c0e8cc83f315d1c645943bcabd8cf8bd259 --- /dev/null +++ b/scripts/gradio/app_gradio_space.py @@ -0,0 +1,430 @@ +import os +import time +import pdb + +import cuid +import gradio as gr +import spaces +import numpy as np +import sys + +from huggingface_hub import snapshot_download +import subprocess + + +ProjectDir = os.path.abspath(os.path.dirname(__file__)) +CheckpointsDir = os.path.join(ProjectDir, "checkpoints") + +sys.path.insert(0, ProjectDir) +sys.path.insert(0, f"{ProjectDir}/MMCM") +sys.path.insert(0, f"{ProjectDir}/diffusers/src") +sys.path.insert(0, f"{ProjectDir}/controlnet_aux/src") +sys.path.insert(0, f"{ProjectDir}/scripts/gradio") + +result = subprocess.run( + ["pip", "install", "--no-cache-dir", "-U", "openmim"], + capture_output=True, + text=True, +) +print(result) + +result = subprocess.run(["mim", "install", "mmengine"], capture_output=True, text=True) +print(result) + +result = subprocess.run( + ["mim", "install", "mmcv>=2.0.1"], capture_output=True, text=True +) +print(result) + +result = subprocess.run( + ["mim", "install", "mmdet>=3.1.0"], capture_output=True, text=True +) +print(result) + +result = subprocess.run( + ["mim", "install", "mmpose>=1.1.0"], capture_output=True, text=True +) +print(result) +ignore_video2video = True +max_image_edge = 960 + + +def download_model(): + if not os.path.exists(CheckpointsDir): + print("Checkpoint Not Downloaded, start downloading...") + tic = time.time() + snapshot_download( + repo_id="TMElyralab/MuseV", + local_dir=CheckpointsDir, + max_workers=8, + local_dir_use_symlinks=True, + ) + toc = time.time() + print(f"download cost {toc-tic} seconds") + else: + print("Already download the model.") + + +download_model() # for huggingface deployment. +if not ignore_video2video: + from gradio_video2video import online_v2v_inference +from gradio_text2video import online_t2v_inference + + +@spaces.GPU(duration=180) +def hf_online_t2v_inference( + prompt, + image_np, + seed, + fps, + w, + h, + video_len, + img_edge_ratio, +): + img_edge_ratio, _, _ = limit_shape( + image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge + ) + if not isinstance(image_np, np.ndarray): # None + raise gr.Error("Need input reference image") + return online_t2v_inference( + prompt, image_np, seed, fps, w, h, video_len, img_edge_ratio + ) + + +@spaces.GPU(duration=180) +def hg_online_v2v_inference( + prompt, + image_np, + video, + processor, + seed, + fps, + w, + h, + video_length, + img_edge_ratio, +): + img_edge_ratio, _, _ = limit_shape( + image_np, w, h, img_edge_ratio, max_image_edge=max_image_edge + ) + if not isinstance(image_np, np.ndarray): # None + raise gr.Error("Need input reference image") + return online_v2v_inference( + prompt, + image_np, + video, + processor, + seed, + fps, + w, + h, + video_length, + img_edge_ratio, + ) + + +def limit_shape(image, input_w, input_h, img_edge_ratio, max_image_edge=max_image_edge): + """limite generation video shape to avoid gpu memory overflow""" + if input_h == -1 and input_w == -1: + if isinstance(image, np.ndarray): + input_h, input_w, _ = image.shape + elif isinstance(image, PIL.Image.Image): + input_w, input_h = image.size + else: + raise ValueError( + f"image should be in [image, ndarray], but given {type(image)}" + ) + if img_edge_ratio == 0: + img_edge_ratio = 1 + img_edge_ratio_infact = min(max_image_edge / max(input_h, input_w), img_edge_ratio) + # print( + # image.shape, + # input_w, + # input_h, + # img_edge_ratio, + # max_image_edge, + # img_edge_ratio_infact, + # ) + if img_edge_ratio != 1: + return ( + img_edge_ratio_infact, + input_w * img_edge_ratio_infact, + input_h * img_edge_ratio_infact, + ) + else: + return img_edge_ratio_infact, -1, -1 + + +def limit_length(length): + """limite generation video frames numer to avoid gpu memory overflow""" + + if length > 24 * 6: + gr.Warning("Length need to smaller than 144, dute to gpu memory limit") + length = 24 * 6 + return length + + +class ConcatenateBlock(gr.blocks.Block): + def __init__(self, options): + self.options = options + self.current_string = "" + + def update_string(self, new_choice): + if new_choice and new_choice not in self.current_string.split(", "): + if self.current_string == "": + self.current_string = new_choice + else: + self.current_string += ", " + new_choice + return self.current_string + + +def process_input(new_choice): + return concatenate_block.update_string(new_choice), "" + + +control_options = [ + "pose", + "pose_body", + "pose_hand", + "pose_face", + "pose_hand_body", + "pose_hand_face", + "dwpose", + "dwpose_face", + "dwpose_hand", + "dwpose_body", + "dwpose_body_hand", + "canny", + "tile", + "hed", + "hed_scribble", + "depth", + "pidi", + "normal_bae", + "lineart", + "lineart_anime", + "zoe", + "sam", + "mobile_sam", + "leres", + "content", + "face_detector", +] +concatenate_block = ConcatenateBlock(control_options) + + +css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}""" + + +with gr.Blocks(css=css) as demo: + gr.Markdown( + "

MuseV: Infinite-length and High Fidelity Virtual Human Video Generation with Visual Conditioned Parallel Denoising

\ +

\ +
\ + Zhiqiang Xia *,\ + Zhaokang Chen*,\ + Bin Wu,\ + Chao Li,\ + Kwok-Wai Hung,\ + Chao Zhan,\ + Yingjie He,\ + Wenjiang Zhou\ + (*Equal Contribution, Corresponding Author, benbinwu@tencent.com)\ +
\ + Lyra Lab, Tencent Music Entertainment\ +

\ + [Github Repo]\ + , which is important to Open-Source projects. Thanks!\ + [ArXiv(Coming Soon)] \ + [Project Page(Coming Soon)] \ + If MuseV is useful, please help star the repo~
" + ) + with gr.Tab("Text to Video"): + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt") + image = gr.Image(label="VisionCondImage") + seed = gr.Number( + label="Seed (seed=-1 means that the seeds run each time are different)", + value=-1, + ) + video_length = gr.Number( + label="Video Length(need smaller than 144,If you want to be able to generate longer videos, run it locally )", + value=12, + ) + fps = gr.Number(label="Generate Video FPS", value=6) + gr.Markdown( + ( + "If W&H is -1, then use the Reference Image's Size. Size of target video is $(W, H)*img\_edge\_ratio$. \n" + "The shorter the image size, the larger the motion amplitude, and the lower video quality.\n" + "The longer the W&H, the smaller the motion amplitude, and the higher video quality.\n" + "Due to the GPU VRAM limits, the W&H need smaller than 960px" + ) + ) + with gr.Row(): + w = gr.Number(label="Width", value=-1) + h = gr.Number(label="Height", value=-1) + img_edge_ratio = gr.Number(label="img_edge_ratio", value=1.0) + with gr.Row(): + out_w = gr.Number(label="Output Width", value=0, interactive=False) + out_h = gr.Number(label="Output Height", value=0, interactive=False) + img_edge_ratio_infact = gr.Number( + label="img_edge_ratio in fact", + value=1.0, + interactive=False, + ) + btn1 = gr.Button("Generate") + out = gr.Video() + # pdb.set_trace() + i2v_examples_256 = [ + [ + "(masterpiece, best quality, highres:1),(1boy, solo:1),(eye blinks:1.8),(head wave:1.3)", + "../../data/images/yongen.jpeg", + ], + [ + "(masterpiece, best quality, highres:1), peaceful beautiful sea scene", + "../../data/images/seaside4.jpeg", + ], + ] + with gr.Row(): + gr.Examples( + examples=i2v_examples_256, + inputs=[prompt, image], + outputs=[out], + fn=hf_online_t2v_inference, + cache_examples=False, + ) + img_edge_ratio.change( + fn=limit_shape, + inputs=[image, w, h, img_edge_ratio], + outputs=[img_edge_ratio_infact, out_w, out_h], + ) + + video_length.change( + fn=limit_length, inputs=[video_length], outputs=[video_length] + ) + + btn1.click( + fn=hf_online_t2v_inference, + inputs=[ + prompt, + image, + seed, + fps, + w, + h, + video_length, + img_edge_ratio_infact, + ], + outputs=out, + ) + + with gr.Tab("Video to Video"): + if ignore_video2video: + gr.Markdown( + ( + "Due to GPU limit, MuseVDemo now only support Text2Video. If you want to try Video2Video, please run it locally. \n" + "We are trying to support video2video in the future. Thanks for your understanding." + ) + ) + else: + with gr.Row(): + with gr.Column(): + prompt = gr.Textbox(label="Prompt") + gr.Markdown( + ( + "pose of VisionCondImage should be same as of the first frame of the video. " + "its better generate target first frame whose pose is same as of first frame of the video with text2image tool, sch as MJ, SDXL." + ) + ) + image = gr.Image(label="VisionCondImage") + video = gr.Video(label="ReferVideo") + # radio = gr.inputs.Radio(, label="Select an option") + # ctr_button = gr.inputs.Button(label="Add ControlNet List") + # output_text = gr.outputs.Textbox() + processor = gr.Textbox( + label=f"Control Condition. gradio code now only support dwpose_body_hand, use command can support multi of {control_options}", + value="dwpose_body_hand", + ) + gr.Markdown("seed=-1 means that seeds are different in every run") + seed = gr.Number( + label="Seed (seed=-1 means that the seeds run each time are different)", + value=-1, + ) + video_length = gr.Number(label="Video Length", value=12) + fps = gr.Number(label="Generate Video FPS", value=6) + gr.Markdown( + ( + "If W&H is -1, then use the Reference Image's Size. Size of target video is $(W, H)*img\_edge\_ratio$. \n" + "The shorter the image size, the larger the motion amplitude, and the lower video quality.\n" + "The longer the W&H, the smaller the motion amplitude, and the higher video quality.\n" + "Due to the GPU VRAM limits, the W&H need smaller than 2000px" + ) + ) + with gr.Row(): + w = gr.Number(label="Width", value=-1) + h = gr.Number(label="Height", value=-1) + img_edge_ratio = gr.Number(label="img_edge_ratio", value=1.0) + + with gr.Row(): + out_w = gr.Number(label="Width", value=0, interactive=False) + out_h = gr.Number(label="Height", value=0, interactive=False) + img_edge_ratio_infact = gr.Number( + label="img_edge_ratio in fact", + value=1.0, + interactive=False, + ) + btn2 = gr.Button("Generate") + out1 = gr.Video() + + v2v_examples_256 = [ + [ + "(masterpiece, best quality, highres:1), harley quinn is dancing, animation, by joshua klein", + "../../data/demo/cyber_girl.png", + "../../data/demo/video1.mp4", + ], + ] + with gr.Row(): + gr.Examples( + examples=v2v_examples_256, + inputs=[prompt, image, video], + outputs=[out], + fn=hg_online_v2v_inference, + cache_examples=False, + ) + + img_edge_ratio.change( + fn=limit_shape, + inputs=[image, w, h, img_edge_ratio], + outputs=[img_edge_ratio_infact, out_w, out_h], + ) + video_length.change( + fn=limit_length, inputs=[video_length], outputs=[video_length] + ) + btn2.click( + fn=hg_online_v2v_inference, + inputs=[ + prompt, + image, + video, + processor, + seed, + fps, + w, + h, + video_length, + img_edge_ratio_infact, + ], + outputs=out1, + ) + + +# Set the IP and port +ip_address = "0.0.0.0" # Replace with your desired IP address +port_number = 7860 # Replace with your desired port number + + +demo.queue().launch( + share=True, debug=True, server_name=ip_address, server_port=port_number +) diff --git a/scripts/gradio/entrypoint.sh b/scripts/gradio/entrypoint.sh new file mode 100644 index 0000000000000000000000000000000000000000..4dfc9a75dd1a46afddfc362ae9c7f2d31c28938d --- /dev/null +++ b/scripts/gradio/entrypoint.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +echo "entrypoint.sh" +whoami +which python +export PYTHONPATH=${PYTHONPATH}:/home/user/app/MuseV:/home/user/app/MuseV/MMCM:/home/user/app/MuseV/diffusers/src:/home/user/app/MuseV/controlnet_aux/src +echo "pythonpath" $PYTHONPATH +# chmod 777 -R /home/user/app/MuseV +# Print the contents of the diffusers/src directory +# echo "Contents of /home/user/app/MuseV/diffusers/src:" +# Load ~/.bashrc +# source ~/.bashrc + +source /opt/conda/etc/profile.d/conda.sh +conda activate musev +which python +python ap_space.py \ No newline at end of file diff --git a/scripts/gradio/gradio_text2video.py b/scripts/gradio/gradio_text2video.py new file mode 100644 index 0000000000000000000000000000000000000000..a91001edf0eaa81177408f1ecf90418460ea3e12 --- /dev/null +++ b/scripts/gradio/gradio_text2video.py @@ -0,0 +1,937 @@ +import argparse +import copy +import os +from pathlib import Path +import logging +from collections import OrderedDict +from pprint import pprint +import random +import gradio as gr +from argparse import Namespace + +import numpy as np +from omegaconf import OmegaConf, SCMode +import torch +from einops import rearrange, repeat +import cv2 +from PIL import Image +from diffusers.models.autoencoder_kl import AutoencoderKL + +from mmcm.utils.load_util import load_pyhon_obj +from mmcm.utils.seed_util import set_all_seed +from mmcm.utils.signature import get_signature_of_string +from mmcm.utils.task_util import fiss_tasks, generate_tasks as generate_tasks_from_table +from mmcm.vision.utils.data_type_util import is_video, is_image, read_image_as_5d +from mmcm.utils.str_util import clean_str_for_save +from mmcm.vision.data.video_dataset import DecordVideoDataset +from musev.auto_prompt.util import generate_prompts + + +from musev.models.facein_loader import load_facein_extractor_and_proj_by_name +from musev.models.referencenet_loader import load_referencenet_by_name +from musev.models.ip_adapter_loader import ( + load_ip_adapter_vision_clip_encoder_by_name, + load_vision_clip_encoder_by_name, + load_ip_adapter_image_proj_by_name, +) +from musev.models.ip_adapter_face_loader import ( + load_ip_adapter_face_extractor_and_proj_by_name, +) +from musev.pipelines.pipeline_controlnet_predictor import ( + DiffusersPipelinePredictor, +) +from musev.models.referencenet import ReferenceNet2D +from musev.models.unet_loader import load_unet_by_name +from musev.utils.util import save_videos_grid_with_opencv +from musev import logger + +use_v2v_predictor = False +if use_v2v_predictor: + from gradio_video2video import sd_predictor as video_sd_predictor + +logger.setLevel("INFO") + +file_dir = os.path.dirname(__file__) +PROJECT_DIR = os.path.join(os.path.dirname(__file__), "../..") +DATA_DIR = os.path.join(PROJECT_DIR, "data") +CACHE_PATH = "./t2v_input_image" + + +# TODO:use group to group arguments + + +args_dict = { + "add_static_video_prompt": False, + "context_batch_size": 1, + "context_frames": 12, + "context_overlap": 4, + "context_schedule": "uniform_v2", + "context_stride": 1, + "cross_attention_dim": 768, + "face_image_path": None, + "facein_model_cfg_path": "../../configs/model/facein.py", + "facein_model_name": None, + "facein_scale": 1.0, + "fix_condition_images": False, + "fixed_ip_adapter_image": True, + "fixed_refer_face_image": True, + "fixed_refer_image": True, + "fps": 4, + "guidance_scale": 7.5, + "height": None, + "img_length_ratio": 1.0, + "img_weight": 0.001, + "interpolation_factor": 1, + "ip_adapter_face_model_cfg_path": "../../configs/model/ip_adapter.py", + "ip_adapter_face_model_name": None, + "ip_adapter_face_scale": 1.0, + "ip_adapter_model_cfg_path": "../../configs/model/ip_adapter.py", + "ip_adapter_model_name": "musev_referencenet", + "ip_adapter_scale": 1.0, + "ipadapter_image_path": None, + "lcm_model_cfg_path": "../../configs/model/lcm_model.py", + "lcm_model_name": None, + "log_level": "INFO", + "motion_speed": 8.0, + "n_batch": 1, + "n_cols": 3, + "n_repeat": 1, + "n_vision_condition": 1, + "need_hist_match": False, + "need_img_based_video_noise": True, + "need_redraw": False, + "negative_prompt": "V2", + "negprompt_cfg_path": "../../configs/model/negative_prompt.py", + "noise_type": "video_fusion", + "num_inference_steps": 30, + "output_dir": "./results/", + "overwrite": False, + "prompt_only_use_image_prompt": False, + "record_mid_video_latents": False, + "record_mid_video_noises": False, + "redraw_condition_image": False, + "redraw_condition_image_with_facein": True, + "redraw_condition_image_with_ip_adapter_face": True, + "redraw_condition_image_with_ipdapter": True, + "redraw_condition_image_with_referencenet": True, + "referencenet_image_path": None, + "referencenet_model_cfg_path": "../../configs/model/referencenet.py", + "referencenet_model_name": "musev_referencenet", + "save_filetype": "mp4", + "save_images": False, + "sd_model_cfg_path": "../../configs/model/T2I_all_model.py", + "sd_model_name": "majicmixRealv6Fp16", + "seed": None, + "strength": 0.8, + "target_datas": "boy_dance2", + "test_data_path": "../../configs/infer/testcase_video_famous.yaml", + "time_size": 24, + "unet_model_cfg_path": "../../configs/model/motion_model.py", + "unet_model_name": "musev_referencenet", + "use_condition_image": True, + "use_video_redraw": True, + "vae_model_path": "../../checkpoints/vae/sd-vae-ft-mse", + "video_guidance_scale": 3.5, + "video_guidance_scale_end": None, + "video_guidance_scale_method": "linear", + "video_negative_prompt": "V2", + "video_num_inference_steps": 10, + "video_overlap": 1, + "vision_clip_extractor_class_name": "ImageClipVisionFeatureExtractor", + "vision_clip_model_path": "../../checkpoints/IP-Adapter/models/image_encoder", + "w_ind_noise": 0.5, + "width": None, + "write_info": False, +} +args = Namespace(**args_dict) +print("args") +pprint(args) +print("\n") + +logger.setLevel(args.log_level) +overwrite = args.overwrite +cross_attention_dim = args.cross_attention_dim +time_size = args.time_size # 一次视频生成的帧数 +n_batch = args.n_batch # 按照time_size的尺寸 生成n_batch次,总帧数 = time_size * n_batch +fps = args.fps +# need_redraw = args.need_redraw # 视频重绘视频使用视频网络 +# use_video_redraw = args.use_video_redraw # 视频重绘视频使用视频网络 +fix_condition_images = args.fix_condition_images +use_condition_image = args.use_condition_image # 当 test_data 中有图像时,作为初始图像 +redraw_condition_image = args.redraw_condition_image # 用于视频生成的首帧是否使用重绘后的 +need_img_based_video_noise = ( + args.need_img_based_video_noise +) # 视频加噪过程中是否使用首帧 condition_images +img_weight = args.img_weight +height = args.height # 如果测试数据中没有单独指定宽高,则默认这里 +width = args.width # 如果测试数据中没有单独指定宽高,则默认这里 +img_length_ratio = args.img_length_ratio # 如果测试数据中没有单独指定图像宽高比resize比例,则默认这里 +n_cols = args.n_cols +noise_type = args.noise_type +strength = args.strength # 首帧重绘程度参数 +video_guidance_scale = args.video_guidance_scale # 视频 condition与 uncond的权重参数 +guidance_scale = args.guidance_scale # 时序条件帧 condition与uncond的权重参数 +video_num_inference_steps = args.video_num_inference_steps # 视频迭代次数 +num_inference_steps = args.num_inference_steps # 时序条件帧 重绘参数 +seed = args.seed +save_filetype = args.save_filetype +save_images = args.save_images +sd_model_cfg_path = args.sd_model_cfg_path +sd_model_name = ( + args.sd_model_name + if args.sd_model_name in ["all", "None"] + else args.sd_model_name.split(",") +) +unet_model_cfg_path = args.unet_model_cfg_path +unet_model_name = args.unet_model_name +test_data_path = args.test_data_path +target_datas = ( + args.target_datas if args.target_datas == "all" else args.target_datas.split(",") +) +device = "cuda" if torch.cuda.is_available() else "cpu" +torch_dtype = torch.float16 +negprompt_cfg_path = args.negprompt_cfg_path +video_negative_prompt = args.video_negative_prompt +negative_prompt = args.negative_prompt +motion_speed = args.motion_speed +need_hist_match = args.need_hist_match +video_guidance_scale_end = args.video_guidance_scale_end +video_guidance_scale_method = args.video_guidance_scale_method +add_static_video_prompt = args.add_static_video_prompt +n_vision_condition = args.n_vision_condition +lcm_model_cfg_path = args.lcm_model_cfg_path +lcm_model_name = args.lcm_model_name +referencenet_model_cfg_path = args.referencenet_model_cfg_path +referencenet_model_name = args.referencenet_model_name +ip_adapter_model_cfg_path = args.ip_adapter_model_cfg_path +ip_adapter_model_name = args.ip_adapter_model_name +vision_clip_model_path = args.vision_clip_model_path +vision_clip_extractor_class_name = args.vision_clip_extractor_class_name +facein_model_cfg_path = args.facein_model_cfg_path +facein_model_name = args.facein_model_name +ip_adapter_face_model_cfg_path = args.ip_adapter_face_model_cfg_path +ip_adapter_face_model_name = args.ip_adapter_face_model_name + +fixed_refer_image = args.fixed_refer_image +fixed_ip_adapter_image = args.fixed_ip_adapter_image +fixed_refer_face_image = args.fixed_refer_face_image +redraw_condition_image_with_referencenet = args.redraw_condition_image_with_referencenet +redraw_condition_image_with_ipdapter = args.redraw_condition_image_with_ipdapter +redraw_condition_image_with_facein = args.redraw_condition_image_with_facein +redraw_condition_image_with_ip_adapter_face = ( + args.redraw_condition_image_with_ip_adapter_face +) +w_ind_noise = args.w_ind_noise +ip_adapter_scale = args.ip_adapter_scale +facein_scale = args.facein_scale +ip_adapter_face_scale = args.ip_adapter_face_scale +face_image_path = args.face_image_path +ipadapter_image_path = args.ipadapter_image_path +referencenet_image_path = args.referencenet_image_path +vae_model_path = args.vae_model_path +prompt_only_use_image_prompt = args.prompt_only_use_image_prompt +# serial_denoise parameter start +record_mid_video_noises = args.record_mid_video_noises +record_mid_video_latents = args.record_mid_video_latents +video_overlap = args.video_overlap +# serial_denoise parameter end +# parallel_denoise parameter start +context_schedule = args.context_schedule +context_frames = args.context_frames +context_stride = args.context_stride +context_overlap = args.context_overlap +context_batch_size = args.context_batch_size +interpolation_factor = args.interpolation_factor +n_repeat = args.n_repeat + +# parallel_denoise parameter end + +b = 1 +negative_embedding = [ + ["../../checkpoints/embedding/badhandv4.pt", "badhandv4"], + [ + "../../checkpoints/embedding/ng_deepnegative_v1_75t.pt", + "ng_deepnegative_v1_75t", + ], + [ + "../../checkpoints/embedding/EasyNegativeV2.safetensors", + "EasyNegativeV2", + ], + [ + "../../checkpoints/embedding/bad_prompt_version2-neg.pt", + "bad_prompt_version2-neg", + ], +] +prefix_prompt = "" +suffix_prompt = ", beautiful, masterpiece, best quality" +suffix_prompt = "" + + +# sd model parameters + +if sd_model_name != "None": + # 使用 cfg_path 里的sd_model_path + sd_model_params_dict_src = load_pyhon_obj(sd_model_cfg_path, "MODEL_CFG") + sd_model_params_dict = { + k: v + for k, v in sd_model_params_dict_src.items() + if sd_model_name == "all" or k in sd_model_name + } +else: + # 使用命令行给的sd_model_path, 需要单独设置 sd_model_name 为None, + sd_model_name = os.path.basename(sd_model_cfg_path).split(".")[0] + sd_model_params_dict = {sd_model_name: {"sd": sd_model_cfg_path}} + sd_model_params_dict_src = sd_model_params_dict +if len(sd_model_params_dict) == 0: + raise ValueError( + "has not target model, please set one of {}".format( + " ".join(list(sd_model_params_dict_src.keys())) + ) + ) +print("running model, T2I SD") +pprint(sd_model_params_dict) + +# lcm +if lcm_model_name is not None: + lcm_model_params_dict_src = load_pyhon_obj(lcm_model_cfg_path, "MODEL_CFG") + print("lcm_model_params_dict_src") + lcm_lora_dct = lcm_model_params_dict_src[lcm_model_name] +else: + lcm_lora_dct = None +print("lcm: ", lcm_model_name, lcm_lora_dct) + + +# motion net parameters +if os.path.isdir(unet_model_cfg_path): + unet_model_path = unet_model_cfg_path +elif os.path.isfile(unet_model_cfg_path): + unet_model_params_dict_src = load_pyhon_obj(unet_model_cfg_path, "MODEL_CFG") + print("unet_model_params_dict_src", unet_model_params_dict_src.keys()) + unet_model_path = unet_model_params_dict_src[unet_model_name]["unet"] +else: + raise ValueError(f"expect dir or file, but given {unet_model_cfg_path}") +print("unet: ", unet_model_name, unet_model_path) + + +# referencenet +if referencenet_model_name is not None: + if os.path.isdir(referencenet_model_cfg_path): + referencenet_model_path = referencenet_model_cfg_path + elif os.path.isfile(referencenet_model_cfg_path): + referencenet_model_params_dict_src = load_pyhon_obj( + referencenet_model_cfg_path, "MODEL_CFG" + ) + print( + "referencenet_model_params_dict_src", + referencenet_model_params_dict_src.keys(), + ) + referencenet_model_path = referencenet_model_params_dict_src[ + referencenet_model_name + ]["net"] + else: + raise ValueError(f"expect dir or file, but given {referencenet_model_cfg_path}") +else: + referencenet_model_path = None +print("referencenet: ", referencenet_model_name, referencenet_model_path) + + +# ip_adapter +if ip_adapter_model_name is not None: + ip_adapter_model_params_dict_src = load_pyhon_obj( + ip_adapter_model_cfg_path, "MODEL_CFG" + ) + print("ip_adapter_model_params_dict_src", ip_adapter_model_params_dict_src.keys()) + ip_adapter_model_params_dict = ip_adapter_model_params_dict_src[ + ip_adapter_model_name + ] +else: + ip_adapter_model_params_dict = None +print("ip_adapter: ", ip_adapter_model_name, ip_adapter_model_params_dict) + + +# facein +if facein_model_name is not None: + facein_model_params_dict_src = load_pyhon_obj(facein_model_cfg_path, "MODEL_CFG") + print("facein_model_params_dict_src", facein_model_params_dict_src.keys()) + facein_model_params_dict = facein_model_params_dict_src[facein_model_name] +else: + facein_model_params_dict = None +print("facein: ", facein_model_name, facein_model_params_dict) + +# ip_adapter_face +if ip_adapter_face_model_name is not None: + ip_adapter_face_model_params_dict_src = load_pyhon_obj( + ip_adapter_face_model_cfg_path, "MODEL_CFG" + ) + print( + "ip_adapter_face_model_params_dict_src", + ip_adapter_face_model_params_dict_src.keys(), + ) + ip_adapter_face_model_params_dict = ip_adapter_face_model_params_dict_src[ + ip_adapter_face_model_name + ] +else: + ip_adapter_face_model_params_dict = None +print( + "ip_adapter_face: ", ip_adapter_face_model_name, ip_adapter_face_model_params_dict +) + + +# negative_prompt +def get_negative_prompt(negative_prompt, cfg_path=None, n: int = 10): + name = negative_prompt[:n] + if cfg_path is not None and cfg_path not in ["None", "none"]: + dct = load_pyhon_obj(cfg_path, "Negative_Prompt_CFG") + negative_prompt = dct[negative_prompt]["prompt"] + + return name, negative_prompt + + +negtive_prompt_length = 10 +video_negative_prompt_name, video_negative_prompt = get_negative_prompt( + video_negative_prompt, + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, +) +negative_prompt_name, negative_prompt = get_negative_prompt( + negative_prompt, + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, +) + +print("video_negprompt", video_negative_prompt_name, video_negative_prompt) +print("negprompt", negative_prompt_name, negative_prompt) + +output_dir = args.output_dir +os.makedirs(output_dir, exist_ok=True) + + +# test_data_parameters +def load_yaml(path): + tasks = OmegaConf.to_container( + OmegaConf.load(path), structured_config_mode=SCMode.INSTANTIATE, resolve=True + ) + return tasks + + +# if test_data_path.endswith(".yaml"): +# test_datas_src = load_yaml(test_data_path) +# elif test_data_path.endswith(".csv"): +# test_datas_src = generate_tasks_from_table(test_data_path) +# else: +# raise ValueError("expect yaml or csv, but given {}".format(test_data_path)) + +# test_datas = [ +# test_data +# for test_data in test_datas_src +# if target_datas == "all" or test_data.get("name", None) in target_datas +# ] + +# test_datas = fiss_tasks(test_datas) +# test_datas = generate_prompts(test_datas) + +# n_test_datas = len(test_datas) +# if n_test_datas == 0: +# raise ValueError( +# "n_test_datas == 0, set target_datas=None or set atleast one of {}".format( +# " ".join(list(d.get("name", "None") for d in test_datas_src)) +# ) +# ) +# print("n_test_datas", n_test_datas) +# # pprint(test_datas) + + +def read_image(path): + name = os.path.basename(path).split(".")[0] + image = read_image_as_5d(path) + return image, name + + +def read_image_lst(path): + images_names = [read_image(x) for x in path] + images, names = zip(*images_names) + images = np.concatenate(images, axis=2) + name = "_".join(names) + return images, name + + +def read_image_and_name(path): + if isinstance(path, str): + path = [path] + images, name = read_image_lst(path) + return images, name + + +if referencenet_model_name is not None and not use_v2v_predictor: + referencenet = load_referencenet_by_name( + model_name=referencenet_model_name, + # sd_model=sd_model_path, + # sd_model="../../checkpoints//Moore-AnimateAnyone/AnimateAnyone/reference_unet.pth", + sd_referencenet_model=referencenet_model_path, + cross_attention_dim=cross_attention_dim, + ) +else: + referencenet = None + referencenet_model_name = "no" + +if vision_clip_extractor_class_name is not None and not use_v2v_predictor: + vision_clip_extractor = load_vision_clip_encoder_by_name( + ip_image_encoder=vision_clip_model_path, + vision_clip_extractor_class_name=vision_clip_extractor_class_name, + ) + logger.info( + f"vision_clip_extractor, name={vision_clip_extractor_class_name}, path={vision_clip_model_path}" + ) +else: + vision_clip_extractor = None + logger.info(f"vision_clip_extractor, None") + +if ip_adapter_model_name is not None and not use_v2v_predictor: + ip_adapter_image_proj = load_ip_adapter_image_proj_by_name( + model_name=ip_adapter_model_name, + ip_image_encoder=ip_adapter_model_params_dict.get( + "ip_image_encoder", vision_clip_model_path + ), + ip_ckpt=ip_adapter_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=ip_adapter_model_params_dict["clip_embeddings_dim"], + clip_extra_context_tokens=ip_adapter_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=ip_adapter_model_params_dict["ip_scale"], + device=device, + ) +else: + ip_adapter_image_proj = None + ip_adapter_model_name = "no" + +for model_name, sd_model_params in sd_model_params_dict.items(): + lora_dict = sd_model_params.get("lora", None) + model_sex = sd_model_params.get("sex", None) + model_style = sd_model_params.get("style", None) + sd_model_path = sd_model_params["sd"] + test_model_vae_model_path = sd_model_params.get("vae", vae_model_path) + + unet = ( + load_unet_by_name( + model_name=unet_model_name, + sd_unet_model=unet_model_path, + sd_model=sd_model_path, + # sd_model="../../checkpoints//Moore-AnimateAnyone/AnimateAnyone/denoising_unet.pth", + cross_attention_dim=cross_attention_dim, + need_t2i_facein=facein_model_name is not None, + # facein 目前没参与训练,但在unet中定义了,载入相关参数会报错,所以用strict控制 + strict=not (facein_model_name is not None), + need_t2i_ip_adapter_face=ip_adapter_face_model_name is not None, + ) + if not use_v2v_predictor + else None + ) + + if facein_model_name is not None and not use_v2v_predictor: + ( + face_emb_extractor, + facein_image_proj, + ) = load_facein_extractor_and_proj_by_name( + model_name=facein_model_name, + ip_image_encoder=facein_model_params_dict["ip_image_encoder"], + ip_ckpt=facein_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=facein_model_params_dict["clip_embeddings_dim"], + clip_extra_context_tokens=facein_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=facein_model_params_dict["ip_scale"], + device=device, + # facein目前没有参与unet中的训练,需要单独载入参数 + unet=unet, + ) + else: + face_emb_extractor = None + facein_image_proj = None + + if ip_adapter_face_model_name is not None and not use_v2v_predictor: + ( + ip_adapter_face_emb_extractor, + ip_adapter_face_image_proj, + ) = load_ip_adapter_face_extractor_and_proj_by_name( + model_name=ip_adapter_face_model_name, + ip_image_encoder=ip_adapter_face_model_params_dict["ip_image_encoder"], + ip_ckpt=ip_adapter_face_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=ip_adapter_face_model_params_dict[ + "clip_embeddings_dim" + ], + clip_extra_context_tokens=ip_adapter_face_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=ip_adapter_face_model_params_dict["ip_scale"], + device=device, + unet=unet, # ip_adapter_face 目前没有参与unet中的训练,需要单独载入参数 + ) + else: + ip_adapter_face_emb_extractor = None + ip_adapter_face_image_proj = None + + print("test_model_vae_model_path", test_model_vae_model_path) + + sd_predictor = ( + DiffusersPipelinePredictor( + sd_model_path=sd_model_path, + unet=unet, + lora_dict=lora_dict, + lcm_lora_dct=lcm_lora_dct, + device=device, + dtype=torch_dtype, + negative_embedding=negative_embedding, + referencenet=referencenet, + ip_adapter_image_proj=ip_adapter_image_proj, + vision_clip_extractor=vision_clip_extractor, + facein_image_proj=facein_image_proj, + face_emb_extractor=face_emb_extractor, + vae_model=test_model_vae_model_path, + ip_adapter_face_emb_extractor=ip_adapter_face_emb_extractor, + ip_adapter_face_image_proj=ip_adapter_face_image_proj, + ) + if not use_v2v_predictor + else video_sd_predictor + ) + if use_v2v_predictor: + print( + "text2video use video_sd_predictor, sd_predictor type is ", + type(sd_predictor), + ) + logger.debug(f"load sd_predictor"), + + # TODO:这里修改为gradio +import cuid + + +def generate_cuid(): + return cuid.cuid() + + +def online_t2v_inference( + prompt, + image_np, + seed, + fps, + w, + h, + video_len, + img_edge_ratio: float = 1.0, + progress=gr.Progress(track_tqdm=True), +): + progress(0, desc="Starting...") + # Save the uploaded image to a specified path + if not os.path.exists(CACHE_PATH): + os.makedirs(CACHE_PATH) + image_cuid = generate_cuid() + + image_path = os.path.join(CACHE_PATH, f"{image_cuid}.jpg") + image = Image.fromarray(image_np) + image.save(image_path) + + time_size = int(video_len) + test_data = { + "name": image_cuid, + "prompt": prompt, + # 'video_path': None, + "condition_images": image_path, + "refer_image": image_path, + "ipadapter_image": image_path, + "height": h, + "width": w, + "img_length_ratio": img_edge_ratio, + # 'style': 'anime', + # 'sex': 'female' + } + batch = [] + texts = [] + print("\n test_data", test_data, model_name) + test_data_name = test_data.get("name", test_data) + prompt = test_data["prompt"] + prompt = prefix_prompt + prompt + suffix_prompt + prompt_hash = get_signature_of_string(prompt, length=5) + test_data["prompt_hash"] = prompt_hash + test_data_height = test_data.get("height", height) + test_data_width = test_data.get("width", width) + test_data_condition_images_path = test_data.get("condition_images", None) + test_data_condition_images_index = test_data.get("condition_images_index", None) + test_data_redraw_condition_image = test_data.get( + "redraw_condition_image", redraw_condition_image + ) + # read condition_image + if ( + test_data_condition_images_path is not None + and use_condition_image + and ( + isinstance(test_data_condition_images_path, list) + or ( + isinstance(test_data_condition_images_path, str) + and is_image(test_data_condition_images_path) + ) + ) + ): + ( + test_data_condition_images, + test_data_condition_images_name, + ) = read_image_and_name(test_data_condition_images_path) + condition_image_height = test_data_condition_images.shape[3] + condition_image_width = test_data_condition_images.shape[4] + logger.debug( + f"test_data_condition_images use {test_data_condition_images_path}" + ) + else: + test_data_condition_images = None + test_data_condition_images_name = "no" + condition_image_height = None + condition_image_width = None + logger.debug(f"test_data_condition_images is None") + + # 当没有指定生成视频的宽高时,使用输入条件的宽高,优先使用 condition_image,低优使用 video + if test_data_height in [None, -1]: + test_data_height = condition_image_height + + if test_data_width in [None, -1]: + test_data_width = condition_image_width + + test_data_img_length_ratio = float( + test_data.get("img_length_ratio", img_length_ratio) + ) + # 为了和video2video保持对齐,使用64而不是8作为宽、高最小粒度 + # test_data_height = int(test_data_height * test_data_img_length_ratio // 8 * 8) + # test_data_width = int(test_data_width * test_data_img_length_ratio // 8 * 8) + test_data_height = int(test_data_height * test_data_img_length_ratio // 64 * 64) + test_data_width = int(test_data_width * test_data_img_length_ratio // 64 * 64) + pprint(test_data) + print(f"test_data_height={test_data_height}") + print(f"test_data_width={test_data_width}") + # continue + test_data_style = test_data.get("style", None) + test_data_sex = test_data.get("sex", None) + # 如果使用|进行多参数任务设置时对应的字段是字符串类型,需要显式转换浮点数。 + test_data_motion_speed = float(test_data.get("motion_speed", motion_speed)) + test_data_w_ind_noise = float(test_data.get("w_ind_noise", w_ind_noise)) + test_data_img_weight = float(test_data.get("img_weight", img_weight)) + logger.debug(f"test_data_condition_images_path {test_data_condition_images_path}") + logger.debug(f"test_data_condition_images_index {test_data_condition_images_index}") + test_data_refer_image_path = test_data.get("refer_image", referencenet_image_path) + test_data_ipadapter_image_path = test_data.get( + "ipadapter_image", ipadapter_image_path + ) + test_data_refer_face_image_path = test_data.get("face_image", face_image_path) + + if negprompt_cfg_path is not None: + if "video_negative_prompt" in test_data: + ( + test_data_video_negative_prompt_name, + test_data_video_negative_prompt, + ) = get_negative_prompt( + test_data.get( + "video_negative_prompt", + ), + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, + ) + else: + test_data_video_negative_prompt_name = video_negative_prompt_name + test_data_video_negative_prompt = video_negative_prompt + if "negative_prompt" in test_data: + ( + test_data_negative_prompt_name, + test_data_negative_prompt, + ) = get_negative_prompt( + test_data.get( + "negative_prompt", + ), + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, + ) + else: + test_data_negative_prompt_name = negative_prompt_name + test_data_negative_prompt = negative_prompt + else: + test_data_video_negative_prompt = test_data.get( + "video_negative_prompt", video_negative_prompt + ) + test_data_video_negative_prompt_name = test_data_video_negative_prompt[ + :negtive_prompt_length + ] + test_data_negative_prompt = test_data.get("negative_prompt", negative_prompt) + test_data_negative_prompt_name = test_data_negative_prompt[ + :negtive_prompt_length + ] + + # 准备 test_data_refer_image + if referencenet is not None: + if test_data_refer_image_path is None: + test_data_refer_image = test_data_condition_images + test_data_refer_image_name = test_data_condition_images_name + logger.debug(f"test_data_refer_image use test_data_condition_images") + else: + test_data_refer_image, test_data_refer_image_name = read_image_and_name( + test_data_refer_image_path + ) + logger.debug(f"test_data_refer_image use {test_data_refer_image_path}") + else: + test_data_refer_image = None + test_data_refer_image_name = "no" + logger.debug(f"test_data_refer_image is None") + + # 准备 test_data_ipadapter_image + if vision_clip_extractor is not None: + if test_data_ipadapter_image_path is None: + test_data_ipadapter_image = test_data_condition_images + test_data_ipadapter_image_name = test_data_condition_images_name + + logger.debug(f"test_data_ipadapter_image use test_data_condition_images") + else: + ( + test_data_ipadapter_image, + test_data_ipadapter_image_name, + ) = read_image_and_name(test_data_ipadapter_image_path) + logger.debug( + f"test_data_ipadapter_image use f{test_data_ipadapter_image_path}" + ) + else: + test_data_ipadapter_image = None + test_data_ipadapter_image_name = "no" + logger.debug(f"test_data_ipadapter_image is None") + + # 准备 test_data_refer_face_image + if facein_image_proj is not None or ip_adapter_face_image_proj is not None: + if test_data_refer_face_image_path is None: + test_data_refer_face_image = test_data_condition_images + test_data_refer_face_image_name = test_data_condition_images_name + + logger.debug(f"test_data_refer_face_image use test_data_condition_images") + else: + ( + test_data_refer_face_image, + test_data_refer_face_image_name, + ) = read_image_and_name(test_data_refer_face_image_path) + logger.debug( + f"test_data_refer_face_image use f{test_data_refer_face_image_path}" + ) + else: + test_data_refer_face_image = None + test_data_refer_face_image_name = "no" + logger.debug(f"test_data_refer_face_image is None") + + # # 当模型的sex、style与test_data同时存在且不相等时,就跳过这个测试用例 + # if ( + # model_sex is not None + # and test_data_sex is not None + # and model_sex != test_data_sex + # ) or ( + # model_style is not None + # and test_data_style is not None + # and model_style != test_data_style + # ): + # print("model doesnt match test_data") + # print("model name: ", model_name) + # print("test_data: ", test_data) + # continue + if add_static_video_prompt: + test_data_video_negative_prompt = "static video, {}".format( + test_data_video_negative_prompt + ) + for i_num in range(n_repeat): + test_data_seed = random.randint(0, 1e8) if seed in [None, -1] else seed + cpu_generator, gpu_generator = set_all_seed(int(test_data_seed)) + save_file_name = ( + f"m={model_name}_rm={referencenet_model_name}_case={test_data_name}" + f"_w={test_data_width}_h={test_data_height}_t={time_size}_nb={n_batch}" + f"_s={test_data_seed}_p={prompt_hash}" + f"_w={test_data_img_weight}" + f"_ms={test_data_motion_speed}" + f"_s={strength}_g={video_guidance_scale}" + f"_c-i={test_data_condition_images_name[:5]}_r-c={test_data_redraw_condition_image}" + f"_w={test_data_w_ind_noise}_{test_data_video_negative_prompt_name}" + f"_r={test_data_refer_image_name[:3]}_ip={test_data_refer_image_name[:3]}_f={test_data_refer_face_image_name[:3]}" + ) + + save_file_name = clean_str_for_save(save_file_name) + output_path = os.path.join( + output_dir, + f"{save_file_name}.{save_filetype}", + ) + if os.path.exists(output_path) and not overwrite: + print("existed", output_path) + continue + + print("output_path", output_path) + out_videos = sd_predictor.run_pipe_text2video( + video_length=time_size, + prompt=prompt, + width=test_data_width, + height=test_data_height, + generator=gpu_generator, + noise_type=noise_type, + negative_prompt=test_data_negative_prompt, + video_negative_prompt=test_data_video_negative_prompt, + max_batch_num=n_batch, + strength=strength, + need_img_based_video_noise=need_img_based_video_noise, + video_num_inference_steps=video_num_inference_steps, + condition_images=test_data_condition_images, + fix_condition_images=fix_condition_images, + video_guidance_scale=video_guidance_scale, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + redraw_condition_image=test_data_redraw_condition_image, + img_weight=test_data_img_weight, + w_ind_noise=test_data_w_ind_noise, + n_vision_condition=n_vision_condition, + motion_speed=test_data_motion_speed, + need_hist_match=need_hist_match, + video_guidance_scale_end=video_guidance_scale_end, + video_guidance_scale_method=video_guidance_scale_method, + vision_condition_latent_index=test_data_condition_images_index, + refer_image=test_data_refer_image, + fixed_refer_image=fixed_refer_image, + redraw_condition_image_with_referencenet=redraw_condition_image_with_referencenet, + ip_adapter_image=test_data_ipadapter_image, + refer_face_image=test_data_refer_face_image, + fixed_refer_face_image=fixed_refer_face_image, + facein_scale=facein_scale, + redraw_condition_image_with_facein=redraw_condition_image_with_facein, + ip_adapter_face_scale=ip_adapter_face_scale, + redraw_condition_image_with_ip_adapter_face=redraw_condition_image_with_ip_adapter_face, + fixed_ip_adapter_image=fixed_ip_adapter_image, + ip_adapter_scale=ip_adapter_scale, + redraw_condition_image_with_ipdapter=redraw_condition_image_with_ipdapter, + prompt_only_use_image_prompt=prompt_only_use_image_prompt, + # need_redraw=need_redraw, + # use_video_redraw=use_video_redraw, + # serial_denoise parameter start + record_mid_video_noises=record_mid_video_noises, + record_mid_video_latents=record_mid_video_latents, + video_overlap=video_overlap, + # serial_denoise parameter end + # parallel_denoise parameter start + context_schedule=context_schedule, + context_frames=context_frames, + context_stride=context_stride, + context_overlap=context_overlap, + context_batch_size=context_batch_size, + interpolation_factor=interpolation_factor, + # parallel_denoise parameter end + ) + out = np.concatenate([out_videos], axis=0) + texts = ["out"] + save_videos_grid_with_opencv( + out, + output_path, + texts=texts, + fps=fps, + tensor_order="b c t h w", + n_cols=n_cols, + write_info=args.write_info, + save_filetype=save_filetype, + save_images=save_images, + ) + print("Save to", output_path) + print("\n" * 2) + return output_path diff --git a/scripts/gradio/gradio_video2video.py b/scripts/gradio/gradio_video2video.py new file mode 100644 index 0000000000000000000000000000000000000000..ca51afe9b60d7579bd0315800de9aeafdf2d3a7b --- /dev/null +++ b/scripts/gradio/gradio_video2video.py @@ -0,0 +1,1026 @@ +import argparse +import copy +import os +from pathlib import Path +import logging +from collections import OrderedDict +from pprint import pprint +import random +import gradio as gr + +import numpy as np +from omegaconf import OmegaConf, SCMode +import torch +from einops import rearrange, repeat +import cv2 +from PIL import Image +from diffusers.models.autoencoder_kl import AutoencoderKL + +from mmcm.utils.load_util import load_pyhon_obj +from mmcm.utils.seed_util import set_all_seed +from mmcm.utils.signature import get_signature_of_string +from mmcm.utils.task_util import fiss_tasks, generate_tasks as generate_tasks_from_table +from mmcm.vision.utils.data_type_util import is_video, is_image, read_image_as_5d +from mmcm.utils.str_util import clean_str_for_save +from mmcm.vision.data.video_dataset import DecordVideoDataset +from musev.auto_prompt.util import generate_prompts + +from musev.models.controlnet import PoseGuider +from musev.models.facein_loader import load_facein_extractor_and_proj_by_name +from musev.models.referencenet_loader import load_referencenet_by_name +from musev.models.ip_adapter_loader import ( + load_ip_adapter_vision_clip_encoder_by_name, + load_vision_clip_encoder_by_name, + load_ip_adapter_image_proj_by_name, +) +from musev.models.ip_adapter_face_loader import ( + load_ip_adapter_face_extractor_and_proj_by_name, +) +from musev.pipelines.pipeline_controlnet_predictor import ( + DiffusersPipelinePredictor, +) +from musev.models.referencenet import ReferenceNet2D +from musev.models.unet_loader import load_unet_by_name +from musev.utils.util import save_videos_grid_with_opencv +from musev import logger + +logger.setLevel("INFO") + +file_dir = os.path.dirname(__file__) +PROJECT_DIR = os.path.join(os.path.dirname(__file__), "../..") +DATA_DIR = os.path.join(PROJECT_DIR, "data") +CACHE_PATH = "./t2v_input_image" + + +# TODO:use group to group arguments +args_dict = { + "add_static_video_prompt": False, + "context_batch_size": 1, + "context_frames": 12, + "context_overlap": 4, + "context_schedule": "uniform_v2", + "context_stride": 1, + "controlnet_conditioning_scale": 1.0, + "controlnet_name": "dwpose_body_hand", + "cross_attention_dim": 768, + "enable_zero_snr": False, + "end_to_end": True, + "face_image_path": None, + "facein_model_cfg_path": "../../configs/model/facein.py", + "facein_model_name": None, + "facein_scale": 1.0, + "fix_condition_images": False, + "fixed_ip_adapter_image": True, + "fixed_refer_face_image": True, + "fixed_refer_image": True, + "fps": 4, + "guidance_scale": 7.5, + "height": None, + "img_length_ratio": 1.0, + "img_weight": 0.001, + "interpolation_factor": 1, + "ip_adapter_face_model_cfg_path": "../../configs/model/ip_adapter.py", + "ip_adapter_face_model_name": None, + "ip_adapter_face_scale": 1.0, + "ip_adapter_model_cfg_path": "../../configs/model/ip_adapter.py", + "ip_adapter_model_name": "musev_referencenet_pose", + "ip_adapter_scale": 1.0, + "ipadapter_image_path": None, + "lcm_model_cfg_path": "../../configs/model/lcm_model.py", + "lcm_model_name": None, + "log_level": "INFO", + "motion_speed": 8.0, + "n_batch": 1, + "n_cols": 3, + "n_repeat": 1, + "n_vision_condition": 1, + "need_hist_match": False, + "need_img_based_video_noise": True, + "need_return_condition": False, + "need_return_videos": False, + "need_video2video": False, + "negative_prompt": "V2", + "negprompt_cfg_path": "../../configs/model/negative_prompt.py", + "noise_type": "video_fusion", + "num_inference_steps": 30, + "output_dir": "./results/", + "overwrite": False, + "pose_guider_model_path": None, + "prompt_only_use_image_prompt": False, + "record_mid_video_latents": False, + "record_mid_video_noises": False, + "redraw_condition_image": False, + "redraw_condition_image_with_facein": True, + "redraw_condition_image_with_ip_adapter_face": True, + "redraw_condition_image_with_ipdapter": True, + "redraw_condition_image_with_referencenet": True, + "referencenet_image_path": None, + "referencenet_model_cfg_path": "../../configs/model/referencenet.py", + "referencenet_model_name": "musev_referencenet", + "sample_rate": 1, + "save_filetype": "mp4", + "save_images": False, + "sd_model_cfg_path": "../../configs/model/T2I_all_model.py", + "sd_model_name": "majicmixRealv6Fp16", + "seed": None, + "strength": 0.8, + "target_datas": "boy_dance2", + "test_data_path": "./configs/infer/testcase_video_famous.yaml", + "time_size": 12, + "unet_model_cfg_path": "../../configs/model/motion_model.py", + "unet_model_name": "musev_referencenet_pose", + "use_condition_image": True, + "vae_model_path": "../../checkpoints/vae/sd-vae-ft-mse", + "video_guidance_scale": 3.5, + "video_guidance_scale_end": None, + "video_guidance_scale_method": "linear", + "video_has_condition": True, + "video_is_middle": False, + "video_negative_prompt": "V2", + "video_num_inference_steps": 10, + "video_overlap": 1, + "video_strength": 1.0, + "vision_clip_extractor_class_name": "ImageClipVisionFeatureExtractor", + "vision_clip_model_path": "../../checkpoints/IP-Adapter/models/image_encoder", + "w_ind_noise": 0.5, + "which2video": "video_middle", + "width": None, + "write_info": False, +} +args = argparse.Namespace(**args_dict) +print("args") +pprint(args.__dict__) +print("\n") + +logger.setLevel(args.log_level) +overwrite = args.overwrite +cross_attention_dim = args.cross_attention_dim +time_size = args.time_size # 一次视频生成的帧数 +n_batch = args.n_batch # 按照time_size的尺寸 生成n_batch次,总帧数 = time_size * n_batch +fps = args.fps +fix_condition_images = args.fix_condition_images +use_condition_image = args.use_condition_image # 当 test_data 中有图像时,作为初始图像 +redraw_condition_image = args.redraw_condition_image # 用于视频生成的首帧是否使用重绘后的 +need_img_based_video_noise = ( + args.need_img_based_video_noise +) # 视频加噪过程中是否使用首帧 condition_images +img_weight = args.img_weight +height = args.height # 如果测试数据中没有单独指定宽高,则默认这里 +width = args.width # 如果测试数据中没有单独指定宽高,则默认这里 +img_length_ratio = args.img_length_ratio # 如果测试数据中没有单独指定图像宽高比resize比例,则默认这里 +n_cols = args.n_cols +noise_type = args.noise_type +strength = args.strength # 首帧重绘程度参数 +video_guidance_scale = args.video_guidance_scale # 视频 condition与 uncond的权重参数 +guidance_scale = args.guidance_scale # 时序条件帧 condition与uncond的权重参数 +video_num_inference_steps = args.video_num_inference_steps # 视频迭代次数 +num_inference_steps = args.num_inference_steps # 时序条件帧 重绘参数 +seed = args.seed +save_filetype = args.save_filetype +save_images = args.save_images +sd_model_cfg_path = args.sd_model_cfg_path +sd_model_name = ( + args.sd_model_name if args.sd_model_name == "all" else args.sd_model_name.split(",") +) +unet_model_cfg_path = args.unet_model_cfg_path +unet_model_name = args.unet_model_name +test_data_path = args.test_data_path +target_datas = ( + args.target_datas if args.target_datas == "all" else args.target_datas.split(",") +) +device = "cuda" if torch.cuda.is_available() else "cpu" +torch_dtype = torch.float16 +controlnet_name = args.controlnet_name +controlnet_name_str = controlnet_name +if controlnet_name is not None: + controlnet_name = controlnet_name.split(",") + if len(controlnet_name) == 1: + controlnet_name = controlnet_name[0] + +video_strength = args.video_strength # 视频重绘程度参数 +sample_rate = args.sample_rate +controlnet_conditioning_scale = args.controlnet_conditioning_scale + +end_to_end = args.end_to_end # 是否首尾相连生成长视频 +control_guidance_start = 0.0 +control_guidance_end = 0.5 +control_guidance_end = 1.0 +negprompt_cfg_path = args.negprompt_cfg_path +video_negative_prompt = args.video_negative_prompt +negative_prompt = args.negative_prompt +motion_speed = args.motion_speed +need_hist_match = args.need_hist_match +video_guidance_scale_end = args.video_guidance_scale_end +video_guidance_scale_method = args.video_guidance_scale_method +add_static_video_prompt = args.add_static_video_prompt +n_vision_condition = args.n_vision_condition +lcm_model_cfg_path = args.lcm_model_cfg_path +lcm_model_name = args.lcm_model_name +referencenet_model_cfg_path = args.referencenet_model_cfg_path +referencenet_model_name = args.referencenet_model_name +ip_adapter_model_cfg_path = args.ip_adapter_model_cfg_path +ip_adapter_model_name = args.ip_adapter_model_name +vision_clip_model_path = args.vision_clip_model_path +vision_clip_extractor_class_name = args.vision_clip_extractor_class_name +facein_model_cfg_path = args.facein_model_cfg_path +facein_model_name = args.facein_model_name +ip_adapter_face_model_cfg_path = args.ip_adapter_face_model_cfg_path +ip_adapter_face_model_name = args.ip_adapter_face_model_name + +fixed_refer_image = args.fixed_refer_image +fixed_ip_adapter_image = args.fixed_ip_adapter_image +fixed_refer_face_image = args.fixed_refer_face_image +redraw_condition_image_with_referencenet = args.redraw_condition_image_with_referencenet +redraw_condition_image_with_ipdapter = args.redraw_condition_image_with_ipdapter +redraw_condition_image_with_facein = args.redraw_condition_image_with_facein +redraw_condition_image_with_ip_adapter_face = ( + args.redraw_condition_image_with_ip_adapter_face +) +w_ind_noise = args.w_ind_noise +ip_adapter_scale = args.ip_adapter_scale +facein_scale = args.facein_scale +ip_adapter_face_scale = args.ip_adapter_face_scale +face_image_path = args.face_image_path +ipadapter_image_path = args.ipadapter_image_path +referencenet_image_path = args.referencenet_image_path +vae_model_path = args.vae_model_path +prompt_only_use_image_prompt = args.prompt_only_use_image_prompt +pose_guider_model_path = args.pose_guider_model_path +need_video2video = args.need_video2video +# serial_denoise parameter start +record_mid_video_noises = args.record_mid_video_noises +record_mid_video_latents = args.record_mid_video_latents +video_overlap = args.video_overlap +# serial_denoise parameter end +# parallel_denoise parameter start +context_schedule = args.context_schedule +context_frames = args.context_frames +context_stride = args.context_stride +context_overlap = args.context_overlap +context_batch_size = args.context_batch_size +interpolation_factor = args.interpolation_factor +n_repeat = args.n_repeat + +video_is_middle = args.video_is_middle +video_has_condition = args.video_has_condition +need_return_videos = args.need_return_videos +need_return_condition = args.need_return_condition +# parallel_denoise parameter end +need_controlnet = controlnet_name is not None + +which2video = args.which2video +if which2video == "video": + which2video_name = "v2v" +elif which2video == "video_middle": + which2video_name = "vm2v" +else: + raise ValueError( + "which2video only support video, video_middle, but given {which2video}" + ) +b = 1 +negative_embedding = [ + ["../../checkpoints/embedding/badhandv4.pt", "badhandv4"], + [ + "../../checkpoints/embedding/ng_deepnegative_v1_75t.pt", + "ng_deepnegative_v1_75t", + ], + [ + "../../checkpoints/embedding/EasyNegativeV2.safetensors", + "EasyNegativeV2", + ], + [ + "../../checkpoints/embedding/bad_prompt_version2-neg.pt", + "bad_prompt_version2-neg", + ], +] +prefix_prompt = "" +suffix_prompt = ", beautiful, masterpiece, best quality" +suffix_prompt = "" + +if sd_model_name != "None": + # 使用 cfg_path 里的sd_model_path + sd_model_params_dict_src = load_pyhon_obj(sd_model_cfg_path, "MODEL_CFG") + sd_model_params_dict = { + k: v + for k, v in sd_model_params_dict_src.items() + if sd_model_name == "all" or k in sd_model_name + } +else: + # 使用命令行给的sd_model_path, 需要单独设置 sd_model_name 为None, + sd_model_name = os.path.basename(sd_model_cfg_path).split(".")[0] + sd_model_params_dict = {sd_model_name: {"sd": sd_model_cfg_path}} + sd_model_params_dict_src = sd_model_params_dict +if len(sd_model_params_dict) == 0: + raise ValueError( + "has not target model, please set one of {}".format( + " ".join(list(sd_model_params_dict_src.keys())) + ) + ) +print("running model, T2I SD") +pprint(sd_model_params_dict) + +# lcm +if lcm_model_name is not None: + lcm_model_params_dict_src = load_pyhon_obj(lcm_model_cfg_path, "MODEL_CFG") + print("lcm_model_params_dict_src") + lcm_lora_dct = lcm_model_params_dict_src[lcm_model_name] +else: + lcm_lora_dct = None +print("lcm: ", lcm_model_name, lcm_lora_dct) + + +# motion net parameters +if os.path.isdir(unet_model_cfg_path): + unet_model_path = unet_model_cfg_path +elif os.path.isfile(unet_model_cfg_path): + unet_model_params_dict_src = load_pyhon_obj(unet_model_cfg_path, "MODEL_CFG") + print("unet_model_params_dict_src", unet_model_params_dict_src.keys()) + unet_model_path = unet_model_params_dict_src[unet_model_name]["unet"] +else: + raise ValueError(f"expect dir or file, but given {unet_model_cfg_path}") +print("unet: ", unet_model_name, unet_model_path) + + +# referencenet +if referencenet_model_name is not None: + if os.path.isdir(referencenet_model_cfg_path): + referencenet_model_path = referencenet_model_cfg_path + elif os.path.isfile(referencenet_model_cfg_path): + referencenet_model_params_dict_src = load_pyhon_obj( + referencenet_model_cfg_path, "MODEL_CFG" + ) + print( + "referencenet_model_params_dict_src", + referencenet_model_params_dict_src.keys(), + ) + referencenet_model_path = referencenet_model_params_dict_src[ + referencenet_model_name + ]["net"] + else: + raise ValueError(f"expect dir or file, but given {referencenet_model_cfg_path}") +else: + referencenet_model_path = None +print("referencenet: ", referencenet_model_name, referencenet_model_path) + + +# ip_adapter +if ip_adapter_model_name is not None: + ip_adapter_model_params_dict_src = load_pyhon_obj( + ip_adapter_model_cfg_path, "MODEL_CFG" + ) + print("ip_adapter_model_params_dict_src", ip_adapter_model_params_dict_src.keys()) + ip_adapter_model_params_dict = ip_adapter_model_params_dict_src[ + ip_adapter_model_name + ] +else: + ip_adapter_model_params_dict = None +print("ip_adapter: ", ip_adapter_model_name, ip_adapter_model_params_dict) + + +# facein +if facein_model_name is not None: + facein_model_params_dict_src = load_pyhon_obj(facein_model_cfg_path, "MODEL_CFG") + print("facein_model_params_dict_src", facein_model_params_dict_src.keys()) + facein_model_params_dict = facein_model_params_dict_src[facein_model_name] +else: + facein_model_params_dict = None +print("facein: ", facein_model_name, facein_model_params_dict) + +# ip_adapter_face +if ip_adapter_face_model_name is not None: + ip_adapter_face_model_params_dict_src = load_pyhon_obj( + ip_adapter_face_model_cfg_path, "MODEL_CFG" + ) + print( + "ip_adapter_face_model_params_dict_src", + ip_adapter_face_model_params_dict_src.keys(), + ) + ip_adapter_face_model_params_dict = ip_adapter_face_model_params_dict_src[ + ip_adapter_face_model_name + ] +else: + ip_adapter_face_model_params_dict = None +print( + "ip_adapter_face: ", ip_adapter_face_model_name, ip_adapter_face_model_params_dict +) + + +# negative_prompt +def get_negative_prompt(negative_prompt, cfg_path=None, n: int = 10): + name = negative_prompt[:n] + if cfg_path is not None and cfg_path not in ["None", "none"]: + dct = load_pyhon_obj(cfg_path, "Negative_Prompt_CFG") + negative_prompt = dct[negative_prompt]["prompt"] + + return name, negative_prompt + + +negtive_prompt_length = 10 +video_negative_prompt_name, video_negative_prompt = get_negative_prompt( + video_negative_prompt, + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, +) +negative_prompt_name, negative_prompt = get_negative_prompt( + negative_prompt, + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, +) + +print("video_negprompt", video_negative_prompt_name, video_negative_prompt) +print("negprompt", negative_prompt_name, negative_prompt) + +output_dir = args.output_dir +os.makedirs(output_dir, exist_ok=True) + + +# test_data_parameters +def load_yaml(path): + tasks = OmegaConf.to_container( + OmegaConf.load(path), structured_config_mode=SCMode.INSTANTIATE, resolve=True + ) + return tasks + + +# if test_data_path.endswith(".yaml"): +# test_datas_src = load_yaml(test_data_path) +# elif test_data_path.endswith(".csv"): +# test_datas_src = generate_tasks_from_table(test_data_path) +# else: +# raise ValueError("expect yaml or csv, but given {}".format(test_data_path)) + +# test_datas = [ +# test_data +# for test_data in test_datas_src +# if target_datas == "all" or test_data.get("name", None) in target_datas +# ] + +# test_datas = fiss_tasks(test_datas) +# test_datas = generate_prompts(test_datas) + +# n_test_datas = len(test_datas) +# if n_test_datas == 0: +# raise ValueError( +# "n_test_datas == 0, set target_datas=None or set atleast one of {}".format( +# " ".join(list(d.get("name", "None") for d in test_datas_src)) +# ) +# ) +# print("n_test_datas", n_test_datas) +# # pprint(test_datas) + + +def read_image(path): + name = os.path.basename(path).split(".")[0] + image = read_image_as_5d(path) + return image, name + + +def read_image_lst(path): + images_names = [read_image(x) for x in path] + images, names = zip(*images_names) + images = np.concatenate(images, axis=2) + name = "_".join(names) + return images, name + + +def read_image_and_name(path): + if isinstance(path, str): + path = [path] + images, name = read_image_lst(path) + return images, name + + +if referencenet_model_name is not None: + referencenet = load_referencenet_by_name( + model_name=referencenet_model_name, + # sd_model=sd_model_path, + # sd_model="../../checkpoints/Moore-AnimateAnyone/AnimateAnyone/reference_unet.pth", + sd_referencenet_model=referencenet_model_path, + cross_attention_dim=cross_attention_dim, + ) +else: + referencenet = None + referencenet_model_name = "no" + +if vision_clip_extractor_class_name is not None: + vision_clip_extractor = load_vision_clip_encoder_by_name( + ip_image_encoder=vision_clip_model_path, + vision_clip_extractor_class_name=vision_clip_extractor_class_name, + ) + logger.info( + f"vision_clip_extractor, name={vision_clip_extractor_class_name}, path={vision_clip_model_path}" + ) +else: + vision_clip_extractor = None + logger.info(f"vision_clip_extractor, None") + +if ip_adapter_model_name is not None: + ip_adapter_image_proj = load_ip_adapter_image_proj_by_name( + model_name=ip_adapter_model_name, + ip_image_encoder=ip_adapter_model_params_dict.get( + "ip_image_encoder", vision_clip_model_path + ), + ip_ckpt=ip_adapter_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=ip_adapter_model_params_dict["clip_embeddings_dim"], + clip_extra_context_tokens=ip_adapter_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=ip_adapter_model_params_dict["ip_scale"], + device=device, + ) +else: + ip_adapter_image_proj = None + ip_adapter_model_name = "no" + +if pose_guider_model_path is not None: + logger.info(f"PoseGuider ={pose_guider_model_path}") + pose_guider = PoseGuider.from_pretrained( + pose_guider_model_path, + conditioning_embedding_channels=320, + block_out_channels=(16, 32, 96, 256), + ) +else: + pose_guider = None + +for model_name, sd_model_params in sd_model_params_dict.items(): + lora_dict = sd_model_params.get("lora", None) + model_sex = sd_model_params.get("sex", None) + model_style = sd_model_params.get("style", None) + sd_model_path = sd_model_params["sd"] + test_model_vae_model_path = sd_model_params.get("vae", vae_model_path) + + unet = load_unet_by_name( + model_name=unet_model_name, + sd_unet_model=unet_model_path, + sd_model=sd_model_path, + # sd_model="../../checkpoints/Moore-AnimateAnyone/AnimateAnyone/denoising_unet.pth", + cross_attention_dim=cross_attention_dim, + need_t2i_facein=facein_model_name is not None, + # facein 目前没参与训练,但在unet中定义了,载入相关参数会报错,所以用strict控制 + strict=not (facein_model_name is not None), + need_t2i_ip_adapter_face=ip_adapter_face_model_name is not None, + ) + + if facein_model_name is not None: + ( + face_emb_extractor, + facein_image_proj, + ) = load_facein_extractor_and_proj_by_name( + model_name=facein_model_name, + ip_image_encoder=facein_model_params_dict["ip_image_encoder"], + ip_ckpt=facein_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=facein_model_params_dict["clip_embeddings_dim"], + clip_extra_context_tokens=facein_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=facein_model_params_dict["ip_scale"], + device=device, + # facein目前没有参与unet中的训练,需要单独载入参数 + unet=unet, + ) + else: + face_emb_extractor = None + facein_image_proj = None + + if ip_adapter_face_model_name is not None: + ( + ip_adapter_face_emb_extractor, + ip_adapter_face_image_proj, + ) = load_ip_adapter_face_extractor_and_proj_by_name( + model_name=ip_adapter_face_model_name, + ip_image_encoder=ip_adapter_face_model_params_dict["ip_image_encoder"], + ip_ckpt=ip_adapter_face_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=ip_adapter_face_model_params_dict[ + "clip_embeddings_dim" + ], + clip_extra_context_tokens=ip_adapter_face_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=ip_adapter_face_model_params_dict["ip_scale"], + device=device, + unet=unet, # ip_adapter_face 目前没有参与unet中的训练,需要单独载入参数 + ) + else: + ip_adapter_face_emb_extractor = None + ip_adapter_face_image_proj = None + + print("test_model_vae_model_path", test_model_vae_model_path) + + sd_predictor = DiffusersPipelinePredictor( + sd_model_path=sd_model_path, + unet=unet, + lora_dict=lora_dict, + lcm_lora_dct=lcm_lora_dct, + device=device, + dtype=torch_dtype, + negative_embedding=negative_embedding, + referencenet=referencenet, + ip_adapter_image_proj=ip_adapter_image_proj, + vision_clip_extractor=vision_clip_extractor, + facein_image_proj=facein_image_proj, + face_emb_extractor=face_emb_extractor, + vae_model=test_model_vae_model_path, + ip_adapter_face_emb_extractor=ip_adapter_face_emb_extractor, + ip_adapter_face_image_proj=ip_adapter_face_image_proj, + pose_guider=pose_guider, + controlnet_name=controlnet_name, + # TODO: 一些过期参数,待去掉 + include_body=True, + include_face=False, + include_hand=True, + enable_zero_snr=args.enable_zero_snr, + ) + logger.debug(f"load referencenet"), + +# TODO:这里修改为gradio +import cuid + + +def generate_cuid(): + return cuid.cuid() + + +def online_v2v_inference( + prompt, + image_np, + video, + processor, + seed, + fps, + w, + h, + video_length, + img_edge_ratio: float = 1.0, + progress=gr.Progress(track_tqdm=True), +): + progress(0, desc="Starting...") + # Save the uploaded image to a specified path + if not os.path.exists(CACHE_PATH): + os.makedirs(CACHE_PATH) + image_cuid = generate_cuid() + import pdb + + image_path = os.path.join(CACHE_PATH, f"{image_cuid}.jpg") + image = Image.fromarray(image_np) + image.save(image_path) + time_size = int(video_length) + test_data = { + "name": image_cuid, + "prompt": prompt, + "video_path": video, + "condition_images": image_path, + "refer_image": image_path, + "ipadapter_image": image_path, + "height": h, + "width": w, + "img_length_ratio": img_edge_ratio, + # 'style': 'anime', + # 'sex': 'female' + } + batch = [] + texts = [] + video_path = test_data.get("video_path") + video_reader = DecordVideoDataset( + video_path, + time_size=int(video_length), + step=time_size, + sample_rate=sample_rate, + device="cpu", + data_type="rgb", + channels_order="c t h w", + drop_last=True, + ) + video_height = video_reader.height + video_width = video_reader.width + + print("\n i_test_data", test_data, model_name) + test_data_name = test_data.get("name", test_data) + prompt = test_data["prompt"] + prompt = prefix_prompt + prompt + suffix_prompt + prompt_hash = get_signature_of_string(prompt, length=5) + test_data["prompt_hash"] = prompt_hash + test_data_height = test_data.get("height", height) + test_data_width = test_data.get("width", width) + test_data_condition_images_path = test_data.get("condition_images", None) + test_data_condition_images_index = test_data.get("condition_images_index", None) + test_data_redraw_condition_image = test_data.get( + "redraw_condition_image", redraw_condition_image + ) + # read condition_image + if ( + test_data_condition_images_path is not None + and use_condition_image + and ( + isinstance(test_data_condition_images_path, list) + or ( + isinstance(test_data_condition_images_path, str) + and is_image(test_data_condition_images_path) + ) + ) + ): + ( + test_data_condition_images, + test_data_condition_images_name, + ) = read_image_and_name(test_data_condition_images_path) + condition_image_height = test_data_condition_images.shape[3] + condition_image_width = test_data_condition_images.shape[4] + logger.debug( + f"test_data_condition_images use {test_data_condition_images_path}" + ) + else: + test_data_condition_images = None + test_data_condition_images_name = "no" + condition_image_height = None + condition_image_width = None + logger.debug(f"test_data_condition_images is None") + + # 当没有指定生成视频的宽高时,使用输入条件的宽高,优先使用 condition_image,低优使用 video + if test_data_height in [None, -1]: + test_data_height = condition_image_height + + if test_data_width in [None, -1]: + test_data_width = condition_image_width + + test_data_img_length_ratio = float( + test_data.get("img_length_ratio", img_length_ratio) + ) + + test_data_height = int(test_data_height * test_data_img_length_ratio // 64 * 64) + test_data_width = int(test_data_width * test_data_img_length_ratio // 64 * 64) + pprint(test_data) + print(f"test_data_height={test_data_height}") + print(f"test_data_width={test_data_width}") + # continue + test_data_style = test_data.get("style", None) + test_data_sex = test_data.get("sex", None) + # 如果使用|进行多参数任务设置时对应的字段是字符串类型,需要显式转换浮点数。 + test_data_motion_speed = float(test_data.get("motion_speed", motion_speed)) + test_data_w_ind_noise = float(test_data.get("w_ind_noise", w_ind_noise)) + test_data_img_weight = float(test_data.get("img_weight", img_weight)) + logger.debug(f"test_data_condition_images_path {test_data_condition_images_path}") + logger.debug(f"test_data_condition_images_index {test_data_condition_images_index}") + test_data_refer_image_path = test_data.get("refer_image", referencenet_image_path) + test_data_ipadapter_image_path = test_data.get( + "ipadapter_image", ipadapter_image_path + ) + test_data_refer_face_image_path = test_data.get("face_image", face_image_path) + test_data_video_is_middle = test_data.get("video_is_middle", video_is_middle) + test_data_video_has_condition = test_data.get( + "video_has_condition", video_has_condition + ) + + controlnet_processor_params = { + "detect_resolution": min(test_data_height, test_data_width), + "image_resolution": min(test_data_height, test_data_width), + } + if negprompt_cfg_path is not None: + if "video_negative_prompt" in test_data: + ( + test_data_video_negative_prompt_name, + test_data_video_negative_prompt, + ) = get_negative_prompt( + test_data.get( + "video_negative_prompt", + ), + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, + ) + else: + test_data_video_negative_prompt_name = video_negative_prompt_name + test_data_video_negative_prompt = video_negative_prompt + if "negative_prompt" in test_data: + ( + test_data_negative_prompt_name, + test_data_negative_prompt, + ) = get_negative_prompt( + test_data.get( + "negative_prompt", + ), + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, + ) + else: + test_data_negative_prompt_name = negative_prompt_name + test_data_negative_prompt = negative_prompt + else: + test_data_video_negative_prompt = test_data.get( + "video_negative_prompt", video_negative_prompt + ) + test_data_video_negative_prompt_name = test_data_video_negative_prompt[ + :negtive_prompt_length + ] + test_data_negative_prompt = test_data.get("negative_prompt", negative_prompt) + test_data_negative_prompt_name = test_data_negative_prompt[ + :negtive_prompt_length + ] + + # 准备 test_data_refer_image + if referencenet is not None: + if test_data_refer_image_path is None: + test_data_refer_image = test_data_condition_images + test_data_refer_image_name = test_data_condition_images_name + logger.debug(f"test_data_refer_image use test_data_condition_images") + else: + test_data_refer_image, test_data_refer_image_name = read_image_and_name( + test_data_refer_image_path + ) + logger.debug(f"test_data_refer_image use {test_data_refer_image_path}") + else: + test_data_refer_image = None + test_data_refer_image_name = "no" + logger.debug(f"test_data_refer_image is None") + + # 准备 test_data_ipadapter_image + if vision_clip_extractor is not None: + if test_data_ipadapter_image_path is None: + test_data_ipadapter_image = test_data_condition_images + test_data_ipadapter_image_name = test_data_condition_images_name + + logger.debug(f"test_data_ipadapter_image use test_data_condition_images") + else: + ( + test_data_ipadapter_image, + test_data_ipadapter_image_name, + ) = read_image_and_name(test_data_ipadapter_image_path) + logger.debug( + f"test_data_ipadapter_image use f{test_data_ipadapter_image_path}" + ) + else: + test_data_ipadapter_image = None + test_data_ipadapter_image_name = "no" + logger.debug(f"test_data_ipadapter_image is None") + + # 准备 test_data_refer_face_image + if facein_image_proj is not None or ip_adapter_face_image_proj is not None: + if test_data_refer_face_image_path is None: + test_data_refer_face_image = test_data_condition_images + test_data_refer_face_image_name = test_data_condition_images_name + + logger.debug(f"test_data_refer_face_image use test_data_condition_images") + else: + ( + test_data_refer_face_image, + test_data_refer_face_image_name, + ) = read_image_and_name(test_data_refer_face_image_path) + logger.debug( + f"test_data_refer_face_image use f{test_data_refer_face_image_path}" + ) + else: + test_data_refer_face_image = None + test_data_refer_face_image_name = "no" + logger.debug(f"test_data_refer_face_image is None") + + # # 当模型的sex、style与test_data同时存在且不相等时,就跳过这个测试用例 + # if ( + # model_sex is not None + # and test_data_sex is not None + # and model_sex != test_data_sex + # ) or ( + # model_style is not None + # and test_data_style is not None + # and model_style != test_data_style + # ): + # print("model doesnt match test_data") + # print("model name: ", model_name) + # print("test_data: ", test_data) + # continue + # video + filename = os.path.basename(video_path).split(".")[0] + for i_num in range(n_repeat): + test_data_seed = random.randint(0, 1e8) if seed in [None, -1] else seed + cpu_generator, gpu_generator = set_all_seed(int(test_data_seed)) + + save_file_name = ( + f"{which2video_name}_m={model_name}_rm={referencenet_model_name}_c={test_data_name}" + f"_w={test_data_width}_h={test_data_height}_t={time_size}_n={n_batch}" + f"_vn={video_num_inference_steps}" + f"_w={test_data_img_weight}_w={test_data_w_ind_noise}" + f"_s={test_data_seed}_n={controlnet_name_str}" + f"_s={strength}_g={guidance_scale}_vs={video_strength}_vg={video_guidance_scale}" + f"_p={prompt_hash}_{test_data_video_negative_prompt_name[:10]}" + f"_r={test_data_refer_image_name[:3]}_ip={test_data_refer_image_name[:3]}_f={test_data_refer_face_image_name[:3]}" + ) + save_file_name = clean_str_for_save(save_file_name) + output_path = os.path.join( + output_dir, + f"{save_file_name}.{save_filetype}", + ) + if os.path.exists(output_path) and not overwrite: + print("existed", output_path) + continue + + if which2video in ["video", "video_middle"]: + need_video2video = False + if which2video == "video": + need_video2video = True + + ( + out_videos, + out_condition, + videos, + ) = sd_predictor.run_pipe_video2video( + video=video_path, + time_size=time_size, + step=time_size, + sample_rate=sample_rate, + need_return_videos=need_return_videos, + need_return_condition=need_return_condition, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + end_to_end=end_to_end, + need_video2video=need_video2video, + video_strength=video_strength, + prompt=prompt, + width=test_data_width, + height=test_data_height, + generator=gpu_generator, + noise_type=noise_type, + negative_prompt=test_data_negative_prompt, + video_negative_prompt=test_data_video_negative_prompt, + max_batch_num=n_batch, + strength=strength, + need_img_based_video_noise=need_img_based_video_noise, + video_num_inference_steps=video_num_inference_steps, + condition_images=test_data_condition_images, + fix_condition_images=fix_condition_images, + video_guidance_scale=video_guidance_scale, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + redraw_condition_image=test_data_redraw_condition_image, + img_weight=test_data_img_weight, + w_ind_noise=test_data_w_ind_noise, + n_vision_condition=n_vision_condition, + motion_speed=test_data_motion_speed, + need_hist_match=need_hist_match, + video_guidance_scale_end=video_guidance_scale_end, + video_guidance_scale_method=video_guidance_scale_method, + vision_condition_latent_index=test_data_condition_images_index, + refer_image=test_data_refer_image, + fixed_refer_image=fixed_refer_image, + redraw_condition_image_with_referencenet=redraw_condition_image_with_referencenet, + ip_adapter_image=test_data_ipadapter_image, + refer_face_image=test_data_refer_face_image, + fixed_refer_face_image=fixed_refer_face_image, + facein_scale=facein_scale, + redraw_condition_image_with_facein=redraw_condition_image_with_facein, + ip_adapter_face_scale=ip_adapter_face_scale, + redraw_condition_image_with_ip_adapter_face=redraw_condition_image_with_ip_adapter_face, + fixed_ip_adapter_image=fixed_ip_adapter_image, + ip_adapter_scale=ip_adapter_scale, + redraw_condition_image_with_ipdapter=redraw_condition_image_with_ipdapter, + prompt_only_use_image_prompt=prompt_only_use_image_prompt, + controlnet_processor_params=controlnet_processor_params, + # serial_denoise parameter start + record_mid_video_noises=record_mid_video_noises, + record_mid_video_latents=record_mid_video_latents, + video_overlap=video_overlap, + # serial_denoise parameter end + # parallel_denoise parameter start + context_schedule=context_schedule, + context_frames=context_frames, + context_stride=context_stride, + context_overlap=context_overlap, + context_batch_size=context_batch_size, + interpolation_factor=interpolation_factor, + # parallel_denoise parameter end + video_is_middle=test_data_video_is_middle, + video_has_condition=test_data_video_has_condition, + ) + else: + raise ValueError( + f"only support video, videomiddle2video, but given {which2video_name}" + ) + print("out_videos.shape", out_videos.shape) + batch = [out_videos] + texts = ["out"] + if videos is not None: + print("videos.shape", videos.shape) + batch.insert(0, videos / 255.0) + texts.insert(0, "videos") + if need_controlnet and out_condition is not None: + if not isinstance(out_condition, list): + print("out_condition", out_condition.shape) + batch.append(out_condition / 255.0) + texts.append(controlnet_name) + else: + batch.extend([x / 255.0 for x in out_condition]) + texts.extend(controlnet_name) + out = np.concatenate(batch, axis=0) + save_videos_grid_with_opencv( + out, + output_path, + texts=texts, + fps=fps, + tensor_order="b c t h w", + n_cols=n_cols, + write_info=args.write_info, + save_filetype=save_filetype, + save_images=save_images, + ) + print("Save to", output_path) + print("\n" * 2) + return output_path diff --git a/scripts/inference/text2video.py b/scripts/inference/text2video.py new file mode 100644 index 0000000000000000000000000000000000000000..50e1233a757d3636fa9b8e10ea3a4971db1342c1 --- /dev/null +++ b/scripts/inference/text2video.py @@ -0,0 +1,1298 @@ +import argparse +import copy +import os +from pathlib import Path +import logging +from collections import OrderedDict +from pprint import pprint +import random + +import numpy as np +from omegaconf import OmegaConf, SCMode +import torch +from einops import rearrange, repeat +import cv2 +from PIL import Image +from diffusers.models.autoencoder_kl import AutoencoderKL + +from mmcm.utils.load_util import load_pyhon_obj +from mmcm.utils.seed_util import set_all_seed +from mmcm.utils.signature import get_signature_of_string +from mmcm.utils.task_util import fiss_tasks, generate_tasks as generate_tasks_from_table +from mmcm.vision.utils.data_type_util import is_video, is_image, read_image_as_5d +from mmcm.utils.str_util import clean_str_for_save +from mmcm.vision.data.video_dataset import DecordVideoDataset +from musev.auto_prompt.util import generate_prompts + + +from musev.models.facein_loader import load_facein_extractor_and_proj_by_name +from musev.models.referencenet_loader import load_referencenet_by_name +from musev.models.ip_adapter_loader import ( + load_ip_adapter_vision_clip_encoder_by_name, + load_vision_clip_encoder_by_name, + load_ip_adapter_image_proj_by_name, +) +from musev.models.ip_adapter_face_loader import ( + load_ip_adapter_face_extractor_and_proj_by_name, +) +from musev.pipelines.pipeline_controlnet_predictor import ( + DiffusersPipelinePredictor, +) +from musev.models.referencenet import ReferenceNet2D +from musev.models.unet_loader import load_unet_by_name +from musev.utils.util import save_videos_grid_with_opencv +from musev import logger + +logger.setLevel("INFO") + +file_dir = os.path.dirname(__file__) +PROJECT_DIR = os.path.join(os.path.dirname(__file__), "../..") +DATA_DIR = os.path.join(PROJECT_DIR, "data") + + +# TODO:use group to group arguments +def parse_args(): + parser = argparse.ArgumentParser(description="musev Text to video") + parser.add_argument( + "-test_data_path", + type=str, + help=( + "Path to the test data configuration file, now only support yaml ext, " + "task file simialr to musev/configs/tasks/example.yaml" + ), + ) + parser.add_argument( + "--target_datas", + type=str, + default="all", + help="Names of the test data to run, to select sub tasks, default=`all`", + ) + parser.add_argument( + "--sd_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "configs/model/T2I_all_model.py"), + help="Path to the model configuration file", + ) + parser.add_argument( + "--sd_model_name", + type=str, + default="all", + help="Names of the models to run, or path.", + ) + parser.add_argument( + "--unet_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/motion_model.py"), + help="Path to motion_cfg path or motion unet path", + ) + parser.add_argument( + "--unet_model_name", + type=str, + default="musev_referencenet", + help=( + "class Name of the unet model, use load_unet_by_name to init unet," + "now only support `musev`, `musev_referencenet`," + ), + ) + parser.add_argument( + "--lcm_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/lcm_model.py"), + help="Path to lcm lora path", + ) + parser.add_argument( + "--lcm_model_name", + type=str, + default=None, + help="lcm model name, None means do not use lcm_lora default=`None`", + choices=[ + "lcm", + ], + ) + parser.add_argument( + "--referencenet_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/referencenet.py"), + help="Path to referencenet model config path", + ) + parser.add_argument( + "--referencenet_model_name", + type=str, + default=None, + help="referencenet model name, None means do not use referencenet, default=`None`", + choices=["musev_referencenet"], + ) + parser.add_argument( + "--ip_adapter_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/ip_adapter.py"), + help="Path to ip_adapter model config path", + ) + parser.add_argument( + "--ip_adapter_model_name", + type=str, + default=None, + help="ip_adapter model name, None means do not use ip_adapter, default=`None`", + choices=["musev_referencenet"], + ) + parser.add_argument( + "--vision_clip_model_path", + type=str, + default="./checkpoints/ip_adapter/models/image_encoder", + help="vision_clip_extractor_class_name vision_clip_model_path, default=`./checkpoints/ip_adapter/models/image_encoder`", + ) + parser.add_argument( + "--vision_clip_extractor_class_name", + type=str, + default=None, + help="vision_clip_extractor_class_name None means according to ip_adapter_model_name, default=`None`", + choices=["ImageClipVisionFeatureExtractor"], + ) + parser.add_argument( + "--facein_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/facein.py"), + help="Path to facein model config path", + ) + parser.add_argument( + "--facein_model_name", + type=str, + default=None, + help="facein model name, None means do not use facein, now unsupported default=`None`", + ) + parser.add_argument( + "--ip_adapter_face_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/ip_adapter.py"), + help="Path to facein model config path", + ) + parser.add_argument( + "--ip_adapter_face_model_name", + type=str, + default=None, + help="facein model name, None means do not use ip_adapter_face, default=`None`", + ) + parser.add_argument( + "--output_dir", + type=str, + default=os.path.join(PROJECT_DIR, "results"), + help="Output directory, default=`musev/results`", + ) + parser.add_argument( + "--save_filetype", + type=str, + default="mp4", + help="Type of file to save the video, default=`mp4`", + choices=["gif", "mp4", "webp", "images"], + ) + parser.add_argument( + "--save_images", + action="store_true", + default=False, + help="more than video, whether save generated video into images, default=`False`", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Whether to overwrite existing files, default=`False`", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed, default=`None`", + ) + parser.add_argument( + "--cross_attention_dim", + type=int, + default=768, + help="Cross attention dimension, default=`768`", + ) + parser.add_argument( + "--n_batch", + type=int, + default=1, + help="Maximum number of iterations to run, total_frames=n_batch*time_size, default=`1`", + ) + parser.add_argument( + "--fps", + type=int, + default=4, + help="Frames per second for save video,default is same to of training, default=`4`", + ) + parser.add_argument( + "--use_condition_image", + action="store_false", + help=( + "Whether to use the first frame of the test dataset as the initial image, default=`True`" + "now only support image" + ), + ) + parser.add_argument( + "--fix_condition_images", + action="store_true", + help=("Whether to fix condition_image for every shot, default=`False`"), + ) + parser.add_argument( + "--redraw_condition_image", + action="store_true", + help="Whether to use the redrawn first frame as the initial image, default=`False`", + ) + parser.add_argument( + "--need_img_based_video_noise", + action="store_false", + help="Whether to use noise based on the initial frame when adding noise to the video, default=`True`", + ) + parser.add_argument( + "--img_weight", + type=float, + default=1e-3, + help="Weight of the vision_condtion frame to video noise, default=`1e-3`", + ) + parser.add_argument( + "--write_info", + action="store_true", + help="Whether to write frame index, default=`False`", + ) + parser.add_argument( + "--height", + type=int, + default=None, + help="Height of the generated video, if none then use height of condition_image, if all none raise error, default=`None`", + ) + parser.add_argument( + "--width", + type=int, + default=None, + help="Width of the generated video, if none then use height of condition_image, if all none raise error, default=`None`", + ) + parser.add_argument( + "--img_length_ratio", + type=float, + default=1.0, + help="ratio to resize target width, target height of generated video, default=`1.0`", + ) + + parser.add_argument( + "--n_cols", + type=int, + default=3, + help="Number of columns in the output video grid, unused, now", + ) + parser.add_argument( + "--time_size", + type=int, + default=12, + help="Number of frames to generate per iteration, same as of training, default=`12`", + ) + parser.add_argument( + "--noise_type", + type=str, + default="video_fusion", + help="Type of noise to add to the video, default=`video_fusion`", + choices=["video_fusion", "random"], + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=7.5, + help="guidance_scale of first frame, default=`7.5`", + ) + parser.add_argument( + "--video_guidance_scale", + type=float, + default=3.5, + help="video_guidance_scale of video, the greater the value, the greater the video change, the more likely video error, default=`3.5`", + ) + parser.add_argument( + "--video_guidance_scale_end", + type=float, + default=None, + help="changed video_guidance_scale_end with timesteps, None means unchanged, default=`None`", + ), + parser.add_argument( + "--video_guidance_scale_method", + type=str, + default="linear", + help="generate changed video_guidance_scale with timesteps, default=`linear`", + choices=["linear", "two_stage", "three_stage", "fix_two_stage"], + ), + parser.add_argument( + "--num_inference_steps", + type=int, + default=30, + help="inference steps of first frame redraw, default=`30", + ) + parser.add_argument( + "--video_num_inference_steps", + type=int, + default=10, + help="inference steps of video, default=`10`", + ) + parser.add_argument( + "--strength", + type=float, + default=0.8, + help="Strength of the redrawn first frame, default=`0.8`", + ) + parser.add_argument( + "--negprompt_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "configs/model/negative_prompt.py"), + help="Path to the negtive prompt configuration file", + ) + parser.add_argument( + "--video_negative_prompt", + type=str, + default="V2", + help="video negative prompt", + ), + parser.add_argument( + "--negative_prompt", + type=str, + default="V2", + help="first frame negative prompt", + ), + parser.add_argument( + "--motion_speed", + type=float, + default=8.0, + help="motion speed, sample rate in training stage, default=`8.0`", + ), + parser.add_argument( + "--need_hist_match", + default=False, + action="store_true", + help="wthether hist match video with vis cond, default=`False`", + ), + parser.add_argument( + "--log_level", + type=str, + default="INFO", + ) + parser.add_argument( + "--add_static_video_prompt", + action="store_true", + default=False, + help="add static_video_prompt in head of prompt", + ) + parser.add_argument( + "--n_vision_condition", + type=int, + default=1, + help="num of vision_condition , default=`1`", + ) + parser.add_argument( + "--fixed_refer_image", + action="store_false", + default=True, + help="whether fix referencenet image or not, if none and referencenet is not None, use vision condition frame, default=`True`", + ) + parser.add_argument( + "--fixed_ip_adapter_image", + action="store_false", + default=True, + help="whether fixed_ip_adapter_image or not , if none and ipadapter is not None, use vision condition frame, default=`True`", + ) + parser.add_argument( + "--fixed_refer_face_image", + action="store_false", + default=True, + help="whether fix facein image or not, if not and ipadapterfaceid is not None, use vision condition frame, default=`True`", + ) + parser.add_argument( + "--redraw_condition_image_with_referencenet", + action="store_false", + default=True, + help="whether use ip_adapter when redrawing vision condition image default=`True`", + ) + parser.add_argument( + "--redraw_condition_image_with_ipdapter", + action="store_false", + default=True, + help="whether use ip_adapter when redrawing vision condition image default=`True`", + ) + parser.add_argument( + "--redraw_condition_image_with_facein", + action="store_false", + default=True, + help="whether use face tool when redrawing vision condition image, default=`True`", + ) + parser.add_argument( + "--w_ind_noise", + default=0.5, + type=float, + help="independent ration of videofusion noise, the greater the value, the greater the video change, the more likely video error, default=`0.5`", + ) + parser.add_argument( + "--ip_adapter_scale", + default=1.0, + type=float, + help="ipadapter weight, default=`1.0`", + ) + parser.add_argument( + "--facein_scale", + default=1.0, + type=float, + help="facein weight, default=`1.0`", + ) + parser.add_argument( + "--face_image_path", + default=None, + type=str, + help="face_image_str, default=`None`", + ) + parser.add_argument( + "--ipadapter_image_path", + default=None, + type=str, + help="face_image_str, default=`None`", + ) + parser.add_argument( + "--referencenet_image_path", + default=None, + type=str, + help="referencenet_image_path, default=`None`", + ) + parser.add_argument( + "--vae_model_path", + default="./checkpoints/vae/sd-vae-ft-mse", + type=str, + help="vae path, default=`./checkpoints/vae/sd-vae-ft-mse`", + ) + parser.add_argument( + "--redraw_condition_image_with_ip_adapter_face", + action="store_false", + default=True, + help="whether use facein when redrawing vision condition image, default=`True`", + ) + parser.add_argument( + "--ip_adapter_face_scale", + default=1.0, + type=float, + help="ip_adapter face default=`1.0`", + ) + parser.add_argument( + "--prompt_only_use_image_prompt", + action="store_true", + default=False, + help="prompt_only_use_image_prompt, if true, replace text_prompt_emb with image_prompt_emb in ip_adapter_cross_attn, default=`False`", + ) + parser.add_argument( + "--record_mid_video_noises", + action="store_true", + default=False, + help="whether record middle timestep noise of the last frames of last shot, default=`False`", + ) + parser.add_argument( + "--record_mid_video_latents", + action="store_true", + default=False, + help="whether record middle timestep latent of the last frames of last shot, default=`False`", + ) + parser.add_argument( + "--video_overlap", + default=1, + type=int, + help="overlap when generate long video with end2end method, default=`1`", + ) + parser.add_argument( + "--context_schedule", + default="uniform_v2", + type=str, + help="how to generate multi shot index when parallel denoise, default=`uniform_v2`", + choices=["uniform", "uniform_v2"], + ) + parser.add_argument( + "--context_frames", + default=12, + type=int, + help="window size of a subshot in parallel denoise, default=`12`", + ) + parser.add_argument( + "--context_stride", + default=1, + type=int, + help="window stride of a subshot in parallel denoise, unvalid paramter, to delete, default=`1`", + ) + parser.add_argument( + "--context_overlap", + default=4, + type=int, + help="window overlap of a subshot in parallel denoise,default=`4`", + ) + parser.add_argument( + "--context_batch_size", + default=1, + type=int, + help="num of subshot in parallel denoise, change in batch_size, need more gpu memory, default=`1`", + ) + parser.add_argument( + "--interpolation_factor", + default=1, + type=int, + help="whether do super resolution to latents, `1` means do nothing, default=`1`", + ) + parser.add_argument( + "--n_repeat", + default=1, + type=int, + help="repeat times for every task, default=`1`", + ) + args = parser.parse_args() + return args + + +args = parse_args() +print("args") +pprint(args.__dict__) +print("\n") + +logger.setLevel(args.log_level) +overwrite = args.overwrite +cross_attention_dim = args.cross_attention_dim +time_size = args.time_size # 一次视频生成的帧数 +n_batch = args.n_batch # 按照time_size的尺寸 生成n_batch次,总帧数 = time_size * n_batch +fps = args.fps +fix_condition_images = args.fix_condition_images +use_condition_image = args.use_condition_image # 当 test_data 中有图像时,作为初始图像 +redraw_condition_image = args.redraw_condition_image # 用于视频生成的首帧是否使用重绘后的 +need_img_based_video_noise = ( + args.need_img_based_video_noise +) # 视频加噪过程中是否使用首帧 condition_images +img_weight = args.img_weight +height = args.height # 如果测试数据中没有单独指定宽高,则默认这里 +width = args.width # 如果测试数据中没有单独指定宽高,则默认这里 +img_length_ratio = args.img_length_ratio # 如果测试数据中没有单独指定图像宽高比resize比例,则默认这里 +n_cols = args.n_cols +noise_type = args.noise_type +strength = args.strength # 首帧重绘程度参数 +video_guidance_scale = args.video_guidance_scale # 视频 condition与 uncond的权重参数 +guidance_scale = args.guidance_scale # 时序条件帧 condition与uncond的权重参数 +video_num_inference_steps = args.video_num_inference_steps # 视频迭代次数 +num_inference_steps = args.num_inference_steps # 时序条件帧 重绘参数 +seed = args.seed +save_filetype = args.save_filetype +save_images = args.save_images +sd_model_cfg_path = args.sd_model_cfg_path +sd_model_name = ( + args.sd_model_name + if args.sd_model_name in ["all", "None"] + else args.sd_model_name.split(",") +) +unet_model_cfg_path = args.unet_model_cfg_path +unet_model_name = args.unet_model_name +test_data_path = args.test_data_path +target_datas = ( + args.target_datas if args.target_datas == "all" else args.target_datas.split(",") +) +device = "cuda" if torch.cuda.is_available() else "cpu" +torch_dtype = torch.float16 +negprompt_cfg_path = args.negprompt_cfg_path +video_negative_prompt = args.video_negative_prompt +negative_prompt = args.negative_prompt +motion_speed = args.motion_speed +need_hist_match = args.need_hist_match +video_guidance_scale_end = args.video_guidance_scale_end +video_guidance_scale_method = args.video_guidance_scale_method +add_static_video_prompt = args.add_static_video_prompt +n_vision_condition = args.n_vision_condition +lcm_model_cfg_path = args.lcm_model_cfg_path +lcm_model_name = args.lcm_model_name +referencenet_model_cfg_path = args.referencenet_model_cfg_path +referencenet_model_name = args.referencenet_model_name +ip_adapter_model_cfg_path = args.ip_adapter_model_cfg_path +ip_adapter_model_name = args.ip_adapter_model_name +vision_clip_model_path = args.vision_clip_model_path +vision_clip_extractor_class_name = args.vision_clip_extractor_class_name +facein_model_cfg_path = args.facein_model_cfg_path +facein_model_name = args.facein_model_name +ip_adapter_face_model_cfg_path = args.ip_adapter_face_model_cfg_path +ip_adapter_face_model_name = args.ip_adapter_face_model_name + +fixed_refer_image = args.fixed_refer_image +fixed_ip_adapter_image = args.fixed_ip_adapter_image +fixed_refer_face_image = args.fixed_refer_face_image +redraw_condition_image_with_referencenet = args.redraw_condition_image_with_referencenet +redraw_condition_image_with_ipdapter = args.redraw_condition_image_with_ipdapter +redraw_condition_image_with_facein = args.redraw_condition_image_with_facein +redraw_condition_image_with_ip_adapter_face = ( + args.redraw_condition_image_with_ip_adapter_face +) +w_ind_noise = args.w_ind_noise +ip_adapter_scale = args.ip_adapter_scale +facein_scale = args.facein_scale +ip_adapter_face_scale = args.ip_adapter_face_scale +face_image_path = args.face_image_path +ipadapter_image_path = args.ipadapter_image_path +referencenet_image_path = args.referencenet_image_path +vae_model_path = args.vae_model_path +prompt_only_use_image_prompt = args.prompt_only_use_image_prompt +# serial_denoise parameter start +record_mid_video_noises = args.record_mid_video_noises +record_mid_video_latents = args.record_mid_video_latents +video_overlap = args.video_overlap +# serial_denoise parameter end +# parallel_denoise parameter start +context_schedule = args.context_schedule +context_frames = args.context_frames +context_stride = args.context_stride +context_overlap = args.context_overlap +context_batch_size = args.context_batch_size +interpolation_factor = args.interpolation_factor +n_repeat = args.n_repeat + +# parallel_denoise parameter end + +b = 1 +negative_embedding = [ + ["./checkpoints/embedding/badhandv4.pt", "badhandv4"], + [ + "./checkpoints/embedding/ng_deepnegative_v1_75t.pt", + "ng_deepnegative_v1_75t", + ], + [ + "./checkpoints/embedding/EasyNegativeV2.safetensors", + "EasyNegativeV2", + ], + [ + "./checkpoints/embedding/bad_prompt_version2-neg.pt", + "bad_prompt_version2-neg", + ], +] +prefix_prompt = "" +suffix_prompt = ", beautiful, masterpiece, best quality" +suffix_prompt = "" + + +# sd model parameters +if sd_model_name != "None": + # use sd_model_path in sd_model_cfg_path + sd_model_params_dict_src = load_pyhon_obj(sd_model_cfg_path, "MODEL_CFG") + sd_model_params_dict = { + k: v + for k, v in sd_model_params_dict_src.items() + if sd_model_name == "all" or k in sd_model_name + } +else: + # get sd_model_path in sd_model_cfg_path by sd_model_name + # if set path of sd_model_path in cmd, should set sd_model_name as None, + sd_model_name = os.path.basename(sd_model_cfg_path).split(".")[0] + sd_model_params_dict = {sd_model_name: {"sd": sd_model_cfg_path}} + sd_model_params_dict_src = sd_model_params_dict +if len(sd_model_params_dict) == 0: + raise ValueError( + "has not target model, please set one of {}".format( + " ".join(list(sd_model_params_dict_src.keys())) + ) + ) +print("running model, T2I SD") +pprint(sd_model_params_dict) + +# lcm parameters +if lcm_model_name is not None: + lcm_model_params_dict_src = load_pyhon_obj(lcm_model_cfg_path, "MODEL_CFG") + print("lcm_model_params_dict_src") + lcm_lora_dct = lcm_model_params_dict_src[lcm_model_name] +else: + lcm_lora_dct = None +print("lcm: ", lcm_model_name, lcm_lora_dct) + + +# motion net parameters +if os.path.isdir(unet_model_cfg_path): + unet_model_path = unet_model_cfg_path +elif os.path.isfile(unet_model_cfg_path): + unet_model_params_dict_src = load_pyhon_obj(unet_model_cfg_path, "MODEL_CFG") + print("unet_model_params_dict_src", unet_model_params_dict_src.keys()) + unet_model_path = unet_model_params_dict_src[unet_model_name]["unet"] +else: + raise ValueError(f"expect dir or file, but given {unet_model_cfg_path}") +print("unet: ", unet_model_name, unet_model_path) + + +# referencenet parameters +if referencenet_model_name is not None: + if os.path.isdir(referencenet_model_cfg_path): + referencenet_model_path = referencenet_model_cfg_path + elif os.path.isfile(referencenet_model_cfg_path): + referencenet_model_params_dict_src = load_pyhon_obj( + referencenet_model_cfg_path, "MODEL_CFG" + ) + print( + "referencenet_model_params_dict_src", + referencenet_model_params_dict_src.keys(), + ) + referencenet_model_path = referencenet_model_params_dict_src[ + referencenet_model_name + ]["net"] + else: + raise ValueError(f"expect dir or file, but given {referencenet_model_cfg_path}") +else: + referencenet_model_path = None +print("referencenet: ", referencenet_model_name, referencenet_model_path) + + +# ip_adapter parameters +if ip_adapter_model_name is not None: + ip_adapter_model_params_dict_src = load_pyhon_obj( + ip_adapter_model_cfg_path, "MODEL_CFG" + ) + print("ip_adapter_model_params_dict_src", ip_adapter_model_params_dict_src.keys()) + ip_adapter_model_params_dict = ip_adapter_model_params_dict_src[ + ip_adapter_model_name + ] +else: + ip_adapter_model_params_dict = None +print("ip_adapter: ", ip_adapter_model_name, ip_adapter_model_params_dict) + + +# facein parameters +if facein_model_name is not None: + raise NotImplementedError("unsupported facein by now") + facein_model_params_dict_src = load_pyhon_obj(facein_model_cfg_path, "MODEL_CFG") + print("facein_model_params_dict_src", facein_model_params_dict_src.keys()) + facein_model_params_dict = facein_model_params_dict_src[facein_model_name] +else: + facein_model_params_dict = None +print("facein: ", facein_model_name, facein_model_params_dict) + +# ip_adapter_face +if ip_adapter_face_model_name is not None: + ip_adapter_face_model_params_dict_src = load_pyhon_obj( + ip_adapter_face_model_cfg_path, "MODEL_CFG" + ) + print( + "ip_adapter_face_model_params_dict_src", + ip_adapter_face_model_params_dict_src.keys(), + ) + ip_adapter_face_model_params_dict = ip_adapter_face_model_params_dict_src[ + ip_adapter_face_model_name + ] +else: + ip_adapter_face_model_params_dict = None +print( + "ip_adapter_face: ", ip_adapter_face_model_name, ip_adapter_face_model_params_dict +) + + +# negative_prompt +def get_negative_prompt(negative_prompt, cfg_path=None, n: int = 10): + name = negative_prompt[:n] + if cfg_path is not None and cfg_path not in ["None", "none"]: + dct = load_pyhon_obj(cfg_path, "Negative_Prompt_CFG") + negative_prompt = dct[negative_prompt]["prompt"] + + return name, negative_prompt + + +negtive_prompt_length = 10 +video_negative_prompt_name, video_negative_prompt = get_negative_prompt( + video_negative_prompt, + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, +) +negative_prompt_name, negative_prompt = get_negative_prompt( + negative_prompt, + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, +) +print("video_negprompt", video_negative_prompt_name, video_negative_prompt) +print("negprompt", negative_prompt_name, negative_prompt) + +output_dir = args.output_dir +os.makedirs(output_dir, exist_ok=True) + + +# test_data_parameters +def load_yaml(path): + tasks = OmegaConf.to_container( + OmegaConf.load(path), structured_config_mode=SCMode.INSTANTIATE, resolve=True + ) + return tasks + + +if test_data_path.endswith(".yaml"): + test_datas_src = load_yaml(test_data_path) +elif test_data_path.endswith(".csv"): + test_datas_src = generate_tasks_from_table(test_data_path) +else: + raise ValueError("expect yaml or csv, but given {}".format(test_data_path)) + +test_datas = [ + test_data + for test_data in test_datas_src + if target_datas == "all" or test_data.get("name", None) in target_datas +] + +test_datas = fiss_tasks(test_datas) +test_datas = generate_prompts(test_datas) + +n_test_datas = len(test_datas) +if n_test_datas == 0: + raise ValueError( + "n_test_datas == 0, set target_datas=None or set atleast one of {}".format( + " ".join(list(d.get("name", "None") for d in test_datas_src)) + ) + ) +print("n_test_datas", n_test_datas) +# pprint(test_datas) + + +def read_image(path): + name = os.path.basename(path).split(".")[0] + image = read_image_as_5d(path) + return image, name + + +def read_image_lst(path): + images_names = [read_image(x) for x in path] + images, names = zip(*images_names) + images = np.concatenate(images, axis=2) + name = "_".join(names) + return images, name + + +def read_image_and_name(path): + if isinstance(path, str): + path = [path] + images, name = read_image_lst(path) + return images, name + + +# load referencenet +if referencenet_model_name is not None: + referencenet = load_referencenet_by_name( + model_name=referencenet_model_name, + # sd_model=sd_model_path, + # sd_model="./checkpoints/Moore-AnimateAnyone/AnimateAnyone/reference_unet.pth", + sd_referencenet_model=referencenet_model_path, + cross_attention_dim=cross_attention_dim, + ) +else: + referencenet = None + referencenet_model_name = "no" + +# load vision_clip_extractor +if vision_clip_extractor_class_name is not None: + vision_clip_extractor = load_vision_clip_encoder_by_name( + ip_image_encoder=vision_clip_model_path, + vision_clip_extractor_class_name=vision_clip_extractor_class_name, + ) + logger.info( + f"vision_clip_extractor, name={vision_clip_extractor_class_name}, path={vision_clip_model_path}" + ) +else: + vision_clip_extractor = None + logger.info(f"vision_clip_extractor, None") + +# load ip_adapter_model +if ip_adapter_model_name is not None: + ip_adapter_image_proj = load_ip_adapter_image_proj_by_name( + model_name=ip_adapter_model_name, + ip_image_encoder=ip_adapter_model_params_dict.get( + "ip_image_encoder", vision_clip_model_path + ), + ip_ckpt=ip_adapter_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=ip_adapter_model_params_dict["clip_embeddings_dim"], + clip_extra_context_tokens=ip_adapter_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=ip_adapter_model_params_dict["ip_scale"], + device=device, + ) +else: + ip_adapter_image_proj = None + ip_adapter_model_name = "no" + +for model_name, sd_model_params in sd_model_params_dict.items(): + lora_dict = sd_model_params.get("lora", None) + model_sex = sd_model_params.get("sex", None) + model_style = sd_model_params.get("style", None) + sd_model_path = sd_model_params["sd"] + test_model_vae_model_path = sd_model_params.get("vae", vae_model_path) + # load unet according test_data + unet = load_unet_by_name( + model_name=unet_model_name, + sd_unet_model=unet_model_path, + sd_model=sd_model_path, + # sd_model="./checkpoints/Moore-AnimateAnyone/AnimateAnyone/denoising_unet.pth", + cross_attention_dim=cross_attention_dim, + need_t2i_facein=facein_model_name is not None, + # ip_adapter_face_model_name not train in unet, need load individually + strict=not (ip_adapter_face_model_name is not None), + need_t2i_ip_adapter_face=ip_adapter_face_model_name is not None, + ) + + # load facein according test_data + if facein_model_name is not None: + ( + face_emb_extractor, + facein_image_proj, + ) = load_facein_extractor_and_proj_by_name( + model_name=facein_model_name, + ip_image_encoder=facein_model_params_dict["ip_image_encoder"], + ip_ckpt=facein_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=facein_model_params_dict["clip_embeddings_dim"], + clip_extra_context_tokens=facein_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=facein_model_params_dict["ip_scale"], + device=device, + unet=unet, + ) + else: + face_emb_extractor = None + facein_image_proj = None + + # load ipadapter_face model according test_data + if ip_adapter_face_model_name is not None: + ( + ip_adapter_face_emb_extractor, + ip_adapter_face_image_proj, + ) = load_ip_adapter_face_extractor_and_proj_by_name( + model_name=ip_adapter_face_model_name, + ip_image_encoder=ip_adapter_face_model_params_dict["ip_image_encoder"], + ip_ckpt=ip_adapter_face_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=ip_adapter_face_model_params_dict[ + "clip_embeddings_dim" + ], + clip_extra_context_tokens=ip_adapter_face_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=ip_adapter_face_model_params_dict["ip_scale"], + device=device, + unet=unet, + ) + else: + ip_adapter_face_emb_extractor = None + ip_adapter_face_image_proj = None + + print("test_model_vae_model_path", test_model_vae_model_path) + + # init sd_predictor + sd_predictor = DiffusersPipelinePredictor( + sd_model_path=sd_model_path, + unet=unet, + lora_dict=lora_dict, + lcm_lora_dct=lcm_lora_dct, + device=device, + dtype=torch_dtype, + negative_embedding=negative_embedding, + referencenet=referencenet, + ip_adapter_image_proj=ip_adapter_image_proj, + vision_clip_extractor=vision_clip_extractor, + facein_image_proj=facein_image_proj, + face_emb_extractor=face_emb_extractor, + vae_model=test_model_vae_model_path, + ip_adapter_face_emb_extractor=ip_adapter_face_emb_extractor, + ip_adapter_face_image_proj=ip_adapter_face_image_proj, + ) + logger.debug(f"load referencenet"), + + for i_test_data, test_data in enumerate(test_datas): + batch = [] + texts = [] + print("\n i_test_data", i_test_data, model_name) + test_data_name = test_data.get("name", i_test_data) + prompt = test_data["prompt"] + prompt = prefix_prompt + prompt + suffix_prompt + prompt_hash = get_signature_of_string(prompt, length=5) + test_data["prompt_hash"] = prompt_hash + test_data_height = test_data.get("height", height) + test_data_width = test_data.get("width", width) + test_data_condition_images_path = test_data.get("condition_images", None) + test_data_condition_images_index = test_data.get("condition_images_index", None) + test_data_redraw_condition_image = test_data.get( + "redraw_condition_image", redraw_condition_image + ) + # read condition_image + if ( + test_data_condition_images_path is not None + and use_condition_image + and ( + isinstance(test_data_condition_images_path, list) + or ( + isinstance(test_data_condition_images_path, str) + and is_image(test_data_condition_images_path) + ) + ) + ): + ( + test_data_condition_images, + test_data_condition_images_name, + ) = read_image_and_name(test_data_condition_images_path) + condition_image_height = test_data_condition_images.shape[3] + condition_image_width = test_data_condition_images.shape[4] + logger.debug( + f"test_data_condition_images use {test_data_condition_images_path}" + ) + else: + test_data_condition_images = None + test_data_condition_images_name = "no" + condition_image_height = None + condition_image_width = None + logger.debug(f"test_data_condition_images is None") + + # if test_data_height is not assigned, use height of condition, if still None, use of video + if test_data_height is None: + test_data_height = condition_image_height + + if test_data_width is None: + test_data_width = condition_image_width + + test_data_img_length_ratio = float( + test_data.get("img_length_ratio", img_length_ratio) + ) + + # to align height of generated video with video2video, use `64`` as basic pixel unit instead of `8`` + # test_data_height = int(test_data_height * test_data_img_length_ratio // 8 * 8) + # test_data_width = int(test_data_width * test_data_img_length_ratio // 8 * 8) + test_data_height = int(test_data_height * test_data_img_length_ratio // 64 * 64) + test_data_width = int(test_data_width * test_data_img_length_ratio // 64 * 64) + pprint(test_data) + print(f"test_data_height={test_data_height}") + print(f"test_data_width={test_data_width}") + # continue + test_data_style = test_data.get("style", None) + test_data_sex = test_data.get("sex", None) + # if paramters in test_data is str, but float in fact, convert it into float,int. + test_data_motion_speed = float(test_data.get("motion_speed", motion_speed)) + test_data_w_ind_noise = float(test_data.get("w_ind_noise", w_ind_noise)) + test_data_img_weight = float(test_data.get("img_weight", img_weight)) + logger.debug( + f"test_data_condition_images_path {test_data_condition_images_path}" + ) + logger.debug( + f"test_data_condition_images_index {test_data_condition_images_index}" + ) + test_data_refer_image_path = test_data.get( + "refer_image", referencenet_image_path + ) + test_data_ipadapter_image_path = test_data.get( + "ipadapter_image", ipadapter_image_path + ) + test_data_refer_face_image_path = test_data.get("face_image", face_image_path) + + if negprompt_cfg_path is not None: + if "video_negative_prompt" in test_data: + ( + test_data_video_negative_prompt_name, + test_data_video_negative_prompt, + ) = get_negative_prompt( + test_data.get( + "video_negative_prompt", + ), + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, + ) + else: + test_data_video_negative_prompt_name = video_negative_prompt_name + test_data_video_negative_prompt = video_negative_prompt + if "negative_prompt" in test_data: + ( + test_data_negative_prompt_name, + test_data_negative_prompt, + ) = get_negative_prompt( + test_data.get( + "negative_prompt", + ), + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, + ) + else: + test_data_negative_prompt_name = negative_prompt_name + test_data_negative_prompt = negative_prompt + else: + test_data_video_negative_prompt = test_data.get( + "video_negative_prompt", video_negative_prompt + ) + test_data_video_negative_prompt_name = test_data_video_negative_prompt[ + :negtive_prompt_length + ] + test_data_negative_prompt = test_data.get( + "negative_prompt", negative_prompt + ) + test_data_negative_prompt_name = test_data_negative_prompt[ + :negtive_prompt_length + ] + + # prepare test_data_refer_image + if referencenet is not None: + if test_data_refer_image_path is None: + test_data_refer_image = test_data_condition_images + test_data_refer_image_name = test_data_condition_images_name + logger.debug(f"test_data_refer_image use test_data_condition_images") + else: + test_data_refer_image, test_data_refer_image_name = read_image_and_name( + test_data_refer_image_path + ) + logger.debug(f"test_data_refer_image use {test_data_refer_image_path}") + else: + test_data_refer_image = None + test_data_refer_image_name = "no" + logger.debug(f"test_data_refer_image is None") + + # prepare test_data_ipadapter_image + if vision_clip_extractor is not None: + if test_data_ipadapter_image_path is None: + test_data_ipadapter_image = test_data_condition_images + test_data_ipadapter_image_name = test_data_condition_images_name + + logger.debug( + f"test_data_ipadapter_image use test_data_condition_images" + ) + else: + ( + test_data_ipadapter_image, + test_data_ipadapter_image_name, + ) = read_image_and_name(test_data_ipadapter_image_path) + logger.debug( + f"test_data_ipadapter_image use f{test_data_ipadapter_image_path}" + ) + else: + test_data_ipadapter_image = None + test_data_ipadapter_image_name = "no" + logger.debug(f"test_data_ipadapter_image is None") + + # prepare test_data_refer_face_image + + if facein_image_proj is not None or ip_adapter_face_image_proj is not None: + if test_data_refer_face_image_path is None: + test_data_refer_face_image = test_data_condition_images + test_data_refer_face_image_name = test_data_condition_images_name + + logger.debug( + f"test_data_refer_face_image use test_data_condition_images" + ) + else: + ( + test_data_refer_face_image, + test_data_refer_face_image_name, + ) = read_image_and_name(test_data_refer_face_image_path) + logger.debug( + f"test_data_refer_face_image use f{test_data_refer_face_image_path}" + ) + else: + test_data_refer_face_image = None + test_data_refer_face_image_name = "no" + logger.debug(f"test_data_refer_face_image is None") + + # if sex, style of test_data is not aligned with of model + # skip this test_data + + if ( + model_sex is not None + and test_data_sex is not None + and model_sex != test_data_sex + ) or ( + model_style is not None + and test_data_style is not None + and model_style != test_data_style + ): + print("model doesnt match test_data") + print("model name: ", model_name) + print("test_data: ", test_data) + continue + if add_static_video_prompt: + test_data_video_negative_prompt = "static video, {}".format( + test_data_video_negative_prompt + ) + for i_num in range(n_repeat): + test_data_seed = random.randint(0, 1e8) if seed is None else seed + cpu_generator, gpu_generator = set_all_seed(test_data_seed) + save_file_name = ( + f"m={model_name}_rm={referencenet_model_name}_case={test_data_name}" + f"_w={test_data_width}_h={test_data_height}_t={time_size}_nb={n_batch}" + f"_s={test_data_seed}_p={prompt_hash}" + f"_w={test_data_img_weight}" + f"_ms={test_data_motion_speed}" + f"_s={strength}_g={video_guidance_scale}" + f"_c-i={test_data_condition_images_name[:5]}_r-c={test_data_redraw_condition_image}" + f"_w={test_data_w_ind_noise}_{test_data_video_negative_prompt_name}" + f"_r={test_data_refer_image_name[:3]}_ip={test_data_refer_image_name[:3]}_f={test_data_refer_face_image_name[:3]}" + ) + + save_file_name = clean_str_for_save(save_file_name) + output_path = os.path.join( + output_dir, + f"{save_file_name}.{save_filetype}", + ) + if os.path.exists(output_path) and not overwrite: + print("existed", output_path) + continue + + print("output_path", output_path) + out_videos = sd_predictor.run_pipe_text2video( + video_length=time_size, + prompt=prompt, + width=test_data_width, + height=test_data_height, + generator=gpu_generator, + noise_type=noise_type, + negative_prompt=test_data_negative_prompt, + video_negative_prompt=test_data_video_negative_prompt, + max_batch_num=n_batch, + strength=strength, + need_img_based_video_noise=need_img_based_video_noise, + video_num_inference_steps=video_num_inference_steps, + condition_images=test_data_condition_images, + fix_condition_images=fix_condition_images, + video_guidance_scale=video_guidance_scale, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + redraw_condition_image=test_data_redraw_condition_image, + img_weight=test_data_img_weight, + w_ind_noise=test_data_w_ind_noise, + n_vision_condition=n_vision_condition, + motion_speed=test_data_motion_speed, + need_hist_match=need_hist_match, + video_guidance_scale_end=video_guidance_scale_end, + video_guidance_scale_method=video_guidance_scale_method, + vision_condition_latent_index=test_data_condition_images_index, + refer_image=test_data_refer_image, + fixed_refer_image=fixed_refer_image, + redraw_condition_image_with_referencenet=redraw_condition_image_with_referencenet, + ip_adapter_image=test_data_ipadapter_image, + refer_face_image=test_data_refer_face_image, + fixed_refer_face_image=fixed_refer_face_image, + facein_scale=facein_scale, + redraw_condition_image_with_facein=redraw_condition_image_with_facein, + ip_adapter_face_scale=ip_adapter_face_scale, + redraw_condition_image_with_ip_adapter_face=redraw_condition_image_with_ip_adapter_face, + fixed_ip_adapter_image=fixed_ip_adapter_image, + ip_adapter_scale=ip_adapter_scale, + redraw_condition_image_with_ipdapter=redraw_condition_image_with_ipdapter, + prompt_only_use_image_prompt=prompt_only_use_image_prompt, + # serial_denoise parameter start + record_mid_video_noises=record_mid_video_noises, + record_mid_video_latents=record_mid_video_latents, + video_overlap=video_overlap, + # serial_denoise parameter end + # parallel_denoise parameter start + context_schedule=context_schedule, + context_frames=context_frames, + context_stride=context_stride, + context_overlap=context_overlap, + context_batch_size=context_batch_size, + interpolation_factor=interpolation_factor, + # parallel_denoise parameter end + ) + out = np.concatenate([out_videos], axis=0) + texts = ["out"] + save_videos_grid_with_opencv( + out, + output_path, + texts=texts, + fps=fps, + tensor_order="b c t h w", + n_cols=n_cols, + write_info=args.write_info, + save_filetype=save_filetype, + save_images=save_images, + ) + print("Save to", output_path) + print("\n" * 2) diff --git a/scripts/inference/video2video.py b/scripts/inference/video2video.py new file mode 100644 index 0000000000000000000000000000000000000000..701f7788a063193f6ff2b6beefcbff8e6b8aadef --- /dev/null +++ b/scripts/inference/video2video.py @@ -0,0 +1,1489 @@ +import argparse +import copy +import os +from pathlib import Path +import logging +from collections import OrderedDict +from pprint import pprint +import random + +import numpy as np +from omegaconf import OmegaConf, SCMode +import torch +from einops import rearrange, repeat +import cv2 +from PIL import Image +from diffusers.models.autoencoder_kl import AutoencoderKL + +from mmcm.utils.load_util import load_pyhon_obj +from mmcm.utils.seed_util import set_all_seed +from mmcm.utils.signature import get_signature_of_string +from mmcm.utils.task_util import fiss_tasks, generate_tasks as generate_tasks_from_table +from mmcm.vision.utils.data_type_util import is_video, is_image, read_image_as_5d +from mmcm.utils.str_util import clean_str_for_save +from mmcm.vision.data.video_dataset import DecordVideoDataset +from musev.auto_prompt.util import generate_prompts + +from musev.models.controlnet import PoseGuider +from musev.models.facein_loader import load_facein_extractor_and_proj_by_name +from musev.models.referencenet_loader import load_referencenet_by_name +from musev.models.ip_adapter_loader import ( + load_ip_adapter_vision_clip_encoder_by_name, + load_vision_clip_encoder_by_name, + load_ip_adapter_image_proj_by_name, +) +from musev.models.ip_adapter_face_loader import ( + load_ip_adapter_face_extractor_and_proj_by_name, +) +from musev.pipelines.pipeline_controlnet_predictor import ( + DiffusersPipelinePredictor, +) +from musev.models.referencenet import ReferenceNet2D +from musev.models.unet_loader import load_unet_by_name +from musev.utils.util import save_videos_grid_with_opencv +from musev import logger + +logger.setLevel("INFO") + +file_dir = os.path.dirname(__file__) +PROJECT_DIR = os.path.join(os.path.dirname(__file__), "../..") +DATA_DIR = os.path.join(PROJECT_DIR, "data") + + +# TODO:use group to group arguments +def parse_args(): + parser = argparse.ArgumentParser(description="musev video to video") + parser.add_argument( + "-test_data_path", + type=str, + help=( + "Path to the test data configuration file, now only support yaml ext, " + "task file simialr to musev/configs/tasks/example.yaml" + ), + ) + parser.add_argument( + "--target_datas", + type=str, + default="all", + help="Names of the test data to run, to select sub tasks, default=`all`", + ) + parser.add_argument( + "--sd_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "configs/model/T2I_all_model.py"), + help="Path to the model configuration file", + ) + parser.add_argument( + "--sd_model_name", + type=str, + default="all", + help="Names of the models to run, or path.", + ) + parser.add_argument( + "--unet_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/motion_model.py"), + help="Path to motion_cfg path or motion unet path", + ) + parser.add_argument( + "--unet_model_name", + type=str, + default="musev_referencenet", + help=( + "class Name of the unet model, use load_unet_by_name to init unet," + "now only support `musev`, `musev_referencenet`," + ), + ) + parser.add_argument( + "--lcm_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/lcm_model.py"), + help="Path to lcm lora path", + ) + parser.add_argument( + "--lcm_model_name", + type=str, + default=None, + help="lcm model name, None means do not use lcm_lora default=`None`", + choices=[ + "lcm", + ], + ) + parser.add_argument( + "--referencenet_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/referencenet.py"), + help="Path to referencenet model config path", + ) + parser.add_argument( + "--referencenet_model_name", + type=str, + default=None, + help="referencenet model name, None means do not use referencenet, default=`None`", + choices=["musev_referencenet", "musev_referencenet_pose"], + ) + parser.add_argument( + "--ip_adapter_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/ip_adapter.py"), + help="Path to ip_adapter model config path", + ) + parser.add_argument( + "--ip_adapter_model_name", + type=str, + default=None, + help="ip_adapter model name, None means do not use ip_adapter, default=`None`", + choices=["musev_referencenet", "musev_referencenet_pose"], + ) + parser.add_argument( + "--vision_clip_model_path", + type=str, + default="./checkpoints/ip_adapter/models/image_encoder", + help="vision_clip_extractor_class_name vision_clip_model_path, default=`./checkpoints/ip_adapter/models/image_encoder`", + ) + parser.add_argument( + "--vision_clip_extractor_class_name", + type=str, + default=None, + help="vision_clip_extractor_class_name None means according to ip_adapter_model_name, default=`None`", + choices=["ImageClipVisionFeatureExtractor"], + ) + parser.add_argument( + "--facein_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/facein.py"), + help="Path to facein model config path", + ) + parser.add_argument( + "--facein_model_name", + type=str, + default=None, + help="facein model name, None means do not use facein, now unsupported default=`None`", + ) + parser.add_argument( + "--ip_adapter_face_model_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "./configs/model/ip_adapter.py"), + help="Path to facein model config path", + ) + parser.add_argument( + "--ip_adapter_face_model_name", + type=str, + default=None, + help="facein model name, None means do not use ip_adapter_face, default=`None`", + ) + parser.add_argument( + "--output_dir", + type=str, + default=os.path.join(PROJECT_DIR, "results"), + help="Output directory, default=`musev/results`", + ) + parser.add_argument( + "--save_filetype", + type=str, + default="mp4", + help="Type of file to save the video, default=`mp4`", + choices=["gif", "mp4", "webp", "images"], + ) + parser.add_argument( + "--save_images", + action="store_true", + default=False, + help="more than video, whether save generated video into images, default=`False`", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Whether to overwrite existing files, default=`False`", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Random seed, default=`None`", + ) + parser.add_argument( + "--cross_attention_dim", + type=int, + default=768, + help="Cross attention dimension, default=`768`", + ) + parser.add_argument( + "--n_batch", + type=int, + default=1, + help="Maximum number of iterations to run, total_frames=n_batch*time_size, default=`1`", + ) + parser.add_argument( + "--fps", + type=int, + default=4, + help="Frames per second for save video,default is same to of training, default=`4`", + ) + parser.add_argument( + "--use_condition_image", + action="store_false", + help=( + "Whether to use the first frame of the test dataset as the initial image, default=`True`" + "now only support image" + ), + ) + parser.add_argument( + "--fix_condition_images", + action="store_true", + help=("Whether to fix condition_image for every shot, default=`False`"), + ) + parser.add_argument( + "--redraw_condition_image", + action="store_true", + help="Whether to use the redrawn first frame as the initial image, default=`False`", + ) + parser.add_argument( + "--need_img_based_video_noise", + action="store_false", + help="Whether to use noise based on the initial frame when adding noise to the video, default=`True`", + ) + parser.add_argument( + "--img_weight", + type=float, + default=1e-3, + help="Weight of the vision_condtion frame to video noise, default=`1e-3`", + ) + parser.add_argument( + "--write_info", + action="store_true", + help="Whether to write frame index, default=`False`", + ) + parser.add_argument( + "--height", + type=int, + default=None, + help="Height of the generated video, if none then use height of condition_image, if all none raise error, default=`None`", + ) + parser.add_argument( + "--width", + type=int, + default=None, + help="Width of the generated video, if none then use height of condition_image, if all none raise error, default=`None`", + ) + parser.add_argument( + "--img_length_ratio", + type=float, + default=1.0, + help="ratio to resize target width, target height of generated video, default=`1.0`", + ) + + parser.add_argument( + "--n_cols", + type=int, + default=3, + help="Number of columns in the output video grid, unused, now", + ) + parser.add_argument( + "--time_size", + type=int, + default=12, + help="Number of frames to generate per iteration, same as of training, default=`12`", + ) + parser.add_argument( + "--noise_type", + type=str, + default="video_fusion", + help="Type of noise to add to the video, default=`video_fusion`", + choices=["video_fusion", "random"], + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=7.5, + help="guidance_scale of first frame, default=`7.5`", + ) + parser.add_argument( + "--video_guidance_scale", + type=float, + default=3.5, + help="video_guidance_scale of video, the greater the value, the greater the video change, the more likely video error, default=`3.5`", + ) + parser.add_argument( + "--video_guidance_scale_end", + type=float, + default=None, + help="changed video_guidance_scale_end with timesteps, None means unchanged, default=`None`", + ), + parser.add_argument( + "--video_guidance_scale_method", + type=str, + default="linear", + help="generate changed video_guidance_scale with timesteps, default=`linear`", + choices=["linear", "two_stage", "three_stage", "fix_two_stage"], + ), + parser.add_argument( + "--num_inference_steps", + type=int, + default=30, + help="inference steps of first frame redraw, default=`30", + ) + parser.add_argument( + "--video_num_inference_steps", + type=int, + default=10, + help="inference steps of video, default=`10`", + ) + parser.add_argument( + "--strength", + type=float, + default=0.8, + help="Strength of the redrawn image, default=`0.8`", + ) + parser.add_argument( + "--video_strength", + type=float, + default=1.0, + help="Strength of the redrawn video, default=`1.0`", + ) + parser.add_argument( + "--negprompt_cfg_path", + type=str, + default=os.path.join(PROJECT_DIR, "configs/model/negative_prompt.py"), + help="Path to the negtive prompt configuration file", + ) + parser.add_argument( + "--video_negative_prompt", + type=str, + default="V2", + help="video negative prompt", + ), + parser.add_argument( + "--negative_prompt", + type=str, + default="V2", + help="first frame negative prompt", + ), + parser.add_argument( + "--motion_speed", + type=float, + default=8.0, + help="motion speed, sample rate in training stage, default=`8.0`", + ), + parser.add_argument( + "--need_hist_match", + default=False, + action="store_true", + help="wthether hist match video with vis cond, default=`False`", + ), + parser.add_argument( + "--log_level", + type=str, + default="INFO", + ) + parser.add_argument( + "--add_static_video_prompt", + action="store_true", + default=False, + help="add static_video_prompt in head of prompt", + ) + parser.add_argument( + "--n_vision_condition", + type=int, + default=1, + help="num of vision_condition , default=`1`", + ) + parser.add_argument( + "--controlnet_name", + type=str, + default=None, + help="controlnet for video2video, if multicontrolnet, use `,` sep, such as `a,b`, default=`None`", + ) + parser.add_argument( + "--sample_rate", + type=int, + default=1, + help="get one per sample_rate frames from given video, default=`1`", + ) + parser.add_argument( + "--controlnet_conditioning_scale", + type=float, + default=1.0, + help="controlnet 的重绘参数, default=`1.0", + ) + parser.add_argument( + "--which2video", + default="video", + type=str, + choices=["video", "video_middle"], + help=( + "which part to guide video generateion" + "video_middle, only controlnet condition, or called videio middle, like pose, depth" + "video2video, more than video middle, use video guide noise like img2img pipeline, default=`video`" + ), + ), + parser.add_argument( + "--end_to_end", + default=True, + action="store_false", + help="whether end2end to generate long video, default=`True`", + ), + parser.add_argument( + "--fixed_refer_image", + action="store_false", + default=True, + help="whether fix referencenet image or not, if none and referencenet is not None, use vision condition frame, default=`True`", + ) + parser.add_argument( + "--fixed_ip_adapter_image", + action="store_false", + default=True, + help="whether fixed_ip_adapter_image or not , if none and ipadapter is not None, use vision condition frame, default=`True`", + ) + parser.add_argument( + "--fixed_refer_face_image", + action="store_false", + default=True, + help="whether fix facein image or not, if not and ipadapterfaceid is not None, use vision condition frame, default=`True`", + ) + parser.add_argument( + "--redraw_condition_image_with_referencenet", + action="store_false", + default=True, + help="whether use ip_adapter when redrawing vision condition image default=`True`", + ) + parser.add_argument( + "--redraw_condition_image_with_ipdapter", + action="store_false", + default=True, + help="whether use ip_adapter when redrawing vision condition image default=`True`", + ) + parser.add_argument( + "--need_video2video", + action="store_true", + default=False, + help="whether use video guide initial noise, default=`False`", + ) + + parser.add_argument( + "--redraw_condition_image_with_facein", + action="store_false", + default=True, + help="whether use face tool when redrawing vision condition image, default=`True`", + ) + parser.add_argument( + "--w_ind_noise", + default=0.5, + type=float, + help="videofusion_noise 中 独立噪声的比例, default=`0.5`", + ) + parser.add_argument( + "--ip_adapter_scale", + default=1.0, + type=float, + help="ipadapter weight, default=`1.0`", + ) + parser.add_argument( + "--facein_scale", + default=1.0, + type=float, + help="facein weight, default=`1.0`", + ) + parser.add_argument( + "--face_image_path", + default=None, + type=str, + help="face_image_str, default=`None`", + ) + parser.add_argument( + "--ipadapter_image_path", + default=None, + type=str, + help="face_image_str, default=`None`", + ) + parser.add_argument( + "--referencenet_image_path", + default=None, + type=str, + help="referencenet_image_path, default=`None`", + ) + parser.add_argument( + "--vae_model_path", + default="./checkpoints/vae/sd-vae-ft-mse", + type=str, + help="vae path, default=`./checkpoints/vae/sd-vae-ft-mse`", + ) + parser.add_argument( + "--redraw_condition_image_with_ip_adapter_face", + action="store_false", + default=True, + help="whether use facein when redrawing vision condition image, default=`True`", + ) + parser.add_argument( + "--ip_adapter_face_scale", + default=1.0, + type=float, + help="ip_adapter face default=`1.0`", + ) + parser.add_argument( + "--prompt_only_use_image_prompt", + action="store_true", + default=False, + help="prompt_only_use_image_prompt, if true, replace text_prompt_emb with image_prompt_emb in ip_adapter_cross_attn, default=`False`", + ) + # moore animateanyone start + parser.add_argument( + "--pose_guider_model_path", + type=str, + default=None, + help="moore pose_guider, refer to MooreAnimateAnyone, similar to controlnet, default=`None`", + ) + parser.add_argument( + "--enable_zero_snr", + action="store_true", + default=False, + help="whether use zero_snr in scheduler, include v_prediction、trailing, etc , default=`False`", + ) + # moore animateanyone end + + parser.add_argument( + "--record_mid_video_noises", + action="store_true", + default=False, + help="whether record middle timestep noise of the last frames of last shot, default=`False`", + ) + parser.add_argument( + "--record_mid_video_latents", + action="store_true", + default=False, + help="whether record middle timestep latent of the last frames of last shot, default=`False`", + ) + parser.add_argument( + "--video_overlap", + default=1, + type=int, + help="overlap when generate long video with end2end method, default=`1`", + ) + parser.add_argument( + "--context_schedule", + default="uniform_v2", + type=str, + help="how to generate multi shot index when parallel denoise, default=`uniform_v2`", + choices=["uniform", "uniform_v2"], + ) + parser.add_argument( + "--context_frames", + default=12, + type=int, + help="window size of a subshot in parallel denoise, default=`12`", + ) + parser.add_argument( + "--context_stride", + default=1, + type=int, + help="window stride of a subshot in parallel denoise, unvalid paramter, to delete, default=`1`", + ) + parser.add_argument( + "--context_overlap", + default=4, + type=int, + help="window overlap of a subshot in parallel denoise,default=`4`", + ) + parser.add_argument( + "--context_batch_size", + default=1, + type=int, + help="num of subshot in parallel denoise, change in batch_size, need more gpu memory, default=`1`", + ) + parser.add_argument( + "--interpolation_factor", + default=1, + type=int, + help="whether do super resolution to latents, `1` means do nothing, default=`1`", + ) + parser.add_argument( + "--video_is_middle", + action="store_true", + default=False, + help="input video_path is natural rgb video or not, False means pose default=`False`", + ) + parser.add_argument( + "--video_has_condition", + action="store_false", + default=True, + help="if video_is_middle true, whether condition of vision condition image is same as of first frame of video_path or not, default=`True`", + ) + parser.add_argument( + "--need_return_videos", + action="store_true", + default=False, + help="whether save video_path with generated video together, default=`False`", + ) + parser.add_argument( + "--need_return_condition", + action="store_true", + default=False, + help="whether save controlnet_middle with generated video together, default=`False`", + ) + + parser.add_argument( + "--n_repeat", + default=1, + type=int, + help="repeat times for every task, default=`1`", + ) + args = parser.parse_args() + return args + + +args = parse_args() +print("args") +pprint(args.__dict__) +print("\n") + +logger.setLevel(args.log_level) +overwrite = args.overwrite +cross_attention_dim = args.cross_attention_dim +time_size = args.time_size # 一次视频生成的帧数 +n_batch = args.n_batch # 按照time_size的尺寸 生成n_batch次,总帧数 = time_size * n_batch +fps = args.fps +fix_condition_images = args.fix_condition_images +use_condition_image = args.use_condition_image # 当 test_data 中有图像时,作为初始图像 +redraw_condition_image = args.redraw_condition_image # 用于视频生成的首帧是否使用重绘后的 +need_img_based_video_noise = ( + args.need_img_based_video_noise +) # 视频加噪过程中是否使用首帧 condition_images +img_weight = args.img_weight +height = args.height # 如果测试数据中没有单独指定宽高,则默认这里 +width = args.width # 如果测试数据中没有单独指定宽高,则默认这里 +img_length_ratio = args.img_length_ratio # 如果测试数据中没有单独指定图像宽高比resize比例,则默认这里 +n_cols = args.n_cols +noise_type = args.noise_type +strength = args.strength # 首帧重绘程度参数 +video_guidance_scale = args.video_guidance_scale # 视频 condition与 uncond的权重参数 +guidance_scale = args.guidance_scale # 时序条件帧 condition与uncond的权重参数 +video_num_inference_steps = args.video_num_inference_steps # 视频迭代次数 +num_inference_steps = args.num_inference_steps # 时序条件帧 重绘参数 +seed = args.seed +save_filetype = args.save_filetype +save_images = args.save_images +sd_model_cfg_path = args.sd_model_cfg_path +sd_model_name = ( + args.sd_model_name if args.sd_model_name == "all" else args.sd_model_name.split(",") +) +unet_model_cfg_path = args.unet_model_cfg_path +unet_model_name = args.unet_model_name +test_data_path = args.test_data_path +target_datas = ( + args.target_datas if args.target_datas == "all" else args.target_datas.split(",") +) +device = "cuda" if torch.cuda.is_available() else "cpu" +torch_dtype = torch.float16 +controlnet_name = args.controlnet_name +controlnet_name_str = controlnet_name +if controlnet_name is not None: + controlnet_name = controlnet_name.split(",") + if len(controlnet_name) == 1: + controlnet_name = controlnet_name[0] + +video_strength = args.video_strength # 视频重绘程度参数 +sample_rate = args.sample_rate +controlnet_conditioning_scale = args.controlnet_conditioning_scale + +end_to_end = args.end_to_end # 是否首尾相连生成长视频 +control_guidance_start = 0.0 +control_guidance_end = 0.5 +control_guidance_end = 1.0 +negprompt_cfg_path = args.negprompt_cfg_path +video_negative_prompt = args.video_negative_prompt +negative_prompt = args.negative_prompt +motion_speed = args.motion_speed +need_hist_match = args.need_hist_match +video_guidance_scale_end = args.video_guidance_scale_end +video_guidance_scale_method = args.video_guidance_scale_method +add_static_video_prompt = args.add_static_video_prompt +n_vision_condition = args.n_vision_condition +lcm_model_cfg_path = args.lcm_model_cfg_path +lcm_model_name = args.lcm_model_name +referencenet_model_cfg_path = args.referencenet_model_cfg_path +referencenet_model_name = args.referencenet_model_name +ip_adapter_model_cfg_path = args.ip_adapter_model_cfg_path +ip_adapter_model_name = args.ip_adapter_model_name +vision_clip_model_path = args.vision_clip_model_path +vision_clip_extractor_class_name = args.vision_clip_extractor_class_name +facein_model_cfg_path = args.facein_model_cfg_path +facein_model_name = args.facein_model_name +ip_adapter_face_model_cfg_path = args.ip_adapter_face_model_cfg_path +ip_adapter_face_model_name = args.ip_adapter_face_model_name + +fixed_refer_image = args.fixed_refer_image +fixed_ip_adapter_image = args.fixed_ip_adapter_image +fixed_refer_face_image = args.fixed_refer_face_image +redraw_condition_image_with_referencenet = args.redraw_condition_image_with_referencenet +redraw_condition_image_with_ipdapter = args.redraw_condition_image_with_ipdapter +redraw_condition_image_with_facein = args.redraw_condition_image_with_facein +redraw_condition_image_with_ip_adapter_face = ( + args.redraw_condition_image_with_ip_adapter_face +) +w_ind_noise = args.w_ind_noise +ip_adapter_scale = args.ip_adapter_scale +facein_scale = args.facein_scale +ip_adapter_face_scale = args.ip_adapter_face_scale +face_image_path = args.face_image_path +ipadapter_image_path = args.ipadapter_image_path +referencenet_image_path = args.referencenet_image_path +vae_model_path = args.vae_model_path +prompt_only_use_image_prompt = args.prompt_only_use_image_prompt +pose_guider_model_path = args.pose_guider_model_path +need_video2video = args.need_video2video +# serial_denoise parameter start +record_mid_video_noises = args.record_mid_video_noises +record_mid_video_latents = args.record_mid_video_latents +video_overlap = args.video_overlap +# serial_denoise parameter end +# parallel_denoise parameter start +context_schedule = args.context_schedule +context_frames = args.context_frames +context_stride = args.context_stride +context_overlap = args.context_overlap +context_batch_size = args.context_batch_size +interpolation_factor = args.interpolation_factor +n_repeat = args.n_repeat + +video_is_middle = args.video_is_middle +video_has_condition = args.video_has_condition +need_return_videos = args.need_return_videos +need_return_condition = args.need_return_condition +# parallel_denoise parameter end +need_controlnet = controlnet_name is not None + +which2video = args.which2video +if which2video == "video": + which2video_name = "v2v" +elif which2video == "video_middle": + which2video_name = "vm2v" +else: + raise ValueError( + "which2video only support video, video_middle, but given {which2video}" + ) +b = 1 +negative_embedding = [ + ["./checkpoints/embedding/badhandv4.pt", "badhandv4"], + [ + "./checkpoints/embedding/ng_deepnegative_v1_75t.pt", + "ng_deepnegative_v1_75t", + ], + [ + "./checkpoints/embedding/EasyNegativeV2.safetensors", + "EasyNegativeV2", + ], + [ + "./checkpoints/embedding/bad_prompt_version2-neg.pt", + "bad_prompt_version2-neg", + ], +] +prefix_prompt = "" +suffix_prompt = ", beautiful, masterpiece, best quality" +suffix_prompt = "" + +if sd_model_name != "None": + # use sd_model_path in sd_model_cfg_path + sd_model_params_dict_src = load_pyhon_obj(sd_model_cfg_path, "MODEL_CFG") + sd_model_params_dict = { + k: v + for k, v in sd_model_params_dict_src.items() + if sd_model_name == "all" or k in sd_model_name + } +else: + # get sd_model_path in sd_model_cfg_path by sd_model_name + # if set path of sd_model_path in cmd, should set sd_model_name as None, + sd_model_name = os.path.basename(sd_model_cfg_path).split(".")[0] + sd_model_params_dict = {sd_model_name: {"sd": sd_model_cfg_path}} + sd_model_params_dict_src = sd_model_params_dict +if len(sd_model_params_dict) == 0: + raise ValueError( + "has not target model, please set one of {}".format( + " ".join(list(sd_model_params_dict_src.keys())) + ) + ) +print("running model, T2I SD") +pprint(sd_model_params_dict) + +# lcm parameters +if lcm_model_name is not None: + lcm_model_params_dict_src = load_pyhon_obj(lcm_model_cfg_path, "MODEL_CFG") + print("lcm_model_params_dict_src") + lcm_lora_dct = lcm_model_params_dict_src[lcm_model_name] +else: + lcm_lora_dct = None +print("lcm: ", lcm_model_name, lcm_lora_dct) + + +# motion net parameters +if os.path.isdir(unet_model_cfg_path): + unet_model_path = unet_model_cfg_path +elif os.path.isfile(unet_model_cfg_path): + unet_model_params_dict_src = load_pyhon_obj(unet_model_cfg_path, "MODEL_CFG") + print("unet_model_params_dict_src", unet_model_params_dict_src.keys()) + unet_model_path = unet_model_params_dict_src[unet_model_name]["unet"] +else: + raise ValueError(f"expect dir or file, but given {unet_model_cfg_path}") +print("unet: ", unet_model_name, unet_model_path) + + +# referencenet parameters +if referencenet_model_name is not None: + if os.path.isdir(referencenet_model_cfg_path): + referencenet_model_path = referencenet_model_cfg_path + elif os.path.isfile(referencenet_model_cfg_path): + referencenet_model_params_dict_src = load_pyhon_obj( + referencenet_model_cfg_path, "MODEL_CFG" + ) + print( + "referencenet_model_params_dict_src", + referencenet_model_params_dict_src.keys(), + ) + referencenet_model_path = referencenet_model_params_dict_src[ + referencenet_model_name + ]["net"] + else: + raise ValueError(f"expect dir or file, but given {referencenet_model_cfg_path}") +else: + referencenet_model_path = None +print("referencenet: ", referencenet_model_name, referencenet_model_path) + + +# ip_adapter parameters +if ip_adapter_model_name is not None: + ip_adapter_model_params_dict_src = load_pyhon_obj( + ip_adapter_model_cfg_path, "MODEL_CFG" + ) + print("ip_adapter_model_params_dict_src", ip_adapter_model_params_dict_src.keys()) + ip_adapter_model_params_dict = ip_adapter_model_params_dict_src[ + ip_adapter_model_name + ] +else: + ip_adapter_model_params_dict = None +print("ip_adapter: ", ip_adapter_model_name, ip_adapter_model_params_dict) + + +# facein parameters +if facein_model_name is not None: + raise NotImplementedError("unsupported facein by now") + facein_model_params_dict_src = load_pyhon_obj(facein_model_cfg_path, "MODEL_CFG") + print("facein_model_params_dict_src", facein_model_params_dict_src.keys()) + facein_model_params_dict = facein_model_params_dict_src[facein_model_name] +else: + facein_model_params_dict = None +print("facein: ", facein_model_name, facein_model_params_dict) + +# ip_adapter_face +if ip_adapter_face_model_name is not None: + ip_adapter_face_model_params_dict_src = load_pyhon_obj( + ip_adapter_face_model_cfg_path, "MODEL_CFG" + ) + print( + "ip_adapter_face_model_params_dict_src", + ip_adapter_face_model_params_dict_src.keys(), + ) + ip_adapter_face_model_params_dict = ip_adapter_face_model_params_dict_src[ + ip_adapter_face_model_name + ] +else: + ip_adapter_face_model_params_dict = None +print( + "ip_adapter_face: ", ip_adapter_face_model_name, ip_adapter_face_model_params_dict +) + + +# negative_prompt +def get_negative_prompt(negative_prompt, cfg_path=None, n: int = 10): + name = negative_prompt[:n] + if cfg_path is not None and cfg_path not in ["None", "none"]: + dct = load_pyhon_obj(cfg_path, "Negative_Prompt_CFG") + negative_prompt = dct[negative_prompt]["prompt"] + + return name, negative_prompt + + +negtive_prompt_length = 10 +video_negative_prompt_name, video_negative_prompt = get_negative_prompt( + video_negative_prompt, + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, +) +negative_prompt_name, negative_prompt = get_negative_prompt( + negative_prompt, + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, +) +print("video_negprompt", video_negative_prompt_name, video_negative_prompt) +print("negprompt", negative_prompt_name, negative_prompt) + +output_dir = args.output_dir +os.makedirs(output_dir, exist_ok=True) + + +# test_data_parameters +def load_yaml(path): + tasks = OmegaConf.to_container( + OmegaConf.load(path), structured_config_mode=SCMode.INSTANTIATE, resolve=True + ) + return tasks + + +if test_data_path.endswith(".yaml"): + test_datas_src = load_yaml(test_data_path) +elif test_data_path.endswith(".csv"): + test_datas_src = generate_tasks_from_table(test_data_path) +else: + raise ValueError("expect yaml or csv, but given {}".format(test_data_path)) + +test_datas = [ + test_data + for test_data in test_datas_src + if target_datas == "all" or test_data.get("name", None) in target_datas +] + +test_datas = fiss_tasks(test_datas) +test_datas = generate_prompts(test_datas) + +n_test_datas = len(test_datas) +if n_test_datas == 0: + raise ValueError( + "n_test_datas == 0, set target_datas=None or set atleast one of {}".format( + " ".join(list(d.get("name", "None") for d in test_datas_src)) + ) + ) +print("n_test_datas", n_test_datas) +# pprint(test_datas) + + +def read_image(path): + name = os.path.basename(path).split(".")[0] + image = read_image_as_5d(path) + return image, name + + +def read_image_lst(path): + images_names = [read_image(x) for x in path] + images, names = zip(*images_names) + images = np.concatenate(images, axis=2) + name = "_".join(names) + return images, name + + +def read_image_and_name(path): + if isinstance(path, str): + path = [path] + images, name = read_image_lst(path) + return images, name + + +# load referencenet +if referencenet_model_name is not None: + referencenet = load_referencenet_by_name( + model_name=referencenet_model_name, + # sd_model=sd_model_path, + # sd_model="./checkpoints/Moore-AnimateAnyone/AnimateAnyone/reference_unet.pth", + sd_referencenet_model=referencenet_model_path, + cross_attention_dim=cross_attention_dim, + ) +else: + referencenet = None + referencenet_model_name = "no" + +# load vision_clip_extractor +if vision_clip_extractor_class_name is not None: + vision_clip_extractor = load_vision_clip_encoder_by_name( + ip_image_encoder=vision_clip_model_path, + vision_clip_extractor_class_name=vision_clip_extractor_class_name, + ) + logger.info( + f"vision_clip_extractor, name={vision_clip_extractor_class_name}, path={vision_clip_model_path}" + ) +else: + vision_clip_extractor = None + logger.info(f"vision_clip_extractor, None") + +# load ip_adapter_model +if ip_adapter_model_name is not None: + ip_adapter_image_proj = load_ip_adapter_image_proj_by_name( + model_name=ip_adapter_model_name, + ip_image_encoder=ip_adapter_model_params_dict.get( + "ip_image_encoder", vision_clip_model_path + ), + ip_ckpt=ip_adapter_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=ip_adapter_model_params_dict["clip_embeddings_dim"], + clip_extra_context_tokens=ip_adapter_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=ip_adapter_model_params_dict["ip_scale"], + device=device, + ) +else: + ip_adapter_image_proj = None + ip_adapter_model_name = "no" + +if pose_guider_model_path is not None: + logger.info(f"PoseGuider ={pose_guider_model_path}") + pose_guider = PoseGuider.from_pretrained( + pose_guider_model_path, + conditioning_embedding_channels=320, + block_out_channels=(16, 32, 96, 256), + ) +else: + pose_guider = None + +for model_name, sd_model_params in sd_model_params_dict.items(): + lora_dict = sd_model_params.get("lora", None) + model_sex = sd_model_params.get("sex", None) + model_style = sd_model_params.get("style", None) + sd_model_path = sd_model_params["sd"] + test_model_vae_model_path = sd_model_params.get("vae", vae_model_path) + # load unet according test_data + unet = load_unet_by_name( + model_name=unet_model_name, + sd_unet_model=unet_model_path, + sd_model=sd_model_path, + # sd_model="./checkpoints/Moore-AnimateAnyone/AnimateAnyone/denoising_unet.pth", + cross_attention_dim=cross_attention_dim, + need_t2i_facein=facein_model_name is not None, + # ip_adapter_face_model_name not train in unet, need load individually + strict=not (ip_adapter_face_model_name is not None), + need_t2i_ip_adapter_face=ip_adapter_face_model_name is not None, + ) + + # load facein according test_data + if facein_model_name is not None: + ( + face_emb_extractor, + facein_image_proj, + ) = load_facein_extractor_and_proj_by_name( + model_name=facein_model_name, + ip_image_encoder=facein_model_params_dict["ip_image_encoder"], + ip_ckpt=facein_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=facein_model_params_dict["clip_embeddings_dim"], + clip_extra_context_tokens=facein_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=facein_model_params_dict["ip_scale"], + device=device, + unet=unet, + ) + else: + face_emb_extractor = None + facein_image_proj = None + + # load ipadapter_face model according test_data + if ip_adapter_face_model_name is not None: + ( + ip_adapter_face_emb_extractor, + ip_adapter_face_image_proj, + ) = load_ip_adapter_face_extractor_and_proj_by_name( + model_name=ip_adapter_face_model_name, + ip_image_encoder=ip_adapter_face_model_params_dict["ip_image_encoder"], + ip_ckpt=ip_adapter_face_model_params_dict["ip_ckpt"], + cross_attention_dim=cross_attention_dim, + clip_embeddings_dim=ip_adapter_face_model_params_dict[ + "clip_embeddings_dim" + ], + clip_extra_context_tokens=ip_adapter_face_model_params_dict[ + "clip_extra_context_tokens" + ], + ip_scale=ip_adapter_face_model_params_dict["ip_scale"], + device=device, + unet=unet, + ) + else: + ip_adapter_face_emb_extractor = None + ip_adapter_face_image_proj = None + + print("test_model_vae_model_path", test_model_vae_model_path) + + # init sd_predictor + sd_predictor = DiffusersPipelinePredictor( + sd_model_path=sd_model_path, + unet=unet, + lora_dict=lora_dict, + lcm_lora_dct=lcm_lora_dct, + device=device, + dtype=torch_dtype, + negative_embedding=negative_embedding, + referencenet=referencenet, + ip_adapter_image_proj=ip_adapter_image_proj, + vision_clip_extractor=vision_clip_extractor, + facein_image_proj=facein_image_proj, + face_emb_extractor=face_emb_extractor, + vae_model=test_model_vae_model_path, + ip_adapter_face_emb_extractor=ip_adapter_face_emb_extractor, + ip_adapter_face_image_proj=ip_adapter_face_image_proj, + pose_guider=pose_guider, + controlnet_name=controlnet_name, + enable_zero_snr=args.enable_zero_snr, + ) + logger.debug(f"load referencenet"), + + for i_test_data, test_data in enumerate(test_datas): + batch = [] + texts = [] + video_path = test_data.get("video_path") + video_reader = DecordVideoDataset( + video_path, + time_size=time_size, + step=time_size, + sample_rate=sample_rate, + device="cpu", + data_type="rgb", + channels_order="c t h w", + drop_last=True, + ) + video_height = video_reader.height + video_width = video_reader.width + + print("\n i_test_data", i_test_data, model_name) + test_data_name = test_data.get("name", i_test_data) + prompt = test_data["prompt"] + prompt = prefix_prompt + prompt + suffix_prompt + prompt_hash = get_signature_of_string(prompt, length=5) + test_data["prompt_hash"] = prompt_hash + test_data_height = test_data.get("height", height) + test_data_width = test_data.get("width", width) + test_data_condition_images_path = test_data.get("condition_images", None) + test_data_condition_images_index = test_data.get("condition_images_index", None) + test_data_redraw_condition_image = test_data.get( + "redraw_condition_image", redraw_condition_image + ) + # read condition_image + if ( + test_data_condition_images_path is not None + and use_condition_image + and ( + isinstance(test_data_condition_images_path, list) + or ( + isinstance(test_data_condition_images_path, str) + and is_image(test_data_condition_images_path) + ) + ) + ): + ( + test_data_condition_images, + test_data_condition_images_name, + ) = read_image_and_name(test_data_condition_images_path) + condition_image_height = test_data_condition_images.shape[3] + condition_image_width = test_data_condition_images.shape[4] + logger.debug( + f"test_data_condition_images use {test_data_condition_images_path}" + ) + else: + test_data_condition_images = None + test_data_condition_images_name = "no" + condition_image_height = None + condition_image_width = None + logger.debug(f"test_data_condition_images is None") + + # if test_data_height is not assigned, use height of condition, if still None, use of video + if test_data_height is None: + test_data_height = ( + condition_image_height + if condition_image_height is not None + else video_height + ) + + if test_data_width is None: + test_data_width = ( + condition_image_width + if condition_image_width is not None + else video_width + ) + + test_data_img_length_ratio = float( + test_data.get("img_length_ratio", img_length_ratio) + ) + + # to align height of generated video with video2video, use `64`` as basic pixel unit instead of `8`` + test_data_height = int(test_data_height * test_data_img_length_ratio // 64 * 64) + test_data_width = int(test_data_width * test_data_img_length_ratio // 64 * 64) + pprint(test_data) + print(f"test_data_height={test_data_height}") + print(f"test_data_width={test_data_width}") + # continue + test_data_style = test_data.get("style", None) + test_data_sex = test_data.get("sex", None) + # if paramters in test_data is str, but float in fact, convert it into float,int. + test_data_motion_speed = float(test_data.get("motion_speed", motion_speed)) + test_data_w_ind_noise = float(test_data.get("w_ind_noise", w_ind_noise)) + test_data_img_weight = float(test_data.get("img_weight", img_weight)) + logger.debug( + f"test_data_condition_images_path {test_data_condition_images_path}" + ) + logger.debug( + f"test_data_condition_images_index {test_data_condition_images_index}" + ) + test_data_refer_image_path = test_data.get( + "refer_image", referencenet_image_path + ) + test_data_ipadapter_image_path = test_data.get( + "ipadapter_image", ipadapter_image_path + ) + test_data_refer_face_image_path = test_data.get("face_image", face_image_path) + test_data_video_is_middle = test_data.get("video_is_middle", video_is_middle) + test_data_video_has_condition = test_data.get( + "video_has_condition", video_has_condition + ) + + controlnet_processor_params = { + "detect_resolution": min(test_data_height, test_data_width), + "image_resolution": min(test_data_height, test_data_width), + } + if negprompt_cfg_path is not None: + if "video_negative_prompt" in test_data: + ( + test_data_video_negative_prompt_name, + test_data_video_negative_prompt, + ) = get_negative_prompt( + test_data.get( + "video_negative_prompt", + ), + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, + ) + else: + test_data_video_negative_prompt_name = video_negative_prompt_name + test_data_video_negative_prompt = video_negative_prompt + if "negative_prompt" in test_data: + ( + test_data_negative_prompt_name, + test_data_negative_prompt, + ) = get_negative_prompt( + test_data.get( + "negative_prompt", + ), + cfg_path=negprompt_cfg_path, + n=negtive_prompt_length, + ) + else: + test_data_negative_prompt_name = negative_prompt_name + test_data_negative_prompt = negative_prompt + else: + test_data_video_negative_prompt = test_data.get( + "video_negative_prompt", video_negative_prompt + ) + test_data_video_negative_prompt_name = test_data_video_negative_prompt[ + :negtive_prompt_length + ] + test_data_negative_prompt = test_data.get( + "negative_prompt", negative_prompt + ) + test_data_negative_prompt_name = test_data_negative_prompt[ + :negtive_prompt_length + ] + + # prepare test_data_refer_image + if referencenet is not None: + if test_data_refer_image_path is None: + test_data_refer_image = test_data_condition_images + test_data_refer_image_name = test_data_condition_images_name + logger.debug(f"test_data_refer_image use test_data_condition_images") + else: + test_data_refer_image, test_data_refer_image_name = read_image_and_name( + test_data_refer_image_path + ) + logger.debug(f"test_data_refer_image use {test_data_refer_image_path}") + else: + test_data_refer_image = None + test_data_refer_image_name = "no" + logger.debug(f"test_data_refer_image is None") + + # prepare test_data_ipadapter_image + if vision_clip_extractor is not None: + if test_data_ipadapter_image_path is None: + test_data_ipadapter_image = test_data_condition_images + test_data_ipadapter_image_name = test_data_condition_images_name + + logger.debug( + f"test_data_ipadapter_image use test_data_condition_images" + ) + else: + ( + test_data_ipadapter_image, + test_data_ipadapter_image_name, + ) = read_image_and_name(test_data_ipadapter_image_path) + logger.debug( + f"test_data_ipadapter_image use f{test_data_ipadapter_image_path}" + ) + else: + test_data_ipadapter_image = None + test_data_ipadapter_image_name = "no" + logger.debug(f"test_data_ipadapter_image is None") + + # prepare test_data_refer_face_image + + if facein_image_proj is not None or ip_adapter_face_image_proj is not None: + if test_data_refer_face_image_path is None: + test_data_refer_face_image = test_data_condition_images + test_data_refer_face_image_name = test_data_condition_images_name + + logger.debug( + f"test_data_refer_face_image use test_data_condition_images" + ) + else: + ( + test_data_refer_face_image, + test_data_refer_face_image_name, + ) = read_image_and_name(test_data_refer_face_image_path) + logger.debug( + f"test_data_refer_face_image use f{test_data_refer_face_image_path}" + ) + else: + test_data_refer_face_image = None + test_data_refer_face_image_name = "no" + logger.debug(f"test_data_refer_face_image is None") + + # if sex, style of test_data is not aligned with of model + # skip this test_data + + if ( + model_sex is not None + and test_data_sex is not None + and model_sex != test_data_sex + ) or ( + model_style is not None + and test_data_style is not None + and model_style != test_data_style + ): + print("model doesnt match test_data") + print("model name: ", model_name) + print("test_data: ", test_data) + continue + # video + filename = os.path.basename(video_path).split(".")[0] + for i_num in range(n_repeat): + test_data_seed = random.randint(0, 1e8) if seed is None else seed + cpu_generator, gpu_generator = set_all_seed(test_data_seed) + + save_file_name = ( + f"{which2video_name}_m={model_name}_rm={referencenet_model_name}_c={test_data_name}" + f"_w={test_data_width}_h={test_data_height}_t={time_size}_n={n_batch}" + f"_vn={video_num_inference_steps}" + f"_w={test_data_img_weight}_w={test_data_w_ind_noise}" + f"_s={test_data_seed}_n={controlnet_name_str}" + f"_s={strength}_g={guidance_scale}_vs={video_strength}_vg={video_guidance_scale}" + f"_p={prompt_hash}_{test_data_video_negative_prompt_name[:10]}" + f"_r={test_data_refer_image_name[:3]}_ip={test_data_refer_image_name[:3]}_f={test_data_refer_face_image_name[:3]}" + ) + save_file_name = clean_str_for_save(save_file_name) + output_path = os.path.join( + output_dir, + f"{save_file_name}.{save_filetype}", + ) + if os.path.exists(output_path) and not overwrite: + print("existed", output_path) + continue + + if which2video in ["video", "video_middle"]: + if which2video == "video": + need_video2video = True + ( + out_videos, + out_condition, + videos, + ) = sd_predictor.run_pipe_video2video( + video=video_path, + time_size=time_size, + step=time_size, + sample_rate=sample_rate, + need_return_videos=need_return_videos, + need_return_condition=need_return_condition, + controlnet_conditioning_scale=controlnet_conditioning_scale, + control_guidance_start=control_guidance_start, + control_guidance_end=control_guidance_end, + end_to_end=end_to_end, + need_video2video=need_video2video, + video_strength=video_strength, + prompt=prompt, + width=test_data_width, + height=test_data_height, + generator=gpu_generator, + noise_type=noise_type, + negative_prompt=test_data_negative_prompt, + video_negative_prompt=test_data_video_negative_prompt, + max_batch_num=n_batch, + strength=strength, + need_img_based_video_noise=need_img_based_video_noise, + video_num_inference_steps=video_num_inference_steps, + condition_images=test_data_condition_images, + fix_condition_images=fix_condition_images, + video_guidance_scale=video_guidance_scale, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + redraw_condition_image=test_data_redraw_condition_image, + img_weight=test_data_img_weight, + w_ind_noise=test_data_w_ind_noise, + n_vision_condition=n_vision_condition, + motion_speed=test_data_motion_speed, + need_hist_match=need_hist_match, + video_guidance_scale_end=video_guidance_scale_end, + video_guidance_scale_method=video_guidance_scale_method, + vision_condition_latent_index=test_data_condition_images_index, + refer_image=test_data_refer_image, + fixed_refer_image=fixed_refer_image, + redraw_condition_image_with_referencenet=redraw_condition_image_with_referencenet, + ip_adapter_image=test_data_ipadapter_image, + refer_face_image=test_data_refer_face_image, + fixed_refer_face_image=fixed_refer_face_image, + facein_scale=facein_scale, + redraw_condition_image_with_facein=redraw_condition_image_with_facein, + ip_adapter_face_scale=ip_adapter_face_scale, + redraw_condition_image_with_ip_adapter_face=redraw_condition_image_with_ip_adapter_face, + fixed_ip_adapter_image=fixed_ip_adapter_image, + ip_adapter_scale=ip_adapter_scale, + redraw_condition_image_with_ipdapter=redraw_condition_image_with_ipdapter, + prompt_only_use_image_prompt=prompt_only_use_image_prompt, + controlnet_processor_params=controlnet_processor_params, + # serial_denoise parameter start + record_mid_video_noises=record_mid_video_noises, + record_mid_video_latents=record_mid_video_latents, + video_overlap=video_overlap, + # serial_denoise parameter end + # parallel_denoise parameter start + context_schedule=context_schedule, + context_frames=context_frames, + context_stride=context_stride, + context_overlap=context_overlap, + context_batch_size=context_batch_size, + interpolation_factor=interpolation_factor, + # parallel_denoise parameter end + video_is_middle=test_data_video_is_middle, + video_has_condition=test_data_video_has_condition, + ) + else: + raise ValueError( + f"only support video, videomiddle2video, but given {which2video_name}" + ) + print("out_videos.shape", out_videos.shape) + batch = [out_videos] + texts = ["out"] + if videos is not None: + print("videos.shape", videos.shape) + batch.insert(0, videos / 255.0) + texts.insert(0, "videos") + if need_controlnet and out_condition is not None: + if not isinstance(out_condition, list): + print("out_condition", out_condition.shape) + batch.append(out_condition / 255.0) + texts.append(controlnet_name) + else: + batch.extend([x / 255.0 for x in out_condition]) + texts.extend(controlnet_name) + out = np.concatenate(batch, axis=0) + save_videos_grid_with_opencv( + out, + output_path, + texts=texts, + fps=fps, + tensor_order="b c t h w", + n_cols=n_cols, + write_info=args.write_info, + save_filetype=save_filetype, + save_images=save_images, + ) + print("Save to", output_path) + print("\n" * 2) diff --git a/scripts/runpod_handler.py b/scripts/runpod_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..c544d108a1121b07494f7f4b96b5d92c797439fe --- /dev/null +++ b/scripts/runpod_handler.py @@ -0,0 +1,193 @@ +import runpod +import os +import sys +from pathlib import Path +import torch +import gradio as gr +import tempfile +from PIL import Image +import numpy as np +import yaml +from typing import Dict, Any, Optional +import threading + +# Add the MuseV directory to the Python path +musev_path = str(Path(__file__).parent.parent) +sys.path.append(musev_path) + +# Import MuseV modules (adjust these imports based on the actual module structure) +from musev.pipelines import MuseVPipeline + +class MuseVService: + def __init__(self): + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.pipeline = None + self.load_model() + + def load_model(self): + # Initialize the MuseV pipeline (adjust parameters as needed) + self.pipeline = MuseVPipeline.from_pretrained( + "TMElyralab/MuseV", + torch_dtype=torch.float16, + device=self.device + ) + self.pipeline.to(self.device) + + def generate_video( + self, + condition_image: Image.Image, + prompt: str, + height: Optional[int] = None, + width: Optional[int] = None, + num_frames: int = 16, + eye_blinks_factor: float = 1.8, + ) -> str: + # Process inputs + if height is None or width is None: + width, height = condition_image.size + aspect_ratio = width / height + if width > height: + width = min(width, 1024) + height = int(width / aspect_ratio) + else: + height = min(height, 1024) + width = int(height * aspect_ratio) + + # Create temporary directory for output + with tempfile.TemporaryDirectory() as temp_dir: + # Save condition image + condition_image_path = os.path.join(temp_dir, "condition.png") + condition_image.save(condition_image_path) + + # Prepare configuration + config = { + "condition_images": condition_image_path, + "prompt": prompt, + "height": height, + "width": width, + "eye_blinks_factor": eye_blinks_factor, + "img_length_ratio": 1.0, + "ipadapter_image": condition_image_path, + "refer_image": condition_image_path, + } + + # Generate video + output_path = os.path.join(temp_dir, "output.mp4") + self.pipeline.generate(config, output_path) + + # Read the video file and return as bytes + with open(output_path, "rb") as f: + video_bytes = f.read() + + return video_bytes + +# Initialize the service +service = MuseVService() + +def handler(event): + """ + RunPod handler function for API requests + """ + try: + # Get the input data + job_input = event["input"] + + # Process the input image + image_data = job_input.get("image") + if not image_data: + raise ValueError("No image provided") + + # Convert base64 image to PIL + import base64 + from io import BytesIO + image = Image.open(BytesIO(base64.b64decode(image_data))) + + # Generate video + video_bytes = service.generate_video( + condition_image=image, + prompt=job_input.get("prompt", ""), + height=job_input.get("height"), + width=job_input.get("width"), + eye_blinks_factor=job_input.get("eye_blinks_factor", 1.8), + ) + + # Encode video as base64 + video_base64 = base64.b64encode(video_bytes).decode() + + return { + "status": "success", + "output": { + "video": video_base64 + } + } + except Exception as e: + return { + "status": "error", + "error": str(e) + } + +def create_gradio_interface(): + """ + Create Gradio interface + """ + def generate_video_gradio(image, prompt, height, width, eye_blinks_factor): + try: + video_bytes = service.generate_video( + condition_image=Image.fromarray(image), + prompt=prompt, + height=height if height > 0 else None, + width=width if width > 0 else None, + eye_blinks_factor=eye_blinks_factor + ) + + # Save video to temporary file + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") + temp_file.write(video_bytes) + temp_file.close() + + return temp_file.name + except Exception as e: + raise gr.Error(str(e)) + + # Create the interface + interface = gr.Interface( + fn=generate_video_gradio, + inputs=[ + gr.Image(label="Input Image", type="numpy"), + gr.Textbox(label="Prompt", placeholder="Enter your prompt here..."), + gr.Number(label="Height (optional)", value=0), + gr.Number(label="Width (optional)", value=0), + gr.Slider(minimum=0.0, maximum=3.0, value=1.8, label="Eye Blinks Factor") + ], + outputs=gr.Video(label="Generated Video"), + title="MuseV Video Generation", + description="Generate videos from images using MuseV", + examples=[ + [ + "path/to/example/image.jpg", + "(masterpiece, best quality, highres:1),(1person, solo:1),(eye blinks:1.8),(head wave:1.3)", + 512, + 512, + 1.8 + ] + ] + ) + return interface + +if __name__ == "__main__": + # Start both the RunPod handler and Gradio interface + interface = create_gradio_interface() + + # Start Gradio in a separate thread + threading.Thread( + target=interface.launch, + kwargs={ + "server_name": "0.0.0.0", + "server_port": 3000, + "share": False + }, + daemon=True + ).start() + + # Start the RunPod handler + runpod.serverless.start({"handler": handler}) \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100755 index 0000000000000000000000000000000000000000..d960a8cd60b5ee42d040cd486491f563c964c2e4 --- /dev/null +++ b/setup.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python + +from setuptools import setup, find_packages + +# with open("README.md", "r") as fh: +# long_description = fh.read() + +with open("requirements.txt", "r") as f: + requirements = f.read().splitlines() + +setup( + name="musev", # used in pip install + version="1.0.0", + author="anchorxia, zkangchen", + author_email="anchorxia@tencent.com, zkangchen@tencent.com", + description="Package about human video creation", + # long_description=long_description, + # long_description_content_type="text/markdown", + url="https://github.com/TMElyralab/MuseV", + # include_package_data=True, # please edit MANIFEST.in + # packages=find_packages(), # used in import + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + install_requires=requirements, +)