ameerazam08 commited on
Commit
5bf1581
·
verified ·
1 Parent(s): 8efb628

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. DiffSynth_Studio.py +15 -0
  2. README.md +101 -0
  3. configs/stable_diffusion/tokenizer/merges.txt +0 -0
  4. configs/stable_diffusion/tokenizer/special_tokens_map.json +24 -0
  5. configs/stable_diffusion/tokenizer/tokenizer_config.json +34 -0
  6. configs/stable_diffusion/tokenizer/vocab.json +0 -0
  7. configs/stable_diffusion_xl/tokenizer_2/merges.txt +0 -0
  8. configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json +24 -0
  9. configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json +38 -0
  10. configs/stable_diffusion_xl/tokenizer_2/vocab.json +0 -0
  11. diffsynth/__init__.py +6 -0
  12. diffsynth/__pycache__/__init__.cpython-39.pyc +0 -0
  13. diffsynth/controlnets/__init__.py +2 -0
  14. diffsynth/controlnets/__pycache__/__init__.cpython-39.pyc +0 -0
  15. diffsynth/controlnets/__pycache__/controlnet_unit.cpython-39.pyc +0 -0
  16. diffsynth/controlnets/__pycache__/processors.cpython-39.pyc +0 -0
  17. diffsynth/controlnets/controlnet_unit.py +53 -0
  18. diffsynth/controlnets/processors.py +51 -0
  19. diffsynth/data/__init__.py +1 -0
  20. diffsynth/data/__pycache__/__init__.cpython-39.pyc +0 -0
  21. diffsynth/data/__pycache__/video.cpython-39.pyc +0 -0
  22. diffsynth/data/video.py +148 -0
  23. diffsynth/extensions/FastBlend/__init__.py +63 -0
  24. diffsynth/extensions/FastBlend/api.py +397 -0
  25. diffsynth/extensions/FastBlend/cupy_kernels.py +119 -0
  26. diffsynth/extensions/FastBlend/data.py +146 -0
  27. diffsynth/extensions/FastBlend/patch_match.py +298 -0
  28. diffsynth/extensions/FastBlend/runners/__init__.py +4 -0
  29. diffsynth/extensions/FastBlend/runners/accurate.py +35 -0
  30. diffsynth/extensions/FastBlend/runners/balanced.py +46 -0
  31. diffsynth/extensions/FastBlend/runners/fast.py +141 -0
  32. diffsynth/extensions/FastBlend/runners/interpolation.py +121 -0
  33. diffsynth/extensions/RIFE/__init__.py +241 -0
  34. diffsynth/models/__init__.py +295 -0
  35. diffsynth/models/__pycache__/__init__.cpython-39.pyc +0 -0
  36. diffsynth/models/__pycache__/attention.cpython-39.pyc +0 -0
  37. diffsynth/models/__pycache__/sd_controlnet.cpython-39.pyc +0 -0
  38. diffsynth/models/__pycache__/sd_lora.cpython-39.pyc +0 -0
  39. diffsynth/models/__pycache__/sd_motion.cpython-39.pyc +0 -0
  40. diffsynth/models/__pycache__/sd_text_encoder.cpython-39.pyc +0 -0
  41. diffsynth/models/__pycache__/sd_unet.cpython-39.pyc +0 -0
  42. diffsynth/models/__pycache__/sd_vae_decoder.cpython-39.pyc +0 -0
  43. diffsynth/models/__pycache__/sd_vae_encoder.cpython-39.pyc +0 -0
  44. diffsynth/models/__pycache__/sdxl_text_encoder.cpython-39.pyc +0 -0
  45. diffsynth/models/__pycache__/sdxl_unet.cpython-39.pyc +0 -0
  46. diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-39.pyc +0 -0
  47. diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-39.pyc +0 -0
  48. diffsynth/models/__pycache__/tiler.cpython-39.pyc +0 -0
  49. diffsynth/models/attention.py +76 -0
  50. diffsynth/models/sd_controlnet.py +584 -0
