ameerazam08
committed on
Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes. See raw diff
- DiffSynth_Studio.py +15 -0
- README.md +101 -0
- configs/stable_diffusion/tokenizer/merges.txt +0 -0
- configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
- configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
- configs/stable_diffusion/tokenizer/vocab.json +0 -0
- configs/stable_diffusion_xl/tokenizer_2/merges.txt +0 -0
- configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
- configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
- configs/stable_diffusion_xl/tokenizer_2/vocab.json +0 -0
- diffsynth/__init__.py +6 -0
- diffsynth/__pycache__/__init__.cpython-39.pyc +0 -0
- diffsynth/controlnets/__init__.py +2 -0
- diffsynth/controlnets/__pycache__/__init__.cpython-39.pyc +0 -0
- diffsynth/controlnets/__pycache__/controlnet_unit.cpython-39.pyc +0 -0
- diffsynth/controlnets/__pycache__/processors.cpython-39.pyc +0 -0
- diffsynth/controlnets/controlnet_unit.py +53 -0
- diffsynth/controlnets/processors.py +51 -0
- diffsynth/data/__init__.py +1 -0
- diffsynth/data/__pycache__/__init__.cpython-39.pyc +0 -0
- diffsynth/data/__pycache__/video.cpython-39.pyc +0 -0
- diffsynth/data/video.py +148 -0
- diffsynth/extensions/FastBlend/__init__.py +63 -0
- diffsynth/extensions/FastBlend/api.py +397 -0
- diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
- diffsynth/extensions/FastBlend/data.py +146 -0
- diffsynth/extensions/FastBlend/patch_match.py +298 -0
- diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
- diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
- diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
- diffsynth/extensions/FastBlend/runners/fast.py +141 -0
- diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
- diffsynth/extensions/RIFE/__init__.py +241 -0
- diffsynth/models/__init__.py +295 -0
- diffsynth/models/__pycache__/__init__.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/attention.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sd_controlnet.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sd_lora.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sd_motion.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sd_text_encoder.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sd_unet.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sd_vae_decoder.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sd_vae_encoder.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sdxl_text_encoder.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sdxl_unet.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-39.pyc +0 -0
- diffsynth/models/__pycache__/tiler.cpython-39.pyc +0 -0
- diffsynth/models/attention.py +76 -0
- diffsynth/models/sd_controlnet.py +584 -0
DiffSynth_Studio.py
ADDED
@@ -0,0 +1,15 @@
# Set web page format
import streamlit as st
st.set_page_config(layout="wide")
# Disable virtual VRAM on Windows systems
import torch
torch.cuda.set_per_process_memory_fraction(0.999, 0)


st.markdown("""
# DiffSynth Studio

[Source Code](https://github.com/Artiprocher/DiffSynth-Studio)

Welcome to DiffSynth Studio.
""")

README.md
ADDED
@@ -0,0 +1,101 @@
# DiffSynth Studio

## Introduction

DiffSynth is a new Diffusion engine. We have restructured architectures, including the Text Encoder, UNet, and VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. This version is currently in its initial stage, supporting SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase.

## Installation

Create the Python environment:

```
conda env create -f environment.yml
```

We find that `conda` sometimes cannot install `cupy` correctly; please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.

Enter the Python environment:

```
conda activate DiffSynthStudio
```

## Usage (in WebUI)

```
python -m streamlit run DiffSynth_Studio.py
```

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954

## Usage (in Python code)

### Example 1: Stable Diffusion

We can generate images at very high resolution. Please see `examples/sd_text_to_image.py` for more details.

|512*512|1024*1024|2048*2048|4096*4096|
|-|-|-|-|
|![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|

### Example 2: Stable Diffusion XL

Generate images with Stable Diffusion XL. Please see `examples/sdxl_text_to_image.py` for more details.

|1024*1024|2048*2048|
|-|-|
|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|

### Example 3: Stable Diffusion XL Turbo

Generate images with Stable Diffusion XL Turbo. You can see `examples/sdxl_turbo.py` for more details, but we highly recommend using it in the WebUI.

|"black car"|"red car"|
|-|-|
|![black_car](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/7fbfd803-68d4-44f3-8713-8c925fec47d0)|![black_car_to_red_car](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/aaf886e4-c33c-4fd8-98e2-29eef117ba00)|

### Example 4: Toon Shading (Diffutoon)

This example is implemented based on [Diffutoon](https://arxiv.org/abs/2401.16224). This approach is well suited to rendering high-resolution videos with rapid motion. You can easily modify the parameters in the config dict. See `examples/diffutoon_toon_shading.py`.

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd

### Example 5: Toon Shading with Editing Signals (Diffutoon)

Coming soon.

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c

### Example 6: Toon Shading (in native Python code)

This example is provided for developers. If you don't want to use a config to manage parameters, you can see `examples/sd_toon_shading.py` to learn how to use it in native Python code.

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/607c199b-6140-410b-a111-3e4ffb01142c

### Example 7: Text to Video

Given a prompt, DiffSynth Studio can generate a video using a Stable Diffusion model and an AnimateDiff model. We can break the limitation on the number of frames! See `examples/sd_text_to_video.py`.

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/8f556355-4079-4445-9b48-e9da77699437

### Example 8: Video Stylization

We provide an example for video stylization. In this pipeline, the rendered video is completely different from the original video, so we need a powerful deflickering algorithm. We use FastBlend to implement the deflickering module. Please see `examples/sd_video_rerender.py` for more details.

https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea

### Example 9: Prompt Processing

If you are not a native English user, we provide a translation service. Our prompter can translate prompts from other languages into English and refine them using "BeautifulPrompt" models. Please see `examples/sd_prompt_refining.py` for more details.

Prompt: "一个漂亮的女孩" (a beautiful girl). The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it into English.

|seed=0|seed=1|seed=2|seed=3|
|-|-|-|-|
|![0_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/ebb25ca8-7ce1-4d9e-8081-59a867c70c4d)|![1_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/a7e79853-3c1a-471a-9c58-c209ec4b76dd)|![2_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/a292b959-a121-481f-b79c-61cc3346f810)|![3_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1c19b54e-5a6f-4d48-960b-a7b2b149bb4c)|

Prompt: "一个漂亮的女孩" (a beautiful girl). The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it into English. Then the [refining model](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) will refine the translated prompt for better visual quality.

|seed=0|seed=1|seed=2|seed=3|
|-|-|-|-|
|![0](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/778b1bd9-44e0-46ac-a99c-712b3fc9aaa4)|![1](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/c03479b8-2082-4c6e-8e1c-3582b98686f6)|![2](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edb33d21-3288-4a55-96ca-a4bfe1b50b00)|![3](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/7848cfc1-cad5-4848-8373-41d24e98e584)|

configs/stable_diffusion/tokenizer/merges.txt
ADDED
The diff for this file is too large to render. See raw diff

configs/stable_diffusion/tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
{
  "bos_token": {
    "content": "<|startoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<|endoftext|>",
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}

configs/stable_diffusion/tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,34 @@
{
  "add_prefix_space": false,
  "bos_token": {
    "__type": "AddedToken",
    "content": "<|startoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "do_lower_case": true,
  "eos_token": {
    "__type": "AddedToken",
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "errors": "replace",
  "model_max_length": 77,
  "name_or_path": "openai/clip-vit-large-patch14",
  "pad_token": "<|endoftext|>",
  "special_tokens_map_file": "./special_tokens_map.json",
  "tokenizer_class": "CLIPTokenizer",
  "unk_token": {
    "__type": "AddedToken",
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}

configs/stable_diffusion/tokenizer/vocab.json
ADDED
The diff for this file is too large to render. See raw diff

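The four files above form a standard Hugging Face tokenizer directory, so they can be loaded directly with `transformers`. This is a minimal sketch (not part of the commit); the class name comes from `"tokenizer_class"` in the config, and the prompt string is just an example.

```python
from transformers import CLIPTokenizer

# Load the CLIP tokenizer shipped in this repository's configs folder.
tokenizer = CLIPTokenizer.from_pretrained("configs/stable_diffusion/tokenizer")

# Encode a prompt to the 77-token length expected by Stable Diffusion text encoders.
ids = tokenizer("a beautiful girl", padding="max_length", max_length=77).input_ids
```
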
configs/stable_diffusion_xl/tokenizer_2/merges.txt
ADDED
The diff for this file is too large to render. See raw diff

configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
{
  "bos_token": {
    "content": "<|startoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "!",
  "unk_token": {
    "content": "<|endoftext|>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}

configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json
ADDED
@@ -0,0 +1,38 @@
{
  "add_prefix_space": false,
  "added_tokens_decoder": {
    "0": {
      "content": "!",
      "lstrip": false,
      "normalized": false,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "49406": {
      "content": "<|startoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    },
    "49407": {
      "content": "<|endoftext|>",
      "lstrip": false,
      "normalized": true,
      "rstrip": false,
      "single_word": false,
      "special": true
    }
  },
  "bos_token": "<|startoftext|>",
  "clean_up_tokenization_spaces": true,
  "do_lower_case": true,
  "eos_token": "<|endoftext|>",
  "errors": "replace",
  "model_max_length": 77,
  "pad_token": "!",
  "tokenizer_class": "CLIPTokenizer",
  "unk_token": "<|endoftext|>"
}

configs/stable_diffusion_xl/tokenizer_2/vocab.json
ADDED
The diff for this file is too large to render. See raw diff

diffsynth/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .data import *
from .models import *
from .prompts import *
from .schedulers import *
from .pipelines import *
from .controlnets import *

diffsynth/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (279 Bytes).

diffsynth/controlnets/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
from .processors import Annotator

diffsynth/controlnets/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (329 Bytes).

diffsynth/controlnets/__pycache__/controlnet_unit.cpython-39.pyc
ADDED
Binary file (3.09 kB).

diffsynth/controlnets/__pycache__/processors.cpython-39.pyc
ADDED
Binary file (1.82 kB).

diffsynth/controlnets/controlnet_unit.py
ADDED
@@ -0,0 +1,53 @@
import torch
import numpy as np
from .processors import Processor_id


class ControlNetConfigUnit:
    def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
        self.processor_id = processor_id
        self.model_path = model_path
        self.scale = scale


class ControlNetUnit:
    def __init__(self, processor, model, scale=1.0):
        self.processor = processor
        self.model = model
        self.scale = scale


class MultiControlNetManager:
    def __init__(self, controlnet_units=[]):
        self.processors = [unit.processor for unit in controlnet_units]
        self.models = [unit.model for unit in controlnet_units]
        self.scales = [unit.scale for unit in controlnet_units]

    def process_image(self, image, processor_id=None):
        if processor_id is None:
            processed_image = [processor(image) for processor in self.processors]
        else:
            processed_image = [self.processors[processor_id](image)]
        processed_image = torch.concat([
            torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
            for image_ in processed_image
        ], dim=0)
        return processed_image

    def __call__(
        self,
        sample, timestep, encoder_hidden_states, conditionings,
        tiled=False, tile_size=64, tile_stride=32
    ):
        res_stack = None
        for conditioning, model, scale in zip(conditionings, self.models, self.scales):
            res_stack_ = model(
                sample, timestep, encoder_hidden_states, conditioning,
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
            )
            res_stack_ = [res * scale for res in res_stack_]
            if res_stack is None:
                res_stack = res_stack_
            else:
                res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
        return res_stack

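To illustrate how these classes compose, here is a minimal sketch (not part of the commit). `depth_model` and `softedge_model` stand for ControlNet modules assumed to be loaded elsewhere, and the annotators require the checkpoints under `models/Annotators` plus a CUDA device; only the classes above are taken as given.

```python
from PIL import Image
from diffsynth.controlnets import Annotator, ControlNetUnit, MultiControlNetManager

# Two ControlNet units, each pairing an image annotator with a ControlNet model.
manager = MultiControlNetManager([
    ControlNetUnit(Annotator("depth"), depth_model, scale=0.8),
    ControlNetUnit(Annotator("softedge"), softedge_model, scale=0.5),
])

# Each processor annotates the PIL image; the results are stacked into a
# float tensor of shape (num_units, 3, H, W) with values in [0, 1].
input_image = Image.open("input.png").convert("RGB")
conditionings = manager.process_image(input_image)
```

The manager itself is then called inside the denoising loop with the UNet's `sample`, `timestep`, and text embeddings, and it sums the scaled residual stacks produced by each ControlNet.
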
diffsynth/controlnets/processors.py
ADDED
@@ -0,0 +1,51 @@
from typing_extensions import Literal, TypeAlias
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    from controlnet_aux.processor import (
        CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
    )


Processor_id: TypeAlias = Literal[
    "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
]

class Annotator:
    def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
        if processor_id == "canny":
            self.processor = CannyDetector()
        elif processor_id == "depth":
            self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "softedge":
            self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "lineart":
            self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "lineart_anime":
            self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "openpose":
            self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
        elif processor_id == "tile":
            self.processor = None
        else:
            raise ValueError(f"Unsupported processor_id: {processor_id}")

        self.processor_id = processor_id
        self.detect_resolution = detect_resolution

    def __call__(self, image):
        width, height = image.size
        if self.processor_id == "openpose":
            kwargs = {
                "include_body": True,
                "include_hand": True,
                "include_face": True
            }
        else:
            kwargs = {}
        if self.processor is not None:
            detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
            image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
        image = image.resize((width, height))
        return image

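A minimal usage sketch (not part of the commit): the `"canny"` processor is the only annotator that needs no pretrained weights, so it is the easiest one to try. The image path is a placeholder.

```python
from PIL import Image
from diffsynth.controlnets import Annotator

annotator = Annotator("canny")
image = Image.open("input.png").convert("RGB")

# Returns a PIL edge map resized back to the input resolution.
edge_map = annotator(image)
edge_map.save("canny_edges.png")
```
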
diffsynth/data/__init__.py
ADDED
@@ -0,0 +1 @@
from .video import VideoData, save_video, save_frames

diffsynth/data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (242 Bytes).

diffsynth/data/__pycache__/video.cpython-39.pyc
ADDED
Binary file (6.09 kB).

diffsynth/data/video.py
ADDED
@@ -0,0 +1,148 @@
import imageio, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i) >= ord("0") and ord(i) <= ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return Image.open(self.file_list[item]).convert("RGB")

    def __del__(self):
        pass


def crop_and_resize(image, height, width):
    image = np.array(image)
    image_height, image_width, _ = image.shape
    if image_height / image_width < height / width:
        croped_width = int(image_height / height * width)
        left = (image_width - croped_width) // 2
        image = image[:, left: left+croped_width]
        image = Image.fromarray(image).resize((width, height))
    else:
        croped_height = int(image_width / width * height)
        left = (image_height - croped_height) // 2
        image = image[left: left+croped_height, :]
        image = Image.fromarray(image).resize((width, height))
    return image


class VideoData:
    def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.set_shape(height, width)

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        width, height = frame.size
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = crop_and_resize(frame, self.height, self.width)
        return frame

    def __del__(self):
        pass

    def save_images(self, folder):
        os.makedirs(folder, exist_ok=True)
        for i in tqdm(range(self.__len__()), desc="Saving images"):
            frame = self.__getitem__(i)
            frame.save(os.path.join(folder, f"{i}.png"))


def save_video(frames, save_path, fps, quality=9):
    writer = imageio.get_writer(save_path, fps=fps, quality=quality)
    for frame in tqdm(frames, desc="Saving video"):
        frame = np.array(frame)
        writer.append_data(frame)
    writer.close()

def save_frames(frames, save_path):
    os.makedirs(save_path, exist_ok=True)
    for i, frame in enumerate(tqdm(frames, desc="Saving images")):
        frame.save(os.path.join(save_path, f"{i}.png"))

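A minimal sketch (not part of the commit) of the loading and saving utilities above: a video is read lazily, each frame is center-cropped and resized to 512x512 on access, and the frames are written back out. The file paths and fps value are placeholders.

```python
from diffsynth.data import VideoData, save_video

video = VideoData(video_file="input.mp4", height=512, width=512)
frames = [video[i] for i in range(len(video))]   # PIL frames, cropped and resized
save_video(frames, "output.mp4", fps=30)
```

`VideoData(image_folder="frames_dir", ...)` works the same way on a folder of numbered `.png`/`.jpg` files, sorted naturally via `split_file_name`.
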
diffsynth/extensions/FastBlend/__init__.py
ADDED
@@ -0,0 +1,63 @@
from .runners.fast import TableManager, PyramidPatchMatcher
from PIL import Image
import numpy as np
import cupy as cp


class FastBlendSmoother:
    def __init__(self):
        self.batch_size = 8
        self.window_size = 64
        self.ebsynth_config = {
            "minimum_patch_size": 5,
            "threads_per_block": 8,
            "num_iter": 5,
            "gpu_id": 0,
            "guide_weight": 10.0,
            "initialize": "identity",
            "tracking_window_size": 0,
        }

    @staticmethod
    def from_model_manager(model_manager):
        # TODO: fetch GPU ID from model_manager
        return FastBlendSmoother()

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
        frames_guide = [np.array(frame) for frame in frames_guide]
        frames_style = [np.array(frame) for frame in frames_style]
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
        # merge
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
        return frames

    def __call__(self, rendered_frames, original_frames=None, **kwargs):
        frames = self.run(
            original_frames, rendered_frames,
            self.batch_size, self.window_size, self.ebsynth_config
        )
        mempool = cp.get_default_memory_pool()
        pinned_mempool = cp.get_default_pinned_memory_pool()
        mempool.free_all_blocks()
        pinned_mempool.free_all_blocks()
        return frames

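A minimal sketch (not part of the commit) of deflickering a rendered video with the smoother above. The guide video and the rendered (stylized) video are assumed to have the same length and resolution, `cupy` must be installed, and the file paths and fps are placeholders.

```python
from diffsynth.data import VideoData, save_video
from diffsynth.extensions.FastBlend import FastBlendSmoother

# Load both videos as lists of PIL frames at the same resolution.
guide = VideoData(video_file="guide.mp4", height=512, width=512).raw_data()
rendered = VideoData(video_file="rendered.mp4", height=512, width=512).raw_data()

# Blend each rendered frame with its neighbors, guided by the motion of the guide video.
smoother = FastBlendSmoother()
blended = smoother(rendered, original_frames=guide)
save_video(blended, "blended.mp4", fps=30)
```
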
diffsynth/extensions/FastBlend/api.py
ADDED
@@ -0,0 +1,397 @@
from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
from .data import VideoData, get_video_fps, save_video, search_for_images
import os
import gradio as gr


def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
    frames_guide = VideoData(video_guide, video_guide_folder)
    frames_style = VideoData(video_style, video_style_folder)
    message = ""
    if len(frames_guide) < len(frames_style):
        message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
        frames_style.set_length(len(frames_guide))
    elif len(frames_guide) > len(frames_style):
        message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
        frames_guide.set_length(len(frames_style))
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, message


def smooth_video(
    video_guide,
    video_guide_folder,
    video_style,
    video_style_folder,
    mode,
    window_size,
    batch_size,
    tracking_window_size,
    output_path,
    fps,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress=None,
):
    # input
    frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        if video_style is None:
            output_path = os.path.join(video_style_folder, "output")
        else:
            output_path = os.path.join(os.path.split(video_style)[0], "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    frames_path = os.path.join(output_path, "frames")
    video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(frames_path, exist_ok=True)
    # process
    if mode == "Fast" or mode == "Balanced":
        tracking_window_size = 0
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size,
    }
    if mode == "Fast":
        FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Balanced":
        BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    elif mode == "Accurate":
        AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
    # output
    try:
        fps = int(fps)
    except:
        fps = get_video_fps(video_style) if video_style is not None else 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
    print("Success!")
    print("Your frames are here:", frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


class KeyFrameMatcher:
    def __init__(self):
        pass

    def extract_number_from_filename(self, file_name):
        result = []
        number = -1
        for i in file_name:
            if ord(i) >= ord("0") and ord(i) <= ord("9"):
                if number == -1:
                    number = 0
                number = number*10 + ord(i) - ord("0")
            else:
                if number != -1:
                    result.append(number)
                    number = -1
        if number != -1:
            result.append(number)
        result = tuple(result)
        return result

    def extract_number_from_filenames(self, file_names):
        numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
        min_length = min(len(i) for i in numbers)
        for i in range(min_length-1, -1, -1):
            if len(set(number[i] for number in numbers)) == len(file_names):
                return [number[i] for number in numbers]
        return list(range(len(file_names)))

    def match_using_filename(self, file_names_a, file_names_b):
        file_names_b_set = set(file_names_b)
        matched_file_name = []
        for file_name in file_names_a:
            if file_name not in file_names_b_set:
                matched_file_name.append(None)
            else:
                matched_file_name.append(file_name)
        return matched_file_name

    def match_using_numbers(self, file_names_a, file_names_b):
        numbers_a = self.extract_number_from_filenames(file_names_a)
        numbers_b = self.extract_number_from_filenames(file_names_b)
        numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
        matched_file_name = []
        for number in numbers_a:
            if number in numbers_b_dict:
                matched_file_name.append(numbers_b_dict[number])
            else:
                matched_file_name.append(None)
        return matched_file_name

    def match_filenames(self, file_names_a, file_names_b):
        matched_file_name = self.match_using_filename(file_names_a, file_names_b)
        if sum([i is not None for i in matched_file_name]) > 0:
            return matched_file_name
        matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
        return matched_file_name


def detect_frames(frames_path, keyframes_path):
    if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
        return "Please input the directory of guide video and rendered frames"
    elif not os.path.exists(frames_path):
        return "Please input the directory of guide video"
    elif not os.path.exists(keyframes_path):
        return "Please input the directory of rendered frames"
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    if len(frames) == 0:
        return f"No images detected in {frames_path}"
    if len(keyframes) == 0:
        return f"No images detected in {keyframes_path}"
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    max_filename_length = max([len(i) for i in frames])
    if sum([i is not None for i in matched_keyframes]) == 0:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            message += "--> No matched keyframes\n"
    else:
        message = ""
        for frame, matched_keyframe in zip(frames, matched_keyframes):
            message += frame + " " * (max_filename_length - len(frame) + 1)
            if matched_keyframe is None:
                message += "--> [to be rendered]\n"
            else:
                message += f"--> {matched_keyframe}\n"
    return message


def check_input_for_interpolating(frames_path, keyframes_path):
    # search for images
    frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
    keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
    # match frames
    matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
    file_list = [file_name for file_name in matched_keyframes if file_name is not None]
    index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
    frames_guide = VideoData(None, frames_path)
    frames_style = VideoData(None, keyframes_path, file_list=file_list)
    # match shape
    message = ""
    height_guide, width_guide = frames_guide.shape()
    height_style, width_style = frames_style.shape()
    if height_guide != height_style or width_guide != width_style:
        message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
        frames_style.set_shape(height_guide, width_guide)
    return frames_guide, frames_style, index_style, message


def interpolate_video(
    frames_path,
    keyframes_path,
    output_path,
    fps,
    batch_size,
    tracking_window_size,
    minimum_patch_size,
    num_iter,
    guide_weight,
    initialize,
    progress=None,
):
    # input
    frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
    if len(message) > 0:
        print(message)
    # output
    if output_path == "":
        output_path = os.path.join(keyframes_path, "output")
        os.makedirs(output_path, exist_ok=True)
        print("No valid output_path. Your video will be saved here:", output_path)
    elif not os.path.exists(output_path):
        os.makedirs(output_path, exist_ok=True)
        print("Your video will be saved here:", output_path)
    output_frames_path = os.path.join(output_path, "frames")
    output_video_path = os.path.join(output_path, "video.mp4")
    os.makedirs(output_frames_path, exist_ok=True)
    # process
    ebsynth_config = {
        "minimum_patch_size": minimum_patch_size,
        "threads_per_block": 8,
        "num_iter": num_iter,
        "gpu_id": 0,
        "guide_weight": guide_weight,
        "initialize": initialize,
        "tracking_window_size": tracking_window_size
    }
    if len(index_style) == 1:
        InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    else:
        InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
    try:
        fps = int(fps)
    except:
        fps = 30
    print("Fps:", fps)
    print("Saving video...")
    video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
    print("Success!")
    print("Your frames are here:", output_frames_path)
    print("Your video is here:", video_path)
    return output_path, fps, video_path


def on_ui_tabs():
    with gr.Blocks(analytics_enabled=False) as ui_component:
        with gr.Tab("Blend"):
            gr.Markdown("""
# Blend

Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
            """)
            with gr.Row():
                with gr.Column():
                    with gr.Tab("Guide video"):
                        video_guide = gr.Video(label="Guide video")
                    with gr.Tab("Guide video (images format)"):
                        video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
                with gr.Column():
                    with gr.Tab("Style video"):
                        video_style = gr.Video(label="Style video")
                    with gr.Tab("Style video (images format)"):
                        video_style_folder = gr.Textbox(label="Style video (images format)", value="")
                with gr.Column():
                    output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
                    fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn = gr.Button(value="Blend")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
                    window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
                    batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Inference mode

|Mode|Time|Memory|Quality|Frame by frame output|Description|
|-|-|-|-|-|-|
|Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
|Balanced|■■|■|■■|Yes|Blend the frames naively.|
|Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|

* Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
                    """)
            btn.click(
                smooth_video,
                inputs=[
                    video_guide,
                    video_guide_folder,
                    video_style,
                    video_style_folder,
                    mode,
                    window_size,
                    batch_size,
                    tracking_window_size,
                    output_path,
                    fps,
                    minimum_patch_size,
                    num_iter,
                    guide_weight,
                    initialize
                ],
                outputs=[output_path, fps, video_output]
            )
        with gr.Tab("Interpolate"):
            gr.Markdown("""
# Interpolate

Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
            """)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        with gr.Column():
                            video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
                        with gr.Column():
                            rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
                    with gr.Row():
                        detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
                    video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                    rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
                with gr.Column():
                    output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
                    fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
                    video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
                    btn_ = gr.Button(value="Interpolate")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("# Settings")
                    batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
                    tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
                    gr.Markdown("## Advanced Settings")
                    minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
                    num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
                    guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
                    initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
                with gr.Column():
                    gr.Markdown("""
# Reference

* Output directory: the directory to save the video.
* Batch size: a larger batch size makes the program faster but requires more VRAM.
* Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
* Advanced settings
    * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
    * Number of iterations: the number of iterations of patch matching. (Default: 5)
    * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
    * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
                    """)
            btn_.click(
                interpolate_video,
                inputs=[
                    video_guide_folder_,
                    rendered_keyframes_,
                    output_path_,
                    fps_,
                    batch_size_,
                    tracking_window_size_,
                    minimum_patch_size_,
                    num_iter_,
                    guide_weight_,
                    initialize_,
                ],
                outputs=[output_path_, fps_, video_output_]
            )

    return [(ui_component, "FastBlend", "FastBlend_ui")]

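A minimal sketch (not part of the commit) of running the FastBlend UI on its own, outside a WebUI host. `on_ui_tabs` returns a list of `(component, title, elem_id)` tuples, so the `gr.Blocks` object can be launched directly; this assumes `gradio` is installed.

```python
from diffsynth.extensions.FastBlend.api import on_ui_tabs

ui_component, title, elem_id = on_ui_tabs()[0]
ui_component.launch()  # serves the "Blend" and "Interpolate" tabs locally
```
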
diffsynth/extensions/FastBlend/cupy_kernels.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cupy as cp
|
2 |
+
|
3 |
+
remapping_kernel = cp.RawKernel(r'''
|
4 |
+
extern "C" __global__
|
5 |
+
void remap(
|
6 |
+
const int height,
|
7 |
+
const int width,
|
8 |
+
const int channel,
|
9 |
+
const int patch_size,
|
10 |
+
const int pad_size,
|
11 |
+
const float* source_style,
|
12 |
+
const int* nnf,
|
13 |
+
float* target_style
|
14 |
+
) {
|
15 |
+
const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    if (x >= height or y >= width) return;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
    const int min_px = x < r ? -x : -r;
    const int max_px = x + r > height - 1 ? height - 1 - x : r;
    const int min_py = y < r ? -y : -r;
    const int max_py = y + r > width - 1 ? width - 1 - y : r;
    int num = 0;
    for (int px = min_px; px <= max_px; px++){
        for (int py = min_py; py <= max_py; py++){
            const int nid = (x + px) * width + y + py;
            const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
            const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
            if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width) continue;
            const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
            num++;
            for (int c = 0; c < channel; c++){
                target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
            }
        }
    }
    for (int c = 0; c < channel; c++){
        target_style[z + pid * channel + c] /= num;
    }
}
''', 'remap')


patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source,
    const int* nnf,
    const float* target,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
    const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
            const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'patch_error')


pairwise_patch_error_kernel = cp.RawKernel(r'''
extern "C" __global__
void pairwise_patch_error(
    const int height,
    const int width,
    const int channel,
    const int patch_size,
    const int pad_size,
    const float* source_a,
    const int* nnf_a,
    const float* source_b,
    const int* nnf_b,
    float* error
) {
    const int r = (patch_size - 1) / 2;
    const int x = blockDim.x * blockIdx.x + threadIdx.x;
    const int y = blockDim.y * blockIdx.y + threadIdx.y;
    const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
    if (x >= height or y >= width) return;
    const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
    const int x_a = nnf_a[z_nnf + 0];
    const int y_a = nnf_a[z_nnf + 1];
    const int x_b = nnf_b[z_nnf + 0];
    const int y_b = nnf_b[z_nnf + 1];
    float e = 0;
    for (int px = -r; px <= r; px++){
        for (int py = -r; py <= r; py++){
            const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
            const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
            for (int c = 0; c < channel; c++){
                const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
                e += diff * diff;
            }
        }
    }
    error[blockIdx.z * height * width + x * width + y] = e;
}
''', 'pairwise_patch_error')
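The kernels above are launched by PatchMatcher in patch_match.py below. For reference, a minimal launch sketch for patch_error_kernel, mirroring that call pattern; the image sizes and zero-filled arrays here are illustrative assumptions only.

# Launch sketch for patch_error_kernel (assumed shapes; requires a CUDA device).
import cupy as cp

height, width, channel, patch_size, batch = 64, 64, 3, 5, 1
pad = patch_size // 2
threads = 8
grid = ((height + threads - 1) // threads, (width + threads - 1) // threads, batch)
block = (threads, threads)

# Images are padded by pad on each side; nnf stores (x, y) correspondences per pixel.
source = cp.zeros((batch, height + 2 * pad, width + 2 * pad, channel), dtype=cp.float32)
target = cp.zeros((batch, height + 2 * pad, width + 2 * pad, channel), dtype=cp.float32)
nnf = cp.zeros((batch, height, width, 2), dtype=cp.int32)
error = cp.zeros((batch, height, width), dtype=cp.float32)

patch_error_kernel(grid, block, (height, width, channel, patch_size, pad, source, nnf, target, error))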
diffsynth/extensions/FastBlend/data.py
ADDED
@@ -0,0 +1,146 @@
import imageio, os
import numpy as np
from PIL import Image


def read_video(file_name):
    reader = imageio.get_reader(file_name)
    video = []
    for frame in reader:
        frame = np.array(frame)
        video.append(frame)
    reader.close()
    return video


def get_video_fps(file_name):
    reader = imageio.get_reader(file_name)
    fps = reader.get_meta_data()["fps"]
    reader.close()
    return fps


def save_video(frames_path, video_path, num_frames, fps):
    writer = imageio.get_writer(video_path, fps=fps, quality=9)
    for i in range(num_frames):
        frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
        writer.append_data(frame)
    writer.close()
    return video_path


class LowMemoryVideo:
    def __init__(self, file_name):
        self.reader = imageio.get_reader(file_name)

    def __len__(self):
        return self.reader.count_frames()

    def __getitem__(self, item):
        return np.array(self.reader.get_data(item))

    def __del__(self):
        self.reader.close()


def split_file_name(file_name):
    result = []
    number = -1
    for i in file_name:
        if ord(i)>=ord("0") and ord(i)<=ord("9"):
            if number == -1:
                number = 0
            number = number*10 + ord(i) - ord("0")
        else:
            if number != -1:
                result.append(number)
                number = -1
            result.append(i)
    if number != -1:
        result.append(number)
    result = tuple(result)
    return result


def search_for_images(folder):
    file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
    file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
    file_list = [i[1] for i in sorted(file_list)]
    file_list = [os.path.join(folder, i) for i in file_list]
    return file_list


def read_images(folder):
    file_list = search_for_images(folder)
    frames = [np.array(Image.open(i)) for i in file_list]
    return frames


class LowMemoryImageFolder:
    def __init__(self, folder, file_list=None):
        if file_list is None:
            self.file_list = search_for_images(folder)
        else:
            self.file_list = [os.path.join(folder, file_name) for file_name in file_list]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        return np.array(Image.open(self.file_list[item]))

    def __del__(self):
        pass


class VideoData:
    def __init__(self, video_file, image_folder, **kwargs):
        if video_file is not None:
            self.data_type = "video"
            self.data = LowMemoryVideo(video_file, **kwargs)
        elif image_folder is not None:
            self.data_type = "images"
            self.data = LowMemoryImageFolder(image_folder, **kwargs)
        else:
            raise ValueError("Cannot open video or image folder")
        self.length = None
        self.height = None
        self.width = None

    def raw_data(self):
        frames = []
        for i in range(self.__len__()):
            frames.append(self.__getitem__(i))
        return frames

    def set_length(self, length):
        self.length = length

    def set_shape(self, height, width):
        self.height = height
        self.width = width

    def __len__(self):
        if self.length is None:
            return len(self.data)
        else:
            return self.length

    def shape(self):
        if self.height is not None and self.width is not None:
            return self.height, self.width
        else:
            height, width, _ = self.__getitem__(0).shape
            return height, width

    def __getitem__(self, item):
        frame = self.data.__getitem__(item)
        height, width, _ = frame.shape
        if self.height is not None and self.width is not None:
            if self.height != height or self.width != width:
                frame = Image.fromarray(frame).resize((self.width, self.height))
                frame = np.array(frame)
        return frame

    def __del__(self):
        pass
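A short usage sketch for the helpers above; "input.mp4" is a placeholder path and the 512x512 shape is an assumption.

video = VideoData("input.mp4", None)     # or VideoData(None, "frames/") for an image folder
video.set_shape(512, 512)                # frames are resized on access if needed
print(len(video), video.shape())         # number of frames, (height, width)
first_frame = video[0]                   # numpy array of shape (512, 512, 3)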
diffsynth/extensions/FastBlend/patch_match.py
ADDED
@@ -0,0 +1,298 @@
from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
import numpy as np
import cupy as cp
import cv2


class PatchMatcher:
    def __init__(
        self, height, width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        random_search_steps=3, random_search_range=4,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0
    ):
        self.height = height
        self.width = width
        self.channel = channel
        self.minimum_patch_size = minimum_patch_size
        self.threads_per_block = threads_per_block
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.guide_weight = guide_weight
        self.random_search_steps = random_search_steps
        self.random_search_range = random_search_range
        self.use_mean_target_style = use_mean_target_style
        self.use_pairwise_patch_error = use_pairwise_patch_error
        self.tracking_window_size = tracking_window_size

        self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
        self.pad_size = self.patch_size_list[0] // 2
        self.grid = (
            (height + threads_per_block - 1) // threads_per_block,
            (width + threads_per_block - 1) // threads_per_block
        )
        self.block = (threads_per_block, threads_per_block)

    def pad_image(self, image):
        return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))

    def unpad_image(self, image):
        return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]

    def apply_nnf_to_image(self, nnf, source):
        batch_size = source.shape[0]
        target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
        remapping_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
        )
        return target

    def get_patch_error(self, source, nnf, target):
        batch_size = source.shape[0]
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
        )
        return error

    def get_pairwise_patch_error(self, source, nnf):
        batch_size = source.shape[0]//2
        error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
        source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
        source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
        pairwise_patch_error_kernel(
            self.grid + (batch_size,),
            self.block,
            (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
        )
        error = error.repeat(2, axis=0)
        return error

    def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
        error_guide = self.get_patch_error(source_guide, nnf, target_guide)
        if self.use_mean_target_style:
            target_style = self.apply_nnf_to_image(nnf, source_style)
            target_style = target_style.mean(axis=0, keepdims=True)
            target_style = target_style.repeat(source_guide.shape[0], axis=0)
        if self.use_pairwise_patch_error:
            error_style = self.get_pairwise_patch_error(source_style, nnf)
        else:
            error_style = self.get_patch_error(source_style, nnf, target_style)
        error = error_guide * self.guide_weight + error_style
        return error

    def clamp_bound(self, nnf):
        nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
        nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
        return nnf

    def random_step(self, nnf, r):
        batch_size = nnf.shape[0]
        step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
        upd_nnf = self.clamp_bound(nnf + step)
        return upd_nnf

    def neighboor_step(self, nnf, d):
        if d==0:
            upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
            upd_nnf[:, :, :, 0] += 1
        elif d==1:
            upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
            upd_nnf[:, :, :, 1] += 1
        elif d==2:
            upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
            upd_nnf[:, :, :, 0] -= 1
        elif d==3:
            upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
            upd_nnf[:, :, :, 1] -= 1
        upd_nnf = self.clamp_bound(upd_nnf)
        return upd_nnf

    def shift_nnf(self, nnf, d):
        if d>0:
            d = min(nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
        else:
            d = max(-nnf.shape[0], d)
            upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
        return upd_nnf

    def track_step(self, nnf, d):
        if self.use_pairwise_patch_error:
            upd_nnf = cp.zeros_like(nnf)
            upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
            upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
        else:
            upd_nnf = self.shift_nnf(nnf, d)
        return upd_nnf

    def C(self, n, m):
        # not used
        c = 1
        for i in range(1, n+1):
            c *= i
        for i in range(1, m+1):
            c //= i
        for i in range(1, n-m+1):
            c //= i
        return c

    def bezier_step(self, nnf, r):
        # not used
        n = r * 2 - 1
        upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
        for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
            if d>0:
                ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
            elif d<0:
                ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
            upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
        upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
        return upd_nnf

    def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
        upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
        upd_idx = (upd_err < err)
        nnf[upd_idx] = upd_nnf[upd_idx]
        err[upd_idx] = upd_err[upd_idx]
        return nnf, err

    def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in cp.random.permutation(4):
            upd_nnf = self.neighboor_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for i in range(self.random_search_steps):
            upd_nnf = self.random_step(nnf, self.random_search_range)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
        for d in range(1, self.tracking_window_size + 1):
            upd_nnf = self.track_step(nnf, d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
            upd_nnf = self.track_step(nnf, -d)
            nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
        return nnf, err

    def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
        nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
        nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
        return nnf, err

    def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
        with cp.cuda.Device(self.gpu_id):
            source_guide = self.pad_image(source_guide)
            target_guide = self.pad_image(target_guide)
            source_style = self.pad_image(source_style)
            for it in range(self.num_iter):
                self.patch_size = self.patch_size_list[it]
                target_style = self.apply_nnf_to_image(nnf, source_style)
                err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
                nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
            target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
        return nnf, target_style


class PyramidPatchMatcher:
    def __init__(
        self, image_height, image_width, channel, minimum_patch_size,
        threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
        use_mean_target_style=False, use_pairwise_patch_error=False,
        tracking_window_size=0,
        initialize="identity"
    ):
        maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
        self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
        self.pyramid_heights = []
        self.pyramid_widths = []
        self.patch_matchers = []
        self.minimum_patch_size = minimum_patch_size
        self.num_iter = num_iter
        self.gpu_id = gpu_id
        self.initialize = initialize
        for level in range(self.pyramid_level):
            height = image_height//(2**(self.pyramid_level - 1 - level))
            width = image_width//(2**(self.pyramid_level - 1 - level))
            self.pyramid_heights.append(height)
            self.pyramid_widths.append(width)
            self.patch_matchers.append(PatchMatcher(
                height, width, channel, minimum_patch_size=minimum_patch_size,
                threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
                use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
                tracking_window_size=tracking_window_size
            ))

    def resample_image(self, images, level):
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        images = images.get()
        images_resample = []
        for image in images:
            image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
            images_resample.append(image_resample)
        images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
        return images_resample

    def initialize_nnf(self, batch_size):
        if self.initialize == "random":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
                cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
            ], axis=3)
        elif self.initialize == "identity":
            height, width = self.pyramid_heights[0], self.pyramid_widths[0]
            nnf = cp.stack([
                cp.repeat(cp.arange(height), width).reshape(height, width),
                cp.tile(cp.arange(width), height).reshape(height, width)
            ], axis=2)
            nnf = cp.stack([nnf] * batch_size)
        else:
            raise NotImplementedError()
        return nnf

    def update_nnf(self, nnf, level):
        # upscale
        nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
        nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
        nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
        # check if scale is 2
        height, width = self.pyramid_heights[level], self.pyramid_widths[level]
        if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
            nnf = nnf.get().astype(np.float32)
            nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
            nnf = cp.array(np.stack(nnf), dtype=cp.int32)
            nnf = self.patch_matchers[level].clamp_bound(nnf)
        return nnf

    def apply_nnf_to_image(self, nnf, image):
        with cp.cuda.Device(self.gpu_id):
            image = self.patch_matchers[-1].pad_image(image)
            image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
        return image

    def estimate_nnf(self, source_guide, target_guide, source_style):
        with cp.cuda.Device(self.gpu_id):
            if not isinstance(source_guide, cp.ndarray):
                source_guide = cp.array(source_guide, dtype=cp.float32)
            if not isinstance(target_guide, cp.ndarray):
                target_guide = cp.array(target_guide, dtype=cp.float32)
            if not isinstance(source_style, cp.ndarray):
                source_style = cp.array(source_style, dtype=cp.float32)
            for level in range(self.pyramid_level):
                nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
                source_guide_ = self.resample_image(source_guide, level)
                target_guide_ = self.resample_image(target_guide, level)
                source_style_ = self.resample_image(source_style, level)
                nnf, target_style = self.patch_matchers[level].estimate_nnf(
                    source_guide_, target_guide_, source_style_, nnf
                )
        return nnf.get(), target_style.get()
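A minimal sketch of the call pattern used by the runners below; the zero-filled float32 arrays and the 512x512 size are placeholder assumptions, and a CUDA device is required since the matcher runs CuPy kernels.

import numpy as np

source_guide = np.zeros((1, 512, 512, 3), dtype=np.float32)   # placeholder guide frame
target_guide = np.zeros((1, 512, 512, 3), dtype=np.float32)   # placeholder guide frame
source_style = np.zeros((1, 512, 512, 3), dtype=np.float32)   # placeholder stylized frame
matcher = PyramidPatchMatcher(image_height=512, image_width=512, channel=3, minimum_patch_size=5)
nnf, remapped = matcher.estimate_nnf(source_guide, target_guide, source_style)
# nnf: per-pixel correspondences of shape (1, 512, 512, 2);
# remapped: source_style warped toward target_guide, same shape as source_style.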
diffsynth/extensions/FastBlend/runners/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .accurate import AccurateModeRunner
from .fast import FastModeRunner
from .balanced import BalancedModeRunner
from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diffsynth/extensions/FastBlend/runners/accurate.py
ADDED
@@ -0,0 +1,35 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class AccurateModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=True,
            **ebsynth_config
        )
        # run
        n = len(frames_style)
        for target in tqdm(range(n), desc=desc):
            l, r = max(target - window_size, 0), min(target + window_size + 1, n)
            remapped_frames = []
            for i in range(l, r, batch_size):
                j = min(i + batch_size, r)
                source_guide = np.stack([frames_guide[source] for source in range(i, j)])
                target_guide = np.stack([frames_guide[target]] * (j - i))
                source_style = np.stack([frames_style[source] for source in range(i, j)])
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                remapped_frames.append(target_style)
            frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
            frame = frame.clip(0, 255).astype("uint8")
            if save_path is not None:
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
diffsynth/extensions/FastBlend/runners/balanced.py
ADDED
@@ -0,0 +1,46 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class BalancedModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # tasks
        n = len(frames_style)
        tasks = []
        for target in range(n):
            for source in range(target - window_size, target + window_size + 1):
                if source >= 0 and source < n and source != target:
                    tasks.append((source, target))
        # run
        frames = [(None, 1) for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
            target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
            source_style = np.stack([frames_style[source] for source, target in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for (source, target), result in zip(tasks_batch, target_style):
                frame, weight = frames[target]
                if frame is None:
                    frame = frames_style[target]
                frames[target] = (
                    frame * (weight / (weight + 1)) + result / (weight + 1),
                    weight + 1
                )
                if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
                    frame = frame.clip(0, 255).astype("uint8")
                    if save_path is not None:
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
                    frames[target] = (None, 1)
diffsynth/extensions/FastBlend/runners/fast.py
ADDED
@@ -0,0 +1,141 @@
from ..patch_match import PyramidPatchMatcher
import functools, os
import numpy as np
from PIL import Image
from tqdm import tqdm


class TableManager:
    def __init__(self):
        pass

    def task_list(self, n):
        tasks = []
        max_level = 1
        while (1<<max_level)<=n:
            max_level += 1
        for i in range(n):
            j = i
            for level in range(max_level):
                if i&(1<<level):
                    continue
                j |= 1<<level
                if j>=n:
                    break
                meta_data = {
                    "source": i,
                    "target": j,
                    "level": level + 1
                }
                tasks.append(meta_data)
        tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
        return tasks

    def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
        n = len(frames_guide)
        tasks = self.task_list(n)
        remapping_table = [[(frames_style[i], 1)] for i in range(n)]
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
            source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, result in zip(tasks_batch, target_style):
                target, level = task["target"], task["level"]
                if len(remapping_table[target])==level:
                    remapping_table[target].append((result, 1))
                else:
                    frame, weight = remapping_table[target][level]
                    remapping_table[target][level] = (
                        frame * (weight / (weight + 1)) + result / (weight + 1),
                        weight + 1
                    )
        return remapping_table

    def remapping_table_to_blending_table(self, table):
        for i in range(len(table)):
            for j in range(1, len(table[i])):
                frame_1, weight_1 = table[i][j-1]
                frame_2, weight_2 = table[i][j]
                frame = (frame_1 + frame_2) / 2
                weight = weight_1 + weight_2
                table[i][j] = (frame, weight)
        return table

    def tree_query(self, leftbound, rightbound):
        node_list = []
        node_index = rightbound
        while node_index>=leftbound:
            node_level = 0
            while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
                node_level += 1
            node_list.append((node_index, node_level))
            node_index -= 1<<node_level
        return node_list

    def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
        n = len(blending_table)
        tasks = []
        frames_result = []
        for target in range(n):
            node_list = self.tree_query(max(target-window_size, 0), target)
            for source, level in node_list:
                if source!=target:
                    meta_data = {
                        "source": source,
                        "target": target,
                        "level": level
                    }
                    tasks.append(meta_data)
                else:
                    frames_result.append(blending_table[target][level])
        for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
            tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
            source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
            target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
            source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for task, frame_2 in zip(tasks_batch, target_style):
                source, target, level = task["source"], task["target"], task["level"]
                frame_1, weight_1 = frames_result[target]
                weight_2 = blending_table[source][level][1]
                weight = weight_1 + weight_2
                frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
                frames_result[target] = (frame, weight)
        return frames_result


class FastModeRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
        frames_guide = frames_guide.raw_data()
        frames_style = frames_style.raw_data()
        table_manager = TableManager()
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            **ebsynth_config
        )
        # left part
        table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
        table_l = table_manager.remapping_table_to_blending_table(table_l)
        table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
        # right part
        table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
        table_r = table_manager.remapping_table_to_blending_table(table_r)
        table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
        # merge
        frames = []
        for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
            weight_m = -1
            weight = weight_l + weight_m + weight_r
            frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
            frames.append(frame)
        frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
        if save_path is not None:
            for target, frame in enumerate(frames):
                Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
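An illustrative invocation of the runner above; the input paths and config values are placeholders, frames_guide and frames_style are VideoData objects from data.py, and ebsynth_config is forwarded to PyramidPatchMatcher as keyword arguments.

from diffsynth.extensions.FastBlend.data import VideoData
from diffsynth.extensions.FastBlend.runners import FastModeRunner

frames_guide = VideoData("guide.mp4", None)    # placeholder input videos
frames_style = VideoData("style.mp4", None)
ebsynth_config = {"minimum_patch_size": 5, "guide_weight": 10.0, "gpu_id": 0}  # illustrative values
FastModeRunner().run(frames_guide, frames_style, batch_size=8, window_size=30,
                     ebsynth_config=ebsynth_config, save_path="output_frames")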
diffsynth/extensions/FastBlend/runners/interpolation.py
ADDED
@@ -0,0 +1,121 @@
from ..patch_match import PyramidPatchMatcher
import os
import numpy as np
from PIL import Image
from tqdm import tqdm


class InterpolationModeRunner:
    def __init__(self):
        pass

    def get_index_dict(self, index_style):
        index_dict = {}
        for i, index in enumerate(index_style):
            index_dict[index] = i
        return index_dict

    def get_weight(self, l, m, r):
        weight_l, weight_r = abs(m - r), abs(m - l)
        if weight_l + weight_r == 0:
            weight_l, weight_r = 0.5, 0.5
        else:
            weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
        return weight_l, weight_r

    def get_task_group(self, index_style, n):
        task_group = []
        index_style = sorted(index_style)
        # first frame
        if index_style[0]>0:
            tasks = []
            for m in range(index_style[0]):
                tasks.append((index_style[0], m, index_style[0]))
            task_group.append(tasks)
        # middle frames
        for l, r in zip(index_style[:-1], index_style[1:]):
            tasks = []
            for m in range(l, r):
                tasks.append((l, m, r))
            task_group.append(tasks)
        # last frame
        tasks = []
        for m in range(index_style[-1], n):
            tasks.append((index_style[-1], m, index_style[-1]))
        task_group.append(tasks)
        return task_group

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        patch_match_engine = PyramidPatchMatcher(
            image_height=frames_style[0].shape[0],
            image_width=frames_style[0].shape[1],
            channel=3,
            use_mean_target_style=False,
            use_pairwise_patch_error=True,
            **ebsynth_config
        )
        # task
        index_dict = self.get_index_dict(index_style)
        task_group = self.get_task_group(index_style, len(frames_guide))
        # run
        for tasks in task_group:
            index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
            for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
                tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
                source_guide, target_guide, source_style = [], [], []
                for l, m, r in tasks_batch:
                    # l -> m
                    source_guide.append(frames_guide[l])
                    target_guide.append(frames_guide[m])
                    source_style.append(frames_style[index_dict[l]])
                    # r -> m
                    source_guide.append(frames_guide[r])
                    target_guide.append(frames_guide[m])
                    source_style.append(frames_style[index_dict[r]])
                source_guide = np.stack(source_guide)
                target_guide = np.stack(target_guide)
                source_style = np.stack(source_style)
                _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
                if save_path is not None:
                    for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
                        weight_l, weight_r = self.get_weight(l, m, r)
                        frame = frame_l * weight_l + frame_r * weight_r
                        frame = frame.clip(0, 255).astype("uint8")
                        Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))


class InterpolationModeSingleFrameRunner:
    def __init__(self):
        pass

    def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
        # check input
        tracking_window_size = ebsynth_config["tracking_window_size"]
        if tracking_window_size * 2 >= batch_size:
            raise ValueError("batch_size should be larger than track_window_size * 2")
        frame_style = frames_style[0]
        frame_guide = frames_guide[index_style[0]]
        patch_match_engine = PyramidPatchMatcher(
            image_height=frame_style.shape[0],
            image_width=frame_style.shape[1],
            channel=3,
            **ebsynth_config
        )
        # run
        frame_id, n = 0, len(frames_guide)
        for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
            if i + batch_size > n:
                l, r = max(n - batch_size, 0), n
            else:
                l, r = i, i + batch_size
            source_guide = np.stack([frame_guide] * (r-l))
            target_guide = np.stack([frames_guide[i] for i in range(l, r)])
            source_style = np.stack([frame_style] * (r-l))
            _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
            for i, frame in zip(range(l, r), target_style):
                if i==frame_id:
                    frame = frame.clip(0, 255).astype("uint8")
                    Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
                    frame_id += 1
            if r < n and r-frame_id <= tracking_window_size:
                break
diffsynth/extensions/RIFE/__init__.py
ADDED
@@ -0,0 +1,241 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image


def warp(tenInput, tenFlow, device):
    backwarp_tenGrid = {}
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat(
            [tenHorizontal, tenVertical], 1).to(device)

    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                  padding=padding, dilation=dilation, bias=True),
        nn.PReLU(out_planes)
    )


class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64):
        super(IFBlock, self).__init__()
        self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
        self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
        self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
        self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
        self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))

    def forward(self, x, flow, scale=1):
        x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
        flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
        feat = self.conv0(torch.cat((x, flow), 1))
        feat = self.convblock0(feat) + feat
        feat = self.convblock1(feat) + feat
        feat = self.convblock2(feat) + feat
        feat = self.convblock3(feat) + feat
        flow = self.conv1(feat)
        mask = self.conv2(feat)
        flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
        mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
        return flow, mask


class IFNet(nn.Module):
    def __init__(self):
        super(IFNet, self).__init__()
        self.block0 = IFBlock(7+4, c=90)
        self.block1 = IFBlock(7+4, c=90)
        self.block2 = IFBlock(7+4, c=90)
        self.block_tea = IFBlock(10+4, c=90)

    def forward(self, x, scale_list=[4, 2, 1], training=False):
        if training == False:
            channel = x.shape[1] // 2
            img0 = x[:, :channel]
            img1 = x[:, channel:]
        flow_list = []
        merged = []
        mask_list = []
        warped_img0 = img0
        warped_img1 = img1
        flow = (x[:, :4]).detach() * 0
        mask = (x[:, :1]).detach() * 0
        block = [self.block0, self.block1, self.block2]
        for i in range(3):
            f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
            f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
            flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
            mask = mask + (m0 + (-m1)) / 2
            mask_list.append(mask)
            flow_list.append(flow)
            warped_img0 = warp(img0, flow[:, :2], device=x.device)
            warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
            merged.append((warped_img0, warped_img1))
        '''
        c0 = self.contextnet(img0, flow[:, :2])
        c1 = self.contextnet(img1, flow[:, 2:4])
        tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
        res = tmp[:, 1:4] * 2 - 1
        '''
        for i in range(3):
            mask_list[i] = torch.sigmoid(mask_list[i])
            merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
        return flow_list, mask_list[2], merged

    def state_dict_converter(self):
        return IFNetStateDictConverter()


class IFNetStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
        return state_dict_

    def from_civitai(self, state_dict):
        return self.from_diffusers(state_dict)


class RIFEInterpolater:
    def __init__(self, model, device="cuda"):
        self.model = model
        self.device = device
        # IFNet does not support float16
        self.torch_dtype = torch.float32

    @staticmethod
    def from_model_manager(model_manager):
        return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)

    def process_image(self, image):
        width, height = image.size
        if width % 32 != 0 or height % 32 != 0:
            # round the resolution up to a multiple of 32, as required by the stride-2 stages
            width = (width + 31) // 32 * 32
            height = (height + 31) // 32 * 32
            image = image.resize((width, height))
        image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
        return image

    def process_images(self, images):
        images = [self.process_image(image) for image in images]
        images = torch.stack(images)
        return images

    def decode_images(self, images):
        images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
        images = [Image.fromarray(image) for image in images]
        return images

    def add_interpolated_images(self, images, interpolated_images):
        output_images = []
        for image, interpolated_image in zip(images, interpolated_images):
            output_images.append(image)
            output_images.append(interpolated_image)
        output_images.append(images[-1])
        return output_images

    @torch.no_grad()
    def interpolate_(self, images, scale=1.0):
        input_tensor = self.process_images(images)
        input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
        input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
        flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
        output_images = self.decode_images(merged[2].cpu())
        if output_images[0].size != images[0].size:
            output_images = [image.resize(images[0].size) for image in output_images]
        return output_images

    @torch.no_grad()
    def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1):
        # Preprocess
        processed_images = self.process_images(images)

        for iter in range(num_iter):
            # Input
            input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)

            # Interpolate
            output_tensor = []
            for batch_id in range(0, input_tensor.shape[0], batch_size):
                batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
                batch_input_tensor = input_tensor[batch_id: batch_id_]
                batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
                flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
                output_tensor.append(merged[2].cpu())

            # Output
            output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
            processed_images = self.add_interpolated_images(processed_images, output_tensor)
            processed_images = torch.stack(processed_images)

        # To images
        output_images = self.decode_images(processed_images)
        if output_images[0].size != images[0].size:
            output_images = [image.resize(images[0].size) for image in output_images]
        return output_images


class RIFESmoother(RIFEInterpolater):
    def __init__(self, model, device="cuda"):
        super(RIFESmoother, self).__init__(model, device=device)

    @staticmethod
    def from_model_manager(model_manager):
        return RIFESmoother(model_manager.RIFE, device=model_manager.device)

    def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
        output_tensor = []
        for batch_id in range(0, input_tensor.shape[0], batch_size):
            batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
            batch_input_tensor = input_tensor[batch_id: batch_id_]
            batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
            flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
            output_tensor.append(merged[2].cpu())
        output_tensor = torch.concat(output_tensor, dim=0)
        return output_tensor

    @torch.no_grad()
    def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
        # Preprocess
        processed_images = self.process_images(rendered_frames)

        for iter in range(num_iter):
            # Input
            input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)

            # Interpolate
            output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)

            # Blend
            input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
            output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)

            # Add to frames
            processed_images[1:-1] = output_tensor

        # To images
        output_images = self.decode_images(processed_images)
        if output_images[0].size != rendered_frames[0].size:
            output_images = [image.resize(rendered_frames[0].size) for image in output_images]
        return output_images
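A sketch of interpolating between two PIL frames with the classes above; the checkpoint path "flownet.pkl" and the frame file names are placeholders, and the loading steps mirror ModelManager.load_RIFE in models/__init__.py below.

import torch
from PIL import Image

model = IFNet().eval()
state_dict = torch.load("flownet.pkl", map_location="cpu")        # placeholder RIFE checkpoint
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(torch.float32).to("cuda")

interpolater = RIFEInterpolater(model, device="cuda")
frames = [Image.open("frame_0.png"), Image.open("frame_1.png")]   # placeholder frames
frames_2x = interpolater.interpolate(frames, scale=1.0, num_iter=1)  # one new frame between each pair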
diffsynth/models/__init__.py
ADDED
@@ -0,0 +1,295 @@
import torch, os
from safetensors import safe_open

from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
from .sd_lora import SDLoRA

from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
from .sdxl_vae_decoder import SDXLVAEDecoder
from .sdxl_vae_encoder import SDXLVAEEncoder

from .sd_controlnet import SDControlNet

from .sd_motion import SDMotionModel


class ModelManager:
    def __init__(self, torch_dtype=torch.float16, device="cuda"):
        self.torch_dtype = torch_dtype
        self.device = device
        self.model = {}
        self.model_path = {}
        self.textual_inversion_dict = {}

    def is_RIFE(self, state_dict):
        param_name = "block_tea.convblock3.0.1.weight"
        return param_name in state_dict or ("module." + param_name) in state_dict

    def is_beautiful_prompt(self, state_dict):
        param_name = "transformer.h.9.self_attention.query_key_value.weight"
        return param_name in state_dict

    def is_stabe_diffusion_xl(self, state_dict):
        param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
        return param_name in state_dict

    def is_stable_diffusion(self, state_dict):
        if self.is_stabe_diffusion_xl(state_dict):
            return False
        param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
        return param_name in state_dict

    def is_controlnet(self, state_dict):
        param_name = "control_model.time_embed.0.weight"
        return param_name in state_dict

    def is_animatediff(self, state_dict):
        param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
        return param_name in state_dict

    def is_sd_lora(self, state_dict):
        param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
        return param_name in state_dict

    def is_translator(self, state_dict):
        param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
        return param_name in state_dict and len(state_dict) == 254

    def load_stable_diffusion(self, state_dict, components=None, file_path=""):
        component_dict = {
            "text_encoder": SDTextEncoder,
            "unet": SDUNet,
            "vae_decoder": SDVAEDecoder,
            "vae_encoder": SDVAEEncoder,
            "refiner": SDXLUNet,
        }
        if components is None:
            components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
        for component in components:
            if component == "text_encoder":
                # Add additional token embeddings to text encoder
                token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
                for keyword in self.textual_inversion_dict:
                    _, embeddings = self.textual_inversion_dict[keyword]
                    token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
                token_embeddings = torch.concat(token_embeddings, dim=0)
                state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
                self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
                self.model[component].to(self.torch_dtype).to(self.device)
            else:
                self.model[component] = component_dict[component]()
                self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
                self.model[component].to(self.torch_dtype).to(self.device)
            self.model_path[component] = file_path

    def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
        component_dict = {
            "text_encoder": SDXLTextEncoder,
            "text_encoder_2": SDXLTextEncoder2,
            "unet": SDXLUNet,
            "vae_decoder": SDXLVAEDecoder,
            "vae_encoder": SDXLVAEEncoder,
        }
        if components is None:
            components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
        for component in components:
            self.model[component] = component_dict[component]()
            self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
            if component in ["vae_decoder", "vae_encoder"]:
                # These two models output NaN when float16 is enabled.
                # The precision problem happens in the last three resnet blocks.
                # I do not know how to solve this problem.
                self.model[component].to(torch.float32).to(self.device)
            else:
                self.model[component].to(self.torch_dtype).to(self.device)
            self.model_path[component] = file_path

    def load_controlnet(self, state_dict, file_path=""):
        component = "controlnet"
        if component not in self.model:
            self.model[component] = []
            self.model_path[component] = []
        model = SDControlNet()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component].append(model)
        self.model_path[component].append(file_path)

    def load_animatediff(self, state_dict, file_path=""):
        component = "motion_modules"
        model = SDMotionModel()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(self.torch_dtype).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_beautiful_prompt(self, state_dict, file_path=""):
        component = "beautiful_prompt"
        from transformers import AutoModelForCausalLM
        model_folder = os.path.dirname(file_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
        ).to(self.device).eval()
        self.model[component] = model
        self.model_path[component] = file_path

    def load_RIFE(self, state_dict, file_path=""):
        component = "RIFE"
        from ..extensions.RIFE import IFNet
        model = IFNet().eval()
        model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
        model.to(torch.float32).to(self.device)
        self.model[component] = model
        self.model_path[component] = file_path

    def load_sd_lora(self, state_dict, alpha):
        SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
        SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)

    def load_translator(self, state_dict, file_path=""):
        # This model is lightweight, so we do not place it on the GPU.
        component = "translator"
        from transformers import AutoModelForSeq2SeqLM
        model_folder = os.path.dirname(file_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
        self.model[component] = model
        self.model_path[component] = file_path

    def search_for_embeddings(self, state_dict):
        embeddings = []
        for k in state_dict:
            if isinstance(state_dict[k], torch.Tensor):
                embeddings.append(state_dict[k])
            elif isinstance(state_dict[k], dict):
                embeddings += self.search_for_embeddings(state_dict[k])
        return embeddings
|
172 |
+
def load_textual_inversions(self, folder):
|
173 |
+
# Store additional tokens here
|
174 |
+
self.textual_inversion_dict = {}
|
175 |
+
|
176 |
+
# Load every textual inversion file
|
177 |
+
for file_name in os.listdir(folder):
|
178 |
+
if file_name.endswith(".txt"):
|
179 |
+
continue
|
180 |
+
keyword = os.path.splitext(file_name)[0]
|
181 |
+
state_dict = load_state_dict(os.path.join(folder, file_name))
|
182 |
+
|
183 |
+
# Search for embeddings
|
184 |
+
for embeddings in self.search_for_embeddings(state_dict):
|
185 |
+
if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
|
186 |
+
tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
|
187 |
+
self.textual_inversion_dict[keyword] = (tokens, embeddings)
|
188 |
+
break
|
189 |
+
|
190 |
+
def load_model(self, file_path, components=None, lora_alphas=[]):
|
191 |
+
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
192 |
+
if self.is_animatediff(state_dict):
|
193 |
+
self.load_animatediff(state_dict, file_path=file_path)
|
194 |
+
elif self.is_controlnet(state_dict):
|
195 |
+
self.load_controlnet(state_dict, file_path=file_path)
|
196 |
+
elif self.is_stabe_diffusion_xl(state_dict):
|
197 |
+
self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
|
198 |
+
elif self.is_stable_diffusion(state_dict):
|
199 |
+
self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
|
200 |
+
elif self.is_sd_lora(state_dict):
|
201 |
+
self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
|
202 |
+
elif self.is_beautiful_prompt(state_dict):
|
203 |
+
self.load_beautiful_prompt(state_dict, file_path=file_path)
|
204 |
+
elif self.is_RIFE(state_dict):
|
205 |
+
self.load_RIFE(state_dict, file_path=file_path)
|
206 |
+
elif self.is_translator(state_dict):
|
207 |
+
self.load_translator(state_dict, file_path=file_path)
|
208 |
+
|
209 |
+
def load_models(self, file_path_list, lora_alphas=[]):
|
210 |
+
for file_path in file_path_list:
|
211 |
+
self.load_model(file_path, lora_alphas=lora_alphas)
|
212 |
+
|
213 |
+
def to(self, device):
|
214 |
+
for component in self.model:
|
215 |
+
if isinstance(self.model[component], list):
|
216 |
+
for model in self.model[component]:
|
217 |
+
model.to(device)
|
218 |
+
else:
|
219 |
+
self.model[component].to(device)
|
220 |
+
torch.cuda.empty_cache()
|
221 |
+
|
222 |
+
def get_model_with_model_path(self, model_path):
|
223 |
+
for component in self.model_path:
|
224 |
+
if isinstance(self.model_path[component], str):
|
225 |
+
if os.path.samefile(self.model_path[component], model_path):
|
226 |
+
return self.model[component]
|
227 |
+
elif isinstance(self.model_path[component], list):
|
228 |
+
for i, model_path_ in enumerate(self.model_path[component]):
|
229 |
+
if os.path.samefile(model_path_, model_path):
|
230 |
+
return self.model[component][i]
|
231 |
+
raise ValueError(f"Please load model {model_path} before you use it.")
|
232 |
+
|
233 |
+
def __getattr__(self, __name):
|
234 |
+
if __name in self.model:
|
235 |
+
return self.model[__name]
|
236 |
+
else:
|
237 |
+
return super.__getattribute__(__name)
|
238 |
+
|
239 |
+
|
240 |
+
def load_state_dict(file_path, torch_dtype=None):
|
241 |
+
if file_path.endswith(".safetensors"):
|
242 |
+
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
|
243 |
+
else:
|
244 |
+
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
|
245 |
+
|
246 |
+
|
247 |
+
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
248 |
+
state_dict = {}
|
249 |
+
with safe_open(file_path, framework="pt", device="cpu") as f:
|
250 |
+
for k in f.keys():
|
251 |
+
state_dict[k] = f.get_tensor(k)
|
252 |
+
if torch_dtype is not None:
|
253 |
+
state_dict[k] = state_dict[k].to(torch_dtype)
|
254 |
+
return state_dict
|
255 |
+
|
256 |
+
|
257 |
+
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
258 |
+
state_dict = torch.load(file_path, map_location="cpu")
|
259 |
+
if torch_dtype is not None:
|
260 |
+
state_dict = {i: state_dict[i].to(torch_dtype) for i in state_dict}
|
261 |
+
return state_dict
|
262 |
+
|
263 |
+
|
264 |
+
def search_parameter(param, state_dict):
|
265 |
+
for name, param_ in state_dict.items():
|
266 |
+
if param.numel() == param_.numel():
|
267 |
+
if param.shape == param_.shape:
|
268 |
+
if torch.dist(param, param_) < 1e-6:
|
269 |
+
return name
|
270 |
+
else:
|
271 |
+
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
|
272 |
+
return name
|
273 |
+
return None
|
274 |
+
|
275 |
+
|
276 |
+
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
277 |
+
matched_keys = set()
|
278 |
+
with torch.no_grad():
|
279 |
+
for name in source_state_dict:
|
280 |
+
rename = search_parameter(source_state_dict[name], target_state_dict)
|
281 |
+
if rename is not None:
|
282 |
+
print(f'"{name}": "{rename}",')
|
283 |
+
matched_keys.add(rename)
|
284 |
+
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
285 |
+
length = source_state_dict[name].shape[0] // 3
|
286 |
+
rename = []
|
287 |
+
for i in range(3):
|
288 |
+
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
289 |
+
if None not in rename:
|
290 |
+
print(f'"{name}": {rename},')
|
291 |
+
for rename_ in rename:
|
292 |
+
matched_keys.add(rename_)
|
293 |
+
for name in target_state_dict:
|
294 |
+
if name not in matched_keys:
|
295 |
+
print("Cannot find", name, target_state_dict[name].shape)
|
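A minimal sketch of how the loading helpers above fit together (the checkpoint path is hypothetical; assumes this package is importable as diffsynth.models):

    from diffsynth.models import load_state_dict

    # Dispatches on the extension: safetensors files go through safe_open,
    # everything else through torch.load.
    state_dict = load_state_dict("models/ControlNet/control_v11p_sd15_lineart.pth")
    # The is_* detectors above probe for a single signature key, e.g. for ControlNet:
    print("control_model.time_embed.0.weight" in state_dict)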
diffsynth/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (10.9 kB).
diffsynth/models/__pycache__/attention.cpython-39.pyc
ADDED
Binary file (2.44 kB).
diffsynth/models/__pycache__/sd_controlnet.cpython-39.pyc
ADDED
Binary file (39.8 kB).
diffsynth/models/__pycache__/sd_lora.cpython-39.pyc
ADDED
Binary file (2.49 kB).
diffsynth/models/__pycache__/sd_motion.cpython-39.pyc
ADDED
Binary file (7.05 kB).
diffsynth/models/__pycache__/sd_text_encoder.cpython-39.pyc
ADDED
Binary file (26 kB).
diffsynth/models/__pycache__/sd_unet.cpython-39.pyc
ADDED
Binary file (86.1 kB).
diffsynth/models/__pycache__/sd_vae_decoder.cpython-39.pyc
ADDED
Binary file (17.3 kB).
diffsynth/models/__pycache__/sd_vae_encoder.cpython-39.pyc
ADDED
Binary file (13.7 kB).
diffsynth/models/__pycache__/sdxl_text_encoder.cpython-39.pyc
ADDED
Binary file (67.3 kB).
diffsynth/models/__pycache__/sdxl_unet.cpython-39.pyc
ADDED
Binary file (213 kB).
diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-39.pyc
ADDED
Binary file (1.11 kB).
diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-39.pyc
ADDED
Binary file (1.11 kB).
diffsynth/models/__pycache__/tiler.cpython-39.pyc
ADDED
Binary file (3.01 kB).
diffsynth/models/attention.py
ADDED
@@ -0,0 +1,76 @@
import torch
from einops import rearrange


def low_version_attention(query, key, value, attn_bias=None):
    scale = 1 / query.shape[-1] ** 0.5
    query = query * scale
    attn = torch.matmul(query, key.transpose(-2, -1))
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    return attn @ value


class Attention(torch.nn.Module):

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        super().__init__()
        dim_inner = head_dim * num_heads
        kv_dim = kv_dim if kv_dim is not None else q_dim
        self.num_heads = num_heads
        self.head_dim = head_dim

        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)

    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        batch_size = encoder_hidden_states.shape[0]

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        q = self.to_q(hidden_states)
        k = self.to_k(encoder_hidden_states)
        v = self.to_v(encoder_hidden_states)

        q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
        k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
        v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)

        if attn_mask is not None:
            hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
        else:
            import xformers.ops as xops
            hidden_states = xops.memory_efficient_attention(q, k, v)
        hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)

        hidden_states = hidden_states.to(q.dtype)
        hidden_states = self.to_out(hidden_states)

        return hidden_states

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask)
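A quick shape check for the Attention module above (illustrative values; assumes this file is importable as diffsynth.models.attention and a PyTorch version that provides scaled_dot_product_attention):

    import torch
    from diffsynth.models.attention import Attention

    attn = Attention(q_dim=320, num_heads=8, head_dim=40)  # dim_inner = 8 * 40 = 320
    x = torch.randn(2, 77, 320)                            # (batch, sequence, q_dim)
    y = attn(x)                                            # self-attention; forward defaults to torch_forward
    print(y.shape)                                         # torch.Size([2, 77, 320])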
diffsynth/models/sd_controlnet.py
ADDED
@@ -0,0 +1,584 @@
import torch
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
from .tiler import TileWorker


class ControlNetConditioningLayer(torch.nn.Module):
    def __init__(self, channels=(3, 16, 32, 96, 256, 320)):
        super().__init__()
        self.blocks = torch.nn.ModuleList([])
        self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
        self.blocks.append(torch.nn.SiLU())
        for i in range(1, len(channels) - 2):
            self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
            self.blocks.append(torch.nn.SiLU())
            self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
            self.blocks.append(torch.nn.SiLU())
        self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))

    def forward(self, conditioning):
        for block in self.blocks:
            conditioning = block(conditioning)
        return conditioning


class SDControlNet(torch.nn.Module):
    def __init__(self, global_pool=False):
        super().__init__()
        self.time_proj = Timesteps(320)
        self.time_embedding = torch.nn.Sequential(
            torch.nn.Linear(320, 1280),
            torch.nn.SiLU(),
            torch.nn.Linear(1280, 1280)
        )
        self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)

        self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))

        self.blocks = torch.nn.ModuleList([
            # CrossAttnDownBlock2D
            ResnetBlock(320, 320, 1280),
            AttentionBlock(8, 40, 320, 1, 768),
            PushBlock(),
            ResnetBlock(320, 320, 1280),
            AttentionBlock(8, 40, 320, 1, 768),
            PushBlock(),
            DownSampler(320),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(320, 640, 1280),
            AttentionBlock(8, 80, 640, 1, 768),
            PushBlock(),
            ResnetBlock(640, 640, 1280),
            AttentionBlock(8, 80, 640, 1, 768),
            PushBlock(),
            DownSampler(640),
            PushBlock(),
            # CrossAttnDownBlock2D
            ResnetBlock(640, 1280, 1280),
            AttentionBlock(8, 160, 1280, 1, 768),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(8, 160, 1280, 1, 768),
            PushBlock(),
            DownSampler(1280),
            PushBlock(),
            # DownBlock2D
            ResnetBlock(1280, 1280, 1280),
            PushBlock(),
            ResnetBlock(1280, 1280, 1280),
            PushBlock(),
            # UNetMidBlock2DCrossAttn
            ResnetBlock(1280, 1280, 1280),
            AttentionBlock(8, 160, 1280, 1, 768),
            ResnetBlock(1280, 1280, 1280),
            PushBlock()
        ])

        self.controlnet_blocks = torch.nn.ModuleList([
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
            torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
        ])

        self.global_pool = global_pool

    def forward(
        self,
        sample, timestep, encoder_hidden_states, conditioning,
        tiled=False, tile_size=64, tile_stride=32,
    ):
        # 1. time
        time_emb = self.time_proj(timestep[None]).to(sample.dtype)
        time_emb = self.time_embedding(time_emb)
        time_emb = time_emb.repeat(sample.shape[0], 1)

        # 2. pre-process
        height, width = sample.shape[2], sample.shape[3]
        hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
        text_emb = encoder_hidden_states
        res_stack = [hidden_states]

        # 3. blocks
        for i, block in enumerate(self.blocks):
            if tiled and not isinstance(block, PushBlock):
                _, _, inter_height, _ = hidden_states.shape
                resize_scale = inter_height / height
                hidden_states = TileWorker().tiled_forward(
                    lambda x: block(x, time_emb, text_emb, res_stack)[0],
                    hidden_states,
                    int(tile_size * resize_scale),
                    int(tile_stride * resize_scale),
                    tile_device=hidden_states.device,
                    tile_dtype=hidden_states.dtype
                )
            else:
                hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)

        # 4. ControlNet blocks
        controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]

        # pool
        if self.global_pool:
            controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]

        return controlnet_res_stack

    def state_dict_converter(self):
        return SDControlNetStateDictConverter()


class SDControlNetStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        # architecture
        block_types = [
            'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
            'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
            'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
            'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
            'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
            'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
            'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
            'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
            'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
        ]

        # controlnet_rename_dict
        controlnet_rename_dict = {
            "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
            "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
            "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
            "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
            "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
            "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
            "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
            "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
            "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
            "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
            "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
            "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
            "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
            "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
            "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
            "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
        }

        # Rename each parameter
        name_list = sorted([name for name in state_dict])
        rename_dict = {}
        block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
        last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
        for name in name_list:
            names = name.split(".")
            if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
                pass
            elif name in controlnet_rename_dict:
                names = controlnet_rename_dict[name].split(".")
            elif names[0] == "controlnet_down_blocks":
                names[0] = "controlnet_blocks"
            elif names[0] == "controlnet_mid_block":
                names = ["controlnet_blocks", "12", names[-1]]
            elif names[0] in ["time_embedding", "add_embedding"]:
                if names[0] == "add_embedding":
                    names[0] = "add_time_embedding"
                names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
            elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
                if names[0] == "mid_block":
                    names.insert(1, "0")
                block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
                block_type_with_id = ".".join(names[:4])
                if block_type_with_id != last_block_type_with_id[block_type]:
                    block_id[block_type] += 1
                    last_block_type_with_id[block_type] = block_type_with_id
                while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
                    block_id[block_type] += 1
                block_type_with_id = ".".join(names[:4])
                names = ["blocks", str(block_id[block_type])] + names[4:]
                if "ff" in names:
                    ff_index = names.index("ff")
                    component = ".".join(names[ff_index:ff_index+3])
                    component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
                    names = names[:ff_index] + [component] + names[ff_index+3:]
                if "to_out" in names:
                    names.pop(names.index("to_out") + 1)
            else:
                raise ValueError(f"Unknown parameters: {name}")
            rename_dict[name] = ".".join(names)

        # Convert state_dict
        state_dict_ = {}
        for name, param in state_dict.items():
            if ".proj_in." in name or ".proj_out." in name:
                param = param.squeeze()
            if rename_dict[name] in [
                "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
                "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
            ]:
                continue
            state_dict_[rename_dict[name]] = param
        return state_dict_

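    # Illustrative usage of the converter above (hypothetical checkpoint path;
    # load_state_dict and SDControlNet are defined in the modules shown earlier):
    #
    #   from diffsynth.models import load_state_dict
    #   from diffsynth.models.sd_controlnet import SDControlNet
    #   model = SDControlNet()
    #   sd = load_state_dict("controlnet/diffusion_pytorch_model.safetensors")
    #   model.load_state_dict(model.state_dict_converter().from_diffusers(sd))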
234 |
+
def from_civitai(self, state_dict):
|
235 |
+
rename_dict = {
|
236 |
+
"control_model.time_embed.0.weight": "time_embedding.0.weight",
|
237 |
+
"control_model.time_embed.0.bias": "time_embedding.0.bias",
|
238 |
+
"control_model.time_embed.2.weight": "time_embedding.2.weight",
|
239 |
+
"control_model.time_embed.2.bias": "time_embedding.2.bias",
|
240 |
+
"control_model.input_blocks.0.0.weight": "conv_in.weight",
|
241 |
+
"control_model.input_blocks.0.0.bias": "conv_in.bias",
|
242 |
+
"control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
|
243 |
+
"control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
|
244 |
+
"control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
|
245 |
+
"control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
|
246 |
+
"control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
|
247 |
+
"control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
|
248 |
+
"control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
|
249 |
+
"control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
|
250 |
+
"control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
|
251 |
+
"control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
|
252 |
+
"control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
|
253 |
+
"control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
|
254 |
+
"control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
|
255 |
+
"control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
|
256 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
|
257 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
|
258 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
|
259 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
|
260 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
|
261 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
|
262 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
|
263 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
|
264 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
|
265 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
|
266 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
|
267 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
|
268 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
|
269 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
|
270 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
|
271 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
|
272 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
|
273 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
|
274 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
|
275 |
+
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
|
276 |
+
"control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
|
277 |
+
"control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
|
278 |
+
"control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
|
279 |
+
"control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
|
280 |
+
"control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
|
281 |
+
"control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
|
282 |
+
"control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
|
283 |
+
"control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
|
284 |
+
"control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
|
285 |
+
"control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
|
286 |
+
"control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
|
287 |
+
"control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
|
288 |
+
"control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
|
289 |
+
"control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
|
290 |
+
"control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
|
291 |
+
"control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
|
292 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
|
293 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
|
294 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
|
295 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
|
296 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
|
297 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
|
298 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
|
299 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
|
300 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
|
301 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
|
302 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
|
303 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
|
304 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
|
305 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
|
306 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
|
307 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
|
308 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
|
309 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
|
310 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
|
311 |
+
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
|
312 |
+
"control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
|
313 |
+
"control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
|
314 |
+
"control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
|
315 |
+
"control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
|
316 |
+
"control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
|
317 |
+
"control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
|
318 |
+
"control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
|
319 |
+
"control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
|
320 |
+
"control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
|
321 |
+
"control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
|
322 |
+
"control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
|
323 |
+
"control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
|
324 |
+
"control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
|
325 |
+
"control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
|
326 |
+
"control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
|
327 |
+
"control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
|
328 |
+
"control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
|
329 |
+
"control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
|
330 |
+
"control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
|
331 |
+
"control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
|
332 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
|
333 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
|
334 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
|
335 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
|
336 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
|
337 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
|
338 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
|
339 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
|
340 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
|
341 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
|
342 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
|
343 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
|
344 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
|
345 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
|
346 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
|
347 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
|
348 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
|
349 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
|
350 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
|
351 |
+
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
|
352 |
+
"control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
|
353 |
+
"control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
|
354 |
+
"control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
|
355 |
+
"control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
|
356 |
+
"control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
|
357 |
+
"control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
|
358 |
+
"control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
|
359 |
+
"control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
|
360 |
+
"control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
|
361 |
+
"control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
|
362 |
+
"control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
|
363 |
+
"control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
|
364 |
+
"control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
|
365 |
+
"control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
|
366 |
+
"control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
|
367 |
+
"control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
|
368 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
|
369 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
|
370 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
|
371 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
|
372 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
|
373 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
|
374 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
|
375 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
|
376 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
|
377 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
|
378 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
|
379 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
|
380 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
|
381 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
|
382 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
|
383 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
|
384 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
|
385 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
|
386 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
|
387 |
+
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
|
388 |
+
"control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
|
389 |
+
"control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
|
390 |
+
"control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
|
391 |
+
"control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
|
392 |
+
"control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
|
393 |
+
"control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
|
394 |
+
"control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
|
395 |
+
"control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
|
396 |
+
"control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
|
397 |
+
"control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
|
398 |
+
"control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
|
399 |
+
"control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
|
400 |
+
"control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
|
401 |
+
"control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
|
402 |
+
"control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
|
403 |
+
"control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
|
404 |
+
"control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
|
405 |
+
"control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
|
406 |
+
"control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
|
407 |
+
"control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
|
408 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
|
409 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
|
410 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
|
411 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
|
412 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
|
413 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
|
414 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
|
415 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
|
416 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
|
417 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
|
418 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
|
419 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
|
420 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
|
421 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
|
422 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
|
423 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
|
424 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
|
425 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
|
426 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
|
427 |
+
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
|
428 |
+
"control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
|
429 |
+
"control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
|
430 |
+
"control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
|
431 |
+
"control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
|
432 |
+
"control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
|
433 |
+
"control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
|
434 |
+
"control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
|
435 |
+
"control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
|
436 |
+
"control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
|
437 |
+
"control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
|
438 |
+
"control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
|
439 |
+
"control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
|
440 |
+
"control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
|
441 |
+
"control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
|
442 |
+
"control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
|
443 |
+
"control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
|
444 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
|
445 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
|
446 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
|
447 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
|
448 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
|
449 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
|
450 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
|
451 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
|
452 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
|
453 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
|
454 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
|
455 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
|
456 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
|
457 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
|
458 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
|
459 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
|
460 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
|
461 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
|
462 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
|
463 |
+
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
|
464 |
+
"control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
|
465 |
+
"control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
|
466 |
+
"control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
|
467 |
+
"control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
|
468 |
+
"control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
|
469 |
+
"control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
|
470 |
+
"control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
|
471 |
+
"control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
|
472 |
+
"control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
|
473 |
+
"control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
|
474 |
+
"control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
|
475 |
+
"control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
|
476 |
+
"control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
|
477 |
+
"control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
|
478 |
+
"control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
|
479 |
+
"control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
|
480 |
+
"control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
|
481 |
+
"control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
|
482 |
+
"control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
|
483 |
+
"control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
|
484 |
+
"control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
|
485 |
+
"control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
|
486 |
+
"control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
|
487 |
+
"control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
|
488 |
+
"control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
|
489 |
+
"control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
|
490 |
+
"control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
|
491 |
+
"control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
|
492 |
+
"control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
|
493 |
+
"control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
|
494 |
+
"control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
|
495 |
+
"control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
|
496 |
+
"control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
|
497 |
+
"control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
|
498 |
+
"control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
|
499 |
+
"control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
|
500 |
+
"control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
|
501 |
+
"control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
|
502 |
+
"control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
|
503 |
+
"control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
|
504 |
+
"control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
|
505 |
+
"control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
|
506 |
+
"control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
|
507 |
+
"control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
|
508 |
+
"control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
|
509 |
+
"control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
|
510 |
+
"control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
|
511 |
+
"control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
|
512 |
+
"control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
|
513 |
+
"control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
|
514 |
+
"control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
|
515 |
+
"control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
|
516 |
+
"control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
|
517 |
+
"control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
|
518 |
+
"control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
|
519 |
+
"control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
|
520 |
+
"control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
|
521 |
+
"control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
|
522 |
+
"control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
|
523 |
+
"control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
|
524 |
+
"control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
|
525 |
+
"control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
|
526 |
+
"control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
|
527 |
+
"control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
|
528 |
+
"control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
|
529 |
+
"control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
|
530 |
+
"control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
|
531 |
+
"control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
|
532 |
+
"control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
|
533 |
+
"control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
|
534 |
+
"control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
|
535 |
+
"control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
|
536 |
+
"control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
|
537 |
+
"control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
|
538 |
+
"control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
|
539 |
+
"control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
|
540 |
+
"control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
|
541 |
+
"control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
|
542 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
|
543 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
|
544 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
|
545 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
|
546 |
+
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
|
547 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
|
548 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
|
549 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
|
550 |
+
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
|
551 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
|
552 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
|
553 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
|
554 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
|
555 |
+
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
|
556 |
+
"control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
|
557 |
+
"control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
|
558 |
+
"control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
|
559 |
+
"control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
|
560 |
+
"control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
|
561 |
+
"control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
|
562 |
+
"control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
|
563 |
+
"control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
|
564 |
+
"control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
|
565 |
+
"control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
|
566 |
+
"control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
|
567 |
+
"control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
|
568 |
+
"control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
|
569 |
+
"control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
|
570 |
+
"control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
|
571 |
+
"control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
|
572 |
+
"control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
|
573 |
+
"control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
|
574 |
+
"control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
|
575 |
+
"control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
|
576 |
+
}
|
577 |
+
state_dict_ = {}
|
578 |
+
for name in state_dict:
|
579 |
+
if name in rename_dict:
|
580 |
+
param = state_dict[name]
|
581 |
+
if ".proj_in." in name or ".proj_out." in name:
|
582 |
+
param = param.squeeze()
|
583 |
+
state_dict_[rename_dict[name]] = param
|
584 |
+
return state_dict_
|
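
The converter above is a plain key-renaming pass over the original ControlNet checkpoint: each `control_model.*` tensor with a known new name is copied under its DiffSynth key, and the `proj_in`/`proj_out` weights, which are stored as 1×1 convolutions in SD-1.5-style checkpoints, are squeezed into linear-layer shape. A minimal, self-contained sketch of the same pattern is shown below; the abbreviated `rename_dict` and the checkpoint path are illustrative placeholders, not part of the repository code.

```python
import torch

# Abbreviated, illustrative rename table; the full mapping is the
# rename_dict defined above in sd_controlnet.py.
rename_dict = {
    "control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
    "control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
    "control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
}

def convert_state_dict(state_dict, rename_dict):
    # Keep only tensors that have a known new name; everything else is dropped.
    state_dict_ = {}
    for name, param in state_dict.items():
        if name in rename_dict:
            # proj_in / proj_out are 1x1 convolutions in the original checkpoint,
            # so squeeze the trailing singleton spatial dims to match the
            # linear layers used on the DiffSynth side.
            if ".proj_in." in name or ".proj_out." in name:
                param = param.squeeze()
            state_dict_[rename_dict[name]] = param
    return state_dict_

# Usage sketch (the checkpoint path is a placeholder):
# raw = torch.load("control_sd15_canny.pth", map_location="cpu")
# converted = convert_state_dict(raw, rename_dict)
```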