DiffSynth_Studio.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Set web page format
2
+ import streamlit as st
3
+ st.set_page_config(layout="wide")
4
+ # Diasble virtual VRAM on windows system
5
+ import torch
6
+ torch.cuda.set_per_process_memory_fraction(0.999, 0)
7
+
8
+
9
+ st.markdown("""
10
+ # DiffSynth Studio
11
+
12
+ [Source Code](https://github.com/Artiprocher/DiffSynth-Studio)
13
+
14
+ Welcome to DiffSynth Studio.
15
+ """)
README.md ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DiffSynth Studio
2
+
3
+ ## Introduction
4
+
5
+ DiffSynth is a new Diffusion engine. We have restructured architectures including Text Encoder, UNet, VAE, among others, maintaining compatibility with models from the open-source community while enhancing computational performance. This version is currently in its initial stage, supporting SD and SDXL architectures. In the future, we plan to develop more interesting features based on this new codebase.
6
+
7
+ ## Installation
8
+
9
+ Create Python environment:
10
+
11
+ ```
12
+ conda env create -f environment.yml
13
+ ```
14
+
15
+ We find that sometimes `conda` cannot install `cupy` correctly, please install it manually. See [this document](https://docs.cupy.dev/en/stable/install.html) for more details.
16
+
17
+ Enter the Python environment:
18
+
19
+ ```
20
+ conda activate DiffSynthStudio
21
+ ```
22
+
23
+ ## Usage (in WebUI)
24
+
25
+ ```
26
+ python -m streamlit run Diffsynth_Studio.py
27
+ ```
28
+
29
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/93085557-73f3-4eee-a205-9829591ef954
30
+
31
+ ## Usage (in Python code)
32
+
33
+ ### Example 1: Stable Diffusion
34
+
35
+ We can generate images with very high resolution. Please see `examples/sd_text_to_image.py` for more details.
36
+
37
+ |512*512|1024*1024|2048*2048|4096*4096|
38
+ |-|-|-|-|
39
+ |![512](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/55f679e9-7445-4605-9315-302e93d11370)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/9087a73c-9164-4c58-b2a0-effc694143fb)|![4096](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edee9e71-fc39-4d1c-9ca9-fa52002c67ac)|
40
+
41
+ ### Example 2: Stable Diffusion XL
42
+
43
+ Generate images with Stable Diffusion XL. Please see `examples/sdxl_text_to_image.py` for more details.
44
+
45
+ |1024*1024|2048*2048|
46
+ |-|-|
47
+ |![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
48
+
49
+ ### Example 3: Stable Diffusion XL Turbo
50
+
51
+ Generate images with Stable Diffusion XL Turbo. You can see `examples/sdxl_turbo.py` for more details, but we highly recommend you to use it in the WebUI.
52
+
53
+ |"black car"|"red car"|
54
+ |-|-|
55
+ |![black_car](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/7fbfd803-68d4-44f3-8713-8c925fec47d0)|![black_car_to_red_car](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/aaf886e4-c33c-4fd8-98e2-29eef117ba00)|
56
+
57
+ ### Example 4: Toon Shading (Diffutoon)
58
+
59
+ This example is implemented based on [Diffutoon](https://arxiv.org/abs/2401.16224). This approach is adept for rendering high-resoluton videos with rapid motion. You can easily modify the parameters in the config dict. See `examples/diffutoon_toon_shading.py`.
60
+
61
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-4709-be5e-b39af82404dd
62
+
63
+ ### Example 5: Toon Shading with Editing Signals (Diffutoon)
64
+
65
+ Coming soon.
66
+
67
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/20528af5-5100-474a-8cdc-440b9efdd86c
68
+
69
+ ### Example 6: Toon Shading (in native Python code)
70
+
71
+ This example is provided for developers. If you don't want to use the config to manage parameters, you can see `examples/sd_toon_shading.py` to learn how to use it in native Python code.
72
+
73
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/607c199b-6140-410b-a111-3e4ffb01142c
74
+
75
+ ### Example 7: Text to Video
76
+
77
+ Given a prompt, DiffSynth Studio can generate a video using a Stable Diffusion model and an AnimateDiff model. We can break the limitation of number of frames! See `examples/sd_text_to_video.py`.
78
+
79
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/8f556355-4079-4445-9b48-e9da77699437
80
+
81
+ ### Example 8: Video Stylization
82
+
83
+ We provide an example for video stylization. In this pipeline, the rendered video is completely different from the original video, thus we need a powerful deflickering algorithm. We use FastBlend to implement the deflickering module. Please see `examples/sd_video_rerender.py` for more details.
84
+
85
+ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
86
+
87
+ ### Example 9: Prompt Processing
88
+
89
+ If you are not native English user, we provide translation service for you. Our prompter can translate other language to English and refine it using "BeautifulPrompt" models. Please see `examples/sd_prompt_refining.py` for more details.
90
+
91
+ Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it to English.
92
+
93
+ |seed=0|seed=1|seed=2|seed=3|
94
+ |-|-|-|-|
95
+ |![0_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/ebb25ca8-7ce1-4d9e-8081-59a867c70c4d)|![1_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/a7e79853-3c1a-471a-9c58-c209ec4b76dd)|![2_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/a292b959-a121-481f-b79c-61cc3346f810)|![3_](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1c19b54e-5a6f-4d48-960b-a7b2b149bb4c)|
96
+
97
+ Prompt: "一个漂亮的女孩". The [translation model](https://huggingface.co/Helsinki-NLP/opus-mt-en-zh) will translate it to English. Then the [refining model](https://huggingface.co/alibaba-pai/pai-bloom-1b1-text2prompt-sd) will refine the translated prompt for better visual quality.
98
+
99
+ |seed=0|seed=1|seed=2|seed=3|
100
+ |-|-|-|-|
101
+ |![0](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/778b1bd9-44e0-46ac-a99c-712b3fc9aaa4)|![1](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/c03479b8-2082-4c6e-8e1c-3582b98686f6)|![2](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/edb33d21-3288-4a55-96ca-a4bfe1b50b00)|![3](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/7848cfc1-cad5-4848-8373-41d24e98e584)|
configs/stable_diffusion/tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
configs/stable_diffusion/tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<|endoftext|>",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
configs/stable_diffusion/tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<|startoftext|>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "do_lower_case": true,
12
+ "eos_token": {
13
+ "__type": "AddedToken",
14
+ "content": "<|endoftext|>",
15
+ "lstrip": false,
16
+ "normalized": true,
17
+ "rstrip": false,
18
+ "single_word": false
19
+ },
20
+ "errors": "replace",
21
+ "model_max_length": 77,
22
+ "name_or_path": "openai/clip-vit-large-patch14",
23
+ "pad_token": "<|endoftext|>",
24
+ "special_tokens_map_file": "./special_tokens_map.json",
25
+ "tokenizer_class": "CLIPTokenizer",
26
+ "unk_token": {
27
+ "__type": "AddedToken",
28
+ "content": "<|endoftext|>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
configs/stable_diffusion/tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
configs/stable_diffusion_xl/tokenizer_2/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
configs/stable_diffusion_xl/tokenizer_2/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "!",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
configs/stable_diffusion_xl/tokenizer_2/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "!",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49406": {
13
+ "content": "<|startoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "49407": {
21
+ "content": "<|endoftext|>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "bos_token": "<|startoftext|>",
30
+ "clean_up_tokenization_spaces": true,
31
+ "do_lower_case": true,
32
+ "eos_token": "<|endoftext|>",
33
+ "errors": "replace",
34
+ "model_max_length": 77,
35
+ "pad_token": "!",
36
+ "tokenizer_class": "CLIPTokenizer",
37
+ "unk_token": "<|endoftext|>"
38
+ }
configs/stable_diffusion_xl/tokenizer_2/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
diffsynth/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .data import *
2
+ from .models import *
3
+ from .prompts import *
4
+ from .schedulers import *
5
+ from .pipelines import *
6
+ from .controlnets import *
diffsynth/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (279 Bytes). View file
 
diffsynth/controlnets/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .controlnet_unit import ControlNetConfigUnit, ControlNetUnit, MultiControlNetManager
2
+ from .processors import Annotator
diffsynth/controlnets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (329 Bytes). View file
 
diffsynth/controlnets/__pycache__/controlnet_unit.cpython-39.pyc ADDED
Binary file (3.09 kB). View file
 
diffsynth/controlnets/__pycache__/processors.cpython-39.pyc ADDED
Binary file (1.82 kB). View file
 
diffsynth/controlnets/controlnet_unit.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from .processors import Processor_id
4
+
5
+
6
+ class ControlNetConfigUnit:
7
+ def __init__(self, processor_id: Processor_id, model_path, scale=1.0):
8
+ self.processor_id = processor_id
9
+ self.model_path = model_path
10
+ self.scale = scale
11
+
12
+
13
+ class ControlNetUnit:
14
+ def __init__(self, processor, model, scale=1.0):
15
+ self.processor = processor
16
+ self.model = model
17
+ self.scale = scale
18
+
19
+
20
+ class MultiControlNetManager:
21
+ def __init__(self, controlnet_units=[]):
22
+ self.processors = [unit.processor for unit in controlnet_units]
23
+ self.models = [unit.model for unit in controlnet_units]
24
+ self.scales = [unit.scale for unit in controlnet_units]
25
+
26
+ def process_image(self, image, processor_id=None):
27
+ if processor_id is None:
28
+ processed_image = [processor(image) for processor in self.processors]
29
+ else:
30
+ processed_image = [self.processors[processor_id](image)]
31
+ processed_image = torch.concat([
32
+ torch.Tensor(np.array(image_, dtype=np.float32) / 255).permute(2, 0, 1).unsqueeze(0)
33
+ for image_ in processed_image
34
+ ], dim=0)
35
+ return processed_image
36
+
37
+ def __call__(
38
+ self,
39
+ sample, timestep, encoder_hidden_states, conditionings,
40
+ tiled=False, tile_size=64, tile_stride=32
41
+ ):
42
+ res_stack = None
43
+ for conditioning, model, scale in zip(conditionings, self.models, self.scales):
44
+ res_stack_ = model(
45
+ sample, timestep, encoder_hidden_states, conditioning,
46
+ tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
47
+ )
48
+ res_stack_ = [res * scale for res in res_stack_]
49
+ if res_stack is None:
50
+ res_stack = res_stack_
51
+ else:
52
+ res_stack = [i + j for i, j in zip(res_stack, res_stack_)]
53
+ return res_stack
diffsynth/controlnets/processors.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing_extensions import Literal, TypeAlias
2
+ import warnings
3
+ with warnings.catch_warnings():
4
+ warnings.simplefilter("ignore")
5
+ from controlnet_aux.processor import (
6
+ CannyDetector, MidasDetector, HEDdetector, LineartDetector, LineartAnimeDetector, OpenposeDetector
7
+ )
8
+
9
+
10
+ Processor_id: TypeAlias = Literal[
11
+ "canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "tile"
12
+ ]
13
+
14
+ class Annotator:
15
+ def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
16
+ if processor_id == "canny":
17
+ self.processor = CannyDetector()
18
+ elif processor_id == "depth":
19
+ self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
20
+ elif processor_id == "softedge":
21
+ self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
22
+ elif processor_id == "lineart":
23
+ self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
24
+ elif processor_id == "lineart_anime":
25
+ self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
26
+ elif processor_id == "openpose":
27
+ self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
28
+ elif processor_id == "tile":
29
+ self.processor = None
30
+ else:
31
+ raise ValueError(f"Unsupported processor_id: {processor_id}")
32
+
33
+ self.processor_id = processor_id
34
+ self.detect_resolution = detect_resolution
35
+
36
+ def __call__(self, image):
37
+ width, height = image.size
38
+ if self.processor_id == "openpose":
39
+ kwargs = {
40
+ "include_body": True,
41
+ "include_hand": True,
42
+ "include_face": True
43
+ }
44
+ else:
45
+ kwargs = {}
46
+ if self.processor is not None:
47
+ detect_resolution = self.detect_resolution if self.detect_resolution is not None else min(width, height)
48
+ image = self.processor(image, detect_resolution=detect_resolution, image_resolution=min(width, height), **kwargs)
49
+ image = image.resize((width, height))
50
+ return image
51
+
diffsynth/data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .video import VideoData, save_video, save_frames
diffsynth/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (242 Bytes). View file
 
diffsynth/data/__pycache__/video.cpython-39.pyc ADDED
Binary file (6.09 kB). View file
 
diffsynth/data/video.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+ from tqdm import tqdm
5
+
6
+
7
+ class LowMemoryVideo:
8
+ def __init__(self, file_name):
9
+ self.reader = imageio.get_reader(file_name)
10
+
11
+ def __len__(self):
12
+ return self.reader.count_frames()
13
+
14
+ def __getitem__(self, item):
15
+ return Image.fromarray(np.array(self.reader.get_data(item))).convert("RGB")
16
+
17
+ def __del__(self):
18
+ self.reader.close()
19
+
20
+
21
+ def split_file_name(file_name):
22
+ result = []
23
+ number = -1
24
+ for i in file_name:
25
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
26
+ if number == -1:
27
+ number = 0
28
+ number = number*10 + ord(i) - ord("0")
29
+ else:
30
+ if number != -1:
31
+ result.append(number)
32
+ number = -1
33
+ result.append(i)
34
+ if number != -1:
35
+ result.append(number)
36
+ result = tuple(result)
37
+ return result
38
+
39
+
40
+ def search_for_images(folder):
41
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
42
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
43
+ file_list = [i[1] for i in sorted(file_list)]
44
+ file_list = [os.path.join(folder, i) for i in file_list]
45
+ return file_list
46
+
47
+
48
+ class LowMemoryImageFolder:
49
+ def __init__(self, folder, file_list=None):
50
+ if file_list is None:
51
+ self.file_list = search_for_images(folder)
52
+ else:
53
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
54
+
55
+ def __len__(self):
56
+ return len(self.file_list)
57
+
58
+ def __getitem__(self, item):
59
+ return Image.open(self.file_list[item]).convert("RGB")
60
+
61
+ def __del__(self):
62
+ pass
63
+
64
+
65
+ def crop_and_resize(image, height, width):
66
+ image = np.array(image)
67
+ image_height, image_width, _ = image.shape
68
+ if image_height / image_width < height / width:
69
+ croped_width = int(image_height / height * width)
70
+ left = (image_width - croped_width) // 2
71
+ image = image[:, left: left+croped_width]
72
+ image = Image.fromarray(image).resize((width, height))
73
+ else:
74
+ croped_height = int(image_width / width * height)
75
+ left = (image_height - croped_height) // 2
76
+ image = image[left: left+croped_height, :]
77
+ image = Image.fromarray(image).resize((width, height))
78
+ return image
79
+
80
+
81
+ class VideoData:
82
+ def __init__(self, video_file=None, image_folder=None, height=None, width=None, **kwargs):
83
+ if video_file is not None:
84
+ self.data_type = "video"
85
+ self.data = LowMemoryVideo(video_file, **kwargs)
86
+ elif image_folder is not None:
87
+ self.data_type = "images"
88
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
89
+ else:
90
+ raise ValueError("Cannot open video or image folder")
91
+ self.length = None
92
+ self.set_shape(height, width)
93
+
94
+ def raw_data(self):
95
+ frames = []
96
+ for i in range(self.__len__()):
97
+ frames.append(self.__getitem__(i))
98
+ return frames
99
+
100
+ def set_length(self, length):
101
+ self.length = length
102
+
103
+ def set_shape(self, height, width):
104
+ self.height = height
105
+ self.width = width
106
+
107
+ def __len__(self):
108
+ if self.length is None:
109
+ return len(self.data)
110
+ else:
111
+ return self.length
112
+
113
+ def shape(self):
114
+ if self.height is not None and self.width is not None:
115
+ return self.height, self.width
116
+ else:
117
+ height, width, _ = self.__getitem__(0).shape
118
+ return height, width
119
+
120
+ def __getitem__(self, item):
121
+ frame = self.data.__getitem__(item)
122
+ width, height = frame.size
123
+ if self.height is not None and self.width is not None:
124
+ if self.height != height or self.width != width:
125
+ frame = crop_and_resize(frame, self.height, self.width)
126
+ return frame
127
+
128
+ def __del__(self):
129
+ pass
130
+
131
+ def save_images(self, folder):
132
+ os.makedirs(folder, exist_ok=True)
133
+ for i in tqdm(range(self.__len__()), desc="Saving images"):
134
+ frame = self.__getitem__(i)
135
+ frame.save(os.path.join(folder, f"{i}.png"))
136
+
137
+
138
+ def save_video(frames, save_path, fps, quality=9):
139
+ writer = imageio.get_writer(save_path, fps=fps, quality=quality)
140
+ for frame in tqdm(frames, desc="Saving video"):
141
+ frame = np.array(frame)
142
+ writer.append_data(frame)
143
+ writer.close()
144
+
145
+ def save_frames(frames, save_path):
146
+ os.makedirs(save_path, exist_ok=True)
147
+ for i, frame in enumerate(tqdm(frames, desc="Saving images")):
148
+ frame.save(os.path.join(save_path, f"{i}.png"))
diffsynth/extensions/FastBlend/__init__.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .runners.fast import TableManager, PyramidPatchMatcher
2
+ from PIL import Image
3
+ import numpy as np
4
+ import cupy as cp
5
+
6
+
7
+ class FastBlendSmoother:
8
+ def __init__(self):
9
+ self.batch_size = 8
10
+ self.window_size = 64
11
+ self.ebsynth_config = {
12
+ "minimum_patch_size": 5,
13
+ "threads_per_block": 8,
14
+ "num_iter": 5,
15
+ "gpu_id": 0,
16
+ "guide_weight": 10.0,
17
+ "initialize": "identity",
18
+ "tracking_window_size": 0,
19
+ }
20
+
21
+ @staticmethod
22
+ def from_model_manager(model_manager):
23
+ # TODO: fetch GPU ID from model_manager
24
+ return FastBlendSmoother()
25
+
26
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config):
27
+ frames_guide = [np.array(frame) for frame in frames_guide]
28
+ frames_style = [np.array(frame) for frame in frames_style]
29
+ table_manager = TableManager()
30
+ patch_match_engine = PyramidPatchMatcher(
31
+ image_height=frames_style[0].shape[0],
32
+ image_width=frames_style[0].shape[1],
33
+ channel=3,
34
+ **ebsynth_config
35
+ )
36
+ # left part
37
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="FastBlend Step 1/4")
38
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
39
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="FastBlend Step 2/4")
40
+ # right part
41
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="FastBlend Step 3/4")
42
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
43
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="FastBlend Step 4/4")[::-1]
44
+ # merge
45
+ frames = []
46
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
47
+ weight_m = -1
48
+ weight = weight_l + weight_m + weight_r
49
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
50
+ frames.append(frame)
51
+ frames = [Image.fromarray(frame.clip(0, 255).astype("uint8")) for frame in frames]
52
+ return frames
53
+
54
+ def __call__(self, rendered_frames, original_frames=None, **kwargs):
55
+ frames = self.run(
56
+ original_frames, rendered_frames,
57
+ self.batch_size, self.window_size, self.ebsynth_config
58
+ )
59
+ mempool = cp.get_default_memory_pool()
60
+ pinned_mempool = cp.get_default_pinned_memory_pool()
61
+ mempool.free_all_blocks()
62
+ pinned_mempool.free_all_blocks()
63
+ return frames
diffsynth/extensions/FastBlend/api.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .runners import AccurateModeRunner, FastModeRunner, BalancedModeRunner, InterpolationModeRunner, InterpolationModeSingleFrameRunner
2
+ from .data import VideoData, get_video_fps, save_video, search_for_images
3
+ import os
4
+ import gradio as gr
5
+
6
+
7
+ def check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder):
8
+ frames_guide = VideoData(video_guide, video_guide_folder)
9
+ frames_style = VideoData(video_style, video_style_folder)
10
+ message = ""
11
+ if len(frames_guide) < len(frames_style):
12
+ message += f"The number of frames mismatches. Only the first {len(frames_guide)} frames of style video will be used.\n"
13
+ frames_style.set_length(len(frames_guide))
14
+ elif len(frames_guide) > len(frames_style):
15
+ message += f"The number of frames mismatches. Only the first {len(frames_style)} frames of guide video will be used.\n"
16
+ frames_guide.set_length(len(frames_style))
17
+ height_guide, width_guide = frames_guide.shape()
18
+ height_style, width_style = frames_style.shape()
19
+ if height_guide != height_style or width_guide != width_style:
20
+ message += f"The shape of frames mismatches. The frames in style video will be resized to (height: {height_guide}, width: {width_guide})\n"
21
+ frames_style.set_shape(height_guide, width_guide)
22
+ return frames_guide, frames_style, message
23
+
24
+
25
+ def smooth_video(
26
+ video_guide,
27
+ video_guide_folder,
28
+ video_style,
29
+ video_style_folder,
30
+ mode,
31
+ window_size,
32
+ batch_size,
33
+ tracking_window_size,
34
+ output_path,
35
+ fps,
36
+ minimum_patch_size,
37
+ num_iter,
38
+ guide_weight,
39
+ initialize,
40
+ progress = None,
41
+ ):
42
+ # input
43
+ frames_guide, frames_style, message = check_input_for_blending(video_guide, video_guide_folder, video_style, video_style_folder)
44
+ if len(message) > 0:
45
+ print(message)
46
+ # output
47
+ if output_path == "":
48
+ if video_style is None:
49
+ output_path = os.path.join(video_style_folder, "output")
50
+ else:
51
+ output_path = os.path.join(os.path.split(video_style)[0], "output")
52
+ os.makedirs(output_path, exist_ok=True)
53
+ print("No valid output_path. Your video will be saved here:", output_path)
54
+ elif not os.path.exists(output_path):
55
+ os.makedirs(output_path, exist_ok=True)
56
+ print("Your video will be saved here:", output_path)
57
+ frames_path = os.path.join(output_path, "frames")
58
+ video_path = os.path.join(output_path, "video.mp4")
59
+ os.makedirs(frames_path, exist_ok=True)
60
+ # process
61
+ if mode == "Fast" or mode == "Balanced":
62
+ tracking_window_size = 0
63
+ ebsynth_config = {
64
+ "minimum_patch_size": minimum_patch_size,
65
+ "threads_per_block": 8,
66
+ "num_iter": num_iter,
67
+ "gpu_id": 0,
68
+ "guide_weight": guide_weight,
69
+ "initialize": initialize,
70
+ "tracking_window_size": tracking_window_size,
71
+ }
72
+ if mode == "Fast":
73
+ FastModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
74
+ elif mode == "Balanced":
75
+ BalancedModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
76
+ elif mode == "Accurate":
77
+ AccurateModeRunner().run(frames_guide, frames_style, batch_size=batch_size, window_size=window_size, ebsynth_config=ebsynth_config, save_path=frames_path)
78
+ # output
79
+ try:
80
+ fps = int(fps)
81
+ except:
82
+ fps = get_video_fps(video_style) if video_style is not None else 30
83
+ print("Fps:", fps)
84
+ print("Saving video...")
85
+ video_path = save_video(frames_path, video_path, num_frames=len(frames_style), fps=fps)
86
+ print("Success!")
87
+ print("Your frames are here:", frames_path)
88
+ print("Your video is here:", video_path)
89
+ return output_path, fps, video_path
90
+
91
+
92
+ class KeyFrameMatcher:
93
+ def __init__(self):
94
+ pass
95
+
96
+ def extract_number_from_filename(self, file_name):
97
+ result = []
98
+ number = -1
99
+ for i in file_name:
100
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
101
+ if number == -1:
102
+ number = 0
103
+ number = number*10 + ord(i) - ord("0")
104
+ else:
105
+ if number != -1:
106
+ result.append(number)
107
+ number = -1
108
+ if number != -1:
109
+ result.append(number)
110
+ result = tuple(result)
111
+ return result
112
+
113
+ def extract_number_from_filenames(self, file_names):
114
+ numbers = [self.extract_number_from_filename(file_name) for file_name in file_names]
115
+ min_length = min(len(i) for i in numbers)
116
+ for i in range(min_length-1, -1, -1):
117
+ if len(set(number[i] for number in numbers))==len(file_names):
118
+ return [number[i] for number in numbers]
119
+ return list(range(len(file_names)))
120
+
121
+ def match_using_filename(self, file_names_a, file_names_b):
122
+ file_names_b_set = set(file_names_b)
123
+ matched_file_name = []
124
+ for file_name in file_names_a:
125
+ if file_name not in file_names_b_set:
126
+ matched_file_name.append(None)
127
+ else:
128
+ matched_file_name.append(file_name)
129
+ return matched_file_name
130
+
131
+ def match_using_numbers(self, file_names_a, file_names_b):
132
+ numbers_a = self.extract_number_from_filenames(file_names_a)
133
+ numbers_b = self.extract_number_from_filenames(file_names_b)
134
+ numbers_b_dict = {number: file_name for number, file_name in zip(numbers_b, file_names_b)}
135
+ matched_file_name = []
136
+ for number in numbers_a:
137
+ if number in numbers_b_dict:
138
+ matched_file_name.append(numbers_b_dict[number])
139
+ else:
140
+ matched_file_name.append(None)
141
+ return matched_file_name
142
+
143
+ def match_filenames(self, file_names_a, file_names_b):
144
+ matched_file_name = self.match_using_filename(file_names_a, file_names_b)
145
+ if sum([i is not None for i in matched_file_name]) > 0:
146
+ return matched_file_name
147
+ matched_file_name = self.match_using_numbers(file_names_a, file_names_b)
148
+ return matched_file_name
149
+
150
+
151
+ def detect_frames(frames_path, keyframes_path):
152
+ if not os.path.exists(frames_path) and not os.path.exists(keyframes_path):
153
+ return "Please input the directory of guide video and rendered frames"
154
+ elif not os.path.exists(frames_path):
155
+ return "Please input the directory of guide video"
156
+ elif not os.path.exists(keyframes_path):
157
+ return "Please input the directory of rendered frames"
158
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
159
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
160
+ if len(frames)==0:
161
+ return f"No images detected in {frames_path}"
162
+ if len(keyframes)==0:
163
+ return f"No images detected in {keyframes_path}"
164
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
165
+ max_filename_length = max([len(i) for i in frames])
166
+ if sum([i is not None for i in matched_keyframes])==0:
167
+ message = ""
168
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
169
+ message += frame + " " * (max_filename_length - len(frame) + 1)
170
+ message += "--> No matched keyframes\n"
171
+ else:
172
+ message = ""
173
+ for frame, matched_keyframe in zip(frames, matched_keyframes):
174
+ message += frame + " " * (max_filename_length - len(frame) + 1)
175
+ if matched_keyframe is None:
176
+ message += "--> [to be rendered]\n"
177
+ else:
178
+ message += f"--> {matched_keyframe}\n"
179
+ return message
180
+
181
+
182
+ def check_input_for_interpolating(frames_path, keyframes_path):
183
+ # search for images
184
+ frames = [os.path.split(i)[-1] for i in search_for_images(frames_path)]
185
+ keyframes = [os.path.split(i)[-1] for i in search_for_images(keyframes_path)]
186
+ # match frames
187
+ matched_keyframes = KeyFrameMatcher().match_filenames(frames, keyframes)
188
+ file_list = [file_name for file_name in matched_keyframes if file_name is not None]
189
+ index_style = [i for i, file_name in enumerate(matched_keyframes) if file_name is not None]
190
+ frames_guide = VideoData(None, frames_path)
191
+ frames_style = VideoData(None, keyframes_path, file_list=file_list)
192
+ # match shape
193
+ message = ""
194
+ height_guide, width_guide = frames_guide.shape()
195
+ height_style, width_style = frames_style.shape()
196
+ if height_guide != height_style or width_guide != width_style:
197
+ message += f"The shape of frames mismatches. The rendered keyframes will be resized to (height: {height_guide}, width: {width_guide})\n"
198
+ frames_style.set_shape(height_guide, width_guide)
199
+ return frames_guide, frames_style, index_style, message
200
+
201
+
202
+ def interpolate_video(
203
+ frames_path,
204
+ keyframes_path,
205
+ output_path,
206
+ fps,
207
+ batch_size,
208
+ tracking_window_size,
209
+ minimum_patch_size,
210
+ num_iter,
211
+ guide_weight,
212
+ initialize,
213
+ progress = None,
214
+ ):
215
+ # input
216
+ frames_guide, frames_style, index_style, message = check_input_for_interpolating(frames_path, keyframes_path)
217
+ if len(message) > 0:
218
+ print(message)
219
+ # output
220
+ if output_path == "":
221
+ output_path = os.path.join(keyframes_path, "output")
222
+ os.makedirs(output_path, exist_ok=True)
223
+ print("No valid output_path. Your video will be saved here:", output_path)
224
+ elif not os.path.exists(output_path):
225
+ os.makedirs(output_path, exist_ok=True)
226
+ print("Your video will be saved here:", output_path)
227
+ output_frames_path = os.path.join(output_path, "frames")
228
+ output_video_path = os.path.join(output_path, "video.mp4")
229
+ os.makedirs(output_frames_path, exist_ok=True)
230
+ # process
231
+ ebsynth_config = {
232
+ "minimum_patch_size": minimum_patch_size,
233
+ "threads_per_block": 8,
234
+ "num_iter": num_iter,
235
+ "gpu_id": 0,
236
+ "guide_weight": guide_weight,
237
+ "initialize": initialize,
238
+ "tracking_window_size": tracking_window_size
239
+ }
240
+ if len(index_style)==1:
241
+ InterpolationModeSingleFrameRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
242
+ else:
243
+ InterpolationModeRunner().run(frames_guide, frames_style, index_style, batch_size=batch_size, ebsynth_config=ebsynth_config, save_path=output_frames_path)
244
+ try:
245
+ fps = int(fps)
246
+ except:
247
+ fps = 30
248
+ print("Fps:", fps)
249
+ print("Saving video...")
250
+ video_path = save_video(output_frames_path, output_video_path, num_frames=len(frames_guide), fps=fps)
251
+ print("Success!")
252
+ print("Your frames are here:", output_frames_path)
253
+ print("Your video is here:", video_path)
254
+ return output_path, fps, video_path
255
+
256
+
257
+ def on_ui_tabs():
258
+ with gr.Blocks(analytics_enabled=False) as ui_component:
259
+ with gr.Tab("Blend"):
260
+ gr.Markdown("""
261
+ # Blend
262
+
263
+ Given a guide video and a style video, this algorithm will make the style video fluent according to the motion features of the guide video. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/208d902d-6aba-48d7-b7d5-cd120ebd306d) to see the example. Note that this extension doesn't support long videos. Please use short videos (e.g., several seconds). The algorithm is mainly designed for 512*512 resolution. Please use a larger `Minimum patch size` for higher resolution.
264
+ """)
265
+ with gr.Row():
266
+ with gr.Column():
267
+ with gr.Tab("Guide video"):
268
+ video_guide = gr.Video(label="Guide video")
269
+ with gr.Tab("Guide video (images format)"):
270
+ video_guide_folder = gr.Textbox(label="Guide video (images format)", value="")
271
+ with gr.Column():
272
+ with gr.Tab("Style video"):
273
+ video_style = gr.Video(label="Style video")
274
+ with gr.Tab("Style video (images format)"):
275
+ video_style_folder = gr.Textbox(label="Style video (images format)", value="")
276
+ with gr.Column():
277
+ output_path = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of style video")
278
+ fps = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
279
+ video_output = gr.Video(label="Output video", interactive=False, show_share_button=True)
280
+ btn = gr.Button(value="Blend")
281
+ with gr.Row():
282
+ with gr.Column():
283
+ gr.Markdown("# Settings")
284
+ mode = gr.Radio(["Fast", "Balanced", "Accurate"], label="Inference mode", value="Fast", interactive=True)
285
+ window_size = gr.Slider(label="Sliding window size", value=15, minimum=1, maximum=1000, step=1, interactive=True)
286
+ batch_size = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
287
+ tracking_window_size = gr.Slider(label="Tracking window size (only for accurate mode)", value=0, minimum=0, maximum=10, step=1, interactive=True)
288
+ gr.Markdown("## Advanced Settings")
289
+ minimum_patch_size = gr.Slider(label="Minimum patch size (odd number)", value=5, minimum=5, maximum=99, step=2, interactive=True)
290
+ num_iter = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
291
+ guide_weight = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
292
+ initialize = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
293
+ with gr.Column():
294
+ gr.Markdown("""
295
+ # Reference
296
+
297
+ * Output directory: the directory to save the video.
298
+ * Inference mode
299
+
300
+ |Mode|Time|Memory|Quality|Frame by frame output|Description|
301
+ |-|-|-|-|-|-|
302
+ |Fast|■|■■■|■■|No|Blend the frames using a tree-like data structure, which requires much RAM but is fast.|
303
+ |Balanced|■■|■|■■|Yes|Blend the frames naively.|
304
+ |Accurate|■■■|■|■■■|Yes|Blend the frames and align them together for higher video quality. When [batch size] >= [sliding window size] * 2 + 1, the performance is the best.|
305
+
306
+ * Sliding window size: our algorithm will blend the frames in a sliding windows. If the size is n, each frame will be blended with the last n frames and the next n frames. A large sliding window can make the video fluent but sometimes smoggy.
307
+ * Batch size: a larger batch size makes the program faster but requires more VRAM.
308
+ * Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
309
+ * Advanced settings
310
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. (Default: 5)
311
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
312
+ * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
313
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
314
+ """)
315
+ btn.click(
316
+ smooth_video,
317
+ inputs=[
318
+ video_guide,
319
+ video_guide_folder,
320
+ video_style,
321
+ video_style_folder,
322
+ mode,
323
+ window_size,
324
+ batch_size,
325
+ tracking_window_size,
326
+ output_path,
327
+ fps,
328
+ minimum_patch_size,
329
+ num_iter,
330
+ guide_weight,
331
+ initialize
332
+ ],
333
+ outputs=[output_path, fps, video_output]
334
+ )
335
+ with gr.Tab("Interpolate"):
336
+ gr.Markdown("""
337
+ # Interpolate
338
+
339
+ Given a guide video and some rendered keyframes, this algorithm will render the remaining frames. Click [here](https://github.com/Artiprocher/sd-webui-fastblend/assets/35051019/3490c5b4-8f67-478f-86de-f9adc2ace16a) to see the example. The algorithm is experimental and is only tested for 512*512 resolution.
340
+ """)
341
+ with gr.Row():
342
+ with gr.Column():
343
+ with gr.Row():
344
+ with gr.Column():
345
+ video_guide_folder_ = gr.Textbox(label="Guide video (images format)", value="")
346
+ with gr.Column():
347
+ rendered_keyframes_ = gr.Textbox(label="Rendered keyframes (images format)", value="")
348
+ with gr.Row():
349
+ detected_frames = gr.Textbox(label="Detected frames", value="Please input the directory of guide video and rendered frames", lines=9, max_lines=9, interactive=False)
350
+ video_guide_folder_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
351
+ rendered_keyframes_.change(detect_frames, inputs=[video_guide_folder_, rendered_keyframes_], outputs=detected_frames)
352
+ with gr.Column():
353
+ output_path_ = gr.Textbox(label="Output directory", value="", placeholder="Leave empty to use the directory of rendered keyframes")
354
+ fps_ = gr.Textbox(label="Fps", value="", placeholder="Leave empty to use the default fps")
355
+ video_output_ = gr.Video(label="Output video", interactive=False, show_share_button=True)
356
+ btn_ = gr.Button(value="Interpolate")
357
+ with gr.Row():
358
+ with gr.Column():
359
+ gr.Markdown("# Settings")
360
+ batch_size_ = gr.Slider(label="Batch size", value=8, minimum=1, maximum=128, step=1, interactive=True)
361
+ tracking_window_size_ = gr.Slider(label="Tracking window size", value=0, minimum=0, maximum=10, step=1, interactive=True)
362
+ gr.Markdown("## Advanced Settings")
363
+ minimum_patch_size_ = gr.Slider(label="Minimum patch size (odd number, larger is better)", value=15, minimum=5, maximum=99, step=2, interactive=True)
364
+ num_iter_ = gr.Slider(label="Number of iterations", value=5, minimum=1, maximum=10, step=1, interactive=True)
365
+ guide_weight_ = gr.Slider(label="Guide weight", value=10.0, minimum=0.0, maximum=100.0, step=0.1, interactive=True)
366
+ initialize_ = gr.Radio(["identity", "random"], label="NNF initialization", value="identity", interactive=True)
367
+ with gr.Column():
368
+ gr.Markdown("""
369
+ # Reference
370
+
371
+ * Output directory: the directory to save the video.
372
+ * Batch size: a larger batch size makes the program faster but requires more VRAM.
373
+ * Tracking window size (only for accurate mode): The size of window in which our algorithm tracks moving objects. Empirically, 1 is enough.
374
+ * Advanced settings
375
+ * Minimum patch size (odd number): the minimum patch size used for patch matching. **This parameter should be larger than that in blending. (Default: 15)**
376
+ * Number of iterations: the number of iterations of patch matching. (Default: 5)
377
+ * Guide weight: a parameter that determines how much motion feature applied to the style video. (Default: 10)
378
+ * NNF initialization: how to initialize the NNF (Nearest Neighbor Field). (Default: identity)
379
+ """)
380
+ btn_.click(
381
+ interpolate_video,
382
+ inputs=[
383
+ video_guide_folder_,
384
+ rendered_keyframes_,
385
+ output_path_,
386
+ fps_,
387
+ batch_size_,
388
+ tracking_window_size_,
389
+ minimum_patch_size_,
390
+ num_iter_,
391
+ guide_weight_,
392
+ initialize_,
393
+ ],
394
+ outputs=[output_path_, fps_, video_output_]
395
+ )
396
+
397
+ return [(ui_component, "FastBlend", "FastBlend_ui")]
diffsynth/extensions/FastBlend/cupy_kernels.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cupy as cp
2
+
3
+ remapping_kernel = cp.RawKernel(r'''
4
+ extern "C" __global__
5
+ void remap(
6
+ const int height,
7
+ const int width,
8
+ const int channel,
9
+ const int patch_size,
10
+ const int pad_size,
11
+ const float* source_style,
12
+ const int* nnf,
13
+ float* target_style
14
+ ) {
15
+ const int r = (patch_size - 1) / 2;
16
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
17
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
18
+ if (x >= height or y >= width) return;
19
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
20
+ const int pid = (x + pad_size) * (width + pad_size * 2) + (y + pad_size);
21
+ const int min_px = x < r ? -x : -r;
22
+ const int max_px = x + r > height - 1 ? height - 1 - x : r;
23
+ const int min_py = y < r ? -y : -r;
24
+ const int max_py = y + r > width - 1 ? width - 1 - y : r;
25
+ int num = 0;
26
+ for (int px = min_px; px <= max_px; px++){
27
+ for (int py = min_py; py <= max_py; py++){
28
+ const int nid = (x + px) * width + y + py;
29
+ const int x_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 0] - px;
30
+ const int y_ = nnf[blockIdx.z * height * width * 2 + nid*2 + 1] - py;
31
+ if (x_ < 0 or y_ < 0 or x_ >= height or y_ >= width)continue;
32
+ const int pid_ = (x_ + pad_size) * (width + pad_size * 2) + (y_ + pad_size);
33
+ num++;
34
+ for (int c = 0; c < channel; c++){
35
+ target_style[z + pid * channel + c] += source_style[z + pid_ * channel + c];
36
+ }
37
+ }
38
+ }
39
+ for (int c = 0; c < channel; c++){
40
+ target_style[z + pid * channel + c] /= num;
41
+ }
42
+ }
43
+ ''', 'remap')
44
+
45
+
46
+ patch_error_kernel = cp.RawKernel(r'''
47
+ extern "C" __global__
48
+ void patch_error(
49
+ const int height,
50
+ const int width,
51
+ const int channel,
52
+ const int patch_size,
53
+ const int pad_size,
54
+ const float* source,
55
+ const int* nnf,
56
+ const float* target,
57
+ float* error
58
+ ) {
59
+ const int r = (patch_size - 1) / 2;
60
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
61
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
62
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
63
+ if (x >= height or y >= width) return;
64
+ const int x_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 0];
65
+ const int y_ = nnf[blockIdx.z * height * width * 2 + (x * width + y)*2 + 1];
66
+ float e = 0;
67
+ for (int px = -r; px <= r; px++){
68
+ for (int py = -r; py <= r; py++){
69
+ const int pid = (x + pad_size + px) * (width + pad_size * 2) + y + pad_size + py;
70
+ const int pid_ = (x_ + pad_size + px) * (width + pad_size * 2) + y_ + pad_size + py;
71
+ for (int c = 0; c < channel; c++){
72
+ const float diff = target[z + pid * channel + c] - source[z + pid_ * channel + c];
73
+ e += diff * diff;
74
+ }
75
+ }
76
+ }
77
+ error[blockIdx.z * height * width + x * width + y] = e;
78
+ }
79
+ ''', 'patch_error')
80
+
81
+
82
+ pairwise_patch_error_kernel = cp.RawKernel(r'''
83
+ extern "C" __global__
84
+ void pairwise_patch_error(
85
+ const int height,
86
+ const int width,
87
+ const int channel,
88
+ const int patch_size,
89
+ const int pad_size,
90
+ const float* source_a,
91
+ const int* nnf_a,
92
+ const float* source_b,
93
+ const int* nnf_b,
94
+ float* error
95
+ ) {
96
+ const int r = (patch_size - 1) / 2;
97
+ const int x = blockDim.x * blockIdx.x + threadIdx.x;
98
+ const int y = blockDim.y * blockIdx.y + threadIdx.y;
99
+ const int z = blockIdx.z * (height + pad_size * 2) * (width + pad_size * 2) * channel;
100
+ if (x >= height or y >= width) return;
101
+ const int z_nnf = blockIdx.z * height * width * 2 + (x * width + y) * 2;
102
+ const int x_a = nnf_a[z_nnf + 0];
103
+ const int y_a = nnf_a[z_nnf + 1];
104
+ const int x_b = nnf_b[z_nnf + 0];
105
+ const int y_b = nnf_b[z_nnf + 1];
106
+ float e = 0;
107
+ for (int px = -r; px <= r; px++){
108
+ for (int py = -r; py <= r; py++){
109
+ const int pid_a = (x_a + pad_size + px) * (width + pad_size * 2) + y_a + pad_size + py;
110
+ const int pid_b = (x_b + pad_size + px) * (width + pad_size * 2) + y_b + pad_size + py;
111
+ for (int c = 0; c < channel; c++){
112
+ const float diff = source_a[z + pid_a * channel + c] - source_b[z + pid_b * channel + c];
113
+ e += diff * diff;
114
+ }
115
+ }
116
+ }
117
+ error[blockIdx.z * height * width + x * width + y] = e;
118
+ }
119
+ ''', 'pairwise_patch_error')
diffsynth/extensions/FastBlend/data.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import imageio, os
2
+ import numpy as np
3
+ from PIL import Image
4
+
5
+
6
+ def read_video(file_name):
7
+ reader = imageio.get_reader(file_name)
8
+ video = []
9
+ for frame in reader:
10
+ frame = np.array(frame)
11
+ video.append(frame)
12
+ reader.close()
13
+ return video
14
+
15
+
16
+ def get_video_fps(file_name):
17
+ reader = imageio.get_reader(file_name)
18
+ fps = reader.get_meta_data()["fps"]
19
+ reader.close()
20
+ return fps
21
+
22
+
23
+ def save_video(frames_path, video_path, num_frames, fps):
24
+ writer = imageio.get_writer(video_path, fps=fps, quality=9)
25
+ for i in range(num_frames):
26
+ frame = np.array(Image.open(os.path.join(frames_path, "%05d.png" % i)))
27
+ writer.append_data(frame)
28
+ writer.close()
29
+ return video_path
30
+
31
+
32
+ class LowMemoryVideo:
33
+ def __init__(self, file_name):
34
+ self.reader = imageio.get_reader(file_name)
35
+
36
+ def __len__(self):
37
+ return self.reader.count_frames()
38
+
39
+ def __getitem__(self, item):
40
+ return np.array(self.reader.get_data(item))
41
+
42
+ def __del__(self):
43
+ self.reader.close()
44
+
45
+
46
+ def split_file_name(file_name):
47
+ result = []
48
+ number = -1
49
+ for i in file_name:
50
+ if ord(i)>=ord("0") and ord(i)<=ord("9"):
51
+ if number == -1:
52
+ number = 0
53
+ number = number*10 + ord(i) - ord("0")
54
+ else:
55
+ if number != -1:
56
+ result.append(number)
57
+ number = -1
58
+ result.append(i)
59
+ if number != -1:
60
+ result.append(number)
61
+ result = tuple(result)
62
+ return result
63
+
64
+
65
+ def search_for_images(folder):
66
+ file_list = [i for i in os.listdir(folder) if i.endswith(".jpg") or i.endswith(".png")]
67
+ file_list = [(split_file_name(file_name), file_name) for file_name in file_list]
68
+ file_list = [i[1] for i in sorted(file_list)]
69
+ file_list = [os.path.join(folder, i) for i in file_list]
70
+ return file_list
71
+
72
+
73
+ def read_images(folder):
74
+ file_list = search_for_images(folder)
75
+ frames = [np.array(Image.open(i)) for i in file_list]
76
+ return frames
77
+
78
+
79
+ class LowMemoryImageFolder:
80
+ def __init__(self, folder, file_list=None):
81
+ if file_list is None:
82
+ self.file_list = search_for_images(folder)
83
+ else:
84
+ self.file_list = [os.path.join(folder, file_name) for file_name in file_list]
85
+
86
+ def __len__(self):
87
+ return len(self.file_list)
88
+
89
+ def __getitem__(self, item):
90
+ return np.array(Image.open(self.file_list[item]))
91
+
92
+ def __del__(self):
93
+ pass
94
+
95
+
96
+ class VideoData:
97
+ def __init__(self, video_file, image_folder, **kwargs):
98
+ if video_file is not None:
99
+ self.data_type = "video"
100
+ self.data = LowMemoryVideo(video_file, **kwargs)
101
+ elif image_folder is not None:
102
+ self.data_type = "images"
103
+ self.data = LowMemoryImageFolder(image_folder, **kwargs)
104
+ else:
105
+ raise ValueError("Cannot open video or image folder")
106
+ self.length = None
107
+ self.height = None
108
+ self.width = None
109
+
110
+ def raw_data(self):
111
+ frames = []
112
+ for i in range(self.__len__()):
113
+ frames.append(self.__getitem__(i))
114
+ return frames
115
+
116
+ def set_length(self, length):
117
+ self.length = length
118
+
119
+ def set_shape(self, height, width):
120
+ self.height = height
121
+ self.width = width
122
+
123
+ def __len__(self):
124
+ if self.length is None:
125
+ return len(self.data)
126
+ else:
127
+ return self.length
128
+
129
+ def shape(self):
130
+ if self.height is not None and self.width is not None:
131
+ return self.height, self.width
132
+ else:
133
+ height, width, _ = self.__getitem__(0).shape
134
+ return height, width
135
+
136
+ def __getitem__(self, item):
137
+ frame = self.data.__getitem__(item)
138
+ height, width, _ = frame.shape
139
+ if self.height is not None and self.width is not None:
140
+ if self.height != height or self.width != width:
141
+ frame = Image.fromarray(frame).resize((self.width, self.height))
142
+ frame = np.array(frame)
143
+ return frame
144
+
145
+ def __del__(self):
146
+ pass
diffsynth/extensions/FastBlend/patch_match.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .cupy_kernels import remapping_kernel, patch_error_kernel, pairwise_patch_error_kernel
2
+ import numpy as np
3
+ import cupy as cp
4
+ import cv2
5
+
6
+
7
+ class PatchMatcher:
8
+ def __init__(
9
+ self, height, width, channel, minimum_patch_size,
10
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
11
+ random_search_steps=3, random_search_range=4,
12
+ use_mean_target_style=False, use_pairwise_patch_error=False,
13
+ tracking_window_size=0
14
+ ):
15
+ self.height = height
16
+ self.width = width
17
+ self.channel = channel
18
+ self.minimum_patch_size = minimum_patch_size
19
+ self.threads_per_block = threads_per_block
20
+ self.num_iter = num_iter
21
+ self.gpu_id = gpu_id
22
+ self.guide_weight = guide_weight
23
+ self.random_search_steps = random_search_steps
24
+ self.random_search_range = random_search_range
25
+ self.use_mean_target_style = use_mean_target_style
26
+ self.use_pairwise_patch_error = use_pairwise_patch_error
27
+ self.tracking_window_size = tracking_window_size
28
+
29
+ self.patch_size_list = [minimum_patch_size + i*2 for i in range(num_iter)][::-1]
30
+ self.pad_size = self.patch_size_list[0] // 2
31
+ self.grid = (
32
+ (height + threads_per_block - 1) // threads_per_block,
33
+ (width + threads_per_block - 1) // threads_per_block
34
+ )
35
+ self.block = (threads_per_block, threads_per_block)
36
+
37
+ def pad_image(self, image):
38
+ return cp.pad(image, ((0, 0), (self.pad_size, self.pad_size), (self.pad_size, self.pad_size), (0, 0)))
39
+
40
+ def unpad_image(self, image):
41
+ return image[:, self.pad_size: -self.pad_size, self.pad_size: -self.pad_size, :]
42
+
43
+ def apply_nnf_to_image(self, nnf, source):
44
+ batch_size = source.shape[0]
45
+ target = cp.zeros((batch_size, self.height + self.pad_size * 2, self.width + self.pad_size * 2, self.channel), dtype=cp.float32)
46
+ remapping_kernel(
47
+ self.grid + (batch_size,),
48
+ self.block,
49
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target)
50
+ )
51
+ return target
52
+
53
+ def get_patch_error(self, source, nnf, target):
54
+ batch_size = source.shape[0]
55
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
56
+ patch_error_kernel(
57
+ self.grid + (batch_size,),
58
+ self.block,
59
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source, nnf, target, error)
60
+ )
61
+ return error
62
+
63
+ def get_pairwise_patch_error(self, source, nnf):
64
+ batch_size = source.shape[0]//2
65
+ error = cp.zeros((batch_size, self.height, self.width), dtype=cp.float32)
66
+ source_a, nnf_a = source[0::2].copy(), nnf[0::2].copy()
67
+ source_b, nnf_b = source[1::2].copy(), nnf[1::2].copy()
68
+ pairwise_patch_error_kernel(
69
+ self.grid + (batch_size,),
70
+ self.block,
71
+ (self.height, self.width, self.channel, self.patch_size, self.pad_size, source_a, nnf_a, source_b, nnf_b, error)
72
+ )
73
+ error = error.repeat(2, axis=0)
74
+ return error
75
+
76
+ def get_error(self, source_guide, target_guide, source_style, target_style, nnf):
77
+ error_guide = self.get_patch_error(source_guide, nnf, target_guide)
78
+ if self.use_mean_target_style:
79
+ target_style = self.apply_nnf_to_image(nnf, source_style)
80
+ target_style = target_style.mean(axis=0, keepdims=True)
81
+ target_style = target_style.repeat(source_guide.shape[0], axis=0)
82
+ if self.use_pairwise_patch_error:
83
+ error_style = self.get_pairwise_patch_error(source_style, nnf)
84
+ else:
85
+ error_style = self.get_patch_error(source_style, nnf, target_style)
86
+ error = error_guide * self.guide_weight + error_style
87
+ return error
88
+
89
+ def clamp_bound(self, nnf):
90
+ nnf[:,:,:,0] = cp.clip(nnf[:,:,:,0], 0, self.height-1)
91
+ nnf[:,:,:,1] = cp.clip(nnf[:,:,:,1], 0, self.width-1)
92
+ return nnf
93
+
94
+ def random_step(self, nnf, r):
95
+ batch_size = nnf.shape[0]
96
+ step = cp.random.randint(-r, r+1, size=(batch_size, self.height, self.width, 2), dtype=cp.int32)
97
+ upd_nnf = self.clamp_bound(nnf + step)
98
+ return upd_nnf
99
+
100
+ def neighboor_step(self, nnf, d):
101
+ if d==0:
102
+ upd_nnf = cp.concatenate([nnf[:, :1, :], nnf[:, :-1, :]], axis=1)
103
+ upd_nnf[:, :, :, 0] += 1
104
+ elif d==1:
105
+ upd_nnf = cp.concatenate([nnf[:, :, :1], nnf[:, :, :-1]], axis=2)
106
+ upd_nnf[:, :, :, 1] += 1
107
+ elif d==2:
108
+ upd_nnf = cp.concatenate([nnf[:, 1:, :], nnf[:, -1:, :]], axis=1)
109
+ upd_nnf[:, :, :, 0] -= 1
110
+ elif d==3:
111
+ upd_nnf = cp.concatenate([nnf[:, :, 1:], nnf[:, :, -1:]], axis=2)
112
+ upd_nnf[:, :, :, 1] -= 1
113
+ upd_nnf = self.clamp_bound(upd_nnf)
114
+ return upd_nnf
115
+
116
+ def shift_nnf(self, nnf, d):
117
+ if d>0:
118
+ d = min(nnf.shape[0], d)
119
+ upd_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
120
+ else:
121
+ d = max(-nnf.shape[0], d)
122
+ upd_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
123
+ return upd_nnf
124
+
125
+ def track_step(self, nnf, d):
126
+ if self.use_pairwise_patch_error:
127
+ upd_nnf = cp.zeros_like(nnf)
128
+ upd_nnf[0::2] = self.shift_nnf(nnf[0::2], d)
129
+ upd_nnf[1::2] = self.shift_nnf(nnf[1::2], d)
130
+ else:
131
+ upd_nnf = self.shift_nnf(nnf, d)
132
+ return upd_nnf
133
+
134
+ def C(self, n, m):
135
+ # not used
136
+ c = 1
137
+ for i in range(1, n+1):
138
+ c *= i
139
+ for i in range(1, m+1):
140
+ c //= i
141
+ for i in range(1, n-m+1):
142
+ c //= i
143
+ return c
144
+
145
+ def bezier_step(self, nnf, r):
146
+ # not used
147
+ n = r * 2 - 1
148
+ upd_nnf = cp.zeros(shape=nnf.shape, dtype=cp.float32)
149
+ for i, d in enumerate(list(range(-r, 0)) + list(range(1, r+1))):
150
+ if d>0:
151
+ ctl_nnf = cp.concatenate([nnf[d:]] + [nnf[-1:]] * d, axis=0)
152
+ elif d<0:
153
+ ctl_nnf = cp.concatenate([nnf[:1]] * (-d) + [nnf[:d]], axis=0)
154
+ upd_nnf += ctl_nnf * (self.C(n, i) / 2**n)
155
+ upd_nnf = self.clamp_bound(upd_nnf).astype(nnf.dtype)
156
+ return upd_nnf
157
+
158
+ def update(self, source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf):
159
+ upd_err = self.get_error(source_guide, target_guide, source_style, target_style, upd_nnf)
160
+ upd_idx = (upd_err < err)
161
+ nnf[upd_idx] = upd_nnf[upd_idx]
162
+ err[upd_idx] = upd_err[upd_idx]
163
+ return nnf, err
164
+
165
+ def propagation(self, source_guide, target_guide, source_style, target_style, nnf, err):
166
+ for d in cp.random.permutation(4):
167
+ upd_nnf = self.neighboor_step(nnf, d)
168
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
169
+ return nnf, err
170
+
171
+ def random_search(self, source_guide, target_guide, source_style, target_style, nnf, err):
172
+ for i in range(self.random_search_steps):
173
+ upd_nnf = self.random_step(nnf, self.random_search_range)
174
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
175
+ return nnf, err
176
+
177
+ def track(self, source_guide, target_guide, source_style, target_style, nnf, err):
178
+ for d in range(1, self.tracking_window_size + 1):
179
+ upd_nnf = self.track_step(nnf, d)
180
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
181
+ upd_nnf = self.track_step(nnf, -d)
182
+ nnf, err = self.update(source_guide, target_guide, source_style, target_style, nnf, err, upd_nnf)
183
+ return nnf, err
184
+
185
+ def iteration(self, source_guide, target_guide, source_style, target_style, nnf, err):
186
+ nnf, err = self.propagation(source_guide, target_guide, source_style, target_style, nnf, err)
187
+ nnf, err = self.random_search(source_guide, target_guide, source_style, target_style, nnf, err)
188
+ nnf, err = self.track(source_guide, target_guide, source_style, target_style, nnf, err)
189
+ return nnf, err
190
+
191
+ def estimate_nnf(self, source_guide, target_guide, source_style, nnf):
192
+ with cp.cuda.Device(self.gpu_id):
193
+ source_guide = self.pad_image(source_guide)
194
+ target_guide = self.pad_image(target_guide)
195
+ source_style = self.pad_image(source_style)
196
+ for it in range(self.num_iter):
197
+ self.patch_size = self.patch_size_list[it]
198
+ target_style = self.apply_nnf_to_image(nnf, source_style)
199
+ err = self.get_error(source_guide, target_guide, source_style, target_style, nnf)
200
+ nnf, err = self.iteration(source_guide, target_guide, source_style, target_style, nnf, err)
201
+ target_style = self.unpad_image(self.apply_nnf_to_image(nnf, source_style))
202
+ return nnf, target_style
203
+
204
+
205
+ class PyramidPatchMatcher:
206
+ def __init__(
207
+ self, image_height, image_width, channel, minimum_patch_size,
208
+ threads_per_block=8, num_iter=5, gpu_id=0, guide_weight=10.0,
209
+ use_mean_target_style=False, use_pairwise_patch_error=False,
210
+ tracking_window_size=0,
211
+ initialize="identity"
212
+ ):
213
+ maximum_patch_size = minimum_patch_size + (num_iter - 1) * 2
214
+ self.pyramid_level = int(np.log2(min(image_height, image_width) / maximum_patch_size))
215
+ self.pyramid_heights = []
216
+ self.pyramid_widths = []
217
+ self.patch_matchers = []
218
+ self.minimum_patch_size = minimum_patch_size
219
+ self.num_iter = num_iter
220
+ self.gpu_id = gpu_id
221
+ self.initialize = initialize
222
+ for level in range(self.pyramid_level):
223
+ height = image_height//(2**(self.pyramid_level - 1 - level))
224
+ width = image_width//(2**(self.pyramid_level - 1 - level))
225
+ self.pyramid_heights.append(height)
226
+ self.pyramid_widths.append(width)
227
+ self.patch_matchers.append(PatchMatcher(
228
+ height, width, channel, minimum_patch_size=minimum_patch_size,
229
+ threads_per_block=threads_per_block, num_iter=num_iter, gpu_id=gpu_id, guide_weight=guide_weight,
230
+ use_mean_target_style=use_mean_target_style, use_pairwise_patch_error=use_pairwise_patch_error,
231
+ tracking_window_size=tracking_window_size
232
+ ))
233
+
234
+ def resample_image(self, images, level):
235
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
236
+ images = images.get()
237
+ images_resample = []
238
+ for image in images:
239
+ image_resample = cv2.resize(image, (width, height), interpolation=cv2.INTER_AREA)
240
+ images_resample.append(image_resample)
241
+ images_resample = cp.array(np.stack(images_resample), dtype=cp.float32)
242
+ return images_resample
243
+
244
+ def initialize_nnf(self, batch_size):
245
+ if self.initialize == "random":
246
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
247
+ nnf = cp.stack([
248
+ cp.random.randint(0, height, (batch_size, height, width), dtype=cp.int32),
249
+ cp.random.randint(0, width, (batch_size, height, width), dtype=cp.int32)
250
+ ], axis=3)
251
+ elif self.initialize == "identity":
252
+ height, width = self.pyramid_heights[0], self.pyramid_widths[0]
253
+ nnf = cp.stack([
254
+ cp.repeat(cp.arange(height), width).reshape(height, width),
255
+ cp.tile(cp.arange(width), height).reshape(height, width)
256
+ ], axis=2)
257
+ nnf = cp.stack([nnf] * batch_size)
258
+ else:
259
+ raise NotImplementedError()
260
+ return nnf
261
+
262
+ def update_nnf(self, nnf, level):
263
+ # upscale
264
+ nnf = nnf.repeat(2, axis=1).repeat(2, axis=2) * 2
265
+ nnf[:,[i for i in range(nnf.shape[0]) if i&1],:,0] += 1
266
+ nnf[:,:,[i for i in range(nnf.shape[0]) if i&1],1] += 1
267
+ # check if scale is 2
268
+ height, width = self.pyramid_heights[level], self.pyramid_widths[level]
269
+ if height != nnf.shape[0] * 2 or width != nnf.shape[1] * 2:
270
+ nnf = nnf.get().astype(np.float32)
271
+ nnf = [cv2.resize(n, (width, height), interpolation=cv2.INTER_LINEAR) for n in nnf]
272
+ nnf = cp.array(np.stack(nnf), dtype=cp.int32)
273
+ nnf = self.patch_matchers[level].clamp_bound(nnf)
274
+ return nnf
275
+
276
+ def apply_nnf_to_image(self, nnf, image):
277
+ with cp.cuda.Device(self.gpu_id):
278
+ image = self.patch_matchers[-1].pad_image(image)
279
+ image = self.patch_matchers[-1].apply_nnf_to_image(nnf, image)
280
+ return image
281
+
282
+ def estimate_nnf(self, source_guide, target_guide, source_style):
283
+ with cp.cuda.Device(self.gpu_id):
284
+ if not isinstance(source_guide, cp.ndarray):
285
+ source_guide = cp.array(source_guide, dtype=cp.float32)
286
+ if not isinstance(target_guide, cp.ndarray):
287
+ target_guide = cp.array(target_guide, dtype=cp.float32)
288
+ if not isinstance(source_style, cp.ndarray):
289
+ source_style = cp.array(source_style, dtype=cp.float32)
290
+ for level in range(self.pyramid_level):
291
+ nnf = self.initialize_nnf(source_guide.shape[0]) if level==0 else self.update_nnf(nnf, level)
292
+ source_guide_ = self.resample_image(source_guide, level)
293
+ target_guide_ = self.resample_image(target_guide, level)
294
+ source_style_ = self.resample_image(source_style, level)
295
+ nnf, target_style = self.patch_matchers[level].estimate_nnf(
296
+ source_guide_, target_guide_, source_style_, nnf
297
+ )
298
+ return nnf.get(), target_style.get()
diffsynth/extensions/FastBlend/runners/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .accurate import AccurateModeRunner
2
+ from .fast import FastModeRunner
3
+ from .balanced import BalancedModeRunner
4
+ from .interpolation import InterpolationModeRunner, InterpolationModeSingleFrameRunner
diffsynth/extensions/FastBlend/runners/accurate.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class AccurateModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Accurate Mode", save_path=None):
13
+ patch_match_engine = PyramidPatchMatcher(
14
+ image_height=frames_style[0].shape[0],
15
+ image_width=frames_style[0].shape[1],
16
+ channel=3,
17
+ use_mean_target_style=True,
18
+ **ebsynth_config
19
+ )
20
+ # run
21
+ n = len(frames_style)
22
+ for target in tqdm(range(n), desc=desc):
23
+ l, r = max(target - window_size, 0), min(target + window_size + 1, n)
24
+ remapped_frames = []
25
+ for i in range(l, r, batch_size):
26
+ j = min(i + batch_size, r)
27
+ source_guide = np.stack([frames_guide[source] for source in range(i, j)])
28
+ target_guide = np.stack([frames_guide[target]] * (j - i))
29
+ source_style = np.stack([frames_style[source] for source in range(i, j)])
30
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
31
+ remapped_frames.append(target_style)
32
+ frame = np.concatenate(remapped_frames, axis=0).mean(axis=0)
33
+ frame = frame.clip(0, 255).astype("uint8")
34
+ if save_path is not None:
35
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
diffsynth/extensions/FastBlend/runners/balanced.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class BalancedModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, desc="Balanced Mode", save_path=None):
13
+ patch_match_engine = PyramidPatchMatcher(
14
+ image_height=frames_style[0].shape[0],
15
+ image_width=frames_style[0].shape[1],
16
+ channel=3,
17
+ **ebsynth_config
18
+ )
19
+ # tasks
20
+ n = len(frames_style)
21
+ tasks = []
22
+ for target in range(n):
23
+ for source in range(target - window_size, target + window_size + 1):
24
+ if source >= 0 and source < n and source != target:
25
+ tasks.append((source, target))
26
+ # run
27
+ frames = [(None, 1) for i in range(n)]
28
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
29
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
30
+ source_guide = np.stack([frames_guide[source] for source, target in tasks_batch])
31
+ target_guide = np.stack([frames_guide[target] for source, target in tasks_batch])
32
+ source_style = np.stack([frames_style[source] for source, target in tasks_batch])
33
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
34
+ for (source, target), result in zip(tasks_batch, target_style):
35
+ frame, weight = frames[target]
36
+ if frame is None:
37
+ frame = frames_style[target]
38
+ frames[target] = (
39
+ frame * (weight / (weight + 1)) + result / (weight + 1),
40
+ weight + 1
41
+ )
42
+ if weight + 1 == min(n, target + window_size + 1) - max(0, target - window_size):
43
+ frame = frame.clip(0, 255).astype("uint8")
44
+ if save_path is not None:
45
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
46
+ frames[target] = (None, 1)
diffsynth/extensions/FastBlend/runners/fast.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import functools, os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class TableManager:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def task_list(self, n):
13
+ tasks = []
14
+ max_level = 1
15
+ while (1<<max_level)<=n:
16
+ max_level += 1
17
+ for i in range(n):
18
+ j = i
19
+ for level in range(max_level):
20
+ if i&(1<<level):
21
+ continue
22
+ j |= 1<<level
23
+ if j>=n:
24
+ break
25
+ meta_data = {
26
+ "source": i,
27
+ "target": j,
28
+ "level": level + 1
29
+ }
30
+ tasks.append(meta_data)
31
+ tasks.sort(key=functools.cmp_to_key(lambda u, v: u["level"]-v["level"]))
32
+ return tasks
33
+
34
+ def build_remapping_table(self, frames_guide, frames_style, patch_match_engine, batch_size, desc=""):
35
+ n = len(frames_guide)
36
+ tasks = self.task_list(n)
37
+ remapping_table = [[(frames_style[i], 1)] for i in range(n)]
38
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
39
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
40
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
41
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
42
+ source_style = np.stack([frames_style[task["source"]] for task in tasks_batch])
43
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
44
+ for task, result in zip(tasks_batch, target_style):
45
+ target, level = task["target"], task["level"]
46
+ if len(remapping_table[target])==level:
47
+ remapping_table[target].append((result, 1))
48
+ else:
49
+ frame, weight = remapping_table[target][level]
50
+ remapping_table[target][level] = (
51
+ frame * (weight / (weight + 1)) + result / (weight + 1),
52
+ weight + 1
53
+ )
54
+ return remapping_table
55
+
56
+ def remapping_table_to_blending_table(self, table):
57
+ for i in range(len(table)):
58
+ for j in range(1, len(table[i])):
59
+ frame_1, weight_1 = table[i][j-1]
60
+ frame_2, weight_2 = table[i][j]
61
+ frame = (frame_1 + frame_2) / 2
62
+ weight = weight_1 + weight_2
63
+ table[i][j] = (frame, weight)
64
+ return table
65
+
66
+ def tree_query(self, leftbound, rightbound):
67
+ node_list = []
68
+ node_index = rightbound
69
+ while node_index>=leftbound:
70
+ node_level = 0
71
+ while (1<<node_level)&node_index and node_index-(1<<node_level+1)+1>=leftbound:
72
+ node_level += 1
73
+ node_list.append((node_index, node_level))
74
+ node_index -= 1<<node_level
75
+ return node_list
76
+
77
+ def process_window_sum(self, frames_guide, blending_table, patch_match_engine, window_size, batch_size, desc=""):
78
+ n = len(blending_table)
79
+ tasks = []
80
+ frames_result = []
81
+ for target in range(n):
82
+ node_list = self.tree_query(max(target-window_size, 0), target)
83
+ for source, level in node_list:
84
+ if source!=target:
85
+ meta_data = {
86
+ "source": source,
87
+ "target": target,
88
+ "level": level
89
+ }
90
+ tasks.append(meta_data)
91
+ else:
92
+ frames_result.append(blending_table[target][level])
93
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=desc):
94
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
95
+ source_guide = np.stack([frames_guide[task["source"]] for task in tasks_batch])
96
+ target_guide = np.stack([frames_guide[task["target"]] for task in tasks_batch])
97
+ source_style = np.stack([blending_table[task["source"]][task["level"]][0] for task in tasks_batch])
98
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
99
+ for task, frame_2 in zip(tasks_batch, target_style):
100
+ source, target, level = task["source"], task["target"], task["level"]
101
+ frame_1, weight_1 = frames_result[target]
102
+ weight_2 = blending_table[source][level][1]
103
+ weight = weight_1 + weight_2
104
+ frame = frame_1 * (weight_1 / weight) + frame_2 * (weight_2 / weight)
105
+ frames_result[target] = (frame, weight)
106
+ return frames_result
107
+
108
+
109
+ class FastModeRunner:
110
+ def __init__(self):
111
+ pass
112
+
113
+ def run(self, frames_guide, frames_style, batch_size, window_size, ebsynth_config, save_path=None):
114
+ frames_guide = frames_guide.raw_data()
115
+ frames_style = frames_style.raw_data()
116
+ table_manager = TableManager()
117
+ patch_match_engine = PyramidPatchMatcher(
118
+ image_height=frames_style[0].shape[0],
119
+ image_width=frames_style[0].shape[1],
120
+ channel=3,
121
+ **ebsynth_config
122
+ )
123
+ # left part
124
+ table_l = table_manager.build_remapping_table(frames_guide, frames_style, patch_match_engine, batch_size, desc="Fast Mode Step 1/4")
125
+ table_l = table_manager.remapping_table_to_blending_table(table_l)
126
+ table_l = table_manager.process_window_sum(frames_guide, table_l, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 2/4")
127
+ # right part
128
+ table_r = table_manager.build_remapping_table(frames_guide[::-1], frames_style[::-1], patch_match_engine, batch_size, desc="Fast Mode Step 3/4")
129
+ table_r = table_manager.remapping_table_to_blending_table(table_r)
130
+ table_r = table_manager.process_window_sum(frames_guide[::-1], table_r, patch_match_engine, window_size, batch_size, desc="Fast Mode Step 4/4")[::-1]
131
+ # merge
132
+ frames = []
133
+ for (frame_l, weight_l), frame_m, (frame_r, weight_r) in zip(table_l, frames_style, table_r):
134
+ weight_m = -1
135
+ weight = weight_l + weight_m + weight_r
136
+ frame = frame_l * (weight_l / weight) + frame_m * (weight_m / weight) + frame_r * (weight_r / weight)
137
+ frames.append(frame)
138
+ frames = [frame.clip(0, 255).astype("uint8") for frame in frames]
139
+ if save_path is not None:
140
+ for target, frame in enumerate(frames):
141
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % target))
diffsynth/extensions/FastBlend/runners/interpolation.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..patch_match import PyramidPatchMatcher
2
+ import os
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+
7
+
8
+ class InterpolationModeRunner:
9
+ def __init__(self):
10
+ pass
11
+
12
+ def get_index_dict(self, index_style):
13
+ index_dict = {}
14
+ for i, index in enumerate(index_style):
15
+ index_dict[index] = i
16
+ return index_dict
17
+
18
+ def get_weight(self, l, m, r):
19
+ weight_l, weight_r = abs(m - r), abs(m - l)
20
+ if weight_l + weight_r == 0:
21
+ weight_l, weight_r = 0.5, 0.5
22
+ else:
23
+ weight_l, weight_r = weight_l / (weight_l + weight_r), weight_r / (weight_l + weight_r)
24
+ return weight_l, weight_r
25
+
26
+ def get_task_group(self, index_style, n):
27
+ task_group = []
28
+ index_style = sorted(index_style)
29
+ # first frame
30
+ if index_style[0]>0:
31
+ tasks = []
32
+ for m in range(index_style[0]):
33
+ tasks.append((index_style[0], m, index_style[0]))
34
+ task_group.append(tasks)
35
+ # middle frames
36
+ for l, r in zip(index_style[:-1], index_style[1:]):
37
+ tasks = []
38
+ for m in range(l, r):
39
+ tasks.append((l, m, r))
40
+ task_group.append(tasks)
41
+ # last frame
42
+ tasks = []
43
+ for m in range(index_style[-1], n):
44
+ tasks.append((index_style[-1], m, index_style[-1]))
45
+ task_group.append(tasks)
46
+ return task_group
47
+
48
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
49
+ patch_match_engine = PyramidPatchMatcher(
50
+ image_height=frames_style[0].shape[0],
51
+ image_width=frames_style[0].shape[1],
52
+ channel=3,
53
+ use_mean_target_style=False,
54
+ use_pairwise_patch_error=True,
55
+ **ebsynth_config
56
+ )
57
+ # task
58
+ index_dict = self.get_index_dict(index_style)
59
+ task_group = self.get_task_group(index_style, len(frames_guide))
60
+ # run
61
+ for tasks in task_group:
62
+ index_start, index_end = min([i[1] for i in tasks]), max([i[1] for i in tasks])
63
+ for batch_id in tqdm(range(0, len(tasks), batch_size), desc=f"Rendering frames {index_start}...{index_end}"):
64
+ tasks_batch = tasks[batch_id: min(batch_id+batch_size, len(tasks))]
65
+ source_guide, target_guide, source_style = [], [], []
66
+ for l, m, r in tasks_batch:
67
+ # l -> m
68
+ source_guide.append(frames_guide[l])
69
+ target_guide.append(frames_guide[m])
70
+ source_style.append(frames_style[index_dict[l]])
71
+ # r -> m
72
+ source_guide.append(frames_guide[r])
73
+ target_guide.append(frames_guide[m])
74
+ source_style.append(frames_style[index_dict[r]])
75
+ source_guide = np.stack(source_guide)
76
+ target_guide = np.stack(target_guide)
77
+ source_style = np.stack(source_style)
78
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
79
+ if save_path is not None:
80
+ for frame_l, frame_r, (l, m, r) in zip(target_style[0::2], target_style[1::2], tasks_batch):
81
+ weight_l, weight_r = self.get_weight(l, m, r)
82
+ frame = frame_l * weight_l + frame_r * weight_r
83
+ frame = frame.clip(0, 255).astype("uint8")
84
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % m))
85
+
86
+
87
+ class InterpolationModeSingleFrameRunner:
88
+ def __init__(self):
89
+ pass
90
+
91
+ def run(self, frames_guide, frames_style, index_style, batch_size, ebsynth_config, save_path=None):
92
+ # check input
93
+ tracking_window_size = ebsynth_config["tracking_window_size"]
94
+ if tracking_window_size * 2 >= batch_size:
95
+ raise ValueError("batch_size should be larger than track_window_size * 2")
96
+ frame_style = frames_style[0]
97
+ frame_guide = frames_guide[index_style[0]]
98
+ patch_match_engine = PyramidPatchMatcher(
99
+ image_height=frame_style.shape[0],
100
+ image_width=frame_style.shape[1],
101
+ channel=3,
102
+ **ebsynth_config
103
+ )
104
+ # run
105
+ frame_id, n = 0, len(frames_guide)
106
+ for i in tqdm(range(0, n, batch_size - tracking_window_size * 2), desc=f"Rendering frames 0...{n}"):
107
+ if i + batch_size > n:
108
+ l, r = max(n - batch_size, 0), n
109
+ else:
110
+ l, r = i, i + batch_size
111
+ source_guide = np.stack([frame_guide] * (r-l))
112
+ target_guide = np.stack([frames_guide[i] for i in range(l, r)])
113
+ source_style = np.stack([frame_style] * (r-l))
114
+ _, target_style = patch_match_engine.estimate_nnf(source_guide, target_guide, source_style)
115
+ for i, frame in zip(range(l, r), target_style):
116
+ if i==frame_id:
117
+ frame = frame.clip(0, 255).astype("uint8")
118
+ Image.fromarray(frame).save(os.path.join(save_path, "%05d.png" % frame_id))
119
+ frame_id += 1
120
+ if r < n and r-frame_id <= tracking_window_size:
121
+ break
diffsynth/extensions/RIFE/__init__.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+
8
+ def warp(tenInput, tenFlow, device):
9
+ backwarp_tenGrid = {}
10
+ k = (str(tenFlow.device), str(tenFlow.size()))
11
+ if k not in backwarp_tenGrid:
12
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view(
13
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
14
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view(
15
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
16
+ backwarp_tenGrid[k] = torch.cat(
17
+ [tenHorizontal, tenVertical], 1).to(device)
18
+
19
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
20
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
21
+
22
+ g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
23
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)
24
+
25
+
26
+ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
27
+ return nn.Sequential(
28
+ nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
29
+ padding=padding, dilation=dilation, bias=True),
30
+ nn.PReLU(out_planes)
31
+ )
32
+
33
+
34
+ class IFBlock(nn.Module):
35
+ def __init__(self, in_planes, c=64):
36
+ super(IFBlock, self).__init__()
37
+ self.conv0 = nn.Sequential(conv(in_planes, c//2, 3, 2, 1), conv(c//2, c, 3, 2, 1),)
38
+ self.convblock0 = nn.Sequential(conv(c, c), conv(c, c))
39
+ self.convblock1 = nn.Sequential(conv(c, c), conv(c, c))
40
+ self.convblock2 = nn.Sequential(conv(c, c), conv(c, c))
41
+ self.convblock3 = nn.Sequential(conv(c, c), conv(c, c))
42
+ self.conv1 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 4, 4, 2, 1))
43
+ self.conv2 = nn.Sequential(nn.ConvTranspose2d(c, c//2, 4, 2, 1), nn.PReLU(c//2), nn.ConvTranspose2d(c//2, 1, 4, 2, 1))
44
+
45
+ def forward(self, x, flow, scale=1):
46
+ x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
47
+ flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 1. / scale
48
+ feat = self.conv0(torch.cat((x, flow), 1))
49
+ feat = self.convblock0(feat) + feat
50
+ feat = self.convblock1(feat) + feat
51
+ feat = self.convblock2(feat) + feat
52
+ feat = self.convblock3(feat) + feat
53
+ flow = self.conv1(feat)
54
+ mask = self.conv2(feat)
55
+ flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False) * scale
56
+ mask = F.interpolate(mask, scale_factor=scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
57
+ return flow, mask
58
+
59
+
60
+ class IFNet(nn.Module):
61
+ def __init__(self):
62
+ super(IFNet, self).__init__()
63
+ self.block0 = IFBlock(7+4, c=90)
64
+ self.block1 = IFBlock(7+4, c=90)
65
+ self.block2 = IFBlock(7+4, c=90)
66
+ self.block_tea = IFBlock(10+4, c=90)
67
+
68
+ def forward(self, x, scale_list=[4, 2, 1], training=False):
69
+ if training == False:
70
+ channel = x.shape[1] // 2
71
+ img0 = x[:, :channel]
72
+ img1 = x[:, channel:]
73
+ flow_list = []
74
+ merged = []
75
+ mask_list = []
76
+ warped_img0 = img0
77
+ warped_img1 = img1
78
+ flow = (x[:, :4]).detach() * 0
79
+ mask = (x[:, :1]).detach() * 0
80
+ block = [self.block0, self.block1, self.block2]
81
+ for i in range(3):
82
+ f0, m0 = block[i](torch.cat((warped_img0[:, :3], warped_img1[:, :3], mask), 1), flow, scale=scale_list[i])
83
+ f1, m1 = block[i](torch.cat((warped_img1[:, :3], warped_img0[:, :3], -mask), 1), torch.cat((flow[:, 2:4], flow[:, :2]), 1), scale=scale_list[i])
84
+ flow = flow + (f0 + torch.cat((f1[:, 2:4], f1[:, :2]), 1)) / 2
85
+ mask = mask + (m0 + (-m1)) / 2
86
+ mask_list.append(mask)
87
+ flow_list.append(flow)
88
+ warped_img0 = warp(img0, flow[:, :2], device=x.device)
89
+ warped_img1 = warp(img1, flow[:, 2:4], device=x.device)
90
+ merged.append((warped_img0, warped_img1))
91
+ '''
92
+ c0 = self.contextnet(img0, flow[:, :2])
93
+ c1 = self.contextnet(img1, flow[:, 2:4])
94
+ tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
95
+ res = tmp[:, 1:4] * 2 - 1
96
+ '''
97
+ for i in range(3):
98
+ mask_list[i] = torch.sigmoid(mask_list[i])
99
+ merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
100
+ return flow_list, mask_list[2], merged
101
+
102
+ def state_dict_converter(self):
103
+ return IFNetStateDictConverter()
104
+
105
+
106
+ class IFNetStateDictConverter:
107
+ def __init__(self):
108
+ pass
109
+
110
+ def from_diffusers(self, state_dict):
111
+ state_dict_ = {k.replace("module.", ""): v for k, v in state_dict.items()}
112
+ return state_dict_
113
+
114
+ def from_civitai(self, state_dict):
115
+ return self.from_diffusers(state_dict)
116
+
117
+
118
+ class RIFEInterpolater:
119
+ def __init__(self, model, device="cuda"):
120
+ self.model = model
121
+ self.device = device
122
+ # IFNet only does not support float16
123
+ self.torch_dtype = torch.float32
124
+
125
+ @staticmethod
126
+ def from_model_manager(model_manager):
127
+ return RIFEInterpolater(model_manager.RIFE, device=model_manager.device)
128
+
129
+ def process_image(self, image):
130
+ width, height = image.size
131
+ if width % 32 != 0 or height % 32 != 0:
132
+ width = (width + 31) // 32
133
+ height = (height + 31) // 32
134
+ image = image.resize((width, height))
135
+ image = torch.Tensor(np.array(image, dtype=np.float32)[:, :, [2,1,0]] / 255).permute(2, 0, 1)
136
+ return image
137
+
138
+ def process_images(self, images):
139
+ images = [self.process_image(image) for image in images]
140
+ images = torch.stack(images)
141
+ return images
142
+
143
+ def decode_images(self, images):
144
+ images = (images[:, [2,1,0]].permute(0, 2, 3, 1) * 255).clip(0, 255).numpy().astype(np.uint8)
145
+ images = [Image.fromarray(image) for image in images]
146
+ return images
147
+
148
+ def add_interpolated_images(self, images, interpolated_images):
149
+ output_images = []
150
+ for image, interpolated_image in zip(images, interpolated_images):
151
+ output_images.append(image)
152
+ output_images.append(interpolated_image)
153
+ output_images.append(images[-1])
154
+ return output_images
155
+
156
+
157
+ @torch.no_grad()
158
+ def interpolate_(self, images, scale=1.0):
159
+ input_tensor = self.process_images(images)
160
+ input_tensor = torch.cat((input_tensor[:-1], input_tensor[1:]), dim=1)
161
+ input_tensor = input_tensor.to(device=self.device, dtype=self.torch_dtype)
162
+ flow, mask, merged = self.model(input_tensor, [4/scale, 2/scale, 1/scale])
163
+ output_images = self.decode_images(merged[2].cpu())
164
+ if output_images[0].size != images[0].size:
165
+ output_images = [image.resize(images[0].size) for image in output_images]
166
+ return output_images
167
+
168
+
169
+ @torch.no_grad()
170
+ def interpolate(self, images, scale=1.0, batch_size=4, num_iter=1):
171
+ # Preprocess
172
+ processed_images = self.process_images(images)
173
+
174
+ for iter in range(num_iter):
175
+ # Input
176
+ input_tensor = torch.cat((processed_images[:-1], processed_images[1:]), dim=1)
177
+
178
+ # Interpolate
179
+ output_tensor = []
180
+ for batch_id in range(0, input_tensor.shape[0], batch_size):
181
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
182
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
183
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
184
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
185
+ output_tensor.append(merged[2].cpu())
186
+
187
+ # Output
188
+ output_tensor = torch.concat(output_tensor, dim=0).clip(0, 1)
189
+ processed_images = self.add_interpolated_images(processed_images, output_tensor)
190
+ processed_images = torch.stack(processed_images)
191
+
192
+ # To images
193
+ output_images = self.decode_images(processed_images)
194
+ if output_images[0].size != images[0].size:
195
+ output_images = [image.resize(images[0].size) for image in output_images]
196
+ return output_images
197
+
198
+
199
+ class RIFESmoother(RIFEInterpolater):
200
+ def __init__(self, model, device="cuda"):
201
+ super(RIFESmoother, self).__init__(model, device=device)
202
+
203
+ @staticmethod
204
+ def from_model_manager(model_manager):
205
+ return RIFESmoother(model_manager.RIFE, device=model_manager.device)
206
+
207
+ def process_tensors(self, input_tensor, scale=1.0, batch_size=4):
208
+ output_tensor = []
209
+ for batch_id in range(0, input_tensor.shape[0], batch_size):
210
+ batch_id_ = min(batch_id + batch_size, input_tensor.shape[0])
211
+ batch_input_tensor = input_tensor[batch_id: batch_id_]
212
+ batch_input_tensor = batch_input_tensor.to(device=self.device, dtype=self.torch_dtype)
213
+ flow, mask, merged = self.model(batch_input_tensor, [4/scale, 2/scale, 1/scale])
214
+ output_tensor.append(merged[2].cpu())
215
+ output_tensor = torch.concat(output_tensor, dim=0)
216
+ return output_tensor
217
+
218
+ @torch.no_grad()
219
+ def __call__(self, rendered_frames, scale=1.0, batch_size=4, num_iter=1, **kwargs):
220
+ # Preprocess
221
+ processed_images = self.process_images(rendered_frames)
222
+
223
+ for iter in range(num_iter):
224
+ # Input
225
+ input_tensor = torch.cat((processed_images[:-2], processed_images[2:]), dim=1)
226
+
227
+ # Interpolate
228
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
229
+
230
+ # Blend
231
+ input_tensor = torch.cat((processed_images[1:-1], output_tensor), dim=1)
232
+ output_tensor = self.process_tensors(input_tensor, scale=scale, batch_size=batch_size)
233
+
234
+ # Add to frames
235
+ processed_images[1:-1] = output_tensor
236
+
237
+ # To images
238
+ output_images = self.decode_images(processed_images)
239
+ if output_images[0].size != rendered_frames[0].size:
240
+ output_images = [image.resize(rendered_frames[0].size) for image in output_images]
241
+ return output_images
diffsynth/models/__init__.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch, os
2
+ from safetensors import safe_open
3
+
4
+ from .sd_text_encoder import SDTextEncoder
5
+ from .sd_unet import SDUNet
6
+ from .sd_vae_encoder import SDVAEEncoder
7
+ from .sd_vae_decoder import SDVAEDecoder
8
+ from .sd_lora import SDLoRA
9
+
10
+ from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
11
+ from .sdxl_unet import SDXLUNet
12
+ from .sdxl_vae_decoder import SDXLVAEDecoder
13
+ from .sdxl_vae_encoder import SDXLVAEEncoder
14
+
15
+ from .sd_controlnet import SDControlNet
16
+
17
+ from .sd_motion import SDMotionModel
18
+
19
+
20
+ class ModelManager:
21
+ def __init__(self, torch_dtype=torch.float16, device="cuda"):
22
+ self.torch_dtype = torch_dtype
23
+ self.device = device
24
+ self.model = {}
25
+ self.model_path = {}
26
+ self.textual_inversion_dict = {}
27
+
28
+ def is_RIFE(self, state_dict):
29
+ param_name = "block_tea.convblock3.0.1.weight"
30
+ return param_name in state_dict or ("module." + param_name) in state_dict
31
+
32
+ def is_beautiful_prompt(self, state_dict):
33
+ param_name = "transformer.h.9.self_attention.query_key_value.weight"
34
+ return param_name in state_dict
35
+
36
+ def is_stabe_diffusion_xl(self, state_dict):
37
+ param_name = "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight"
38
+ return param_name in state_dict
39
+
40
+ def is_stable_diffusion(self, state_dict):
41
+ if self.is_stabe_diffusion_xl(state_dict):
42
+ return False
43
+ param_name = "model.diffusion_model.output_blocks.9.1.transformer_blocks.0.norm3.weight"
44
+ return param_name in state_dict
45
+
46
+ def is_controlnet(self, state_dict):
47
+ param_name = "control_model.time_embed.0.weight"
48
+ return param_name in state_dict
49
+
50
+ def is_animatediff(self, state_dict):
51
+ param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
52
+ return param_name in state_dict
53
+
54
+ def is_sd_lora(self, state_dict):
55
+ param_name = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_ff_net_2.lora_up.weight"
56
+ return param_name in state_dict
57
+
58
+ def is_translator(self, state_dict):
59
+ param_name = "model.encoder.layers.5.self_attn_layer_norm.weight"
60
+ return param_name in state_dict and len(state_dict) == 254
61
+
62
+ def load_stable_diffusion(self, state_dict, components=None, file_path=""):
63
+ component_dict = {
64
+ "text_encoder": SDTextEncoder,
65
+ "unet": SDUNet,
66
+ "vae_decoder": SDVAEDecoder,
67
+ "vae_encoder": SDVAEEncoder,
68
+ "refiner": SDXLUNet,
69
+ }
70
+ if components is None:
71
+ components = ["text_encoder", "unet", "vae_decoder", "vae_encoder"]
72
+ for component in components:
73
+ if component == "text_encoder":
74
+ # Add additional token embeddings to text encoder
75
+ token_embeddings = [state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"]]
76
+ for keyword in self.textual_inversion_dict:
77
+ _, embeddings = self.textual_inversion_dict[keyword]
78
+ token_embeddings.append(embeddings.to(dtype=token_embeddings[0].dtype))
79
+ token_embeddings = torch.concat(token_embeddings, dim=0)
80
+ state_dict["cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"] = token_embeddings
81
+ self.model[component] = component_dict[component](vocab_size=token_embeddings.shape[0])
82
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
83
+ self.model[component].to(self.torch_dtype).to(self.device)
84
+ else:
85
+ self.model[component] = component_dict[component]()
86
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
87
+ self.model[component].to(self.torch_dtype).to(self.device)
88
+ self.model_path[component] = file_path
89
+
90
+ def load_stable_diffusion_xl(self, state_dict, components=None, file_path=""):
91
+ component_dict = {
92
+ "text_encoder": SDXLTextEncoder,
93
+ "text_encoder_2": SDXLTextEncoder2,
94
+ "unet": SDXLUNet,
95
+ "vae_decoder": SDXLVAEDecoder,
96
+ "vae_encoder": SDXLVAEEncoder,
97
+ }
98
+ if components is None:
99
+ components = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder"]
100
+ for component in components:
101
+ self.model[component] = component_dict[component]()
102
+ self.model[component].load_state_dict(self.model[component].state_dict_converter().from_civitai(state_dict))
103
+ if component in ["vae_decoder", "vae_encoder"]:
104
+ # These two model will output nan when float16 is enabled.
105
+ # The precision problem happens in the last three resnet blocks.
106
+ # I do not know how to solve this problem.
107
+ self.model[component].to(torch.float32).to(self.device)
108
+ else:
109
+ self.model[component].to(self.torch_dtype).to(self.device)
110
+ self.model_path[component] = file_path
111
+
112
+ def load_controlnet(self, state_dict, file_path=""):
113
+ component = "controlnet"
114
+ if component not in self.model:
115
+ self.model[component] = []
116
+ self.model_path[component] = []
117
+ model = SDControlNet()
118
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
119
+ model.to(self.torch_dtype).to(self.device)
120
+ self.model[component].append(model)
121
+ self.model_path[component].append(file_path)
122
+
123
+ def load_animatediff(self, state_dict, file_path=""):
124
+ component = "motion_modules"
125
+ model = SDMotionModel()
126
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
127
+ model.to(self.torch_dtype).to(self.device)
128
+ self.model[component] = model
129
+ self.model_path[component] = file_path
130
+
131
+ def load_beautiful_prompt(self, state_dict, file_path=""):
132
+ component = "beautiful_prompt"
133
+ from transformers import AutoModelForCausalLM
134
+ model_folder = os.path.dirname(file_path)
135
+ model = AutoModelForCausalLM.from_pretrained(
136
+ model_folder, state_dict=state_dict, local_files_only=True, torch_dtype=self.torch_dtype
137
+ ).to(self.device).eval()
138
+ self.model[component] = model
139
+ self.model_path[component] = file_path
140
+
141
+ def load_RIFE(self, state_dict, file_path=""):
142
+ component = "RIFE"
143
+ from ..extensions.RIFE import IFNet
144
+ model = IFNet().eval()
145
+ model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
146
+ model.to(torch.float32).to(self.device)
147
+ self.model[component] = model
148
+ self.model_path[component] = file_path
149
+
150
+ def load_sd_lora(self, state_dict, alpha):
151
+ SDLoRA().add_lora_to_text_encoder(self.model["text_encoder"], state_dict, alpha=alpha, device=self.device)
152
+ SDLoRA().add_lora_to_unet(self.model["unet"], state_dict, alpha=alpha, device=self.device)
153
+
154
+ def load_translator(self, state_dict, file_path=""):
155
+ # This model is lightweight, we do not place it on GPU.
156
+ component = "translator"
157
+ from transformers import AutoModelForSeq2SeqLM
158
+ model_folder = os.path.dirname(file_path)
159
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_folder).eval()
160
+ self.model[component] = model
161
+ self.model_path[component] = file_path
162
+
163
+ def search_for_embeddings(self, state_dict):
164
+ embeddings = []
165
+ for k in state_dict:
166
+ if isinstance(state_dict[k], torch.Tensor):
167
+ embeddings.append(state_dict[k])
168
+ elif isinstance(state_dict[k], dict):
169
+ embeddings += self.search_for_embeddings(state_dict[k])
170
+ return embeddings
171
+
172
+ def load_textual_inversions(self, folder):
173
+ # Store additional tokens here
174
+ self.textual_inversion_dict = {}
175
+
176
+ # Load every textual inversion file
177
+ for file_name in os.listdir(folder):
178
+ if file_name.endswith(".txt"):
179
+ continue
180
+ keyword = os.path.splitext(file_name)[0]
181
+ state_dict = load_state_dict(os.path.join(folder, file_name))
182
+
183
+ # Search for embeddings
184
+ for embeddings in self.search_for_embeddings(state_dict):
185
+ if len(embeddings.shape) == 2 and embeddings.shape[1] == 768:
186
+ tokens = [f"{keyword}_{i}" for i in range(embeddings.shape[0])]
187
+ self.textual_inversion_dict[keyword] = (tokens, embeddings)
188
+ break
189
+
190
+ def load_model(self, file_path, components=None, lora_alphas=[]):
191
+ state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
192
+ if self.is_animatediff(state_dict):
193
+ self.load_animatediff(state_dict, file_path=file_path)
194
+ elif self.is_controlnet(state_dict):
195
+ self.load_controlnet(state_dict, file_path=file_path)
196
+ elif self.is_stabe_diffusion_xl(state_dict):
197
+ self.load_stable_diffusion_xl(state_dict, components=components, file_path=file_path)
198
+ elif self.is_stable_diffusion(state_dict):
199
+ self.load_stable_diffusion(state_dict, components=components, file_path=file_path)
200
+ elif self.is_sd_lora(state_dict):
201
+ self.load_sd_lora(state_dict, alpha=lora_alphas.pop(0))
202
+ elif self.is_beautiful_prompt(state_dict):
203
+ self.load_beautiful_prompt(state_dict, file_path=file_path)
204
+ elif self.is_RIFE(state_dict):
205
+ self.load_RIFE(state_dict, file_path=file_path)
206
+ elif self.is_translator(state_dict):
207
+ self.load_translator(state_dict, file_path=file_path)
208
+
209
+ def load_models(self, file_path_list, lora_alphas=[]):
210
+ for file_path in file_path_list:
211
+ self.load_model(file_path, lora_alphas=lora_alphas)
212
+
213
+ def to(self, device):
214
+ for component in self.model:
215
+ if isinstance(self.model[component], list):
216
+ for model in self.model[component]:
217
+ model.to(device)
218
+ else:
219
+ self.model[component].to(device)
220
+ torch.cuda.empty_cache()
221
+
222
+ def get_model_with_model_path(self, model_path):
223
+ for component in self.model_path:
224
+ if isinstance(self.model_path[component], str):
225
+ if os.path.samefile(self.model_path[component], model_path):
226
+ return self.model[component]
227
+ elif isinstance(self.model_path[component], list):
228
+ for i, model_path_ in enumerate(self.model_path[component]):
229
+ if os.path.samefile(model_path_, model_path):
230
+ return self.model[component][i]
231
+ raise ValueError(f"Please load model {model_path} before you use it.")
232
+
233
+ def __getattr__(self, __name):
234
+ if __name in self.model:
235
+ return self.model[__name]
236
+ else:
237
+ return super.__getattribute__(__name)
238
+
239
+
240
+ def load_state_dict(file_path, torch_dtype=None):
241
+ if file_path.endswith(".safetensors"):
242
+ return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
243
+ else:
244
+ return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
245
+
246
+
247
+ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
248
+ state_dict = {}
249
+ with safe_open(file_path, framework="pt", device="cpu") as f:
250
+ for k in f.keys():
251
+ state_dict[k] = f.get_tensor(k)
252
+ if torch_dtype is not None:
253
+ state_dict[k] = state_dict[k].to(torch_dtype)
254
+ return state_dict
255
+
256
+
257
+ def load_state_dict_from_bin(file_path, torch_dtype=None):
258
+ state_dict = torch.load(file_path, map_location="cpu")
259
+ if torch_dtype is not None:
260
+ state_dict = {i: state_dict[i].to(torch_dtype) for i in state_dict}
261
+ return state_dict
262
+
263
+
264
+ def search_parameter(param, state_dict):
265
+ for name, param_ in state_dict.items():
266
+ if param.numel() == param_.numel():
267
+ if param.shape == param_.shape:
268
+ if torch.dist(param, param_) < 1e-6:
269
+ return name
270
+ else:
271
+ if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
272
+ return name
273
+ return None
274
+
275
+
276
+ def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
277
+ matched_keys = set()
278
+ with torch.no_grad():
279
+ for name in source_state_dict:
280
+ rename = search_parameter(source_state_dict[name], target_state_dict)
281
+ if rename is not None:
282
+ print(f'"{name}": "{rename}",')
283
+ matched_keys.add(rename)
284
+ elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
285
+ length = source_state_dict[name].shape[0] // 3
286
+ rename = []
287
+ for i in range(3):
288
+ rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
289
+ if None not in rename:
290
+ print(f'"{name}": {rename},')
291
+ for rename_ in rename:
292
+ matched_keys.add(rename_)
293
+ for name in target_state_dict:
294
+ if name not in matched_keys:
295
+ print("Cannot find", name, target_state_dict[name].shape)
diffsynth/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (10.9 kB). View file
 
diffsynth/models/__pycache__/attention.cpython-39.pyc ADDED
Binary file (2.44 kB). View file
 
diffsynth/models/__pycache__/sd_controlnet.cpython-39.pyc ADDED
Binary file (39.8 kB). View file
 
diffsynth/models/__pycache__/sd_lora.cpython-39.pyc ADDED
Binary file (2.49 kB). View file
 
diffsynth/models/__pycache__/sd_motion.cpython-39.pyc ADDED
Binary file (7.05 kB). View file
 
diffsynth/models/__pycache__/sd_text_encoder.cpython-39.pyc ADDED
Binary file (26 kB). View file
 
diffsynth/models/__pycache__/sd_unet.cpython-39.pyc ADDED
Binary file (86.1 kB). View file
 
diffsynth/models/__pycache__/sd_vae_decoder.cpython-39.pyc ADDED
Binary file (17.3 kB). View file
 
diffsynth/models/__pycache__/sd_vae_encoder.cpython-39.pyc ADDED
Binary file (13.7 kB). View file
 
diffsynth/models/__pycache__/sdxl_text_encoder.cpython-39.pyc ADDED
Binary file (67.3 kB). View file
 
diffsynth/models/__pycache__/sdxl_unet.cpython-39.pyc ADDED
Binary file (213 kB). View file
 
diffsynth/models/__pycache__/sdxl_vae_decoder.cpython-39.pyc ADDED
Binary file (1.11 kB). View file
 
diffsynth/models/__pycache__/sdxl_vae_encoder.cpython-39.pyc ADDED
Binary file (1.11 kB). View file
 
diffsynth/models/__pycache__/tiler.cpython-39.pyc ADDED
Binary file (3.01 kB). View file
 
diffsynth/models/attention.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from einops import rearrange
3
+
4
+
5
+ def low_version_attention(query, key, value, attn_bias=None):
6
+ scale = 1 / query.shape[-1] ** 0.5
7
+ query = query * scale
8
+ attn = torch.matmul(query, key.transpose(-2, -1))
9
+ if attn_bias is not None:
10
+ attn = attn + attn_bias
11
+ attn = attn.softmax(-1)
12
+ return attn @ value
13
+
14
+
15
+ class Attention(torch.nn.Module):
16
+
17
+ def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
18
+ super().__init__()
19
+ dim_inner = head_dim * num_heads
20
+ kv_dim = kv_dim if kv_dim is not None else q_dim
21
+ self.num_heads = num_heads
22
+ self.head_dim = head_dim
23
+
24
+ self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
25
+ self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
26
+ self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
27
+ self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
28
+
29
+ def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
30
+ if encoder_hidden_states is None:
31
+ encoder_hidden_states = hidden_states
32
+
33
+ batch_size = encoder_hidden_states.shape[0]
34
+
35
+ q = self.to_q(hidden_states)
36
+ k = self.to_k(encoder_hidden_states)
37
+ v = self.to_v(encoder_hidden_states)
38
+
39
+ q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
40
+ k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
41
+ v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
42
+
43
+ hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
44
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
45
+ hidden_states = hidden_states.to(q.dtype)
46
+
47
+ hidden_states = self.to_out(hidden_states)
48
+
49
+ return hidden_states
50
+
51
+ def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
52
+ if encoder_hidden_states is None:
53
+ encoder_hidden_states = hidden_states
54
+
55
+ q = self.to_q(hidden_states)
56
+ k = self.to_k(encoder_hidden_states)
57
+ v = self.to_v(encoder_hidden_states)
58
+
59
+ q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
60
+ k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
61
+ v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
62
+
63
+ if attn_mask is not None:
64
+ hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
65
+ else:
66
+ import xformers.ops as xops
67
+ hidden_states = xops.memory_efficient_attention(q, k, v)
68
+ hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
69
+
70
+ hidden_states = hidden_states.to(q.dtype)
71
+ hidden_states = self.to_out(hidden_states)
72
+
73
+ return hidden_states
74
+
75
+ def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
76
+ return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask)
diffsynth/models/sd_controlnet.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
3
+ from .tiler import TileWorker
4
+
5
+
6
+ class ControlNetConditioningLayer(torch.nn.Module):
7
+ def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
8
+ super().__init__()
9
+ self.blocks = torch.nn.ModuleList([])
10
+ self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
11
+ self.blocks.append(torch.nn.SiLU())
12
+ for i in range(1, len(channels) - 2):
13
+ self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
14
+ self.blocks.append(torch.nn.SiLU())
15
+ self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
16
+ self.blocks.append(torch.nn.SiLU())
17
+ self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
18
+
19
+ def forward(self, conditioning):
20
+ for block in self.blocks:
21
+ conditioning = block(conditioning)
22
+ return conditioning
23
+
24
+
25
+ class SDControlNet(torch.nn.Module):
26
+ def __init__(self, global_pool=False):
27
+ super().__init__()
28
+ self.time_proj = Timesteps(320)
29
+ self.time_embedding = torch.nn.Sequential(
30
+ torch.nn.Linear(320, 1280),
31
+ torch.nn.SiLU(),
32
+ torch.nn.Linear(1280, 1280)
33
+ )
34
+ self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
35
+
36
+ self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
37
+
38
+ self.blocks = torch.nn.ModuleList([
39
+ # CrossAttnDownBlock2D
40
+ ResnetBlock(320, 320, 1280),
41
+ AttentionBlock(8, 40, 320, 1, 768),
42
+ PushBlock(),
43
+ ResnetBlock(320, 320, 1280),
44
+ AttentionBlock(8, 40, 320, 1, 768),
45
+ PushBlock(),
46
+ DownSampler(320),
47
+ PushBlock(),
48
+ # CrossAttnDownBlock2D
49
+ ResnetBlock(320, 640, 1280),
50
+ AttentionBlock(8, 80, 640, 1, 768),
51
+ PushBlock(),
52
+ ResnetBlock(640, 640, 1280),
53
+ AttentionBlock(8, 80, 640, 1, 768),
54
+ PushBlock(),
55
+ DownSampler(640),
56
+ PushBlock(),
57
+ # CrossAttnDownBlock2D
58
+ ResnetBlock(640, 1280, 1280),
59
+ AttentionBlock(8, 160, 1280, 1, 768),
60
+ PushBlock(),
61
+ ResnetBlock(1280, 1280, 1280),
62
+ AttentionBlock(8, 160, 1280, 1, 768),
63
+ PushBlock(),
64
+ DownSampler(1280),
65
+ PushBlock(),
66
+ # DownBlock2D
67
+ ResnetBlock(1280, 1280, 1280),
68
+ PushBlock(),
69
+ ResnetBlock(1280, 1280, 1280),
70
+ PushBlock(),
71
+ # UNetMidBlock2DCrossAttn
72
+ ResnetBlock(1280, 1280, 1280),
73
+ AttentionBlock(8, 160, 1280, 1, 768),
74
+ ResnetBlock(1280, 1280, 1280),
75
+ PushBlock()
76
+ ])
77
+
78
+ self.controlnet_blocks = torch.nn.ModuleList([
79
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
80
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
81
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
82
+ torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
83
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
84
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
85
+ torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
86
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
87
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
88
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
89
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
90
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
91
+ torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
92
+ ])
93
+
94
+ self.global_pool = global_pool
95
+
96
+ def forward(
97
+ self,
98
+ sample, timestep, encoder_hidden_states, conditioning,
99
+ tiled=False, tile_size=64, tile_stride=32,
100
+ ):
101
+ # 1. time
102
+ time_emb = self.time_proj(timestep[None]).to(sample.dtype)
103
+ time_emb = self.time_embedding(time_emb)
104
+ time_emb = time_emb.repeat(sample.shape[0], 1)
105
+
106
+ # 2. pre-process
107
+ height, width = sample.shape[2], sample.shape[3]
108
+ hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
109
+ text_emb = encoder_hidden_states
110
+ res_stack = [hidden_states]
111
+
112
+ # 3. blocks
113
+ for i, block in enumerate(self.blocks):
114
+ if tiled and not isinstance(block, PushBlock):
115
+ _, _, inter_height, _ = hidden_states.shape
116
+ resize_scale = inter_height / height
117
+ hidden_states = TileWorker().tiled_forward(
118
+ lambda x: block(x, time_emb, text_emb, res_stack)[0],
119
+ hidden_states,
120
+ int(tile_size * resize_scale),
121
+ int(tile_stride * resize_scale),
122
+ tile_device=hidden_states.device,
123
+ tile_dtype=hidden_states.dtype
124
+ )
125
+ else:
126
+ hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
127
+
128
+ # 4. ControlNet blocks
129
+ controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
130
+
131
+ # pool
132
+ if self.global_pool:
133
+ controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
134
+
135
+ return controlnet_res_stack
136
+
137
+ def state_dict_converter(self):
138
+ return SDControlNetStateDictConverter()
139
+
140
+
141
+ class SDControlNetStateDictConverter:
142
+ def __init__(self):
143
+ pass
144
+
145
+ def from_diffusers(self, state_dict):
146
+ # architecture
147
+ block_types = [
148
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
149
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
150
+ 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
151
+ 'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
152
+ 'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
153
+ 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
154
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
155
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
156
+ 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
157
+ ]
158
+
159
+ # controlnet_rename_dict
160
+ controlnet_rename_dict = {
161
+ "controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
162
+ "controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
163
+ "controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
164
+ "controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
165
+ "controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
166
+ "controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
167
+ "controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
168
+ "controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
169
+ "controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
170
+ "controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
171
+ "controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
172
+ "controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
173
+ "controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
174
+ "controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
175
+ "controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
176
+ "controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
177
+ }
178
+
179
+ # Rename each parameter
180
+ name_list = sorted([name for name in state_dict])
181
+ rename_dict = {}
182
+ block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
183
+ last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
184
+ for name in name_list:
185
+ names = name.split(".")
186
+ if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
187
+ pass
188
+ elif name in controlnet_rename_dict:
189
+ names = controlnet_rename_dict[name].split(".")
190
+ elif names[0] == "controlnet_down_blocks":
191
+ names[0] = "controlnet_blocks"
192
+ elif names[0] == "controlnet_mid_block":
193
+ names = ["controlnet_blocks", "12", names[-1]]
194
+ elif names[0] in ["time_embedding", "add_embedding"]:
195
+ if names[0] == "add_embedding":
196
+ names[0] = "add_time_embedding"
197
+ names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
198
+ elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
199
+ if names[0] == "mid_block":
200
+ names.insert(1, "0")
201
+ block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
202
+ block_type_with_id = ".".join(names[:4])
203
+ if block_type_with_id != last_block_type_with_id[block_type]:
204
+ block_id[block_type] += 1
205
+ last_block_type_with_id[block_type] = block_type_with_id
206
+ while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
207
+ block_id[block_type] += 1
208
+ block_type_with_id = ".".join(names[:4])
209
+ names = ["blocks", str(block_id[block_type])] + names[4:]
210
+ if "ff" in names:
211
+ ff_index = names.index("ff")
212
+ component = ".".join(names[ff_index:ff_index+3])
213
+ component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
214
+ names = names[:ff_index] + [component] + names[ff_index+3:]
215
+ if "to_out" in names:
216
+ names.pop(names.index("to_out") + 1)
217
+ else:
218
+ raise ValueError(f"Unknown parameters: {name}")
219
+ rename_dict[name] = ".".join(names)
220
+
221
+ # Convert state_dict
222
+ state_dict_ = {}
223
+ for name, param in state_dict.items():
224
+ if ".proj_in." in name or ".proj_out." in name:
225
+ param = param.squeeze()
226
+ if rename_dict[name] in [
227
+ "controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
228
+ "controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
229
+ ]:
230
+ continue
231
+ state_dict_[rename_dict[name]] = param
232
+ return state_dict_
233
+
234
+ def from_civitai(self, state_dict):
235
+ rename_dict = {
236
+ "control_model.time_embed.0.weight": "time_embedding.0.weight",
237
+ "control_model.time_embed.0.bias": "time_embedding.0.bias",
238
+ "control_model.time_embed.2.weight": "time_embedding.2.weight",
239
+ "control_model.time_embed.2.bias": "time_embedding.2.bias",
240
+ "control_model.input_blocks.0.0.weight": "conv_in.weight",
241
+ "control_model.input_blocks.0.0.bias": "conv_in.bias",
242
+ "control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
243
+ "control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
244
+ "control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
245
+ "control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
246
+ "control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
247
+ "control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
248
+ "control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
249
+ "control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
250
+ "control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
251
+ "control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
252
+ "control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
253
+ "control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
254
+ "control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
255
+ "control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
256
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
257
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
258
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
259
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
260
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
261
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
262
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
263
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
264
+ "control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
265
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
266
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
267
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
268
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
269
+ "control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
270
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
271
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
272
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
273
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
274
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
275
+ "control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
276
+ "control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
277
+ "control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
278
+ "control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
279
+ "control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
280
+ "control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
281
+ "control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
282
+ "control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
283
+ "control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
284
+ "control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
285
+ "control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
286
+ "control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
287
+ "control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
288
+ "control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
289
+ "control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
290
+ "control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
291
+ "control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
292
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
293
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
294
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
295
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
296
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
297
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
298
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
299
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
300
+ "control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
301
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
302
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
303
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
304
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
305
+ "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
306
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
307
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
308
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
309
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
310
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
311
+ "control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
312
+ "control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
313
+ "control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
314
+ "control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
315
+ "control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
316
+ "control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
317
+ "control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
318
+ "control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
319
+ "control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
320
+ "control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
321
+ "control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
322
+ "control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
323
+ "control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
324
+ "control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
325
+ "control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
326
+ "control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
327
+ "control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
328
+ "control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
329
+ "control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
330
+ "control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
331
+ "control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
332
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
333
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
334
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
335
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
336
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
337
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
338
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
339
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
340
+ "control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
341
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
342
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
343
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
344
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
345
+ "control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
346
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
347
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
348
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
349
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
350
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
351
+ "control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
352
+ "control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
353
+ "control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
354
+ "control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
355
+ "control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
356
+ "control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
357
+ "control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
358
+ "control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
359
+ "control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
360
+ "control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
361
+ "control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
362
+ "control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
363
+ "control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
364
+ "control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
365
+ "control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
366
+ "control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
367
+ "control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
368
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
369
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
370
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
371
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
372
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
373
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
374
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
375
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
376
+ "control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
377
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
378
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
379
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
380
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
381
+ "control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
382
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
383
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
384
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
385
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
386
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
387
+ "control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
388
+ "control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
389
+ "control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
390
+ "control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
391
+ "control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
392
+ "control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
393
+ "control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
394
+ "control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
395
+ "control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
396
+ "control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
397
+ "control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
398
+ "control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
399
+ "control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
400
+ "control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
401
+ "control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
402
+ "control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
403
+ "control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
404
+ "control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
405
+ "control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
406
+ "control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
407
+ "control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
408
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
409
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
410
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
411
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
412
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
413
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
414
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
415
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
416
+ "control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
417
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
418
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
419
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
420
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
421
+ "control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
422
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
423
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
424
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
425
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
426
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
427
+ "control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
428
+ "control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
429
+ "control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
430
+ "control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
431
+ "control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
432
+ "control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
433
+ "control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
434
+ "control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
435
+ "control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
436
+ "control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
437
+ "control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
438
+ "control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
439
+ "control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
440
+ "control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
441
+ "control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
442
+ "control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
443
+ "control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
444
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
445
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
446
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
447
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
448
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
449
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
450
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
451
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
452
+ "control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
453
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
454
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
455
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
456
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
457
+ "control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
458
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
459
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
460
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
461
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
462
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
463
+ "control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
464
+ "control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
465
+ "control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
466
+ "control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
467
+ "control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
468
+ "control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
469
+ "control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
470
+ "control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
471
+ "control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
472
+ "control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
473
+ "control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
474
+ "control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
475
+ "control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
476
+ "control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
477
+ "control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
478
+ "control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
479
+ "control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
480
+ "control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
481
+ "control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
482
+ "control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
483
+ "control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
484
+ "control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
485
+ "control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
486
+ "control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
487
+ "control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
488
+ "control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
489
+ "control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
490
+ "control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
491
+ "control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
492
+ "control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
493
+ "control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
494
+ "control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
495
+ "control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
496
+ "control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
497
+ "control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
498
+ "control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
499
+ "control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
500
+ "control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
501
+ "control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
502
+ "control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
503
+ "control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
504
+ "control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
505
+ "control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
506
+ "control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
507
+ "control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
508
+ "control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
509
+ "control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
510
+ "control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
511
+ "control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
512
+ "control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
513
+ "control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
514
+ "control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
515
+ "control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
516
+ "control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
517
+ "control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
518
+ "control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
519
+ "control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
520
+ "control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
521
+ "control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
522
+ "control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
523
+ "control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
524
+ "control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
525
+ "control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
526
+ "control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
527
+ "control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
528
+ "control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
529
+ "control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
530
+ "control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
531
+ "control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
532
+ "control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
533
+ "control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
534
+ "control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
535
+ "control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
536
+ "control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
537
+ "control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
538
+ "control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
539
+ "control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
540
+ "control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
541
+ "control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
542
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
543
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
544
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
545
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
546
+ "control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
547
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
548
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
549
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
550
+ "control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
551
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
552
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
553
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
554
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
555
+ "control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
556
+ "control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
557
+ "control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
558
+ "control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
559
+ "control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
560
+ "control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
561
+ "control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
562
+ "control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
563
+ "control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
564
+ "control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
565
+ "control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
566
+ "control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
567
+ "control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
568
+ "control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
569
+ "control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
570
+ "control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
571
+ "control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
572
+ "control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
573
+ "control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
574
+ "control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
575
+ "control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
576
+ }
577
+ state_dict_ = {}
578
+ for name in state_dict:
579
+ if name in rename_dict:
580
+ param = state_dict[name]
581
+ if ".proj_in." in name or ".proj_out." in name:
582
+ param = param.squeeze()
583
+ state_dict_[rename_dict[name]] = param
584
+ return state_dict_