tetrisd commited on
Commit
b4cfcd3
Β·
1 Parent(s): e9e7dc5

Add V2 and some paper code

Browse files
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. README.md +2 -2
  2. app.py +195 -144
  3. diffusers/__init__.py +0 -60
  4. diffusers/__pycache__/__init__.cpython-310.pyc +0 -0
  5. diffusers/__pycache__/configuration_utils.cpython-310.pyc +0 -0
  6. diffusers/__pycache__/dependency_versions_check.cpython-310.pyc +0 -0
  7. diffusers/__pycache__/dependency_versions_table.cpython-310.pyc +0 -0
  8. diffusers/__pycache__/dynamic_modules_utils.cpython-310.pyc +0 -0
  9. diffusers/__pycache__/hub_utils.cpython-310.pyc +0 -0
  10. diffusers/__pycache__/modeling_utils.cpython-310.pyc +0 -0
  11. diffusers/__pycache__/onnx_utils.cpython-310.pyc +0 -0
  12. diffusers/__pycache__/optimization.cpython-310.pyc +0 -0
  13. diffusers/__pycache__/pipeline_utils.cpython-310.pyc +0 -0
  14. diffusers/__pycache__/testing_utils.cpython-310.pyc +0 -0
  15. diffusers/__pycache__/training_utils.cpython-310.pyc +0 -0
  16. diffusers/commands/__init__.py +0 -27
  17. diffusers/commands/__pycache__/__init__.cpython-310.pyc +0 -0
  18. diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc +0 -0
  19. diffusers/commands/__pycache__/env.cpython-310.pyc +0 -0
  20. diffusers/commands/diffusers_cli.py +0 -41
  21. diffusers/commands/env.py +0 -70
  22. diffusers/configuration_utils.py +0 -403
  23. diffusers/dependency_versions_check.py +0 -47
  24. diffusers/dependency_versions_table.py +0 -26
  25. diffusers/dynamic_modules_utils.py +0 -335
  26. diffusers/hub_utils.py +0 -197
  27. diffusers/modeling_utils.py +0 -542
  28. diffusers/models/__init__.py +0 -17
  29. diffusers/models/__pycache__/__init__.cpython-310.pyc +0 -0
  30. diffusers/models/__pycache__/attention.cpython-310.pyc +0 -0
  31. diffusers/models/__pycache__/embeddings.cpython-310.pyc +0 -0
  32. diffusers/models/__pycache__/resnet.cpython-310.pyc +0 -0
  33. diffusers/models/__pycache__/unet_2d.cpython-310.pyc +0 -0
  34. diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc +0 -0
  35. diffusers/models/__pycache__/unet_blocks.cpython-310.pyc +0 -0
  36. diffusers/models/__pycache__/vae.cpython-310.pyc +0 -0
  37. diffusers/models/attention.py +0 -409
  38. diffusers/models/embeddings.py +0 -115
  39. diffusers/models/resnet.py +0 -483
  40. diffusers/models/unet_2d.py +0 -246
  41. diffusers/models/unet_2d_condition.py +0 -272
  42. diffusers/models/unet_blocks.py +0 -1484
  43. diffusers/models/vae.py +0 -585
  44. diffusers/onnx_utils.py +0 -189
  45. diffusers/optimization.py +0 -275
  46. diffusers/pipeline_utils.py +0 -417
  47. diffusers/pipelines/__init__.py +0 -19
  48. diffusers/pipelines/__pycache__/__init__.cpython-310.pyc +0 -0
  49. diffusers/pipelines/ddim/__init__.py +0 -2
  50. diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc +0 -0
README.md CHANGED
@@ -1,12 +1,12 @@
1
  ---
2
- title: Stable Diffusion Attentive Attribution Maps
3
  emoji: πŸ‘€
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.4.1
8
  app_file: app.py
9
- pinned: false
10
  license: mit
11
  ---
12
 
 
1
  ---
2
+ title: Stable Diffusion V2 Attentive Attribution Maps
3
  emoji: πŸ‘€
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
  sdk_version: 3.4.1
8
  app_file: app.py
9
+ pinned: true
10
  license: mit
11
  ---
12
 
app.py CHANGED
@@ -1,170 +1,221 @@
1
- from huggingface_hub import HfApi, HfFolder
2
- import os
3
-
4
- api = HfApi()
5
- api.set_access_token(os.environ['HF_SECRET'])
6
- folder = HfFolder()
7
- folder.save_token(os.environ['HF_SECRET'])
8
-
9
- from threading import Lock
10
  import math
11
- import os
12
- import random
 
 
13
 
 
14
  from diffusers import StableDiffusionPipeline
15
- from diffusers.models.attention import get_global_heat_map, clear_heat_maps
16
  from matplotlib import pyplot as plt
17
  import gradio as gr
18
  import torch
19
- import torch.nn.functional as F
20
- import spacy
21
 
 
 
22
 
23
- if not os.environ.get('NO_DOWNLOAD_SPACY'):
24
- spacy.cli.download('en_core_web_sm')
25
 
 
 
 
26
 
27
- model_id = "runwayml/stable-diffusion-v1-5"
28
- device = "cuda"
29
 
30
- gen = torch.Generator(device='cuda')
31
- gen.manual_seed(12758672)
32
- orig_state = gen.get_state()
33
- pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).to(device)
34
- lock = Lock()
35
- nlp = spacy.load('en_core_web_sm')
36
 
 
 
 
 
 
 
37
 
38
- def expand_m(m, n: int = 1, o=512, mode='bicubic'):
39
- m = m.unsqueeze(0).unsqueeze(0) / n
40
- m = F.interpolate(m.float().detach(), size=(o, o), mode='bicubic', align_corners=False)
41
- m = (m - m.min()) / (m.max() - m.min() + 1e-8)
42
- m = m.cpu().detach()
 
 
 
43
 
44
- return m
45
 
46
 
47
- @torch.no_grad()
48
- def predict(prompt, inf_steps, threshold):
49
- global lock
50
- with torch.cuda.amp.autocast(), lock:
51
- try:
52
- plt.close('all')
53
- except:
54
- pass
55
 
56
- gen.set_state(orig_state.clone())
57
- clear_heat_maps()
 
 
 
 
 
 
58
 
59
- out = pipe(prompt, guidance_scale=7.5, height=512, width=512, do_intermediates=False, generator=gen, num_inference_steps=int(inf_steps))
60
- heat_maps = get_global_heat_map()
61
 
62
- with torch.cuda.amp.autocast(dtype=torch.float32):
63
- m = 0
64
- n = 0
65
- w = ''
66
- w_idx = 0
67
 
68
- fig, ax = plt.subplots()
69
- ax.imshow(out.images[0].cpu().float().detach().permute(1, 2, 0).numpy())
70
- ax.set_xticks([])
71
- ax.set_yticks([])
72
 
73
- fig1, axs1 = plt.subplots(math.ceil(len(out.words) / 4), 4)#, figsize=(20, 20))
74
- fig2, axs2 = plt.subplots(math.ceil(len(out.words) / 4), 4) # , figsize=(20, 20))
 
 
75
 
76
- for idx in range(len(out.words) + 1):
77
- if idx == 0:
78
- continue
79
 
80
- word = out.words[idx - 1]
81
- m += heat_maps[idx]
82
- n += 1
83
- w += word
84
 
85
- if '</w>' not in word:
86
- continue
87
- else:
88
- mplot = expand_m(m, n)
89
- spotlit_im = out.images[0].cpu().float().detach()
90
- w = w.replace('</w>', '')
91
- spotlit_im2 = torch.cat((spotlit_im, (1 - mplot.squeeze(0)).pow(1)), dim=0)
92
-
93
- if len(out.words) <= 4:
94
- a1 = axs1[w_idx % 4]
95
- a2 = axs2[w_idx % 4]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  else:
97
- a1 = axs1[w_idx // 4, w_idx % 4]
98
- a2 = axs2[w_idx // 4, w_idx % 4]
99
-
100
- a1.set_xticks([])
101
- a1.set_yticks([])
102
- a1.imshow(mplot.squeeze().numpy(), cmap='jet')
103
- a1.imshow(spotlit_im2.permute(1, 2, 0).numpy())
104
- a1.set_title(w)
105
-
106
- mask = torch.ones_like(mplot)
107
- mask[mplot < threshold * mplot.max()] = 0
108
- im2 = spotlit_im * mask.squeeze(0)
109
- a2.set_xticks([])
110
- a2.set_yticks([])
111
- a2.imshow(im2.permute(1, 2, 0).numpy())
112
- a2.set_title(w)
113
- m = 0
114
- n = 0
115
- w_idx += 1
116
- w = ''
117
-
118
- for idx in range(w_idx, len(axs1.flatten())):
119
- fig1.delaxes(axs1.flatten()[idx])
120
- fig2.delaxes(axs2.flatten()[idx])
121
-
122
- return fig, fig1, fig2
123
-
124
-
125
- def set_prompt(prompt):
126
- return prompt
127
-
128
-
129
- with gr.Blocks() as demo:
130
- md = '''# DAAM: Attention Maps for Interpreting Stable Diffusion
131
- Check out the paper: [What the DAAM: Interpreting Stable Diffusion Using Cross Attention](http://arxiv.org/abs/2210.04885).
132
- See our (much cleaner) [DAAM codebase](https://github.com/castorini/daam) on GitHub.
133
-
134
- **Update**: We got a community grant! I'll continue running and updating the space, with a major release planned in December.
135
- '''
136
- gr.Markdown(md)
137
-
138
- with gr.Row():
139
- with gr.Column():
140
- dropdown = gr.Dropdown([
141
- 'An angry, bald man doing research',
142
- 'Doing research at Comcast Applied AI labs',
143
- 'Professor Jimmy Lin from the University of Waterloo',
144
- 'Yann Lecun teaching machine learning on a chalkboard',
145
- 'A cat eating cake for her birthday',
146
- 'Steak and dollars on a plate',
147
- 'A fox, a dog, and a wolf in a field'
148
- ], label='Examples', value='An angry, bald man doing research')
149
-
150
- text = gr.Textbox(label='Prompt', value='An angry, bald man doing research')
151
- slider1 = gr.Slider(15, 35, value=25, interactive=True, step=1, label='Inference steps')
152
- slider2 = gr.Slider(0, 1.0, value=0.4, interactive=True, step=0.05, label='Threshold (tau)')
153
- submit_btn = gr.Button('Submit')
154
-
155
- with gr.Tab('Original Image'):
156
- p0 = gr.Plot()
157
-
158
- with gr.Tab('Soft DAAM Maps'):
159
- p1 = gr.Plot()
160
-
161
- with gr.Tab('Hard DAAM Maps'):
162
- p2 = gr.Plot()
163
-
164
- submit_btn.click(fn=predict, inputs=[text, slider1, slider2], outputs=[p0, p1, p2])
165
- dropdown.change(set_prompt, dropdown, text)
166
- dropdown.update()
167
-
168
-
169
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
 
 
 
 
 
 
 
 
 
 
1
  import math
2
+ import time
3
+ from threading import Lock
4
+ from typing import Any, List
5
+ import argparse
6
 
7
+ import numpy as np
8
  from diffusers import StableDiffusionPipeline
 
9
  from matplotlib import pyplot as plt
10
  import gradio as gr
11
  import torch
12
+ from spacy import displacy
 
13
 
14
+ from daam import trace
15
+ from daam.utils import set_seed, cached_nlp
16
 
 
 
17
 
18
+ def dependency(text):
19
+ doc = cached_nlp(text)
20
+ svg = displacy.render(doc, style='dep', options={'compact': True, 'distance': 100})
21
 
22
+ return svg
 
23
 
 
 
 
 
 
 
24
 
25
+ def get_tokenizing_mapping(prompt: str, tokenizer: Any) -> List[List[int]]:
26
+ tokens = tokenizer.tokenize(prompt)
27
+ merge_idxs = []
28
+ words = []
29
+ curr_idxs = []
30
+ curr_word = ''
31
 
32
+ for i, token in enumerate(tokens):
33
+ curr_idxs.append(i + 1) # because of the [CLS] token
34
+ curr_word += token
35
+ if '</w>' in token:
36
+ merge_idxs.append(curr_idxs)
37
+ curr_idxs = []
38
+ words.append(curr_word[:-4])
39
+ curr_word = ''
40
 
41
+ return merge_idxs, words
42
 
43
 
44
+ def get_args():
45
+ model_id_map = {
46
+ 'v1': 'runwayml/stable-diffusion-v1-5',
47
+ 'v2-base': 'stabilityai/stable-diffusion-2-base',
48
+ 'v2-large': 'stabilityai/stable-diffusion-2'
49
+ }
 
 
50
 
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument('--model', '-m', type=str, default='v2-base', choices=list(model_id_map.keys()), help="which diffusion model to use")
53
+ parser.add_argument('--seed', '-s', type=int, default=0, help="the random seed")
54
+ parser.add_argument('--port', '-p', type=int, default=8080, help="the port to launch the demo")
55
+ parser.add_argument('--no-cuda', action='store_true', help="Use CPUs instead of GPUs")
56
+ args = parser.parse_args()
57
+ args.model = model_id_map[args.model]
58
+ return args
59
 
 
 
60
 
61
+ def main():
62
+ args = get_args()
63
+ plt.switch_backend('agg')
 
 
64
 
65
+ device = "cpu" if args.no_cuda else "cuda"
66
+ pipe = StableDiffusionPipeline.from_pretrained(args.model, use_auth_token=True).to(device)
67
+ lock = Lock()
 
68
 
69
+ @torch.no_grad()
70
+ def update_dropdown(prompt):
71
+ tokens = [''] + [x.text for x in cached_nlp(prompt) if x.pos_ == 'ADJ']
72
+ return gr.Dropdown.update(choices=tokens), dependency(prompt)
73
 
74
+ @torch.no_grad()
75
+ def plot(prompt, choice, replaced_word, inf_steps, is_random_seed):
76
+ new_prompt = prompt.replace(',', ', ').replace('.', '. ')
77
 
78
+ if choice:
79
+ if not replaced_word:
80
+ replaced_word = '.'
 
81
 
82
+ new_prompt = [replaced_word if tok.text == choice else tok.text for tok in cached_nlp(prompt)]
83
+ new_prompt = ' '.join(new_prompt)
84
+
85
+ merge_idxs, words = get_tokenizing_mapping(prompt, pipe.tokenizer)
86
+ with torch.cuda.amp.autocast(dtype=torch.float16), lock:
87
+ try:
88
+ plt.close('all')
89
+ plt.clf()
90
+ except:
91
+ pass
92
+
93
+ seed = int(time.time()) if is_random_seed else args.seed
94
+ gen = set_seed(seed)
95
+ prompt = prompt.replace(',', ', ').replace('.', '. ') # hacky fix to address later
96
+
97
+ if choice:
98
+ new_prompt = new_prompt.replace(',', ', ').replace('.', '. ') # hacky fix to address later
99
+
100
+ with trace(pipe, save_heads=new_prompt != prompt) as tc:
101
+ out = pipe(prompt, num_inference_steps=inf_steps, generator=gen)
102
+ image = np.array(out.images[0]) / 255
103
+ heat_map = tc.compute_global_heat_map()
104
+
105
+ if new_prompt == prompt:
106
+ image2 = image
107
  else:
108
+ gen = set_seed(seed)
109
+
110
+ with trace(pipe, load_heads=True) as tc:
111
+ out2 = pipe(new_prompt, num_inference_steps=inf_steps, generator=gen)
112
+ image2 = np.array(out2.images[0]) / 255
113
+ else:
114
+ with trace(pipe) as tc:
115
+ out = pipe(prompt, num_inference_steps=inf_steps, generator=gen)
116
+ image = np.array(out.images[0]) / 255
117
+ heat_map = tc.compute_global_heat_map()
118
+
119
+ # the main image
120
+ if new_prompt == prompt:
121
+ fig, ax = plt.subplots()
122
+ ax.imshow(image)
123
+ ax.set_xticks([])
124
+ ax.set_yticks([])
125
+ else:
126
+ fig, ax = plt.subplots(1, 2)
127
+ ax[0].imshow(image)
128
+
129
+ if choice:
130
+ ax[1].imshow(image2)
131
+
132
+ ax[0].set_title(choice)
133
+ ax[0].set_xticks([])
134
+ ax[0].set_yticks([])
135
+ ax[1].set_title(replaced_word)
136
+ ax[1].set_xticks([])
137
+ ax[1].set_yticks([])
138
+
139
+ # the heat maps
140
+ num_cells = 4
141
+ w = int(num_cells * 3.5)
142
+ h = math.ceil(len(words) / num_cells * 4.5)
143
+ fig_soft, axs_soft = plt.subplots(math.ceil(len(words) / num_cells), num_cells, figsize=(w, h))
144
+ axs_soft = axs_soft.flatten()
145
+ with torch.cuda.amp.autocast(dtype=torch.float32):
146
+ for idx, parsed_map in enumerate(heat_map.parsed_heat_maps()):
147
+ word_ax_soft = axs_soft[idx]
148
+ word_ax_soft.set_xticks([])
149
+ word_ax_soft.set_yticks([])
150
+ parsed_map.word_heat_map.plot_overlay(out.images[0], ax=word_ax_soft)
151
+ word_ax_soft.set_title(parsed_map.word_heat_map.word, fontsize=12)
152
+
153
+ for idx in range(len(words), len(axs_soft)):
154
+ fig_soft.delaxes(axs_soft[idx])
155
+
156
+ return fig, fig_soft
157
+
158
+ with gr.Blocks(css='scrollbar.css') as demo:
159
+ md = '''# DAAM: Attention Maps for Interpreting Stable Diffusion
160
+ Check out the **new** paper (2022/12/7): [What the DAAM: Interpreting Stable Diffusion Using Cross Attention](http://arxiv.org/abs/2210.04885).
161
+ See our (much cleaner) [DAAM codebase](https://github.com/castorini/daam) on GitHub.
162
+ '''
163
+ gr.Markdown(md)
164
+
165
+ with gr.Row():
166
+ with gr.Column():
167
+ dropdown = gr.Dropdown([
168
+ 'An angry, bald man doing research',
169
+ 'A bear and a moose',
170
+ 'A blue car driving through the city',
171
+ 'Monkey walking with hat',
172
+ 'Doing research at Comcast Applied AI labs',
173
+ 'Professor Jimmy Lin from the modern University of Waterloo',
174
+ 'Yann Lecun teaching machine learning on a green chalkboard',
175
+ 'A brown cat eating yummy cake for her birthday',
176
+ 'A brown fox, a white dog, and a blue wolf in a green field',
177
+ ], label='Examples', value='An angry, bald man doing research')
178
+
179
+ text = gr.Textbox(label='Prompt', value='An angry, bald man doing research')
180
+
181
+ with gr.Row():
182
+ doc = cached_nlp('An angry, bald man doing research')
183
+ tokens = [''] + [x.text for x in doc if x.pos_ == 'ADJ']
184
+ dropdown2 = gr.Dropdown(tokens, label='Adjective to replace', interactive=True)
185
+ text2 = gr.Textbox(label='New adjective', value='')
186
+
187
+ checkbox = gr.Checkbox(value=False, label='Random seed')
188
+ slider1 = gr.Slider(15, 30, value=25, interactive=True, step=1, label='Inference steps')
189
+
190
+ submit_btn = gr.Button('Submit', elem_id='submit-btn')
191
+ viz = gr.HTML(dependency('An angry, bald man doing research'), elem_id='viz')
192
+
193
+ with gr.Column():
194
+ with gr.Tab('Images'):
195
+ p0 = gr.Plot()
196
+
197
+ with gr.Tab('DAAM Maps'):
198
+ p1 = gr.Plot()
199
+
200
+ text.change(fn=update_dropdown, inputs=[text], outputs=[dropdown2, viz])
201
+
202
+ submit_btn.click(
203
+ fn=plot,
204
+ inputs=[text, dropdown2, text2, slider1, checkbox],
205
+ outputs=[p0, p1])
206
+ dropdown.change(lambda prompt: prompt, dropdown, text)
207
+ dropdown.update()
208
+
209
+ while True:
210
+ try:
211
+ demo.launch()
212
+ except OSError:
213
+ gr.close_all()
214
+ except KeyboardInterrupt:
215
+ gr.close_all()
216
+ break
217
+
218
+
219
+ if __name__ == '__main__':
220
+ main()
221
 
diffusers/__init__.py DELETED
@@ -1,60 +0,0 @@
1
- from .utils import (
2
- is_inflect_available,
3
- is_onnx_available,
4
- is_scipy_available,
5
- is_transformers_available,
6
- is_unidecode_available,
7
- )
8
-
9
-
10
- __version__ = "0.3.0"
11
-
12
- from .configuration_utils import ConfigMixin
13
- from .modeling_utils import ModelMixin
14
- from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
15
- from .onnx_utils import OnnxRuntimeModel
16
- from .optimization import (
17
- get_constant_schedule,
18
- get_constant_schedule_with_warmup,
19
- get_cosine_schedule_with_warmup,
20
- get_cosine_with_hard_restarts_schedule_with_warmup,
21
- get_linear_schedule_with_warmup,
22
- get_polynomial_decay_schedule_with_warmup,
23
- get_scheduler,
24
- )
25
- from .pipeline_utils import DiffusionPipeline
26
- from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
27
- from .schedulers import (
28
- DDIMScheduler,
29
- DDPMScheduler,
30
- KarrasVeScheduler,
31
- PNDMScheduler,
32
- SchedulerMixin,
33
- ScoreSdeVeScheduler,
34
- )
35
- from .utils import logging
36
-
37
-
38
- if is_scipy_available():
39
- from .schedulers import LMSDiscreteScheduler
40
- else:
41
- from .utils.dummy_scipy_objects import * # noqa F403
42
-
43
- from .training_utils import EMAModel
44
-
45
-
46
- if is_transformers_available():
47
- from .pipelines import (
48
- LDMTextToImagePipeline,
49
- StableDiffusionImg2ImgPipeline,
50
- StableDiffusionInpaintPipeline,
51
- StableDiffusionPipeline,
52
- )
53
- else:
54
- from .utils.dummy_transformers_objects import * # noqa F403
55
-
56
-
57
- if is_transformers_available() and is_onnx_available():
58
- from .pipelines import StableDiffusionOnnxPipeline
59
- else:
60
- from .utils.dummy_transformers_and_onnx_objects import * # noqa F403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (1.85 kB)
 
diffusers/__pycache__/configuration_utils.cpython-310.pyc DELETED
Binary file (15.4 kB)
 
diffusers/__pycache__/dependency_versions_check.cpython-310.pyc DELETED
Binary file (967 Bytes)
 
diffusers/__pycache__/dependency_versions_table.cpython-310.pyc DELETED
Binary file (819 Bytes)
 
diffusers/__pycache__/dynamic_modules_utils.cpython-310.pyc DELETED
Binary file (11.6 kB)
 
diffusers/__pycache__/hub_utils.cpython-310.pyc DELETED
Binary file (5.46 kB)
 
diffusers/__pycache__/modeling_utils.cpython-310.pyc DELETED
Binary file (18.7 kB)
 
diffusers/__pycache__/onnx_utils.cpython-310.pyc DELETED
Binary file (6.3 kB)
 
diffusers/__pycache__/optimization.cpython-310.pyc DELETED
Binary file (10.1 kB)
 
diffusers/__pycache__/pipeline_utils.cpython-310.pyc DELETED
Binary file (14 kB)
 
diffusers/__pycache__/testing_utils.cpython-310.pyc DELETED
Binary file (1.66 kB)
 
diffusers/__pycache__/training_utils.cpython-310.pyc DELETED
Binary file (3.64 kB)
 
diffusers/commands/__init__.py DELETED
@@ -1,27 +0,0 @@
1
- # Copyright 2022 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from abc import ABC, abstractmethod
16
- from argparse import ArgumentParser
17
-
18
-
19
- class BaseDiffusersCLICommand(ABC):
20
- @staticmethod
21
- @abstractmethod
22
- def register_subcommand(parser: ArgumentParser):
23
- raise NotImplementedError()
24
-
25
- @abstractmethod
26
- def run(self):
27
- raise NotImplementedError()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/commands/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (817 Bytes)
 
diffusers/commands/__pycache__/diffusers_cli.cpython-310.pyc DELETED
Binary file (778 Bytes)
 
diffusers/commands/__pycache__/env.cpython-310.pyc DELETED
Binary file (2.17 kB)
 
diffusers/commands/diffusers_cli.py DELETED
@@ -1,41 +0,0 @@
1
- #!/usr/bin/env python
2
- # Copyright 2022 The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from argparse import ArgumentParser
17
-
18
- from .env import EnvironmentCommand
19
-
20
-
21
- def main():
22
- parser = ArgumentParser("Diffusers CLI tool", usage="diffusers-cli <command> [<args>]")
23
- commands_parser = parser.add_subparsers(help="diffusers-cli command helpers")
24
-
25
- # Register commands
26
- EnvironmentCommand.register_subcommand(commands_parser)
27
-
28
- # Let's go
29
- args = parser.parse_args()
30
-
31
- if not hasattr(args, "func"):
32
- parser.print_help()
33
- exit(1)
34
-
35
- # Run
36
- service = args.func(args)
37
- service.run()
38
-
39
-
40
- if __name__ == "__main__":
41
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/commands/env.py DELETED
@@ -1,70 +0,0 @@
1
- # Copyright 2022 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import platform
16
- from argparse import ArgumentParser
17
-
18
- import huggingface_hub
19
-
20
- from .. import __version__ as version
21
- from ..utils import is_torch_available, is_transformers_available
22
- from . import BaseDiffusersCLICommand
23
-
24
-
25
- def info_command_factory(_):
26
- return EnvironmentCommand()
27
-
28
-
29
- class EnvironmentCommand(BaseDiffusersCLICommand):
30
- @staticmethod
31
- def register_subcommand(parser: ArgumentParser):
32
- download_parser = parser.add_parser("env")
33
- download_parser.set_defaults(func=info_command_factory)
34
-
35
- def run(self):
36
- hub_version = huggingface_hub.__version__
37
-
38
- pt_version = "not installed"
39
- pt_cuda_available = "NA"
40
- if is_torch_available():
41
- import torch
42
-
43
- pt_version = torch.__version__
44
- pt_cuda_available = torch.cuda.is_available()
45
-
46
- transformers_version = "not installed"
47
- if is_transformers_available:
48
- import transformers
49
-
50
- transformers_version = transformers.__version__
51
-
52
- info = {
53
- "`diffusers` version": version,
54
- "Platform": platform.platform(),
55
- "Python version": platform.python_version(),
56
- "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
57
- "Huggingface_hub version": hub_version,
58
- "Transformers version": transformers_version,
59
- "Using GPU in script?": "<fill in>",
60
- "Using distributed or parallel set-up in script?": "<fill in>",
61
- }
62
-
63
- print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the two last points.\n")
64
- print(self.format_dict(info))
65
-
66
- return info
67
-
68
- @staticmethod
69
- def format_dict(d):
70
- return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/configuration_utils.py DELETED
@@ -1,403 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 The HuggingFace Inc. team.
3
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
- """ ConfigMixinuration base class and utilities."""
17
- import functools
18
- import inspect
19
- import json
20
- import os
21
- import re
22
- from collections import OrderedDict
23
- from typing import Any, Dict, Tuple, Union
24
-
25
- from huggingface_hub import hf_hub_download
26
- from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
27
- from requests import HTTPError
28
-
29
- from . import __version__
30
- from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
31
-
32
-
33
- logger = logging.get_logger(__name__)
34
-
35
- _re_configuration_file = re.compile(r"config\.(.*)\.json")
36
-
37
-
38
- class ConfigMixin:
39
- r"""
40
- Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all
41
- methods for loading/downloading/saving classes inheriting from [`ConfigMixin`] with
42
- - [`~ConfigMixin.from_config`]
43
- - [`~ConfigMixin.save_config`]
44
-
45
- Class attributes:
46
- - **config_name** (`str`) -- A filename under which the config should stored when calling
47
- [`~ConfigMixin.save_config`] (should be overriden by parent class).
48
- - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
49
- overriden by parent class).
50
- """
51
- config_name = None
52
- ignore_for_config = []
53
-
54
- def register_to_config(self, **kwargs):
55
- if self.config_name is None:
56
- raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
57
- kwargs["_class_name"] = self.__class__.__name__
58
- kwargs["_diffusers_version"] = __version__
59
-
60
- for key, value in kwargs.items():
61
- try:
62
- setattr(self, key, value)
63
- except AttributeError as err:
64
- logger.error(f"Can't set {key} with value {value} for {self}")
65
- raise err
66
-
67
- if not hasattr(self, "_internal_dict"):
68
- internal_dict = kwargs
69
- else:
70
- previous_dict = dict(self._internal_dict)
71
- internal_dict = {**self._internal_dict, **kwargs}
72
- logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
73
-
74
- self._internal_dict = FrozenDict(internal_dict)
75
-
76
- def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
77
- """
78
- Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
79
- [`~ConfigMixin.from_config`] class method.
80
-
81
- Args:
82
- save_directory (`str` or `os.PathLike`):
83
- Directory where the configuration JSON file will be saved (will be created if it does not exist).
84
- """
85
- if os.path.isfile(save_directory):
86
- raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
87
-
88
- os.makedirs(save_directory, exist_ok=True)
89
-
90
- # If we save using the predefined names, we can load using `from_config`
91
- output_config_file = os.path.join(save_directory, self.config_name)
92
-
93
- self.to_json_file(output_config_file)
94
- logger.info(f"ConfigMixinuration saved in {output_config_file}")
95
-
96
- @classmethod
97
- def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs):
98
- r"""
99
- Instantiate a Python class from a pre-defined JSON-file.
100
-
101
- Parameters:
102
- pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
103
- Can be either:
104
-
105
- - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an
106
- organization name, like `google/ddpm-celebahq-256`.
107
- - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g.,
108
- `./my_model_directory/`.
109
-
110
- cache_dir (`Union[str, os.PathLike]`, *optional*):
111
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
112
- standard cache should not be used.
113
- ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
114
- Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
115
- as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
116
- checkpoint with 3 labels).
117
- force_download (`bool`, *optional*, defaults to `False`):
118
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
119
- cached versions if they exist.
120
- resume_download (`bool`, *optional*, defaults to `False`):
121
- Whether or not to delete incompletely received files. Will attempt to resume the download if such a
122
- file exists.
123
- proxies (`Dict[str, str]`, *optional*):
124
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
125
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
126
- output_loading_info(`bool`, *optional*, defaults to `False`):
127
- Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
128
- local_files_only(`bool`, *optional*, defaults to `False`):
129
- Whether or not to only look at local files (i.e., do not try to download the model).
130
- use_auth_token (`str` or *bool*, *optional*):
131
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
132
- when running `transformers-cli login` (stored in `~/.huggingface`).
133
- revision (`str`, *optional*, defaults to `"main"`):
134
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
135
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
136
- identifier allowed by git.
137
- mirror (`str`, *optional*):
138
- Mirror source to accelerate downloads in China. If you are from China and have an accessibility
139
- problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
140
- Please refer to the mirror site for more information.
141
-
142
- <Tip>
143
-
144
- Passing `use_auth_token=True`` is required when you want to use a private model.
145
-
146
- </Tip>
147
-
148
- <Tip>
149
-
150
- Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
151
- use this method in a firewalled environment.
152
-
153
- </Tip>
154
-
155
- """
156
- config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
157
-
158
- init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
159
-
160
- model = cls(**init_dict)
161
-
162
- if return_unused_kwargs:
163
- return model, unused_kwargs
164
- else:
165
- return model
166
-
167
- @classmethod
168
- def get_config_dict(
169
- cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
170
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
171
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
172
- force_download = kwargs.pop("force_download", False)
173
- resume_download = kwargs.pop("resume_download", False)
174
- proxies = kwargs.pop("proxies", None)
175
- use_auth_token = kwargs.pop("use_auth_token", None)
176
- local_files_only = kwargs.pop("local_files_only", False)
177
- revision = kwargs.pop("revision", None)
178
- subfolder = kwargs.pop("subfolder", None)
179
-
180
- user_agent = {"file_type": "config"}
181
-
182
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
183
-
184
- if cls.config_name is None:
185
- raise ValueError(
186
- "`self.config_name` is not defined. Note that one should not load a config from "
187
- "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
188
- )
189
-
190
- if os.path.isfile(pretrained_model_name_or_path):
191
- config_file = pretrained_model_name_or_path
192
- elif os.path.isdir(pretrained_model_name_or_path):
193
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
194
- # Load from a PyTorch checkpoint
195
- config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
196
- elif subfolder is not None and os.path.isfile(
197
- os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
198
- ):
199
- config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
200
- else:
201
- raise EnvironmentError(
202
- f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
203
- )
204
- else:
205
- try:
206
- # Load from URL or cache if already cached
207
- config_file = hf_hub_download(
208
- pretrained_model_name_or_path,
209
- filename=cls.config_name,
210
- cache_dir=cache_dir,
211
- force_download=force_download,
212
- proxies=proxies,
213
- resume_download=resume_download,
214
- local_files_only=local_files_only,
215
- use_auth_token=use_auth_token,
216
- user_agent=user_agent,
217
- subfolder=subfolder,
218
- revision=revision,
219
- )
220
-
221
- except RepositoryNotFoundError:
222
- raise EnvironmentError(
223
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
224
- " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
225
- " token having permission to this repo with `use_auth_token` or log in with `huggingface-cli"
226
- " login` and pass `use_auth_token=True`."
227
- )
228
- except RevisionNotFoundError:
229
- raise EnvironmentError(
230
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
231
- " this model name. Check the model page at"
232
- f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
233
- )
234
- except EntryNotFoundError:
235
- raise EnvironmentError(
236
- f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
237
- )
238
- except HTTPError as err:
239
- raise EnvironmentError(
240
- "There was a specific connection error when trying to load"
241
- f" {pretrained_model_name_or_path}:\n{err}"
242
- )
243
- except ValueError:
244
- raise EnvironmentError(
245
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
246
- f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
247
- f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
248
- " run the library in offline mode at"
249
- " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
250
- )
251
- except EnvironmentError:
252
- raise EnvironmentError(
253
- f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
254
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
255
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
256
- f"containing a {cls.config_name} file"
257
- )
258
-
259
- try:
260
- # Load config dict
261
- config_dict = cls._dict_from_json_file(config_file)
262
- except (json.JSONDecodeError, UnicodeDecodeError):
263
- raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
264
-
265
- return config_dict
266
-
267
- @classmethod
268
- def extract_init_dict(cls, config_dict, **kwargs):
269
- expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
270
- expected_keys.remove("self")
271
- # remove general kwargs if present in dict
272
- if "kwargs" in expected_keys:
273
- expected_keys.remove("kwargs")
274
- # remove keys to be ignored
275
- if len(cls.ignore_for_config) > 0:
276
- expected_keys = expected_keys - set(cls.ignore_for_config)
277
- init_dict = {}
278
- for key in expected_keys:
279
- if key in kwargs:
280
- # overwrite key
281
- init_dict[key] = kwargs.pop(key)
282
- elif key in config_dict:
283
- # use value from config dict
284
- init_dict[key] = config_dict.pop(key)
285
-
286
- unused_kwargs = config_dict.update(kwargs)
287
-
288
- passed_keys = set(init_dict.keys())
289
- if len(expected_keys - passed_keys) > 0:
290
- logger.warning(
291
- f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
292
- )
293
-
294
- return init_dict, unused_kwargs
295
-
296
- @classmethod
297
- def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
298
- with open(json_file, "r", encoding="utf-8") as reader:
299
- text = reader.read()
300
- return json.loads(text)
301
-
302
- def __repr__(self):
303
- return f"{self.__class__.__name__} {self.to_json_string()}"
304
-
305
- @property
306
- def config(self) -> Dict[str, Any]:
307
- return self._internal_dict
308
-
309
- def to_json_string(self) -> str:
310
- """
311
- Serializes this instance to a JSON string.
312
-
313
- Returns:
314
- `str`: String containing all the attributes that make up this configuration instance in JSON format.
315
- """
316
- config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
317
- return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
318
-
319
- def to_json_file(self, json_file_path: Union[str, os.PathLike]):
320
- """
321
- Save this instance to a JSON file.
322
-
323
- Args:
324
- json_file_path (`str` or `os.PathLike`):
325
- Path to the JSON file in which this configuration instance's parameters will be saved.
326
- """
327
- with open(json_file_path, "w", encoding="utf-8") as writer:
328
- writer.write(self.to_json_string())
329
-
330
-
331
- class FrozenDict(OrderedDict):
332
- def __init__(self, *args, **kwargs):
333
- super().__init__(*args, **kwargs)
334
-
335
- for key, value in self.items():
336
- setattr(self, key, value)
337
-
338
- self.__frozen = True
339
-
340
- def __delitem__(self, *args, **kwargs):
341
- raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
342
-
343
- def setdefault(self, *args, **kwargs):
344
- raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
345
-
346
- def pop(self, *args, **kwargs):
347
- raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
348
-
349
- def update(self, *args, **kwargs):
350
- raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
351
-
352
- def __setattr__(self, name, value):
353
- if hasattr(self, "__frozen") and self.__frozen:
354
- raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
355
- super().__setattr__(name, value)
356
-
357
- def __setitem__(self, name, value):
358
- if hasattr(self, "__frozen") and self.__frozen:
359
- raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
360
- super().__setitem__(name, value)
361
-
362
-
363
- def register_to_config(init):
364
- r"""
365
- Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
366
- automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
367
- shouldn't be registered in the config, use the `ignore_for_config` class variable
368
-
369
- Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
370
- """
371
-
372
- @functools.wraps(init)
373
- def inner_init(self, *args, **kwargs):
374
- # Ignore private kwargs in the init.
375
- init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
376
- init(self, *args, **init_kwargs)
377
- if not isinstance(self, ConfigMixin):
378
- raise RuntimeError(
379
- f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
380
- "not inherit from `ConfigMixin`."
381
- )
382
-
383
- ignore = getattr(self, "ignore_for_config", [])
384
- # Get positional arguments aligned with kwargs
385
- new_kwargs = {}
386
- signature = inspect.signature(init)
387
- parameters = {
388
- name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
389
- }
390
- for arg, name in zip(args, parameters.keys()):
391
- new_kwargs[name] = arg
392
-
393
- # Then add all kwargs
394
- new_kwargs.update(
395
- {
396
- k: init_kwargs.get(k, default)
397
- for k, default in parameters.items()
398
- if k not in ignore and k not in new_kwargs
399
- }
400
- )
401
- getattr(self, "register_to_config")(**new_kwargs)
402
-
403
- return inner_init
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/dependency_versions_check.py DELETED
@@ -1,47 +0,0 @@
1
- # Copyright 2020 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import sys
15
-
16
- from .dependency_versions_table import deps
17
- from .utils.versions import require_version, require_version_core
18
-
19
-
20
- # define which module versions we always want to check at run time
21
- # (usually the ones defined in `install_requires` in setup.py)
22
- #
23
- # order specific notes:
24
- # - tqdm must be checked before tokenizers
25
-
26
- pkgs_to_check_at_runtime = "python tqdm regex requests packaging filelock numpy tokenizers".split()
27
- if sys.version_info < (3, 7):
28
- pkgs_to_check_at_runtime.append("dataclasses")
29
- if sys.version_info < (3, 8):
30
- pkgs_to_check_at_runtime.append("importlib_metadata")
31
-
32
- for pkg in pkgs_to_check_at_runtime:
33
- if pkg in deps:
34
- if pkg == "tokenizers":
35
- # must be loaded here, or else tqdm check may fail
36
- from .utils import is_tokenizers_available
37
-
38
- if not is_tokenizers_available():
39
- continue # not required, check version only if installed
40
-
41
- require_version_core(deps[pkg])
42
- else:
43
- raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
44
-
45
-
46
- def dep_version_check(pkg, hint=None):
47
- require_version(deps[pkg], hint)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/dependency_versions_table.py DELETED
@@ -1,26 +0,0 @@
1
- # THIS FILE HAS BEEN AUTOGENERATED. To update:
2
- # 1. modify the `_deps` dict in setup.py
3
- # 2. run `make deps_table_update``
4
- deps = {
5
- "Pillow": "Pillow",
6
- "accelerate": "accelerate>=0.11.0",
7
- "black": "black==22.3",
8
- "datasets": "datasets",
9
- "filelock": "filelock",
10
- "flake8": "flake8>=3.8.3",
11
- "hf-doc-builder": "hf-doc-builder>=0.3.0",
12
- "huggingface-hub": "huggingface-hub>=0.8.1",
13
- "importlib_metadata": "importlib_metadata",
14
- "isort": "isort>=5.5.4",
15
- "modelcards": "modelcards==0.1.4",
16
- "numpy": "numpy",
17
- "pytest": "pytest",
18
- "pytest-timeout": "pytest-timeout",
19
- "pytest-xdist": "pytest-xdist",
20
- "scipy": "scipy",
21
- "regex": "regex!=2019.12.17",
22
- "requests": "requests",
23
- "tensorboard": "tensorboard",
24
- "torch": "torch>=1.4",
25
- "transformers": "transformers>=4.21.0",
26
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/dynamic_modules_utils.py DELETED
@@ -1,335 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2021 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """Utilities to dynamically load objects from the Hub."""
16
-
17
- import importlib
18
- import os
19
- import re
20
- import shutil
21
- import sys
22
- from pathlib import Path
23
- from typing import Dict, Optional, Union
24
-
25
- from huggingface_hub import cached_download
26
-
27
- from .utils import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
28
-
29
-
30
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
-
32
-
33
- def init_hf_modules():
34
- """
35
- Creates the cache directory for modules with an init, and adds it to the Python path.
36
- """
37
- # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
38
- if HF_MODULES_CACHE in sys.path:
39
- return
40
-
41
- sys.path.append(HF_MODULES_CACHE)
42
- os.makedirs(HF_MODULES_CACHE, exist_ok=True)
43
- init_path = Path(HF_MODULES_CACHE) / "__init__.py"
44
- if not init_path.exists():
45
- init_path.touch()
46
-
47
-
48
- def create_dynamic_module(name: Union[str, os.PathLike]):
49
- """
50
- Creates a dynamic module in the cache directory for modules.
51
- """
52
- init_hf_modules()
53
- dynamic_module_path = Path(HF_MODULES_CACHE) / name
54
- # If the parent module does not exist yet, recursively create it.
55
- if not dynamic_module_path.parent.exists():
56
- create_dynamic_module(dynamic_module_path.parent)
57
- os.makedirs(dynamic_module_path, exist_ok=True)
58
- init_path = dynamic_module_path / "__init__.py"
59
- if not init_path.exists():
60
- init_path.touch()
61
-
62
-
63
- def get_relative_imports(module_file):
64
- """
65
- Get the list of modules that are relatively imported in a module file.
66
-
67
- Args:
68
- module_file (`str` or `os.PathLike`): The module file to inspect.
69
- """
70
- with open(module_file, "r", encoding="utf-8") as f:
71
- content = f.read()
72
-
73
- # Imports of the form `import .xxx`
74
- relative_imports = re.findall("^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
75
- # Imports of the form `from .xxx import yyy`
76
- relative_imports += re.findall("^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
77
- # Unique-ify
78
- return list(set(relative_imports))
79
-
80
-
81
- def get_relative_import_files(module_file):
82
- """
83
- Get the list of all files that are needed for a given module. Note that this function recurses through the relative
84
- imports (if a imports b and b imports c, it will return module files for b and c).
85
-
86
- Args:
87
- module_file (`str` or `os.PathLike`): The module file to inspect.
88
- """
89
- no_change = False
90
- files_to_check = [module_file]
91
- all_relative_imports = []
92
-
93
- # Let's recurse through all relative imports
94
- while not no_change:
95
- new_imports = []
96
- for f in files_to_check:
97
- new_imports.extend(get_relative_imports(f))
98
-
99
- module_path = Path(module_file).parent
100
- new_import_files = [str(module_path / m) for m in new_imports]
101
- new_import_files = [f for f in new_import_files if f not in all_relative_imports]
102
- files_to_check = [f"{f}.py" for f in new_import_files]
103
-
104
- no_change = len(new_import_files) == 0
105
- all_relative_imports.extend(files_to_check)
106
-
107
- return all_relative_imports
108
-
109
-
110
- def check_imports(filename):
111
- """
112
- Check if the current Python environment contains all the libraries that are imported in a file.
113
- """
114
- with open(filename, "r", encoding="utf-8") as f:
115
- content = f.read()
116
-
117
- # Imports of the form `import xxx`
118
- imports = re.findall("^\s*import\s+(\S+)\s*$", content, flags=re.MULTILINE)
119
- # Imports of the form `from xxx import yyy`
120
- imports += re.findall("^\s*from\s+(\S+)\s+import", content, flags=re.MULTILINE)
121
- # Only keep the top-level module
122
- imports = [imp.split(".")[0] for imp in imports if not imp.startswith(".")]
123
-
124
- # Unique-ify and test we got them all
125
- imports = list(set(imports))
126
- missing_packages = []
127
- for imp in imports:
128
- try:
129
- importlib.import_module(imp)
130
- except ImportError:
131
- missing_packages.append(imp)
132
-
133
- if len(missing_packages) > 0:
134
- raise ImportError(
135
- "This modeling file requires the following packages that were not found in your environment: "
136
- f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
137
- )
138
-
139
- return get_relative_imports(filename)
140
-
141
-
142
- def get_class_in_module(class_name, module_path):
143
- """
144
- Import a module on the cache directory for modules and extract a class from it.
145
- """
146
- module_path = module_path.replace(os.path.sep, ".")
147
- module = importlib.import_module(module_path)
148
- return getattr(module, class_name)
149
-
150
-
151
- def get_cached_module_file(
152
- pretrained_model_name_or_path: Union[str, os.PathLike],
153
- module_file: str,
154
- cache_dir: Optional[Union[str, os.PathLike]] = None,
155
- force_download: bool = False,
156
- resume_download: bool = False,
157
- proxies: Optional[Dict[str, str]] = None,
158
- use_auth_token: Optional[Union[bool, str]] = None,
159
- revision: Optional[str] = None,
160
- local_files_only: bool = False,
161
- ):
162
- """
163
- Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
164
- Transformers module.
165
-
166
- Args:
167
- pretrained_model_name_or_path (`str` or `os.PathLike`):
168
- This can be either:
169
-
170
- - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
171
- huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
172
- under a user or organization name, like `dbmdz/bert-base-german-cased`.
173
- - a path to a *directory* containing a configuration file saved using the
174
- [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
175
-
176
- module_file (`str`):
177
- The name of the module file containing the class to look for.
178
- cache_dir (`str` or `os.PathLike`, *optional*):
179
- Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
180
- cache should not be used.
181
- force_download (`bool`, *optional*, defaults to `False`):
182
- Whether or not to force to (re-)download the configuration files and override the cached versions if they
183
- exist.
184
- resume_download (`bool`, *optional*, defaults to `False`):
185
- Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
186
- proxies (`Dict[str, str]`, *optional*):
187
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
188
- 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
189
- use_auth_token (`str` or *bool*, *optional*):
190
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
191
- when running `transformers-cli login` (stored in `~/.huggingface`).
192
- revision (`str`, *optional*, defaults to `"main"`):
193
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
194
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
195
- identifier allowed by git.
196
- local_files_only (`bool`, *optional*, defaults to `False`):
197
- If `True`, will only try to load the tokenizer configuration from local files.
198
-
199
- <Tip>
200
-
201
- Passing `use_auth_token=True` is required when you want to use a private model.
202
-
203
- </Tip>
204
-
205
- Returns:
206
- `str`: The path to the module inside the cache.
207
- """
208
- # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
209
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
210
- module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
211
- submodule = "local"
212
-
213
- if os.path.isfile(module_file_or_url):
214
- resolved_module_file = module_file_or_url
215
- else:
216
- try:
217
- # Load from URL or cache if already cached
218
- resolved_module_file = cached_download(
219
- module_file_or_url,
220
- cache_dir=cache_dir,
221
- force_download=force_download,
222
- proxies=proxies,
223
- resume_download=resume_download,
224
- local_files_only=local_files_only,
225
- use_auth_token=use_auth_token,
226
- )
227
-
228
- except EnvironmentError:
229
- logger.error(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
230
- raise
231
-
232
- # Check we have all the requirements in our environment
233
- modules_needed = check_imports(resolved_module_file)
234
-
235
- # Now we move the module inside our cached dynamic modules.
236
- full_submodule = DIFFUSERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
237
- create_dynamic_module(full_submodule)
238
- submodule_path = Path(HF_MODULES_CACHE) / full_submodule
239
- # We always copy local files (we could hash the file to see if there was a change, and give them the name of
240
- # that hash, to only copy when there is a modification but it seems overkill for now).
241
- # The only reason we do the copy is to avoid putting too many folders in sys.path.
242
- shutil.copy(resolved_module_file, submodule_path / module_file)
243
- for module_needed in modules_needed:
244
- module_needed = f"{module_needed}.py"
245
- shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
246
- return os.path.join(full_submodule, module_file)
247
-
248
-
249
- def get_class_from_dynamic_module(
250
- pretrained_model_name_or_path: Union[str, os.PathLike],
251
- module_file: str,
252
- class_name: str,
253
- cache_dir: Optional[Union[str, os.PathLike]] = None,
254
- force_download: bool = False,
255
- resume_download: bool = False,
256
- proxies: Optional[Dict[str, str]] = None,
257
- use_auth_token: Optional[Union[bool, str]] = None,
258
- revision: Optional[str] = None,
259
- local_files_only: bool = False,
260
- **kwargs,
261
- ):
262
- """
263
- Extracts a class from a module file, present in the local folder or repository of a model.
264
-
265
- <Tip warning={true}>
266
-
267
- Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
268
- therefore only be called on trusted repos.
269
-
270
- </Tip>
271
-
272
- Args:
273
- pretrained_model_name_or_path (`str` or `os.PathLike`):
274
- This can be either:
275
-
276
- - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
277
- huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced
278
- under a user or organization name, like `dbmdz/bert-base-german-cased`.
279
- - a path to a *directory* containing a configuration file saved using the
280
- [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
281
-
282
- module_file (`str`):
283
- The name of the module file containing the class to look for.
284
- class_name (`str`):
285
- The name of the class to import in the module.
286
- cache_dir (`str` or `os.PathLike`, *optional*):
287
- Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
288
- cache should not be used.
289
- force_download (`bool`, *optional*, defaults to `False`):
290
- Whether or not to force to (re-)download the configuration files and override the cached versions if they
291
- exist.
292
- resume_download (`bool`, *optional*, defaults to `False`):
293
- Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
294
- proxies (`Dict[str, str]`, *optional*):
295
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
296
- 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
297
- use_auth_token (`str` or `bool`, *optional*):
298
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
299
- when running `transformers-cli login` (stored in `~/.huggingface`).
300
- revision (`str`, *optional*, defaults to `"main"`):
301
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
302
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
303
- identifier allowed by git.
304
- local_files_only (`bool`, *optional*, defaults to `False`):
305
- If `True`, will only try to load the tokenizer configuration from local files.
306
-
307
- <Tip>
308
-
309
- Passing `use_auth_token=True` is required when you want to use a private model.
310
-
311
- </Tip>
312
-
313
- Returns:
314
- `type`: The class, dynamically imported from the module.
315
-
316
- Examples:
317
-
318
- ```python
319
- # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
320
- # module.
321
- cls = get_class_from_dynamic_module("sgugger/my-bert-model", "modeling.py", "MyBertModel")
322
- ```"""
323
- # And lastly we get the class inside our newly created module
324
- final_module = get_cached_module_file(
325
- pretrained_model_name_or_path,
326
- module_file,
327
- cache_dir=cache_dir,
328
- force_download=force_download,
329
- resume_download=resume_download,
330
- proxies=proxies,
331
- use_auth_token=use_auth_token,
332
- revision=revision,
333
- local_files_only=local_files_only,
334
- )
335
- return get_class_in_module(class_name, final_module.replace(".py", ""))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/hub_utils.py DELETED
@@ -1,197 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
-
17
- import os
18
- import shutil
19
- from pathlib import Path
20
- from typing import Optional
21
-
22
- from huggingface_hub import HfFolder, Repository, whoami
23
-
24
- from .pipeline_utils import DiffusionPipeline
25
- from .utils import is_modelcards_available, logging
26
-
27
-
28
- if is_modelcards_available():
29
- from modelcards import CardData, ModelCard
30
-
31
-
32
- logger = logging.get_logger(__name__)
33
-
34
-
35
- MODEL_CARD_TEMPLATE_PATH = Path(__file__).parent / "utils" / "model_card_template.md"
36
-
37
-
38
- def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
39
- if token is None:
40
- token = HfFolder.get_token()
41
- if organization is None:
42
- username = whoami(token)["name"]
43
- return f"{username}/{model_id}"
44
- else:
45
- return f"{organization}/{model_id}"
46
-
47
-
48
- def init_git_repo(args, at_init: bool = False):
49
- """
50
- Args:
51
- Initializes a git repo in `args.hub_model_id`.
52
- at_init (`bool`, *optional*, defaults to `False`):
53
- Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True`
54
- and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out.
55
- """
56
- if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
57
- return
58
- hub_token = args.hub_token if hasattr(args, "hub_token") else None
59
- use_auth_token = True if hub_token is None else hub_token
60
- if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
61
- repo_name = Path(args.output_dir).absolute().name
62
- else:
63
- repo_name = args.hub_model_id
64
- if "/" not in repo_name:
65
- repo_name = get_full_repo_name(repo_name, token=hub_token)
66
-
67
- try:
68
- repo = Repository(
69
- args.output_dir,
70
- clone_from=repo_name,
71
- use_auth_token=use_auth_token,
72
- private=args.hub_private_repo,
73
- )
74
- except EnvironmentError:
75
- if args.overwrite_output_dir and at_init:
76
- # Try again after wiping output_dir
77
- shutil.rmtree(args.output_dir)
78
- repo = Repository(
79
- args.output_dir,
80
- clone_from=repo_name,
81
- use_auth_token=use_auth_token,
82
- )
83
- else:
84
- raise
85
-
86
- repo.git_pull()
87
-
88
- # By default, ignore the checkpoint folders
89
- if not os.path.exists(os.path.join(args.output_dir, ".gitignore")):
90
- with open(os.path.join(args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
91
- writer.writelines(["checkpoint-*/"])
92
-
93
- return repo
94
-
95
-
96
- def push_to_hub(
97
- args,
98
- pipeline: DiffusionPipeline,
99
- repo: Repository,
100
- commit_message: Optional[str] = "End of training",
101
- blocking: bool = True,
102
- **kwargs,
103
- ) -> str:
104
- """
105
- Parameters:
106
- Upload *self.model* and *self.tokenizer* to the πŸ€— model hub on the repo *self.args.hub_model_id*.
107
- commit_message (`str`, *optional*, defaults to `"End of training"`):
108
- Message to commit while pushing.
109
- blocking (`bool`, *optional*, defaults to `True`):
110
- Whether the function should return only when the `git push` has finished.
111
- kwargs:
112
- Additional keyword arguments passed along to [`create_model_card`].
113
- Returns:
114
- The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the
115
- commit and an object to track the progress of the commit if `blocking=True`
116
- """
117
-
118
- if not hasattr(args, "hub_model_id") or args.hub_model_id is None:
119
- model_name = Path(args.output_dir).name
120
- else:
121
- model_name = args.hub_model_id.split("/")[-1]
122
-
123
- output_dir = args.output_dir
124
- os.makedirs(output_dir, exist_ok=True)
125
- logger.info(f"Saving pipeline checkpoint to {output_dir}")
126
- pipeline.save_pretrained(output_dir)
127
-
128
- # Only push from one node.
129
- if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
130
- return
131
-
132
- # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
133
- if (
134
- blocking
135
- and len(repo.command_queue) > 0
136
- and repo.command_queue[-1] is not None
137
- and not repo.command_queue[-1].is_done
138
- ):
139
- repo.command_queue[-1]._process.kill()
140
-
141
- git_head_commit_url = repo.push_to_hub(commit_message=commit_message, blocking=blocking, auto_lfs_prune=True)
142
- # push separately the model card to be independent from the rest of the model
143
- create_model_card(args, model_name=model_name)
144
- try:
145
- repo.push_to_hub(commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True)
146
- except EnvironmentError as exc:
147
- logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
148
-
149
- return git_head_commit_url
150
-
151
-
152
- def create_model_card(args, model_name):
153
- if not is_modelcards_available:
154
- raise ValueError(
155
- "Please make sure to have `modelcards` installed when using the `create_model_card` function. You can"
156
- " install the package with `pip install modelcards`."
157
- )
158
-
159
- if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]:
160
- return
161
-
162
- hub_token = args.hub_token if hasattr(args, "hub_token") else None
163
- repo_name = get_full_repo_name(model_name, token=hub_token)
164
-
165
- model_card = ModelCard.from_template(
166
- card_data=CardData( # Card metadata object that will be converted to YAML block
167
- language="en",
168
- license="apache-2.0",
169
- library_name="diffusers",
170
- tags=[],
171
- datasets=args.dataset_name,
172
- metrics=[],
173
- ),
174
- template_path=MODEL_CARD_TEMPLATE_PATH,
175
- model_name=model_name,
176
- repo_name=repo_name,
177
- dataset_name=args.dataset_name if hasattr(args, "dataset_name") else None,
178
- learning_rate=args.learning_rate,
179
- train_batch_size=args.train_batch_size,
180
- eval_batch_size=args.eval_batch_size,
181
- gradient_accumulation_steps=args.gradient_accumulation_steps
182
- if hasattr(args, "gradient_accumulation_steps")
183
- else None,
184
- adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None,
185
- adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None,
186
- adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,
187
- adam_epsilon=args.adam_epsilon if hasattr(args, "adam_epsilon") else None,
188
- lr_scheduler=args.lr_scheduler if hasattr(args, "lr_scheduler") else None,
189
- lr_warmup_steps=args.lr_warmup_steps if hasattr(args, "lr_warmup_steps") else None,
190
- ema_inv_gamma=args.ema_inv_gamma if hasattr(args, "ema_inv_gamma") else None,
191
- ema_power=args.ema_power if hasattr(args, "ema_power") else None,
192
- ema_max_decay=args.ema_max_decay if hasattr(args, "ema_max_decay") else None,
193
- mixed_precision=args.mixed_precision,
194
- )
195
-
196
- card_path = os.path.join(args.output_dir, "README.md")
197
- model_card.save(card_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/modeling_utils.py DELETED
@@ -1,542 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 The HuggingFace Inc. team.
3
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import os
18
- from typing import Callable, List, Optional, Tuple, Union
19
-
20
- import torch
21
- from torch import Tensor, device
22
-
23
- from huggingface_hub import hf_hub_download
24
- from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
25
- from requests import HTTPError
26
-
27
- from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging
28
-
29
-
30
- WEIGHTS_NAME = "diffusion_pytorch_model.bin"
31
-
32
-
33
- logger = logging.get_logger(__name__)
34
-
35
-
36
- def get_parameter_device(parameter: torch.nn.Module):
37
- try:
38
- return next(parameter.parameters()).device
39
- except StopIteration:
40
- # For torch.nn.DataParallel compatibility in PyTorch 1.5
41
-
42
- def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
43
- tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
44
- return tuples
45
-
46
- gen = parameter._named_members(get_members_fn=find_tensor_attributes)
47
- first_tuple = next(gen)
48
- return first_tuple[1].device
49
-
50
-
51
- def get_parameter_dtype(parameter: torch.nn.Module):
52
- try:
53
- return next(parameter.parameters()).dtype
54
- except StopIteration:
55
- # For torch.nn.DataParallel compatibility in PyTorch 1.5
56
-
57
- def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
58
- tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
59
- return tuples
60
-
61
- gen = parameter._named_members(get_members_fn=find_tensor_attributes)
62
- first_tuple = next(gen)
63
- return first_tuple[1].dtype
64
-
65
-
66
- def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
67
- """
68
- Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
69
- """
70
- try:
71
- return torch.load(checkpoint_file, map_location="cpu")
72
- except Exception as e:
73
- try:
74
- with open(checkpoint_file) as f:
75
- if f.read().startswith("version"):
76
- raise OSError(
77
- "You seem to have cloned a repository without having git-lfs installed. Please install "
78
- "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
79
- "you cloned."
80
- )
81
- else:
82
- raise ValueError(
83
- f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
84
- "model. Make sure you have saved the model properly."
85
- ) from e
86
- except (UnicodeDecodeError, ValueError):
87
- raise OSError(
88
- f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
89
- f"at '{checkpoint_file}'. "
90
- "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
91
- )
92
-
93
-
94
- def _load_state_dict_into_model(model_to_load, state_dict):
95
- # Convert old format to new format if needed from a PyTorch state_dict
96
- # copy state_dict so _load_from_state_dict can modify it
97
- state_dict = state_dict.copy()
98
- error_msgs = []
99
-
100
- # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
101
- # so we need to apply the function recursively.
102
- def load(module: torch.nn.Module, prefix=""):
103
- args = (state_dict, prefix, {}, True, [], [], error_msgs)
104
- module._load_from_state_dict(*args)
105
-
106
- for name, child in module._modules.items():
107
- if child is not None:
108
- load(child, prefix + name + ".")
109
-
110
- load(model_to_load)
111
-
112
- return error_msgs
113
-
114
-
115
- class ModelMixin(torch.nn.Module):
116
- r"""
117
- Base class for all models.
118
-
119
- [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
120
- and saving models.
121
-
122
- - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
123
- [`~modeling_utils.ModelMixin.save_pretrained`].
124
- """
125
- config_name = CONFIG_NAME
126
- _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
127
-
128
- def __init__(self):
129
- super().__init__()
130
-
131
- def save_pretrained(
132
- self,
133
- save_directory: Union[str, os.PathLike],
134
- is_main_process: bool = True,
135
- save_function: Callable = torch.save,
136
- ):
137
- """
138
- Save a model and its configuration file to a directory, so that it can be re-loaded using the
139
- `[`~modeling_utils.ModelMixin.from_pretrained`]` class method.
140
-
141
- Arguments:
142
- save_directory (`str` or `os.PathLike`):
143
- Directory to which to save. Will be created if it doesn't exist.
144
- is_main_process (`bool`, *optional*, defaults to `True`):
145
- Whether the process calling this is the main process or not. Useful when in distributed training like
146
- TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
147
- the main process to avoid race conditions.
148
- save_function (`Callable`):
149
- The function to use to save the state dictionary. Useful on distributed training like TPUs when one
150
- need to replace `torch.save` by another method.
151
- """
152
- if os.path.isfile(save_directory):
153
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
154
- return
155
-
156
- os.makedirs(save_directory, exist_ok=True)
157
-
158
- model_to_save = self
159
-
160
- # Attach architecture to the config
161
- # Save the config
162
- if is_main_process:
163
- model_to_save.save_config(save_directory)
164
-
165
- # Save the model
166
- state_dict = model_to_save.state_dict()
167
-
168
- # Clean the folder from a previous save
169
- for filename in os.listdir(save_directory):
170
- full_filename = os.path.join(save_directory, filename)
171
- # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
172
- # in distributed settings to avoid race conditions.
173
- if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
174
- os.remove(full_filename)
175
-
176
- # Save the model
177
- save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
178
-
179
- logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
180
-
181
- @classmethod
182
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
183
- r"""
184
- Instantiate a pretrained pytorch model from a pre-trained model configuration.
185
-
186
- The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
187
- the model, you should first set it back in training mode with `model.train()`.
188
-
189
- The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
190
- pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
191
- task.
192
-
193
- The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
194
- weights are discarded.
195
-
196
- Parameters:
197
- pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
198
- Can be either:
199
-
200
- - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
201
- Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
202
- - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
203
- `./my_model_directory/`.
204
-
205
- cache_dir (`Union[str, os.PathLike]`, *optional*):
206
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
207
- standard cache should not be used.
208
- torch_dtype (`str` or `torch.dtype`, *optional*):
209
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
210
- will be automatically derived from the model's weights.
211
- force_download (`bool`, *optional*, defaults to `False`):
212
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
213
- cached versions if they exist.
214
- resume_download (`bool`, *optional*, defaults to `False`):
215
- Whether or not to delete incompletely received files. Will attempt to resume the download if such a
216
- file exists.
217
- proxies (`Dict[str, str]`, *optional*):
218
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
219
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
220
- output_loading_info(`bool`, *optional*, defaults to `False`):
221
- Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
222
- local_files_only(`bool`, *optional*, defaults to `False`):
223
- Whether or not to only look at local files (i.e., do not try to download the model).
224
- use_auth_token (`str` or *bool*, *optional*):
225
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
226
- when running `diffusers-cli login` (stored in `~/.huggingface`).
227
- revision (`str`, *optional*, defaults to `"main"`):
228
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
229
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
230
- identifier allowed by git.
231
- mirror (`str`, *optional*):
232
- Mirror source to accelerate downloads in China. If you are from China and have an accessibility
233
- problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
234
- Please refer to the mirror site for more information.
235
-
236
- <Tip>
237
-
238
- Passing `use_auth_token=True`` is required when you want to use a private model.
239
-
240
- </Tip>
241
-
242
- <Tip>
243
-
244
- Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
245
- this method in a firewalled environment.
246
-
247
- </Tip>
248
-
249
- """
250
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
251
- ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
252
- force_download = kwargs.pop("force_download", False)
253
- resume_download = kwargs.pop("resume_download", False)
254
- proxies = kwargs.pop("proxies", None)
255
- output_loading_info = kwargs.pop("output_loading_info", False)
256
- local_files_only = kwargs.pop("local_files_only", False)
257
- use_auth_token = kwargs.pop("use_auth_token", None)
258
- revision = kwargs.pop("revision", None)
259
- from_auto_class = kwargs.pop("_from_auto", False)
260
- torch_dtype = kwargs.pop("torch_dtype", None)
261
- subfolder = kwargs.pop("subfolder", None)
262
-
263
- user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
264
-
265
- # Load config if we don't provide a configuration
266
- config_path = pretrained_model_name_or_path
267
- model, unused_kwargs = cls.from_config(
268
- config_path,
269
- cache_dir=cache_dir,
270
- return_unused_kwargs=True,
271
- force_download=force_download,
272
- resume_download=resume_download,
273
- proxies=proxies,
274
- local_files_only=local_files_only,
275
- use_auth_token=use_auth_token,
276
- revision=revision,
277
- subfolder=subfolder,
278
- **kwargs,
279
- )
280
-
281
- if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
282
- raise ValueError(
283
- f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
284
- )
285
- elif torch_dtype is not None:
286
- model = model.to(torch_dtype)
287
-
288
- model.register_to_config(_name_or_path=pretrained_model_name_or_path)
289
- # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
290
- # Load model
291
- pretrained_model_name_or_path = str(pretrained_model_name_or_path)
292
- if os.path.isdir(pretrained_model_name_or_path):
293
- if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
294
- # Load from a PyTorch checkpoint
295
- model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
296
- elif subfolder is not None and os.path.isfile(
297
- os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
298
- ):
299
- model_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
300
- else:
301
- raise EnvironmentError(
302
- f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
303
- )
304
- else:
305
- try:
306
- # Load from URL or cache if already cached
307
- model_file = hf_hub_download(
308
- pretrained_model_name_or_path,
309
- filename=WEIGHTS_NAME,
310
- cache_dir=cache_dir,
311
- force_download=force_download,
312
- proxies=proxies,
313
- resume_download=resume_download,
314
- local_files_only=local_files_only,
315
- use_auth_token=use_auth_token,
316
- user_agent=user_agent,
317
- subfolder=subfolder,
318
- revision=revision,
319
- )
320
-
321
- except RepositoryNotFoundError:
322
- raise EnvironmentError(
323
- f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
324
- "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
325
- "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
326
- "login` and pass `use_auth_token=True`."
327
- )
328
- except RevisionNotFoundError:
329
- raise EnvironmentError(
330
- f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
331
- "this model name. Check the model page at "
332
- f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
333
- )
334
- except EntryNotFoundError:
335
- raise EnvironmentError(
336
- f"{pretrained_model_name_or_path} does not appear to have a file named {WEIGHTS_NAME}."
337
- )
338
- except HTTPError as err:
339
- raise EnvironmentError(
340
- "There was a specific connection error when trying to load"
341
- f" {pretrained_model_name_or_path}:\n{err}"
342
- )
343
- except ValueError:
344
- raise EnvironmentError(
345
- f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
346
- f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
347
- f" directory containing a file named {WEIGHTS_NAME} or"
348
- " \nCheckout your internet connection or see how to run the library in"
349
- " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
350
- )
351
- except EnvironmentError:
352
- raise EnvironmentError(
353
- f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
354
- "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
355
- f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
356
- f"containing a file named {WEIGHTS_NAME}"
357
- )
358
-
359
- # restore default dtype
360
- state_dict = load_state_dict(model_file)
361
- model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
362
- model,
363
- state_dict,
364
- model_file,
365
- pretrained_model_name_or_path,
366
- ignore_mismatched_sizes=ignore_mismatched_sizes,
367
- )
368
-
369
- # Set model in evaluation mode to deactivate DropOut modules by default
370
- model.eval()
371
-
372
- if output_loading_info:
373
- loading_info = {
374
- "missing_keys": missing_keys,
375
- "unexpected_keys": unexpected_keys,
376
- "mismatched_keys": mismatched_keys,
377
- "error_msgs": error_msgs,
378
- }
379
- return model, loading_info
380
-
381
- return model
382
-
383
- @classmethod
384
- def _load_pretrained_model(
385
- cls,
386
- model,
387
- state_dict,
388
- resolved_archive_file,
389
- pretrained_model_name_or_path,
390
- ignore_mismatched_sizes=False,
391
- ):
392
- # Retrieve missing & unexpected_keys
393
- model_state_dict = model.state_dict()
394
- loaded_keys = [k for k in state_dict.keys()]
395
-
396
- expected_keys = list(model_state_dict.keys())
397
-
398
- original_loaded_keys = loaded_keys
399
-
400
- missing_keys = list(set(expected_keys) - set(loaded_keys))
401
- unexpected_keys = list(set(loaded_keys) - set(expected_keys))
402
-
403
- # Make sure we are able to load base models as well as derived models (with heads)
404
- model_to_load = model
405
-
406
- def _find_mismatched_keys(
407
- state_dict,
408
- model_state_dict,
409
- loaded_keys,
410
- ignore_mismatched_sizes,
411
- ):
412
- mismatched_keys = []
413
- if ignore_mismatched_sizes:
414
- for checkpoint_key in loaded_keys:
415
- model_key = checkpoint_key
416
-
417
- if (
418
- model_key in model_state_dict
419
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
420
- ):
421
- mismatched_keys.append(
422
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
423
- )
424
- del state_dict[checkpoint_key]
425
- return mismatched_keys
426
-
427
- if state_dict is not None:
428
- # Whole checkpoint
429
- mismatched_keys = _find_mismatched_keys(
430
- state_dict,
431
- model_state_dict,
432
- original_loaded_keys,
433
- ignore_mismatched_sizes,
434
- )
435
- error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
436
-
437
- if len(error_msgs) > 0:
438
- error_msg = "\n\t".join(error_msgs)
439
- if "size mismatch" in error_msg:
440
- error_msg += (
441
- "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
442
- )
443
- raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
444
-
445
- if len(unexpected_keys) > 0:
446
- logger.warning(
447
- f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
448
- f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
449
- f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task"
450
- " or with another architecture (e.g. initializing a BertForSequenceClassification model from a"
451
- " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
452
- f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly"
453
- " identical (initializing a BertForSequenceClassification model from a"
454
- " BertForSequenceClassification model)."
455
- )
456
- else:
457
- logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
458
- if len(missing_keys) > 0:
459
- logger.warning(
460
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
461
- f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
462
- " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
463
- )
464
- elif len(mismatched_keys) == 0:
465
- logger.info(
466
- f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
467
- f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the"
468
- f" checkpoint was trained on, you can already use {model.__class__.__name__} for predictions"
469
- " without further training."
470
- )
471
- if len(mismatched_keys) > 0:
472
- mismatched_warning = "\n".join(
473
- [
474
- f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
475
- for key, shape1, shape2 in mismatched_keys
476
- ]
477
- )
478
- logger.warning(
479
- f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
480
- f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
481
- f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be"
482
- " able to use it for predictions and inference."
483
- )
484
-
485
- return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
486
-
487
- @property
488
- def device(self) -> device:
489
- """
490
- `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
491
- device).
492
- """
493
- return get_parameter_device(self)
494
-
495
- @property
496
- def dtype(self) -> torch.dtype:
497
- """
498
- `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
499
- """
500
- return get_parameter_dtype(self)
501
-
502
- def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
503
- """
504
- Get number of (optionally, trainable or non-embeddings) parameters in the module.
505
-
506
- Args:
507
- only_trainable (`bool`, *optional*, defaults to `False`):
508
- Whether or not to return only the number of trainable parameters
509
-
510
- exclude_embeddings (`bool`, *optional*, defaults to `False`):
511
- Whether or not to return only the number of non-embeddings parameters
512
-
513
- Returns:
514
- `int`: The number of parameters.
515
- """
516
-
517
- if exclude_embeddings:
518
- embedding_param_names = [
519
- f"{name}.weight"
520
- for name, module_type in self.named_modules()
521
- if isinstance(module_type, torch.nn.Embedding)
522
- ]
523
- non_embedding_parameters = [
524
- parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
525
- ]
526
- return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
527
- else:
528
- return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
529
-
530
-
531
- def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
532
- """
533
- Recursively unwraps a model from potential containers (as used in distributed training).
534
-
535
- Args:
536
- model (`torch.nn.Module`): The model to unwrap.
537
- """
538
- # since there could be multiple levels of wrapping, unwrap recursively
539
- if hasattr(model, "module"):
540
- return unwrap_model(model.module)
541
- else:
542
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/models/__init__.py DELETED
@@ -1,17 +0,0 @@
1
- # Copyright 2022 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from .unet_2d import UNet2DModel
16
- from .unet_2d_condition import UNet2DConditionModel
17
- from .vae import AutoencoderKL, VQModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/models/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (313 Bytes)
 
diffusers/models/__pycache__/attention.cpython-310.pyc DELETED
Binary file (14.3 kB)
 
diffusers/models/__pycache__/embeddings.cpython-310.pyc DELETED
Binary file (3.72 kB)
 
diffusers/models/__pycache__/resnet.cpython-310.pyc DELETED
Binary file (14.5 kB)
 
diffusers/models/__pycache__/unet_2d.cpython-310.pyc DELETED
Binary file (7.94 kB)
 
diffusers/models/__pycache__/unet_2d_condition.cpython-310.pyc DELETED
Binary file (8.73 kB)
 
diffusers/models/__pycache__/unet_blocks.cpython-310.pyc DELETED
Binary file (23.7 kB)
 
diffusers/models/__pycache__/vae.cpython-310.pyc DELETED
Binary file (16.5 kB)
 
diffusers/models/attention.py DELETED
@@ -1,409 +0,0 @@
1
- import math
2
- from collections import defaultdict
3
- from typing import Optional
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import nn
8
-
9
-
10
- class AttentionBlock(nn.Module):
11
- """
12
- An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
13
- to the N-d case.
14
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
15
- Uses three q, k, v linear layers to compute attention.
16
-
17
- Parameters:
18
- channels (:obj:`int`): The number of channels in the input and output.
19
- num_head_channels (:obj:`int`, *optional*):
20
- The number of channels in each head. If None, then `num_heads` = 1.
21
- num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
22
- rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
23
- eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
24
- """
25
-
26
- def __init__(
27
- self,
28
- channels: int,
29
- num_head_channels: Optional[int] = None,
30
- num_groups: int = 32,
31
- rescale_output_factor: float = 1.0,
32
- eps: float = 1e-5,
33
- ):
34
- super().__init__()
35
- self.channels = channels
36
-
37
- self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
38
- self.num_head_size = num_head_channels
39
- self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
40
-
41
- # define q,k,v as linear layers
42
- self.query = nn.Linear(channels, channels)
43
- self.key = nn.Linear(channels, channels)
44
- self.value = nn.Linear(channels, channels)
45
-
46
- self.rescale_output_factor = rescale_output_factor
47
- self.proj_attn = nn.Linear(channels, channels, 1)
48
-
49
- def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
50
- new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
51
- # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
52
- new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
53
- return new_projection
54
-
55
- def forward(self, hidden_states):
56
- residual = hidden_states
57
- batch, channel, height, width = hidden_states.shape
58
-
59
- # norm
60
- hidden_states = self.group_norm(hidden_states)
61
-
62
- hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
63
-
64
- # proj to q, k, v
65
- query_proj = self.query(hidden_states)
66
- key_proj = self.key(hidden_states)
67
- value_proj = self.value(hidden_states)
68
-
69
- # transpose
70
- query_states = self.transpose_for_scores(query_proj)
71
- key_states = self.transpose_for_scores(key_proj)
72
- value_states = self.transpose_for_scores(value_proj)
73
-
74
- # get scores
75
- scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads))
76
-
77
- attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale)
78
- attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
79
-
80
- # compute attention output
81
- hidden_states = torch.matmul(attention_probs, value_states)
82
-
83
- hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
84
- new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
85
- hidden_states = hidden_states.view(new_hidden_states_shape)
86
-
87
- # compute next hidden_states
88
- hidden_states = self.proj_attn(hidden_states)
89
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
90
-
91
- # res connect and rescale
92
- hidden_states = (hidden_states + residual) / self.rescale_output_factor
93
- return hidden_states
94
-
95
-
96
- class SpatialTransformer(nn.Module):
97
- """
98
- Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
99
- standard transformer action. Finally, reshape to image.
100
-
101
- Parameters:
102
- in_channels (:obj:`int`): The number of channels in the input and output.
103
- n_heads (:obj:`int`): The number of heads to use for multi-head attention.
104
- d_head (:obj:`int`): The number of channels in each head.
105
- depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
106
- dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
107
- context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
108
- """
109
-
110
- def __init__(
111
- self,
112
- in_channels: int,
113
- n_heads: int,
114
- d_head: int,
115
- depth: int = 1,
116
- dropout: float = 0.0,
117
- context_dim: Optional[int] = None,
118
- ):
119
- super().__init__()
120
- self.n_heads = n_heads
121
- self.d_head = d_head
122
- self.in_channels = in_channels
123
- inner_dim = n_heads * d_head
124
- self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
125
-
126
- self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
127
-
128
- self.transformer_blocks = nn.ModuleList(
129
- [
130
- BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
131
- for d in range(depth)
132
- ]
133
- )
134
-
135
- self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
136
-
137
- def _set_attention_slice(self, slice_size):
138
- for block in self.transformer_blocks:
139
- block._set_attention_slice(slice_size)
140
-
141
- def forward(self, x, context=None):
142
- # note: if no context is given, cross-attention defaults to self-attention
143
- b, c, h, w = x.shape
144
- x_in = x
145
- x = self.norm(x)
146
- x = self.proj_in(x)
147
- x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
148
- for block in self.transformer_blocks:
149
- x = block(x, context=context)
150
- x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
151
- x = self.proj_out(x)
152
- return x + x_in
153
-
154
-
155
- class BasicTransformerBlock(nn.Module):
156
- r"""
157
- A basic Transformer block.
158
-
159
- Parameters:
160
- dim (:obj:`int`): The number of channels in the input and output.
161
- n_heads (:obj:`int`): The number of heads to use for multi-head attention.
162
- d_head (:obj:`int`): The number of channels in each head.
163
- dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
164
- context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
165
- gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
166
- checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
167
- """
168
-
169
- def __init__(
170
- self,
171
- dim: int,
172
- n_heads: int,
173
- d_head: int,
174
- dropout=0.0,
175
- context_dim: Optional[int] = None,
176
- gated_ff: bool = True,
177
- checkpoint: bool = True,
178
- ):
179
- super().__init__()
180
- self.attn1 = CrossAttention(
181
- query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
182
- ) # is a self-attention
183
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
184
- self.attn2 = CrossAttention(
185
- query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
186
- ) # is self-attn if context is none
187
- self.norm1 = nn.LayerNorm(dim)
188
- self.norm2 = nn.LayerNorm(dim)
189
- self.norm3 = nn.LayerNorm(dim)
190
- self.checkpoint = checkpoint
191
-
192
- def _set_attention_slice(self, slice_size):
193
- self.attn1._slice_size = slice_size
194
- self.attn2._slice_size = slice_size
195
-
196
- def forward(self, x, context=None):
197
- x = x.contiguous() if x.device.type == "mps" else x
198
- x = self.attn1(self.norm1(x)) + x
199
- x = self.attn2(self.norm2(x), context=context) + x
200
- x = self.ff(self.norm3(x)) + x
201
- return x
202
-
203
-
204
- heat_maps = defaultdict(list)
205
- all_heat_maps = []
206
-
207
-
208
- def clear_heat_maps():
209
- global heat_maps, all_heat_maps
210
- heat_maps = defaultdict(list)
211
- all_heat_maps = []
212
-
213
-
214
- def next_heat_map():
215
- global heat_maps, all_heat_maps
216
- all_heat_maps.append(heat_maps)
217
- heat_maps = defaultdict(list)
218
-
219
-
220
- def get_global_heat_map(last_n: int = None, idx: int = None, factors=None):
221
- global heat_maps, all_heat_maps
222
-
223
- if idx is not None:
224
- heat_maps2 = [all_heat_maps[idx]]
225
- else:
226
- heat_maps2 = all_heat_maps[-last_n:] if last_n is not None else all_heat_maps
227
-
228
- if factors is None:
229
- factors = {1, 2, 4, 8, 16, 32}
230
-
231
- all_merges = []
232
-
233
- for heat_map_map in heat_maps2:
234
- merge_list = []
235
-
236
- for k, v in heat_map_map.items():
237
- if k in factors:
238
- merge_list.append(torch.stack(v, 0).mean(0))
239
-
240
- all_merges.append(merge_list)
241
-
242
- maps = torch.stack([torch.stack(x, 0) for x in all_merges], dim=0)
243
- return maps.sum(0).cuda().sum(2).sum(0)
244
-
245
-
246
- class CrossAttention(nn.Module):
247
- r"""
248
- A cross attention layer.
249
-
250
- Parameters:
251
- query_dim (:obj:`int`): The number of channels in the query.
252
- context_dim (:obj:`int`, *optional*):
253
- The number of channels in the context. If not given, defaults to `query_dim`.
254
- heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
255
- dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
256
- dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
257
- """
258
-
259
- def __init__(
260
- self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
261
- ):
262
- super().__init__()
263
- inner_dim = dim_head * heads
264
- context_dim = context_dim if context_dim is not None else query_dim
265
-
266
- self.scale = dim_head**-0.5
267
- self.heads = heads
268
- # for slice_size > 0 the attention score computation
269
- # is split across the batch axis to save memory
270
- # You can set slice_size with `set_attention_slice`
271
- self._slice_size = None
272
-
273
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
274
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
275
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
276
-
277
- self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
278
-
279
- def reshape_heads_to_batch_dim(self, tensor):
280
- batch_size, seq_len, dim = tensor.shape
281
- head_size = self.heads
282
- tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
283
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
284
- return tensor
285
-
286
- def reshape_batch_dim_to_heads(self, tensor):
287
- batch_size, seq_len, dim = tensor.shape
288
- head_size = self.heads
289
- tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
290
- tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
291
- return tensor
292
-
293
- def forward(self, x, context=None, mask=None):
294
- batch_size, sequence_length, dim = x.shape
295
-
296
- use_context = context is not None
297
-
298
- q = self.to_q(x)
299
- context = context if context is not None else x
300
- k = self.to_k(context)
301
- v = self.to_v(context)
302
-
303
- q = self.reshape_heads_to_batch_dim(q)
304
- k = self.reshape_heads_to_batch_dim(k)
305
- v = self.reshape_heads_to_batch_dim(v)
306
-
307
- # TODO(PVP) - mask is currently never used. Remember to re-implement when used
308
-
309
- # attention, what we cannot get enough of
310
- hidden_states = self._attention(q, k, v, sequence_length, dim, use_context=use_context)
311
-
312
- return self.to_out(hidden_states)
313
-
314
- @torch.no_grad()
315
- def _up_sample_attn(self, x, factor, method: str = 'bicubic'):
316
- weight = torch.full((factor, factor), 1 / factor**2, device=x.device)
317
- weight = weight.view(1, 1, factor, factor)
318
-
319
- h = w = int(math.sqrt(x.size(1)))
320
- maps = []
321
- x = x.permute(2, 0, 1)
322
-
323
- with torch.cuda.amp.autocast(dtype=torch.float32):
324
- for map_ in x:
325
- map_ = map_.unsqueeze(1).view(map_.size(0), 1, h, w)
326
- if method == 'bicubic':
327
- map_ = F.interpolate(map_, size=(64, 64), mode="bicubic", align_corners=False)
328
- maps.append(map_.squeeze(1))
329
- else:
330
- maps.append(F.conv_transpose2d(map_, weight, stride=factor).squeeze(1).cpu())
331
-
332
- maps = torch.stack(maps, 0).sum(1, keepdim=True).cpu()
333
- return maps
334
-
335
- def _attention(self, query, key, value, sequence_length, dim, use_context: bool = True):
336
- batch_size_attention = query.shape[0]
337
- hidden_states = torch.zeros(
338
- (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
339
- )
340
- slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
341
- for i in range(hidden_states.shape[0] // slice_size):
342
- start_idx = i * slice_size
343
- end_idx = (i + 1) * slice_size
344
- attn_slice = (
345
- torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
346
- )
347
- factor = int(math.sqrt(4096 // attn_slice.shape[1]))
348
- attn_slice = attn_slice.softmax(-1)
349
-
350
- if use_context and attn_slice.shape[-1] == 77:
351
- if factor >= 1:
352
- factor //= 1
353
- maps = self._up_sample_attn(attn_slice, factor)
354
- global heat_maps
355
- heat_maps[factor].append(maps)
356
- # print(attn_slice.size(), query.size(), key.size(), value.size())
357
-
358
- attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
359
-
360
- hidden_states[start_idx:end_idx] = attn_slice
361
-
362
- # reshape hidden_states
363
- hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
364
- return hidden_states
365
-
366
-
367
- class FeedForward(nn.Module):
368
- r"""
369
- A feed-forward layer.
370
-
371
- Parameters:
372
- dim (:obj:`int`): The number of channels in the input.
373
- dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
374
- mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
375
- glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
376
- dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
377
- """
378
-
379
- def __init__(
380
- self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
381
- ):
382
- super().__init__()
383
- inner_dim = int(dim * mult)
384
- dim_out = dim_out if dim_out is not None else dim
385
- project_in = GEGLU(dim, inner_dim)
386
-
387
- self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
388
-
389
- def forward(self, x):
390
- return self.net(x)
391
-
392
-
393
- # feedforward
394
- class GEGLU(nn.Module):
395
- r"""
396
- A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
397
-
398
- Parameters:
399
- dim_in (:obj:`int`): The number of channels in the input.
400
- dim_out (:obj:`int`): The number of channels in the output.
401
- """
402
-
403
- def __init__(self, dim_in: int, dim_out: int):
404
- super().__init__()
405
- self.proj = nn.Linear(dim_in, dim_out * 2)
406
-
407
- def forward(self, x):
408
- x, gate = self.proj(x).chunk(2, dim=-1)
409
- return x * F.gelu(gate)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/models/embeddings.py DELETED
@@ -1,115 +0,0 @@
1
- # Copyright 2022 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- import math
15
-
16
- import numpy as np
17
- import torch
18
- from torch import nn
19
-
20
-
21
- def get_timestep_embedding(
22
- timesteps: torch.Tensor,
23
- embedding_dim: int,
24
- flip_sin_to_cos: bool = False,
25
- downscale_freq_shift: float = 1,
26
- scale: float = 1,
27
- max_period: int = 10000,
28
- ):
29
- """
30
- This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
31
-
32
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
33
- These may be fractional.
34
- :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
35
- embeddings. :return: an [N x dim] Tensor of positional embeddings.
36
- """
37
- assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
38
-
39
- half_dim = embedding_dim // 2
40
- exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
41
- exponent = exponent / (half_dim - downscale_freq_shift)
42
-
43
- emb = torch.exp(exponent).to(device=timesteps.device)
44
- emb = timesteps[:, None].float() * emb[None, :]
45
-
46
- # scale embeddings
47
- emb = scale * emb
48
-
49
- # concat sine and cosine embeddings
50
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
51
-
52
- # flip sine and cosine embeddings
53
- if flip_sin_to_cos:
54
- emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
55
-
56
- # zero pad
57
- if embedding_dim % 2 == 1:
58
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
59
- return emb
60
-
61
-
62
- class TimestepEmbedding(nn.Module):
63
- def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
64
- super().__init__()
65
-
66
- self.linear_1 = nn.Linear(channel, time_embed_dim)
67
- self.act = None
68
- if act_fn == "silu":
69
- self.act = nn.SiLU()
70
- self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
71
-
72
- def forward(self, sample):
73
- sample = self.linear_1(sample)
74
-
75
- if self.act is not None:
76
- sample = self.act(sample)
77
-
78
- sample = self.linear_2(sample)
79
- return sample
80
-
81
-
82
- class Timesteps(nn.Module):
83
- def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
84
- super().__init__()
85
- self.num_channels = num_channels
86
- self.flip_sin_to_cos = flip_sin_to_cos
87
- self.downscale_freq_shift = downscale_freq_shift
88
-
89
- def forward(self, timesteps):
90
- t_emb = get_timestep_embedding(
91
- timesteps,
92
- self.num_channels,
93
- flip_sin_to_cos=self.flip_sin_to_cos,
94
- downscale_freq_shift=self.downscale_freq_shift,
95
- )
96
- return t_emb
97
-
98
-
99
- class GaussianFourierProjection(nn.Module):
100
- """Gaussian Fourier embeddings for noise levels."""
101
-
102
- def __init__(self, embedding_size: int = 256, scale: float = 1.0):
103
- super().__init__()
104
- self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
105
-
106
- # to delete later
107
- self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
108
-
109
- self.weight = self.W
110
-
111
- def forward(self, x):
112
- x = torch.log(x)
113
- x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
114
- out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
115
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/models/resnet.py DELETED
@@ -1,483 +0,0 @@
1
- from functools import partial
2
-
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
-
9
- class Upsample2D(nn.Module):
10
- """
11
- An upsampling layer with an optional convolution.
12
-
13
- :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
14
- applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
15
- upsampling occurs in the inner-two dimensions.
16
- """
17
-
18
- def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
19
- super().__init__()
20
- self.channels = channels
21
- self.out_channels = out_channels or channels
22
- self.use_conv = use_conv
23
- self.use_conv_transpose = use_conv_transpose
24
- self.name = name
25
-
26
- conv = None
27
- if use_conv_transpose:
28
- conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
29
- elif use_conv:
30
- conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)
31
-
32
- # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
33
- if name == "conv":
34
- self.conv = conv
35
- else:
36
- self.Conv2d_0 = conv
37
-
38
- def forward(self, x):
39
- assert x.shape[1] == self.channels
40
- if self.use_conv_transpose:
41
- return self.conv(x)
42
-
43
- x = F.interpolate(x, scale_factor=2.0, mode="nearest")
44
-
45
- # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
46
- if self.use_conv:
47
- if self.name == "conv":
48
- x = self.conv(x)
49
- else:
50
- x = self.Conv2d_0(x)
51
-
52
- return x
53
-
54
-
55
- class Downsample2D(nn.Module):
56
- """
57
- A downsampling layer with an optional convolution.
58
-
59
- :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is
60
- applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
61
- downsampling occurs in the inner-two dimensions.
62
- """
63
-
64
- def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
65
- super().__init__()
66
- self.channels = channels
67
- self.out_channels = out_channels or channels
68
- self.use_conv = use_conv
69
- self.padding = padding
70
- stride = 2
71
- self.name = name
72
-
73
- if use_conv:
74
- conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
75
- else:
76
- assert self.channels == self.out_channels
77
- conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
78
-
79
- # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
80
- if name == "conv":
81
- self.Conv2d_0 = conv
82
- self.conv = conv
83
- elif name == "Conv2d_0":
84
- self.conv = conv
85
- else:
86
- self.conv = conv
87
-
88
- def forward(self, x):
89
- assert x.shape[1] == self.channels
90
- if self.use_conv and self.padding == 0:
91
- pad = (0, 1, 0, 1)
92
- x = F.pad(x, pad, mode="constant", value=0)
93
-
94
- assert x.shape[1] == self.channels
95
- x = self.conv(x)
96
-
97
- return x
98
-
99
-
100
- class FirUpsample2D(nn.Module):
101
- def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
102
- super().__init__()
103
- out_channels = out_channels if out_channels else channels
104
- if use_conv:
105
- self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
106
- self.use_conv = use_conv
107
- self.fir_kernel = fir_kernel
108
- self.out_channels = out_channels
109
-
110
- def _upsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
111
- """Fused `upsample_2d()` followed by `Conv2d()`.
112
-
113
- Args:
114
- Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
115
- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
116
- order.
117
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
118
- C]`.
119
- weight: Weight tensor of the shape `[filterH, filterW, inChannels,
120
- outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
121
- kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
122
- (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
123
- factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
124
-
125
- Returns:
126
- Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
127
- `x`.
128
- """
129
-
130
- assert isinstance(factor, int) and factor >= 1
131
-
132
- # Setup filter kernel.
133
- if kernel is None:
134
- kernel = [1] * factor
135
-
136
- # setup kernel
137
- kernel = np.asarray(kernel, dtype=np.float32)
138
- if kernel.ndim == 1:
139
- kernel = np.outer(kernel, kernel)
140
- kernel /= np.sum(kernel)
141
-
142
- kernel = kernel * (gain * (factor**2))
143
-
144
- if self.use_conv:
145
- convH = weight.shape[2]
146
- convW = weight.shape[3]
147
- inC = weight.shape[1]
148
-
149
- p = (kernel.shape[0] - factor) - (convW - 1)
150
-
151
- stride = (factor, factor)
152
- # Determine data dimensions.
153
- stride = [1, 1, factor, factor]
154
- output_shape = ((x.shape[2] - 1) * factor + convH, (x.shape[3] - 1) * factor + convW)
155
- output_padding = (
156
- output_shape[0] - (x.shape[2] - 1) * stride[0] - convH,
157
- output_shape[1] - (x.shape[3] - 1) * stride[1] - convW,
158
- )
159
- assert output_padding[0] >= 0 and output_padding[1] >= 0
160
- inC = weight.shape[1]
161
- num_groups = x.shape[1] // inC
162
-
163
- # Transpose weights.
164
- weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
165
- weight = weight[..., ::-1, ::-1].permute(0, 2, 1, 3, 4)
166
- weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
167
-
168
- x = F.conv_transpose2d(x, weight, stride=stride, output_padding=output_padding, padding=0)
169
-
170
- x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2 + factor - 1, p // 2 + 1))
171
- else:
172
- p = kernel.shape[0] - factor
173
- x = upfirdn2d_native(
174
- x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
175
- )
176
-
177
- return x
178
-
179
- def forward(self, x):
180
- if self.use_conv:
181
- height = self._upsample_2d(x, self.Conv2d_0.weight, kernel=self.fir_kernel)
182
- height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
183
- else:
184
- height = self._upsample_2d(x, kernel=self.fir_kernel, factor=2)
185
-
186
- return height
187
-
188
-
189
- class FirDownsample2D(nn.Module):
190
- def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
191
- super().__init__()
192
- out_channels = out_channels if out_channels else channels
193
- if use_conv:
194
- self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
195
- self.fir_kernel = fir_kernel
196
- self.use_conv = use_conv
197
- self.out_channels = out_channels
198
-
199
- def _downsample_2d(self, x, weight=None, kernel=None, factor=2, gain=1):
200
- """Fused `Conv2d()` followed by `downsample_2d()`.
201
-
202
- Args:
203
- Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
204
- efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary:
205
- order.
206
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. w: Weight tensor of the shape `[filterH,
207
- filterW, inChannels, outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] //
208
- numGroups`. k: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
209
- factor`, which corresponds to average pooling. factor: Integer downsampling factor (default: 2). gain:
210
- Scaling factor for signal magnitude (default: 1.0).
211
-
212
- Returns:
213
- Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
214
- datatype as `x`.
215
- """
216
-
217
- assert isinstance(factor, int) and factor >= 1
218
- if kernel is None:
219
- kernel = [1] * factor
220
-
221
- # setup kernel
222
- kernel = np.asarray(kernel, dtype=np.float32)
223
- if kernel.ndim == 1:
224
- kernel = np.outer(kernel, kernel)
225
- kernel /= np.sum(kernel)
226
-
227
- kernel = kernel * gain
228
-
229
- if self.use_conv:
230
- _, _, convH, convW = weight.shape
231
- p = (kernel.shape[0] - factor) + (convW - 1)
232
- s = [factor, factor]
233
- x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), pad=((p + 1) // 2, p // 2))
234
- x = F.conv2d(x, weight, stride=s, padding=0)
235
- else:
236
- p = kernel.shape[0] - factor
237
- x = upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
238
-
239
- return x
240
-
241
- def forward(self, x):
242
- if self.use_conv:
243
- x = self._downsample_2d(x, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
244
- x = x + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
245
- else:
246
- x = self._downsample_2d(x, kernel=self.fir_kernel, factor=2)
247
-
248
- return x
249
-
250
-
251
- class ResnetBlock2D(nn.Module):
252
- def __init__(
253
- self,
254
- *,
255
- in_channels,
256
- out_channels=None,
257
- conv_shortcut=False,
258
- dropout=0.0,
259
- temb_channels=512,
260
- groups=32,
261
- groups_out=None,
262
- pre_norm=True,
263
- eps=1e-6,
264
- non_linearity="swish",
265
- time_embedding_norm="default",
266
- kernel=None,
267
- output_scale_factor=1.0,
268
- use_nin_shortcut=None,
269
- up=False,
270
- down=False,
271
- ):
272
- super().__init__()
273
- self.pre_norm = pre_norm
274
- self.pre_norm = True
275
- self.in_channels = in_channels
276
- out_channels = in_channels if out_channels is None else out_channels
277
- self.out_channels = out_channels
278
- self.use_conv_shortcut = conv_shortcut
279
- self.time_embedding_norm = time_embedding_norm
280
- self.up = up
281
- self.down = down
282
- self.output_scale_factor = output_scale_factor
283
-
284
- if groups_out is None:
285
- groups_out = groups
286
-
287
- self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
288
-
289
- self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
290
-
291
- if temb_channels is not None:
292
- self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
293
- else:
294
- self.time_emb_proj = None
295
-
296
- self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
297
- self.dropout = torch.nn.Dropout(dropout)
298
- self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
299
-
300
- if non_linearity == "swish":
301
- self.nonlinearity = lambda x: F.silu(x)
302
- elif non_linearity == "mish":
303
- self.nonlinearity = Mish()
304
- elif non_linearity == "silu":
305
- self.nonlinearity = nn.SiLU()
306
-
307
- self.upsample = self.downsample = None
308
- if self.up:
309
- if kernel == "fir":
310
- fir_kernel = (1, 3, 3, 1)
311
- self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
312
- elif kernel == "sde_vp":
313
- self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
314
- else:
315
- self.upsample = Upsample2D(in_channels, use_conv=False)
316
- elif self.down:
317
- if kernel == "fir":
318
- fir_kernel = (1, 3, 3, 1)
319
- self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
320
- elif kernel == "sde_vp":
321
- self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
322
- else:
323
- self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
324
-
325
- self.use_nin_shortcut = self.in_channels != self.out_channels if use_nin_shortcut is None else use_nin_shortcut
326
-
327
- self.conv_shortcut = None
328
- if self.use_nin_shortcut:
329
- self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
330
-
331
- def forward(self, x, temb):
332
- hidden_states = x
333
-
334
- # make sure hidden states is in float32
335
- # when running in half-precision
336
- hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
337
- hidden_states = self.nonlinearity(hidden_states)
338
-
339
- if self.upsample is not None:
340
- x = self.upsample(x)
341
- hidden_states = self.upsample(hidden_states)
342
- elif self.downsample is not None:
343
- x = self.downsample(x)
344
- hidden_states = self.downsample(hidden_states)
345
-
346
- hidden_states = self.conv1(hidden_states)
347
-
348
- if temb is not None:
349
- temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
350
- hidden_states = hidden_states + temb
351
-
352
- # make sure hidden states is in float32
353
- # when running in half-precision
354
- hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
355
- hidden_states = self.nonlinearity(hidden_states)
356
-
357
- hidden_states = self.dropout(hidden_states)
358
- hidden_states = self.conv2(hidden_states)
359
-
360
- if self.conv_shortcut is not None:
361
- x = self.conv_shortcut(x)
362
-
363
- out = (x + hidden_states) / self.output_scale_factor
364
-
365
- return out
366
-
367
-
368
- class Mish(torch.nn.Module):
369
- def forward(self, x):
370
- return x * torch.tanh(torch.nn.functional.softplus(x))
371
-
372
-
373
- def upsample_2d(x, kernel=None, factor=2, gain=1):
374
- r"""Upsample2D a batch of 2D images with the given filter.
375
-
376
- Args:
377
- Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
378
- filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
379
- `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is a:
380
- multiple of the upsampling factor.
381
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
382
- C]`.
383
- k: FIR filter of the shape `[firH, firW]` or `[firN]`
384
- (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
385
- factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
386
-
387
- Returns:
388
- Tensor of the shape `[N, C, H * factor, W * factor]`
389
- """
390
- assert isinstance(factor, int) and factor >= 1
391
- if kernel is None:
392
- kernel = [1] * factor
393
-
394
- kernel = np.asarray(kernel, dtype=np.float32)
395
- if kernel.ndim == 1:
396
- kernel = np.outer(kernel, kernel)
397
- kernel /= np.sum(kernel)
398
-
399
- kernel = kernel * (gain * (factor**2))
400
- p = kernel.shape[0] - factor
401
- return upfirdn2d_native(
402
- x, torch.tensor(kernel, device=x.device), up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)
403
- )
404
-
405
-
406
- def downsample_2d(x, kernel=None, factor=2, gain=1):
407
- r"""Downsample2D a batch of 2D images with the given filter.
408
-
409
- Args:
410
- Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
411
- given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
412
- specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
413
- shape is a multiple of the downsampling factor.
414
- x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
415
- C]`.
416
- kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
417
- (separable). The default is `[1] * factor`, which corresponds to average pooling.
418
- factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
419
-
420
- Returns:
421
- Tensor of the shape `[N, C, H // factor, W // factor]`
422
- """
423
-
424
- assert isinstance(factor, int) and factor >= 1
425
- if kernel is None:
426
- kernel = [1] * factor
427
-
428
- kernel = np.asarray(kernel, dtype=np.float32)
429
- if kernel.ndim == 1:
430
- kernel = np.outer(kernel, kernel)
431
- kernel /= np.sum(kernel)
432
-
433
- kernel = kernel * gain
434
- p = kernel.shape[0] - factor
435
- return upfirdn2d_native(x, torch.tensor(kernel, device=x.device), down=factor, pad=((p + 1) // 2, p // 2))
436
-
437
-
438
- def upfirdn2d_native(input, kernel, up=1, down=1, pad=(0, 0)):
439
- up_x = up_y = up
440
- down_x = down_y = down
441
- pad_x0 = pad_y0 = pad[0]
442
- pad_x1 = pad_y1 = pad[1]
443
-
444
- _, channel, in_h, in_w = input.shape
445
- input = input.reshape(-1, in_h, in_w, 1)
446
-
447
- _, in_h, in_w, minor = input.shape
448
- kernel_h, kernel_w = kernel.shape
449
-
450
- out = input.view(-1, in_h, 1, in_w, 1, minor)
451
-
452
- # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
453
- if input.device.type == "mps":
454
- out = out.to("cpu")
455
- out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
456
- out = out.view(-1, in_h * up_y, in_w * up_x, minor)
457
-
458
- out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
459
- out = out.to(input.device) # Move back to mps if necessary
460
- out = out[
461
- :,
462
- max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
463
- max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
464
- :,
465
- ]
466
-
467
- out = out.permute(0, 3, 1, 2)
468
- out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
469
- w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
470
- out = F.conv2d(out, w)
471
- out = out.reshape(
472
- -1,
473
- minor,
474
- in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
475
- in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
476
- )
477
- out = out.permute(0, 2, 3, 1)
478
- out = out[:, ::down_y, ::down_x, :]
479
-
480
- out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
481
- out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
482
-
483
- return out.view(-1, channel, out_h, out_w)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/models/unet_2d.py DELETED
@@ -1,246 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional, Tuple, Union
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- from ..configuration_utils import ConfigMixin, register_to_config
8
- from ..modeling_utils import ModelMixin
9
- from ..utils import BaseOutput
10
- from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
11
- from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
12
-
13
-
14
- @dataclass
15
- class UNet2DOutput(BaseOutput):
16
- """
17
- Args:
18
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
19
- Hidden states output. Output of last layer of model.
20
- """
21
-
22
- sample: torch.FloatTensor
23
-
24
-
25
- class UNet2DModel(ModelMixin, ConfigMixin):
26
- r"""
27
- UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output.
28
-
29
- This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
30
- implements for all the model (such as downloading or saving, etc.)
31
-
32
- Parameters:
33
- sample_size (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
34
- Input sample size.
35
- in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
36
- out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
37
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
38
- time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use.
39
- freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding.
40
- flip_sin_to_cos (`bool`, *optional*, defaults to :
41
- obj:`False`): Whether to flip sin to cos for fourier time embedding.
42
- down_block_types (`Tuple[str]`, *optional*, defaults to :
43
- obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
44
- types.
45
- up_block_types (`Tuple[str]`, *optional*, defaults to :
46
- obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
47
- block_out_channels (`Tuple[int]`, *optional*, defaults to :
48
- obj:`(224, 448, 672, 896)`): Tuple of block output channels.
49
- layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block.
50
- mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block.
51
- downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution.
52
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
53
- attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
54
- norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
55
- norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
56
- """
57
-
58
- @register_to_config
59
- def __init__(
60
- self,
61
- sample_size: Optional[int] = None,
62
- in_channels: int = 3,
63
- out_channels: int = 3,
64
- center_input_sample: bool = False,
65
- time_embedding_type: str = "positional",
66
- freq_shift: int = 0,
67
- flip_sin_to_cos: bool = True,
68
- down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
69
- up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
70
- block_out_channels: Tuple[int] = (224, 448, 672, 896),
71
- layers_per_block: int = 2,
72
- mid_block_scale_factor: float = 1,
73
- downsample_padding: int = 1,
74
- act_fn: str = "silu",
75
- attention_head_dim: int = 8,
76
- norm_num_groups: int = 32,
77
- norm_eps: float = 1e-5,
78
- ):
79
- super().__init__()
80
-
81
- self.sample_size = sample_size
82
- time_embed_dim = block_out_channels[0] * 4
83
-
84
- # input
85
- self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
86
-
87
- # time
88
- if time_embedding_type == "fourier":
89
- self.time_proj = GaussianFourierProjection(embedding_size=block_out_channels[0], scale=16)
90
- timestep_input_dim = 2 * block_out_channels[0]
91
- elif time_embedding_type == "positional":
92
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
93
- timestep_input_dim = block_out_channels[0]
94
-
95
- self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
96
-
97
- self.down_blocks = nn.ModuleList([])
98
- self.mid_block = None
99
- self.up_blocks = nn.ModuleList([])
100
-
101
- # down
102
- output_channel = block_out_channels[0]
103
- for i, down_block_type in enumerate(down_block_types):
104
- input_channel = output_channel
105
- output_channel = block_out_channels[i]
106
- is_final_block = i == len(block_out_channels) - 1
107
-
108
- down_block = get_down_block(
109
- down_block_type,
110
- num_layers=layers_per_block,
111
- in_channels=input_channel,
112
- out_channels=output_channel,
113
- temb_channels=time_embed_dim,
114
- add_downsample=not is_final_block,
115
- resnet_eps=norm_eps,
116
- resnet_act_fn=act_fn,
117
- attn_num_head_channels=attention_head_dim,
118
- downsample_padding=downsample_padding,
119
- )
120
- self.down_blocks.append(down_block)
121
-
122
- # mid
123
- self.mid_block = UNetMidBlock2D(
124
- in_channels=block_out_channels[-1],
125
- temb_channels=time_embed_dim,
126
- resnet_eps=norm_eps,
127
- resnet_act_fn=act_fn,
128
- output_scale_factor=mid_block_scale_factor,
129
- resnet_time_scale_shift="default",
130
- attn_num_head_channels=attention_head_dim,
131
- resnet_groups=norm_num_groups,
132
- )
133
-
134
- # up
135
- reversed_block_out_channels = list(reversed(block_out_channels))
136
- output_channel = reversed_block_out_channels[0]
137
- for i, up_block_type in enumerate(up_block_types):
138
- prev_output_channel = output_channel
139
- output_channel = reversed_block_out_channels[i]
140
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
141
-
142
- is_final_block = i == len(block_out_channels) - 1
143
-
144
- up_block = get_up_block(
145
- up_block_type,
146
- num_layers=layers_per_block + 1,
147
- in_channels=input_channel,
148
- out_channels=output_channel,
149
- prev_output_channel=prev_output_channel,
150
- temb_channels=time_embed_dim,
151
- add_upsample=not is_final_block,
152
- resnet_eps=norm_eps,
153
- resnet_act_fn=act_fn,
154
- attn_num_head_channels=attention_head_dim,
155
- )
156
- self.up_blocks.append(up_block)
157
- prev_output_channel = output_channel
158
-
159
- # out
160
- num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
161
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=norm_eps)
162
- self.conv_act = nn.SiLU()
163
- self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
164
-
165
- def forward(
166
- self,
167
- sample: torch.FloatTensor,
168
- timestep: Union[torch.Tensor, float, int],
169
- return_dict: bool = True,
170
- ) -> Union[UNet2DOutput, Tuple]:
171
- """r
172
- Args:
173
- sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
174
- timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
175
- return_dict (`bool`, *optional*, defaults to `True`):
176
- Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple.
177
-
178
- Returns:
179
- [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True,
180
- otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
181
- """
182
- # 0. center input if necessary
183
- if self.config.center_input_sample:
184
- sample = 2 * sample - 1.0
185
-
186
- # 1. time
187
- timesteps = timestep
188
- if not torch.is_tensor(timesteps):
189
- timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
190
- elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
191
- timesteps = timesteps[None].to(sample.device)
192
-
193
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
194
- timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)
195
-
196
- t_emb = self.time_proj(timesteps)
197
- emb = self.time_embedding(t_emb)
198
-
199
- # 2. pre-process
200
- skip_sample = sample
201
- sample = self.conv_in(sample)
202
-
203
- # 3. down
204
- down_block_res_samples = (sample,)
205
- for downsample_block in self.down_blocks:
206
- if hasattr(downsample_block, "skip_conv"):
207
- sample, res_samples, skip_sample = downsample_block(
208
- hidden_states=sample, temb=emb, skip_sample=skip_sample
209
- )
210
- else:
211
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
212
-
213
- down_block_res_samples += res_samples
214
-
215
- # 4. mid
216
- sample = self.mid_block(sample, emb)
217
-
218
- # 5. up
219
- skip_sample = None
220
- for upsample_block in self.up_blocks:
221
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
222
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
223
-
224
- if hasattr(upsample_block, "skip_conv"):
225
- sample, skip_sample = upsample_block(sample, res_samples, emb, skip_sample)
226
- else:
227
- sample = upsample_block(sample, res_samples, emb)
228
-
229
- # 6. post-process
230
- # make sure hidden states is in float32
231
- # when running in half-precision
232
- sample = self.conv_norm_out(sample.float()).type(sample.dtype)
233
- sample = self.conv_act(sample)
234
- sample = self.conv_out(sample)
235
-
236
- if skip_sample is not None:
237
- sample += skip_sample
238
-
239
- if self.config.time_embedding_type == "fourier":
240
- timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
241
- sample = sample / timesteps
242
-
243
- if not return_dict:
244
- return (sample,)
245
-
246
- return UNet2DOutput(sample=sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/models/unet_2d_condition.py DELETED
@@ -1,272 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional, Tuple, Union
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- from ..configuration_utils import ConfigMixin, register_to_config
8
- from ..modeling_utils import ModelMixin
9
- from ..utils import BaseOutput
10
- from .embeddings import TimestepEmbedding, Timesteps
11
- from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
12
-
13
-
14
- @dataclass
15
- class UNet2DConditionOutput(BaseOutput):
16
- """
17
- Args:
18
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
19
- Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
20
- """
21
-
22
- sample: torch.FloatTensor
23
-
24
-
25
- class UNet2DConditionModel(ModelMixin, ConfigMixin):
26
- r"""
27
- UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
28
- and returns sample shaped output.
29
-
30
- This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
31
- implements for all the model (such as downloading or saving, etc.)
32
-
33
- Parameters:
34
- sample_size (`int`, *optional*): The size of the input sample.
35
- in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
36
- out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
37
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
38
- flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
39
- Whether to flip the sin to cos in the time embedding.
40
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
41
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
42
- The tuple of downsample blocks to use.
43
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
44
- The tuple of upsample blocks to use.
45
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
46
- The tuple of output channels for each block.
47
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
48
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
49
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
50
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
51
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
52
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
53
- cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
54
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
55
- """
56
-
57
- @register_to_config
58
- def __init__(
59
- self,
60
- sample_size: Optional[int] = None,
61
- in_channels: int = 4,
62
- out_channels: int = 4,
63
- center_input_sample: bool = False,
64
- flip_sin_to_cos: bool = True,
65
- freq_shift: int = 0,
66
- down_block_types: Tuple[str] = (
67
- "CrossAttnDownBlock2D",
68
- "CrossAttnDownBlock2D",
69
- "CrossAttnDownBlock2D",
70
- "DownBlock2D",
71
- ),
72
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
73
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
74
- layers_per_block: int = 2,
75
- downsample_padding: int = 1,
76
- mid_block_scale_factor: float = 1,
77
- act_fn: str = "silu",
78
- norm_num_groups: int = 32,
79
- norm_eps: float = 1e-5,
80
- cross_attention_dim: int = 1280,
81
- attention_head_dim: int = 8,
82
- ):
83
- super().__init__()
84
-
85
- self.sample_size = sample_size
86
- time_embed_dim = block_out_channels[0] * 4
87
-
88
- # input
89
- self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
90
-
91
- # time
92
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
93
- timestep_input_dim = block_out_channels[0]
94
-
95
- self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
96
-
97
- self.down_blocks = nn.ModuleList([])
98
- self.mid_block = None
99
- self.up_blocks = nn.ModuleList([])
100
-
101
- # down
102
- output_channel = block_out_channels[0]
103
- for i, down_block_type in enumerate(down_block_types):
104
- input_channel = output_channel
105
- output_channel = block_out_channels[i]
106
- is_final_block = i == len(block_out_channels) - 1
107
-
108
- down_block = get_down_block(
109
- down_block_type,
110
- num_layers=layers_per_block,
111
- in_channels=input_channel,
112
- out_channels=output_channel,
113
- temb_channels=time_embed_dim,
114
- add_downsample=not is_final_block,
115
- resnet_eps=norm_eps,
116
- resnet_act_fn=act_fn,
117
- cross_attention_dim=cross_attention_dim,
118
- attn_num_head_channels=attention_head_dim,
119
- downsample_padding=downsample_padding,
120
- )
121
- self.down_blocks.append(down_block)
122
-
123
- # mid
124
- self.mid_block = UNetMidBlock2DCrossAttn(
125
- in_channels=block_out_channels[-1],
126
- temb_channels=time_embed_dim,
127
- resnet_eps=norm_eps,
128
- resnet_act_fn=act_fn,
129
- output_scale_factor=mid_block_scale_factor,
130
- resnet_time_scale_shift="default",
131
- cross_attention_dim=cross_attention_dim,
132
- attn_num_head_channels=attention_head_dim,
133
- resnet_groups=norm_num_groups,
134
- )
135
-
136
- # up
137
- reversed_block_out_channels = list(reversed(block_out_channels))
138
- output_channel = reversed_block_out_channels[0]
139
- for i, up_block_type in enumerate(up_block_types):
140
- prev_output_channel = output_channel
141
- output_channel = reversed_block_out_channels[i]
142
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
143
-
144
- is_final_block = i == len(block_out_channels) - 1
145
-
146
- up_block = get_up_block(
147
- up_block_type,
148
- num_layers=layers_per_block + 1,
149
- in_channels=input_channel,
150
- out_channels=output_channel,
151
- prev_output_channel=prev_output_channel,
152
- temb_channels=time_embed_dim,
153
- add_upsample=not is_final_block,
154
- resnet_eps=norm_eps,
155
- resnet_act_fn=act_fn,
156
- cross_attention_dim=cross_attention_dim,
157
- attn_num_head_channels=attention_head_dim,
158
- )
159
- self.up_blocks.append(up_block)
160
- prev_output_channel = output_channel
161
-
162
- # out
163
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
164
- self.conv_act = nn.SiLU()
165
- self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
166
-
167
- def set_attention_slice(self, slice_size):
168
- if slice_size is not None and self.config.attention_head_dim % slice_size != 0:
169
- raise ValueError(
170
- f"Make sure slice_size {slice_size} is a divisor of "
171
- f"the number of heads used in cross_attention {self.config.attention_head_dim}"
172
- )
173
- if slice_size is not None and slice_size > self.config.attention_head_dim:
174
- raise ValueError(
175
- f"Chunk_size {slice_size} has to be smaller or equal to "
176
- f"the number of heads used in cross_attention {self.config.attention_head_dim}"
177
- )
178
-
179
- for block in self.down_blocks:
180
- if hasattr(block, "attentions") and block.attentions is not None:
181
- block.set_attention_slice(slice_size)
182
-
183
- self.mid_block.set_attention_slice(slice_size)
184
-
185
- for block in self.up_blocks:
186
- if hasattr(block, "attentions") and block.attentions is not None:
187
- block.set_attention_slice(slice_size)
188
-
189
- def forward(
190
- self,
191
- sample: torch.FloatTensor,
192
- timestep: Union[torch.Tensor, float, int],
193
- encoder_hidden_states: torch.Tensor,
194
- return_dict: bool = True,
195
- ) -> Union[UNet2DConditionOutput, Tuple]:
196
- """r
197
- Args:
198
- sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
199
- timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps
200
- encoder_hidden_states (`torch.FloatTensor`): (batch, channel, height, width) encoder hidden states
201
- return_dict (`bool`, *optional*, defaults to `True`):
202
- Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
203
-
204
- Returns:
205
- [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
206
- [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
207
- returning a tuple, the first element is the sample tensor.
208
- """
209
- # 0. center input if necessary
210
- if self.config.center_input_sample:
211
- sample = 2 * sample - 1.0
212
-
213
- # 1. time
214
- timesteps = timestep
215
- if not torch.is_tensor(timesteps):
216
- timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
217
- elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
218
- timesteps = timesteps.to(dtype=torch.float32)
219
- timesteps = timesteps[None].to(device=sample.device)
220
-
221
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
222
- timesteps = timesteps.expand(sample.shape[0])
223
-
224
- t_emb = self.time_proj(timesteps)
225
- emb = self.time_embedding(t_emb)
226
-
227
- # 2. pre-process
228
- sample = self.conv_in(sample)
229
-
230
- # 3. down
231
- down_block_res_samples = (sample,)
232
- for downsample_block in self.down_blocks:
233
- if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
234
- sample, res_samples = downsample_block(
235
- hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
236
- )
237
- else:
238
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
239
-
240
- down_block_res_samples += res_samples
241
-
242
- # 4. mid
243
- sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
244
-
245
- # 5. up
246
- for upsample_block in self.up_blocks:
247
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
248
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
249
-
250
- if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
251
- sample = upsample_block(
252
- hidden_states=sample,
253
- temb=emb,
254
- res_hidden_states_tuple=res_samples,
255
- encoder_hidden_states=encoder_hidden_states,
256
- )
257
- else:
258
- sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
259
-
260
- # 6. post-process
261
- # make sure hidden states is in float32
262
- # when running in half-precision
263
- sample = self.conv_norm_out(sample.float()).type(sample.dtype)
264
- sample = self.conv_act(sample)
265
- sample = self.conv_out(sample)
266
-
267
- return sample
268
-
269
- if not return_dict:
270
- return (sample,)
271
-
272
- return UNet2DConditionOutput(sample=sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/models/unet_blocks.py DELETED
@@ -1,1484 +0,0 @@
1
- # Copyright 2022 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
-
14
- import numpy as np
15
-
16
- # limitations under the License.
17
- import torch
18
- from torch import nn
19
-
20
- from .attention import AttentionBlock, SpatialTransformer
21
- from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
22
-
23
-
24
- def get_down_block(
25
- down_block_type,
26
- num_layers,
27
- in_channels,
28
- out_channels,
29
- temb_channels,
30
- add_downsample,
31
- resnet_eps,
32
- resnet_act_fn,
33
- attn_num_head_channels,
34
- cross_attention_dim=None,
35
- downsample_padding=None,
36
- ):
37
- down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
38
- print(down_block_type)
39
- if down_block_type == "DownBlock2D":
40
- return DownBlock2D(
41
- num_layers=num_layers,
42
- in_channels=in_channels,
43
- out_channels=out_channels,
44
- temb_channels=temb_channels,
45
- add_downsample=add_downsample,
46
- resnet_eps=resnet_eps,
47
- resnet_act_fn=resnet_act_fn,
48
- downsample_padding=downsample_padding,
49
- )
50
- elif down_block_type == "AttnDownBlock2D":
51
- return AttnDownBlock2D(
52
- num_layers=num_layers,
53
- in_channels=in_channels,
54
- out_channels=out_channels,
55
- temb_channels=temb_channels,
56
- add_downsample=add_downsample,
57
- resnet_eps=resnet_eps,
58
- resnet_act_fn=resnet_act_fn,
59
- downsample_padding=downsample_padding,
60
- attn_num_head_channels=attn_num_head_channels,
61
- )
62
- elif down_block_type == "CrossAttnDownBlock2D":
63
- if cross_attention_dim is None:
64
- raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
65
- return CrossAttnDownBlock2D(
66
- num_layers=num_layers,
67
- in_channels=in_channels,
68
- out_channels=out_channels,
69
- temb_channels=temb_channels,
70
- add_downsample=add_downsample,
71
- resnet_eps=resnet_eps,
72
- resnet_act_fn=resnet_act_fn,
73
- downsample_padding=downsample_padding,
74
- cross_attention_dim=cross_attention_dim,
75
- attn_num_head_channels=attn_num_head_channels,
76
- )
77
- elif down_block_type == "SkipDownBlock2D":
78
- return SkipDownBlock2D(
79
- num_layers=num_layers,
80
- in_channels=in_channels,
81
- out_channels=out_channels,
82
- temb_channels=temb_channels,
83
- add_downsample=add_downsample,
84
- resnet_eps=resnet_eps,
85
- resnet_act_fn=resnet_act_fn,
86
- downsample_padding=downsample_padding,
87
- )
88
- elif down_block_type == "AttnSkipDownBlock2D":
89
- return AttnSkipDownBlock2D(
90
- num_layers=num_layers,
91
- in_channels=in_channels,
92
- out_channels=out_channels,
93
- temb_channels=temb_channels,
94
- add_downsample=add_downsample,
95
- resnet_eps=resnet_eps,
96
- resnet_act_fn=resnet_act_fn,
97
- downsample_padding=downsample_padding,
98
- attn_num_head_channels=attn_num_head_channels,
99
- )
100
- elif down_block_type == "DownEncoderBlock2D":
101
- return DownEncoderBlock2D(
102
- num_layers=num_layers,
103
- in_channels=in_channels,
104
- out_channels=out_channels,
105
- add_downsample=add_downsample,
106
- resnet_eps=resnet_eps,
107
- resnet_act_fn=resnet_act_fn,
108
- downsample_padding=downsample_padding,
109
- )
110
-
111
-
112
- def get_up_block(
113
- up_block_type,
114
- num_layers,
115
- in_channels,
116
- out_channels,
117
- prev_output_channel,
118
- temb_channels,
119
- add_upsample,
120
- resnet_eps,
121
- resnet_act_fn,
122
- attn_num_head_channels,
123
- cross_attention_dim=None,
124
- ):
125
- up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
126
- print(up_block_type)
127
- if up_block_type == "UpBlock2D":
128
- return UpBlock2D(
129
- num_layers=num_layers,
130
- in_channels=in_channels,
131
- out_channels=out_channels,
132
- prev_output_channel=prev_output_channel,
133
- temb_channels=temb_channels,
134
- add_upsample=add_upsample,
135
- resnet_eps=resnet_eps,
136
- resnet_act_fn=resnet_act_fn,
137
- )
138
- elif up_block_type == "CrossAttnUpBlock2D":
139
- if cross_attention_dim is None:
140
- raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
141
- return CrossAttnUpBlock2D(
142
- num_layers=num_layers,
143
- in_channels=in_channels,
144
- out_channels=out_channels,
145
- prev_output_channel=prev_output_channel,
146
- temb_channels=temb_channels,
147
- add_upsample=add_upsample,
148
- resnet_eps=resnet_eps,
149
- resnet_act_fn=resnet_act_fn,
150
- cross_attention_dim=cross_attention_dim,
151
- attn_num_head_channels=attn_num_head_channels,
152
- )
153
- elif up_block_type == "AttnUpBlock2D":
154
- return AttnUpBlock2D(
155
- num_layers=num_layers,
156
- in_channels=in_channels,
157
- out_channels=out_channels,
158
- prev_output_channel=prev_output_channel,
159
- temb_channels=temb_channels,
160
- add_upsample=add_upsample,
161
- resnet_eps=resnet_eps,
162
- resnet_act_fn=resnet_act_fn,
163
- attn_num_head_channels=attn_num_head_channels,
164
- )
165
- elif up_block_type == "SkipUpBlock2D":
166
- return SkipUpBlock2D(
167
- num_layers=num_layers,
168
- in_channels=in_channels,
169
- out_channels=out_channels,
170
- prev_output_channel=prev_output_channel,
171
- temb_channels=temb_channels,
172
- add_upsample=add_upsample,
173
- resnet_eps=resnet_eps,
174
- resnet_act_fn=resnet_act_fn,
175
- )
176
- elif up_block_type == "AttnSkipUpBlock2D":
177
- return AttnSkipUpBlock2D(
178
- num_layers=num_layers,
179
- in_channels=in_channels,
180
- out_channels=out_channels,
181
- prev_output_channel=prev_output_channel,
182
- temb_channels=temb_channels,
183
- add_upsample=add_upsample,
184
- resnet_eps=resnet_eps,
185
- resnet_act_fn=resnet_act_fn,
186
- attn_num_head_channels=attn_num_head_channels,
187
- )
188
- elif up_block_type == "UpDecoderBlock2D":
189
- return UpDecoderBlock2D(
190
- num_layers=num_layers,
191
- in_channels=in_channels,
192
- out_channels=out_channels,
193
- add_upsample=add_upsample,
194
- resnet_eps=resnet_eps,
195
- resnet_act_fn=resnet_act_fn,
196
- )
197
- raise ValueError(f"{up_block_type} does not exist.")
198
-
199
-
200
- class UNetMidBlock2D(nn.Module):
201
- def __init__(
202
- self,
203
- in_channels: int,
204
- temb_channels: int,
205
- dropout: float = 0.0,
206
- num_layers: int = 1,
207
- resnet_eps: float = 1e-6,
208
- resnet_time_scale_shift: str = "default",
209
- resnet_act_fn: str = "swish",
210
- resnet_groups: int = 32,
211
- resnet_pre_norm: bool = True,
212
- attn_num_head_channels=1,
213
- attention_type="default",
214
- output_scale_factor=1.0,
215
- **kwargs,
216
- ):
217
- super().__init__()
218
-
219
- self.attention_type = attention_type
220
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
221
-
222
- # there is always at least one resnet
223
- resnets = [
224
- ResnetBlock2D(
225
- in_channels=in_channels,
226
- out_channels=in_channels,
227
- temb_channels=temb_channels,
228
- eps=resnet_eps,
229
- groups=resnet_groups,
230
- dropout=dropout,
231
- time_embedding_norm=resnet_time_scale_shift,
232
- non_linearity=resnet_act_fn,
233
- output_scale_factor=output_scale_factor,
234
- pre_norm=resnet_pre_norm,
235
- )
236
- ]
237
- attentions = []
238
-
239
- for _ in range(num_layers):
240
- attentions.append(
241
- AttentionBlock(
242
- in_channels,
243
- num_head_channels=attn_num_head_channels,
244
- rescale_output_factor=output_scale_factor,
245
- eps=resnet_eps,
246
- num_groups=resnet_groups,
247
- )
248
- )
249
- resnets.append(
250
- ResnetBlock2D(
251
- in_channels=in_channels,
252
- out_channels=in_channels,
253
- temb_channels=temb_channels,
254
- eps=resnet_eps,
255
- groups=resnet_groups,
256
- dropout=dropout,
257
- time_embedding_norm=resnet_time_scale_shift,
258
- non_linearity=resnet_act_fn,
259
- output_scale_factor=output_scale_factor,
260
- pre_norm=resnet_pre_norm,
261
- )
262
- )
263
-
264
- self.attentions = nn.ModuleList(attentions)
265
- self.resnets = nn.ModuleList(resnets)
266
-
267
- def forward(self, hidden_states, temb=None, encoder_states=None):
268
- hidden_states = self.resnets[0](hidden_states, temb)
269
- print(self.attention_type)
270
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
271
- if self.attention_type == "default":
272
- hidden_states = attn(hidden_states)
273
- else:
274
- hidden_states = attn(hidden_states, encoder_states)
275
- hidden_states = resnet(hidden_states, temb)
276
-
277
- return hidden_states
278
-
279
-
280
- class UNetMidBlock2DCrossAttn(nn.Module):
281
- def __init__(
282
- self,
283
- in_channels: int,
284
- temb_channels: int,
285
- dropout: float = 0.0,
286
- num_layers: int = 1,
287
- resnet_eps: float = 1e-6,
288
- resnet_time_scale_shift: str = "default",
289
- resnet_act_fn: str = "swish",
290
- resnet_groups: int = 32,
291
- resnet_pre_norm: bool = True,
292
- attn_num_head_channels=1,
293
- attention_type="default",
294
- output_scale_factor=1.0,
295
- cross_attention_dim=1280,
296
- **kwargs,
297
- ):
298
- super().__init__()
299
-
300
- self.attention_type = attention_type
301
- self.attn_num_head_channels = attn_num_head_channels
302
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
303
-
304
- # there is always at least one resnet
305
- resnets = [
306
- ResnetBlock2D(
307
- in_channels=in_channels,
308
- out_channels=in_channels,
309
- temb_channels=temb_channels,
310
- eps=resnet_eps,
311
- groups=resnet_groups,
312
- dropout=dropout,
313
- time_embedding_norm=resnet_time_scale_shift,
314
- non_linearity=resnet_act_fn,
315
- output_scale_factor=output_scale_factor,
316
- pre_norm=resnet_pre_norm,
317
- )
318
- ]
319
- attentions = []
320
-
321
- for _ in range(num_layers):
322
- attentions.append(
323
- SpatialTransformer(
324
- in_channels,
325
- attn_num_head_channels,
326
- in_channels // attn_num_head_channels,
327
- depth=1,
328
- context_dim=cross_attention_dim,
329
- )
330
- )
331
- resnets.append(
332
- ResnetBlock2D(
333
- in_channels=in_channels,
334
- out_channels=in_channels,
335
- temb_channels=temb_channels,
336
- eps=resnet_eps,
337
- groups=resnet_groups,
338
- dropout=dropout,
339
- time_embedding_norm=resnet_time_scale_shift,
340
- non_linearity=resnet_act_fn,
341
- output_scale_factor=output_scale_factor,
342
- pre_norm=resnet_pre_norm,
343
- )
344
- )
345
-
346
- self.attentions = nn.ModuleList(attentions)
347
- self.resnets = nn.ModuleList(resnets)
348
-
349
- def set_attention_slice(self, slice_size):
350
- if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
351
- raise ValueError(
352
- f"Make sure slice_size {slice_size} is a divisor of "
353
- f"the number of heads used in cross_attention {self.attn_num_head_channels}"
354
- )
355
- if slice_size is not None and slice_size > self.attn_num_head_channels:
356
- raise ValueError(
357
- f"Chunk_size {slice_size} has to be smaller or equal to "
358
- f"the number of heads used in cross_attention {self.attn_num_head_channels}"
359
- )
360
-
361
- for attn in self.attentions:
362
- attn._set_attention_slice(slice_size)
363
-
364
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
365
- hidden_states = self.resnets[0](hidden_states, temb)
366
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
367
- hidden_states = attn(hidden_states, encoder_hidden_states)
368
- hidden_states = resnet(hidden_states, temb)
369
-
370
- return hidden_states
371
-
372
-
373
- class AttnDownBlock2D(nn.Module):
374
- def __init__(
375
- self,
376
- in_channels: int,
377
- out_channels: int,
378
- temb_channels: int,
379
- dropout: float = 0.0,
380
- num_layers: int = 1,
381
- resnet_eps: float = 1e-6,
382
- resnet_time_scale_shift: str = "default",
383
- resnet_act_fn: str = "swish",
384
- resnet_groups: int = 32,
385
- resnet_pre_norm: bool = True,
386
- attn_num_head_channels=1,
387
- attention_type="default",
388
- output_scale_factor=1.0,
389
- downsample_padding=1,
390
- add_downsample=True,
391
- ):
392
- super().__init__()
393
- resnets = []
394
- attentions = []
395
-
396
- self.attention_type = attention_type
397
-
398
- for i in range(num_layers):
399
- in_channels = in_channels if i == 0 else out_channels
400
- resnets.append(
401
- ResnetBlock2D(
402
- in_channels=in_channels,
403
- out_channels=out_channels,
404
- temb_channels=temb_channels,
405
- eps=resnet_eps,
406
- groups=resnet_groups,
407
- dropout=dropout,
408
- time_embedding_norm=resnet_time_scale_shift,
409
- non_linearity=resnet_act_fn,
410
- output_scale_factor=output_scale_factor,
411
- pre_norm=resnet_pre_norm,
412
- )
413
- )
414
- attentions.append(
415
- AttentionBlock(
416
- out_channels,
417
- num_head_channels=attn_num_head_channels,
418
- rescale_output_factor=output_scale_factor,
419
- eps=resnet_eps,
420
- )
421
- )
422
-
423
- self.attentions = nn.ModuleList(attentions)
424
- self.resnets = nn.ModuleList(resnets)
425
-
426
- if add_downsample:
427
- self.downsamplers = nn.ModuleList(
428
- [
429
- Downsample2D(
430
- in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
431
- )
432
- ]
433
- )
434
- else:
435
- self.downsamplers = None
436
-
437
- def forward(self, hidden_states, temb=None):
438
- output_states = ()
439
-
440
- for resnet, attn in zip(self.resnets, self.attentions):
441
- hidden_states = resnet(hidden_states, temb)
442
- hidden_states = attn(hidden_states)
443
- output_states += (hidden_states,)
444
-
445
- if self.downsamplers is not None:
446
- for downsampler in self.downsamplers:
447
- hidden_states = downsampler(hidden_states)
448
-
449
- output_states += (hidden_states,)
450
-
451
- return hidden_states, output_states
452
-
453
-
454
- class CrossAttnDownBlock2D(nn.Module):
455
- def __init__(
456
- self,
457
- in_channels: int,
458
- out_channels: int,
459
- temb_channels: int,
460
- dropout: float = 0.0,
461
- num_layers: int = 1,
462
- resnet_eps: float = 1e-6,
463
- resnet_time_scale_shift: str = "default",
464
- resnet_act_fn: str = "swish",
465
- resnet_groups: int = 32,
466
- resnet_pre_norm: bool = True,
467
- attn_num_head_channels=1,
468
- cross_attention_dim=1280,
469
- attention_type="default",
470
- output_scale_factor=1.0,
471
- downsample_padding=1,
472
- add_downsample=True,
473
- ):
474
- super().__init__()
475
- resnets = []
476
- attentions = []
477
-
478
- self.attention_type = attention_type
479
- self.attn_num_head_channels = attn_num_head_channels
480
-
481
- for i in range(num_layers):
482
- in_channels = in_channels if i == 0 else out_channels
483
- resnets.append(
484
- ResnetBlock2D(
485
- in_channels=in_channels,
486
- out_channels=out_channels,
487
- temb_channels=temb_channels,
488
- eps=resnet_eps,
489
- groups=resnet_groups,
490
- dropout=dropout,
491
- time_embedding_norm=resnet_time_scale_shift,
492
- non_linearity=resnet_act_fn,
493
- output_scale_factor=output_scale_factor,
494
- pre_norm=resnet_pre_norm,
495
- )
496
- )
497
- attentions.append(
498
- SpatialTransformer(
499
- out_channels,
500
- attn_num_head_channels,
501
- out_channels // attn_num_head_channels,
502
- depth=1,
503
- context_dim=cross_attention_dim,
504
- )
505
- )
506
- self.attentions = nn.ModuleList(attentions)
507
- self.resnets = nn.ModuleList(resnets)
508
-
509
- if add_downsample:
510
- self.downsamplers = nn.ModuleList(
511
- [
512
- Downsample2D(
513
- in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
514
- )
515
- ]
516
- )
517
- else:
518
- self.downsamplers = None
519
-
520
- def set_attention_slice(self, slice_size):
521
- if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
522
- raise ValueError(
523
- f"Make sure slice_size {slice_size} is a divisor of "
524
- f"the number of heads used in cross_attention {self.attn_num_head_channels}"
525
- )
526
- if slice_size is not None and slice_size > self.attn_num_head_channels:
527
- raise ValueError(
528
- f"Chunk_size {slice_size} has to be smaller or equal to "
529
- f"the number of heads used in cross_attention {self.attn_num_head_channels}"
530
- )
531
-
532
- for attn in self.attentions:
533
- attn._set_attention_slice(slice_size)
534
-
535
- def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
536
- output_states = ()
537
-
538
- for resnet, attn in zip(self.resnets, self.attentions):
539
- hidden_states = resnet(hidden_states, temb)
540
- hidden_states = attn(hidden_states, context=encoder_hidden_states)
541
- output_states += (hidden_states,)
542
-
543
- if self.downsamplers is not None:
544
- for downsampler in self.downsamplers:
545
- hidden_states = downsampler(hidden_states)
546
-
547
- output_states += (hidden_states,)
548
-
549
- return hidden_states, output_states
550
-
551
-
552
- class DownBlock2D(nn.Module):
553
- def __init__(
554
- self,
555
- in_channels: int,
556
- out_channels: int,
557
- temb_channels: int,
558
- dropout: float = 0.0,
559
- num_layers: int = 1,
560
- resnet_eps: float = 1e-6,
561
- resnet_time_scale_shift: str = "default",
562
- resnet_act_fn: str = "swish",
563
- resnet_groups: int = 32,
564
- resnet_pre_norm: bool = True,
565
- output_scale_factor=1.0,
566
- add_downsample=True,
567
- downsample_padding=1,
568
- ):
569
- super().__init__()
570
- resnets = []
571
-
572
- for i in range(num_layers):
573
- in_channels = in_channels if i == 0 else out_channels
574
- resnets.append(
575
- ResnetBlock2D(
576
- in_channels=in_channels,
577
- out_channels=out_channels,
578
- temb_channels=temb_channels,
579
- eps=resnet_eps,
580
- groups=resnet_groups,
581
- dropout=dropout,
582
- time_embedding_norm=resnet_time_scale_shift,
583
- non_linearity=resnet_act_fn,
584
- output_scale_factor=output_scale_factor,
585
- pre_norm=resnet_pre_norm,
586
- )
587
- )
588
-
589
- self.resnets = nn.ModuleList(resnets)
590
-
591
- if add_downsample:
592
- self.downsamplers = nn.ModuleList(
593
- [
594
- Downsample2D(
595
- in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
596
- )
597
- ]
598
- )
599
- else:
600
- self.downsamplers = None
601
-
602
- def forward(self, hidden_states, temb=None):
603
- output_states = ()
604
-
605
- for resnet in self.resnets:
606
- hidden_states = resnet(hidden_states, temb)
607
- output_states += (hidden_states,)
608
-
609
- if self.downsamplers is not None:
610
- for downsampler in self.downsamplers:
611
- hidden_states = downsampler(hidden_states)
612
-
613
- output_states += (hidden_states,)
614
-
615
- return hidden_states, output_states
616
-
617
-
618
- class DownEncoderBlock2D(nn.Module):
619
- def __init__(
620
- self,
621
- in_channels: int,
622
- out_channels: int,
623
- dropout: float = 0.0,
624
- num_layers: int = 1,
625
- resnet_eps: float = 1e-6,
626
- resnet_time_scale_shift: str = "default",
627
- resnet_act_fn: str = "swish",
628
- resnet_groups: int = 32,
629
- resnet_pre_norm: bool = True,
630
- output_scale_factor=1.0,
631
- add_downsample=True,
632
- downsample_padding=1,
633
- ):
634
- super().__init__()
635
- resnets = []
636
-
637
- for i in range(num_layers):
638
- in_channels = in_channels if i == 0 else out_channels
639
- resnets.append(
640
- ResnetBlock2D(
641
- in_channels=in_channels,
642
- out_channels=out_channels,
643
- temb_channels=None,
644
- eps=resnet_eps,
645
- groups=resnet_groups,
646
- dropout=dropout,
647
- time_embedding_norm=resnet_time_scale_shift,
648
- non_linearity=resnet_act_fn,
649
- output_scale_factor=output_scale_factor,
650
- pre_norm=resnet_pre_norm,
651
- )
652
- )
653
-
654
- self.resnets = nn.ModuleList(resnets)
655
-
656
- if add_downsample:
657
- self.downsamplers = nn.ModuleList(
658
- [
659
- Downsample2D(
660
- in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
661
- )
662
- ]
663
- )
664
- else:
665
- self.downsamplers = None
666
-
667
- def forward(self, hidden_states):
668
- for resnet in self.resnets:
669
- hidden_states = resnet(hidden_states, temb=None)
670
-
671
- if self.downsamplers is not None:
672
- for downsampler in self.downsamplers:
673
- hidden_states = downsampler(hidden_states)
674
-
675
- return hidden_states
676
-
677
-
678
- class AttnDownEncoderBlock2D(nn.Module):
679
- def __init__(
680
- self,
681
- in_channels: int,
682
- out_channels: int,
683
- dropout: float = 0.0,
684
- num_layers: int = 1,
685
- resnet_eps: float = 1e-6,
686
- resnet_time_scale_shift: str = "default",
687
- resnet_act_fn: str = "swish",
688
- resnet_groups: int = 32,
689
- resnet_pre_norm: bool = True,
690
- attn_num_head_channels=1,
691
- output_scale_factor=1.0,
692
- add_downsample=True,
693
- downsample_padding=1,
694
- ):
695
- super().__init__()
696
- resnets = []
697
- attentions = []
698
-
699
- for i in range(num_layers):
700
- in_channels = in_channels if i == 0 else out_channels
701
- resnets.append(
702
- ResnetBlock2D(
703
- in_channels=in_channels,
704
- out_channels=out_channels,
705
- temb_channels=None,
706
- eps=resnet_eps,
707
- groups=resnet_groups,
708
- dropout=dropout,
709
- time_embedding_norm=resnet_time_scale_shift,
710
- non_linearity=resnet_act_fn,
711
- output_scale_factor=output_scale_factor,
712
- pre_norm=resnet_pre_norm,
713
- )
714
- )
715
- attentions.append(
716
- AttentionBlock(
717
- out_channels,
718
- num_head_channels=attn_num_head_channels,
719
- rescale_output_factor=output_scale_factor,
720
- eps=resnet_eps,
721
- num_groups=resnet_groups,
722
- )
723
- )
724
-
725
- self.attentions = nn.ModuleList(attentions)
726
- self.resnets = nn.ModuleList(resnets)
727
-
728
- if add_downsample:
729
- self.downsamplers = nn.ModuleList(
730
- [
731
- Downsample2D(
732
- in_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
733
- )
734
- ]
735
- )
736
- else:
737
- self.downsamplers = None
738
-
739
- def forward(self, hidden_states):
740
- for resnet, attn in zip(self.resnets, self.attentions):
741
- hidden_states = resnet(hidden_states, temb=None)
742
- hidden_states = attn(hidden_states)
743
-
744
- if self.downsamplers is not None:
745
- for downsampler in self.downsamplers:
746
- hidden_states = downsampler(hidden_states)
747
-
748
- return hidden_states
749
-
750
-
751
- class AttnSkipDownBlock2D(nn.Module):
752
- def __init__(
753
- self,
754
- in_channels: int,
755
- out_channels: int,
756
- temb_channels: int,
757
- dropout: float = 0.0,
758
- num_layers: int = 1,
759
- resnet_eps: float = 1e-6,
760
- resnet_time_scale_shift: str = "default",
761
- resnet_act_fn: str = "swish",
762
- resnet_pre_norm: bool = True,
763
- attn_num_head_channels=1,
764
- attention_type="default",
765
- output_scale_factor=np.sqrt(2.0),
766
- downsample_padding=1,
767
- add_downsample=True,
768
- ):
769
- super().__init__()
770
- self.attentions = nn.ModuleList([])
771
- self.resnets = nn.ModuleList([])
772
-
773
- self.attention_type = attention_type
774
-
775
- for i in range(num_layers):
776
- in_channels = in_channels if i == 0 else out_channels
777
- self.resnets.append(
778
- ResnetBlock2D(
779
- in_channels=in_channels,
780
- out_channels=out_channels,
781
- temb_channels=temb_channels,
782
- eps=resnet_eps,
783
- groups=min(in_channels // 4, 32),
784
- groups_out=min(out_channels // 4, 32),
785
- dropout=dropout,
786
- time_embedding_norm=resnet_time_scale_shift,
787
- non_linearity=resnet_act_fn,
788
- output_scale_factor=output_scale_factor,
789
- pre_norm=resnet_pre_norm,
790
- )
791
- )
792
- self.attentions.append(
793
- AttentionBlock(
794
- out_channels,
795
- num_head_channels=attn_num_head_channels,
796
- rescale_output_factor=output_scale_factor,
797
- eps=resnet_eps,
798
- )
799
- )
800
-
801
- if add_downsample:
802
- self.resnet_down = ResnetBlock2D(
803
- in_channels=out_channels,
804
- out_channels=out_channels,
805
- temb_channels=temb_channels,
806
- eps=resnet_eps,
807
- groups=min(out_channels // 4, 32),
808
- dropout=dropout,
809
- time_embedding_norm=resnet_time_scale_shift,
810
- non_linearity=resnet_act_fn,
811
- output_scale_factor=output_scale_factor,
812
- pre_norm=resnet_pre_norm,
813
- use_nin_shortcut=True,
814
- down=True,
815
- kernel="fir",
816
- )
817
- self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
818
- self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
819
- else:
820
- self.resnet_down = None
821
- self.downsamplers = None
822
- self.skip_conv = None
823
-
824
- def forward(self, hidden_states, temb=None, skip_sample=None):
825
- output_states = ()
826
-
827
- for resnet, attn in zip(self.resnets, self.attentions):
828
- hidden_states = resnet(hidden_states, temb)
829
- hidden_states = attn(hidden_states)
830
- output_states += (hidden_states,)
831
-
832
- if self.downsamplers is not None:
833
- hidden_states = self.resnet_down(hidden_states, temb)
834
- for downsampler in self.downsamplers:
835
- skip_sample = downsampler(skip_sample)
836
-
837
- hidden_states = self.skip_conv(skip_sample) + hidden_states
838
-
839
- output_states += (hidden_states,)
840
-
841
- return hidden_states, output_states, skip_sample
842
-
843
-
844
- class SkipDownBlock2D(nn.Module):
845
- def __init__(
846
- self,
847
- in_channels: int,
848
- out_channels: int,
849
- temb_channels: int,
850
- dropout: float = 0.0,
851
- num_layers: int = 1,
852
- resnet_eps: float = 1e-6,
853
- resnet_time_scale_shift: str = "default",
854
- resnet_act_fn: str = "swish",
855
- resnet_pre_norm: bool = True,
856
- output_scale_factor=np.sqrt(2.0),
857
- add_downsample=True,
858
- downsample_padding=1,
859
- ):
860
- super().__init__()
861
- self.resnets = nn.ModuleList([])
862
-
863
- for i in range(num_layers):
864
- in_channels = in_channels if i == 0 else out_channels
865
- self.resnets.append(
866
- ResnetBlock2D(
867
- in_channels=in_channels,
868
- out_channels=out_channels,
869
- temb_channels=temb_channels,
870
- eps=resnet_eps,
871
- groups=min(in_channels // 4, 32),
872
- groups_out=min(out_channels // 4, 32),
873
- dropout=dropout,
874
- time_embedding_norm=resnet_time_scale_shift,
875
- non_linearity=resnet_act_fn,
876
- output_scale_factor=output_scale_factor,
877
- pre_norm=resnet_pre_norm,
878
- )
879
- )
880
-
881
- if add_downsample:
882
- self.resnet_down = ResnetBlock2D(
883
- in_channels=out_channels,
884
- out_channels=out_channels,
885
- temb_channels=temb_channels,
886
- eps=resnet_eps,
887
- groups=min(out_channels // 4, 32),
888
- dropout=dropout,
889
- time_embedding_norm=resnet_time_scale_shift,
890
- non_linearity=resnet_act_fn,
891
- output_scale_factor=output_scale_factor,
892
- pre_norm=resnet_pre_norm,
893
- use_nin_shortcut=True,
894
- down=True,
895
- kernel="fir",
896
- )
897
- self.downsamplers = nn.ModuleList([FirDownsample2D(in_channels, out_channels=out_channels)])
898
- self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
899
- else:
900
- self.resnet_down = None
901
- self.downsamplers = None
902
- self.skip_conv = None
903
-
904
- def forward(self, hidden_states, temb=None, skip_sample=None):
905
- output_states = ()
906
-
907
- for resnet in self.resnets:
908
- hidden_states = resnet(hidden_states, temb)
909
- output_states += (hidden_states,)
910
-
911
- if self.downsamplers is not None:
912
- hidden_states = self.resnet_down(hidden_states, temb)
913
- for downsampler in self.downsamplers:
914
- skip_sample = downsampler(skip_sample)
915
-
916
- hidden_states = self.skip_conv(skip_sample) + hidden_states
917
-
918
- output_states += (hidden_states,)
919
-
920
- return hidden_states, output_states, skip_sample
921
-
922
-
923
- class AttnUpBlock2D(nn.Module):
924
- def __init__(
925
- self,
926
- in_channels: int,
927
- prev_output_channel: int,
928
- out_channels: int,
929
- temb_channels: int,
930
- dropout: float = 0.0,
931
- num_layers: int = 1,
932
- resnet_eps: float = 1e-6,
933
- resnet_time_scale_shift: str = "default",
934
- resnet_act_fn: str = "swish",
935
- resnet_groups: int = 32,
936
- resnet_pre_norm: bool = True,
937
- attention_type="default",
938
- attn_num_head_channels=1,
939
- output_scale_factor=1.0,
940
- add_upsample=True,
941
- ):
942
- super().__init__()
943
- resnets = []
944
- attentions = []
945
-
946
- self.attention_type = attention_type
947
-
948
- for i in range(num_layers):
949
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
950
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
951
-
952
- resnets.append(
953
- ResnetBlock2D(
954
- in_channels=resnet_in_channels + res_skip_channels,
955
- out_channels=out_channels,
956
- temb_channels=temb_channels,
957
- eps=resnet_eps,
958
- groups=resnet_groups,
959
- dropout=dropout,
960
- time_embedding_norm=resnet_time_scale_shift,
961
- non_linearity=resnet_act_fn,
962
- output_scale_factor=output_scale_factor,
963
- pre_norm=resnet_pre_norm,
964
- )
965
- )
966
- attentions.append(
967
- AttentionBlock(
968
- out_channels,
969
- num_head_channels=attn_num_head_channels,
970
- rescale_output_factor=output_scale_factor,
971
- eps=resnet_eps,
972
- )
973
- )
974
-
975
- self.attentions = nn.ModuleList(attentions)
976
- self.resnets = nn.ModuleList(resnets)
977
-
978
- if add_upsample:
979
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
980
- else:
981
- self.upsamplers = None
982
-
983
- def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
984
- for resnet, attn in zip(self.resnets, self.attentions):
985
-
986
- # pop res hidden states
987
- res_hidden_states = res_hidden_states_tuple[-1]
988
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
989
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
990
-
991
- hidden_states = resnet(hidden_states, temb)
992
- hidden_states = attn(hidden_states)
993
-
994
- if self.upsamplers is not None:
995
- for upsampler in self.upsamplers:
996
- hidden_states = upsampler(hidden_states)
997
-
998
- return hidden_states
999
-
1000
-
1001
- class CrossAttnUpBlock2D(nn.Module):
1002
- def __init__(
1003
- self,
1004
- in_channels: int,
1005
- out_channels: int,
1006
- prev_output_channel: int,
1007
- temb_channels: int,
1008
- dropout: float = 0.0,
1009
- num_layers: int = 1,
1010
- resnet_eps: float = 1e-6,
1011
- resnet_time_scale_shift: str = "default",
1012
- resnet_act_fn: str = "swish",
1013
- resnet_groups: int = 32,
1014
- resnet_pre_norm: bool = True,
1015
- attn_num_head_channels=1,
1016
- cross_attention_dim=1280,
1017
- attention_type="default",
1018
- output_scale_factor=1.0,
1019
- downsample_padding=1,
1020
- add_upsample=True,
1021
- ):
1022
- super().__init__()
1023
- resnets = []
1024
- attentions = []
1025
-
1026
- self.attention_type = attention_type
1027
- self.attn_num_head_channels = attn_num_head_channels
1028
-
1029
- for i in range(num_layers):
1030
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1031
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
1032
-
1033
- resnets.append(
1034
- ResnetBlock2D(
1035
- in_channels=resnet_in_channels + res_skip_channels,
1036
- out_channels=out_channels,
1037
- temb_channels=temb_channels,
1038
- eps=resnet_eps,
1039
- groups=resnet_groups,
1040
- dropout=dropout,
1041
- time_embedding_norm=resnet_time_scale_shift,
1042
- non_linearity=resnet_act_fn,
1043
- output_scale_factor=output_scale_factor,
1044
- pre_norm=resnet_pre_norm,
1045
- )
1046
- )
1047
- attentions.append(
1048
- SpatialTransformer(
1049
- out_channels,
1050
- attn_num_head_channels,
1051
- out_channels // attn_num_head_channels,
1052
- depth=1,
1053
- context_dim=cross_attention_dim,
1054
- )
1055
- )
1056
- self.attentions = nn.ModuleList(attentions)
1057
- self.resnets = nn.ModuleList(resnets)
1058
-
1059
- if add_upsample:
1060
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1061
- else:
1062
- self.upsamplers = None
1063
-
1064
- def set_attention_slice(self, slice_size):
1065
- if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
1066
- raise ValueError(
1067
- f"Make sure slice_size {slice_size} is a divisor of "
1068
- f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1069
- )
1070
- if slice_size is not None and slice_size > self.attn_num_head_channels:
1071
- raise ValueError(
1072
- f"Chunk_size {slice_size} has to be smaller or equal to "
1073
- f"the number of heads used in cross_attention {self.attn_num_head_channels}"
1074
- )
1075
-
1076
- for attn in self.attentions:
1077
- attn._set_attention_slice(slice_size)
1078
-
1079
- def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
1080
- for resnet, attn in zip(self.resnets, self.attentions):
1081
-
1082
- # pop res hidden states
1083
- res_hidden_states = res_hidden_states_tuple[-1]
1084
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1085
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1086
-
1087
- hidden_states = resnet(hidden_states, temb)
1088
- hidden_states = attn(hidden_states, context=encoder_hidden_states)
1089
-
1090
- if self.upsamplers is not None:
1091
- for upsampler in self.upsamplers:
1092
- hidden_states = upsampler(hidden_states)
1093
-
1094
- return hidden_states
1095
-
1096
-
1097
- class UpBlock2D(nn.Module):
1098
- def __init__(
1099
- self,
1100
- in_channels: int,
1101
- prev_output_channel: int,
1102
- out_channels: int,
1103
- temb_channels: int,
1104
- dropout: float = 0.0,
1105
- num_layers: int = 1,
1106
- resnet_eps: float = 1e-6,
1107
- resnet_time_scale_shift: str = "default",
1108
- resnet_act_fn: str = "swish",
1109
- resnet_groups: int = 32,
1110
- resnet_pre_norm: bool = True,
1111
- output_scale_factor=1.0,
1112
- add_upsample=True,
1113
- ):
1114
- super().__init__()
1115
- resnets = []
1116
-
1117
- for i in range(num_layers):
1118
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1119
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
1120
-
1121
- resnets.append(
1122
- ResnetBlock2D(
1123
- in_channels=resnet_in_channels + res_skip_channels,
1124
- out_channels=out_channels,
1125
- temb_channels=temb_channels,
1126
- eps=resnet_eps,
1127
- groups=resnet_groups,
1128
- dropout=dropout,
1129
- time_embedding_norm=resnet_time_scale_shift,
1130
- non_linearity=resnet_act_fn,
1131
- output_scale_factor=output_scale_factor,
1132
- pre_norm=resnet_pre_norm,
1133
- )
1134
- )
1135
-
1136
- self.resnets = nn.ModuleList(resnets)
1137
-
1138
- if add_upsample:
1139
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1140
- else:
1141
- self.upsamplers = None
1142
-
1143
- def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
1144
- for resnet in self.resnets:
1145
-
1146
- # pop res hidden states
1147
- res_hidden_states = res_hidden_states_tuple[-1]
1148
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1149
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1150
-
1151
- hidden_states = resnet(hidden_states, temb)
1152
-
1153
- if self.upsamplers is not None:
1154
- for upsampler in self.upsamplers:
1155
- hidden_states = upsampler(hidden_states)
1156
-
1157
- return hidden_states
1158
-
1159
-
1160
- class UpDecoderBlock2D(nn.Module):
1161
- def __init__(
1162
- self,
1163
- in_channels: int,
1164
- out_channels: int,
1165
- dropout: float = 0.0,
1166
- num_layers: int = 1,
1167
- resnet_eps: float = 1e-6,
1168
- resnet_time_scale_shift: str = "default",
1169
- resnet_act_fn: str = "swish",
1170
- resnet_groups: int = 32,
1171
- resnet_pre_norm: bool = True,
1172
- output_scale_factor=1.0,
1173
- add_upsample=True,
1174
- ):
1175
- super().__init__()
1176
- resnets = []
1177
-
1178
- for i in range(num_layers):
1179
- input_channels = in_channels if i == 0 else out_channels
1180
-
1181
- resnets.append(
1182
- ResnetBlock2D(
1183
- in_channels=input_channels,
1184
- out_channels=out_channels,
1185
- temb_channels=None,
1186
- eps=resnet_eps,
1187
- groups=resnet_groups,
1188
- dropout=dropout,
1189
- time_embedding_norm=resnet_time_scale_shift,
1190
- non_linearity=resnet_act_fn,
1191
- output_scale_factor=output_scale_factor,
1192
- pre_norm=resnet_pre_norm,
1193
- )
1194
- )
1195
-
1196
- self.resnets = nn.ModuleList(resnets)
1197
-
1198
- if add_upsample:
1199
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1200
- else:
1201
- self.upsamplers = None
1202
-
1203
- def forward(self, hidden_states):
1204
- for resnet in self.resnets:
1205
- hidden_states = resnet(hidden_states, temb=None)
1206
-
1207
- if self.upsamplers is not None:
1208
- for upsampler in self.upsamplers:
1209
- hidden_states = upsampler(hidden_states)
1210
-
1211
- return hidden_states
1212
-
1213
-
1214
- class AttnUpDecoderBlock2D(nn.Module):
1215
- def __init__(
1216
- self,
1217
- in_channels: int,
1218
- out_channels: int,
1219
- dropout: float = 0.0,
1220
- num_layers: int = 1,
1221
- resnet_eps: float = 1e-6,
1222
- resnet_time_scale_shift: str = "default",
1223
- resnet_act_fn: str = "swish",
1224
- resnet_groups: int = 32,
1225
- resnet_pre_norm: bool = True,
1226
- attn_num_head_channels=1,
1227
- output_scale_factor=1.0,
1228
- add_upsample=True,
1229
- ):
1230
- super().__init__()
1231
- resnets = []
1232
- attentions = []
1233
-
1234
- for i in range(num_layers):
1235
- input_channels = in_channels if i == 0 else out_channels
1236
-
1237
- resnets.append(
1238
- ResnetBlock2D(
1239
- in_channels=input_channels,
1240
- out_channels=out_channels,
1241
- temb_channels=None,
1242
- eps=resnet_eps,
1243
- groups=resnet_groups,
1244
- dropout=dropout,
1245
- time_embedding_norm=resnet_time_scale_shift,
1246
- non_linearity=resnet_act_fn,
1247
- output_scale_factor=output_scale_factor,
1248
- pre_norm=resnet_pre_norm,
1249
- )
1250
- )
1251
- attentions.append(
1252
- AttentionBlock(
1253
- out_channels,
1254
- num_head_channels=attn_num_head_channels,
1255
- rescale_output_factor=output_scale_factor,
1256
- eps=resnet_eps,
1257
- num_groups=resnet_groups,
1258
- )
1259
- )
1260
-
1261
- self.attentions = nn.ModuleList(attentions)
1262
- self.resnets = nn.ModuleList(resnets)
1263
-
1264
- if add_upsample:
1265
- self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
1266
- else:
1267
- self.upsamplers = None
1268
-
1269
- def forward(self, hidden_states):
1270
- for resnet, attn in zip(self.resnets, self.attentions):
1271
- hidden_states = resnet(hidden_states, temb=None)
1272
- hidden_states = attn(hidden_states)
1273
-
1274
- if self.upsamplers is not None:
1275
- for upsampler in self.upsamplers:
1276
- hidden_states = upsampler(hidden_states)
1277
-
1278
- return hidden_states
1279
-
1280
-
1281
- class AttnSkipUpBlock2D(nn.Module):
1282
- def __init__(
1283
- self,
1284
- in_channels: int,
1285
- prev_output_channel: int,
1286
- out_channels: int,
1287
- temb_channels: int,
1288
- dropout: float = 0.0,
1289
- num_layers: int = 1,
1290
- resnet_eps: float = 1e-6,
1291
- resnet_time_scale_shift: str = "default",
1292
- resnet_act_fn: str = "swish",
1293
- resnet_pre_norm: bool = True,
1294
- attn_num_head_channels=1,
1295
- attention_type="default",
1296
- output_scale_factor=np.sqrt(2.0),
1297
- upsample_padding=1,
1298
- add_upsample=True,
1299
- ):
1300
- super().__init__()
1301
- self.attentions = nn.ModuleList([])
1302
- self.resnets = nn.ModuleList([])
1303
-
1304
- self.attention_type = attention_type
1305
-
1306
- for i in range(num_layers):
1307
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1308
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
1309
-
1310
- self.resnets.append(
1311
- ResnetBlock2D(
1312
- in_channels=resnet_in_channels + res_skip_channels,
1313
- out_channels=out_channels,
1314
- temb_channels=temb_channels,
1315
- eps=resnet_eps,
1316
- groups=min(resnet_in_channels + res_skip_channels // 4, 32),
1317
- groups_out=min(out_channels // 4, 32),
1318
- dropout=dropout,
1319
- time_embedding_norm=resnet_time_scale_shift,
1320
- non_linearity=resnet_act_fn,
1321
- output_scale_factor=output_scale_factor,
1322
- pre_norm=resnet_pre_norm,
1323
- )
1324
- )
1325
-
1326
- self.attentions.append(
1327
- AttentionBlock(
1328
- out_channels,
1329
- num_head_channels=attn_num_head_channels,
1330
- rescale_output_factor=output_scale_factor,
1331
- eps=resnet_eps,
1332
- )
1333
- )
1334
-
1335
- self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1336
- if add_upsample:
1337
- self.resnet_up = ResnetBlock2D(
1338
- in_channels=out_channels,
1339
- out_channels=out_channels,
1340
- temb_channels=temb_channels,
1341
- eps=resnet_eps,
1342
- groups=min(out_channels // 4, 32),
1343
- groups_out=min(out_channels // 4, 32),
1344
- dropout=dropout,
1345
- time_embedding_norm=resnet_time_scale_shift,
1346
- non_linearity=resnet_act_fn,
1347
- output_scale_factor=output_scale_factor,
1348
- pre_norm=resnet_pre_norm,
1349
- use_nin_shortcut=True,
1350
- up=True,
1351
- kernel="fir",
1352
- )
1353
- self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1354
- self.skip_norm = torch.nn.GroupNorm(
1355
- num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1356
- )
1357
- self.act = nn.SiLU()
1358
- else:
1359
- self.resnet_up = None
1360
- self.skip_conv = None
1361
- self.skip_norm = None
1362
- self.act = None
1363
-
1364
- def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1365
- for resnet in self.resnets:
1366
- # pop res hidden states
1367
- res_hidden_states = res_hidden_states_tuple[-1]
1368
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1369
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1370
-
1371
- hidden_states = resnet(hidden_states, temb)
1372
-
1373
- hidden_states = self.attentions[0](hidden_states)
1374
-
1375
- if skip_sample is not None:
1376
- skip_sample = self.upsampler(skip_sample)
1377
- else:
1378
- skip_sample = 0
1379
-
1380
- if self.resnet_up is not None:
1381
- skip_sample_states = self.skip_norm(hidden_states)
1382
- skip_sample_states = self.act(skip_sample_states)
1383
- skip_sample_states = self.skip_conv(skip_sample_states)
1384
-
1385
- skip_sample = skip_sample + skip_sample_states
1386
-
1387
- hidden_states = self.resnet_up(hidden_states, temb)
1388
-
1389
- return hidden_states, skip_sample
1390
-
1391
-
1392
- class SkipUpBlock2D(nn.Module):
1393
- def __init__(
1394
- self,
1395
- in_channels: int,
1396
- prev_output_channel: int,
1397
- out_channels: int,
1398
- temb_channels: int,
1399
- dropout: float = 0.0,
1400
- num_layers: int = 1,
1401
- resnet_eps: float = 1e-6,
1402
- resnet_time_scale_shift: str = "default",
1403
- resnet_act_fn: str = "swish",
1404
- resnet_pre_norm: bool = True,
1405
- output_scale_factor=np.sqrt(2.0),
1406
- add_upsample=True,
1407
- upsample_padding=1,
1408
- ):
1409
- super().__init__()
1410
- self.resnets = nn.ModuleList([])
1411
-
1412
- for i in range(num_layers):
1413
- res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
1414
- resnet_in_channels = prev_output_channel if i == 0 else out_channels
1415
-
1416
- self.resnets.append(
1417
- ResnetBlock2D(
1418
- in_channels=resnet_in_channels + res_skip_channels,
1419
- out_channels=out_channels,
1420
- temb_channels=temb_channels,
1421
- eps=resnet_eps,
1422
- groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
1423
- groups_out=min(out_channels // 4, 32),
1424
- dropout=dropout,
1425
- time_embedding_norm=resnet_time_scale_shift,
1426
- non_linearity=resnet_act_fn,
1427
- output_scale_factor=output_scale_factor,
1428
- pre_norm=resnet_pre_norm,
1429
- )
1430
- )
1431
-
1432
- self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
1433
- if add_upsample:
1434
- self.resnet_up = ResnetBlock2D(
1435
- in_channels=out_channels,
1436
- out_channels=out_channels,
1437
- temb_channels=temb_channels,
1438
- eps=resnet_eps,
1439
- groups=min(out_channels // 4, 32),
1440
- groups_out=min(out_channels // 4, 32),
1441
- dropout=dropout,
1442
- time_embedding_norm=resnet_time_scale_shift,
1443
- non_linearity=resnet_act_fn,
1444
- output_scale_factor=output_scale_factor,
1445
- pre_norm=resnet_pre_norm,
1446
- use_nin_shortcut=True,
1447
- up=True,
1448
- kernel="fir",
1449
- )
1450
- self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1451
- self.skip_norm = torch.nn.GroupNorm(
1452
- num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
1453
- )
1454
- self.act = nn.SiLU()
1455
- else:
1456
- self.resnet_up = None
1457
- self.skip_conv = None
1458
- self.skip_norm = None
1459
- self.act = None
1460
-
1461
- def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
1462
- for resnet in self.resnets:
1463
- # pop res hidden states
1464
- res_hidden_states = res_hidden_states_tuple[-1]
1465
- res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1466
- hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1467
-
1468
- hidden_states = resnet(hidden_states, temb)
1469
-
1470
- if skip_sample is not None:
1471
- skip_sample = self.upsampler(skip_sample)
1472
- else:
1473
- skip_sample = 0
1474
-
1475
- if self.resnet_up is not None:
1476
- skip_sample_states = self.skip_norm(hidden_states)
1477
- skip_sample_states = self.act(skip_sample_states)
1478
- skip_sample_states = self.skip_conv(skip_sample_states)
1479
-
1480
- skip_sample = skip_sample + skip_sample_states
1481
-
1482
- hidden_states = self.resnet_up(hidden_states, temb)
1483
-
1484
- return hidden_states, skip_sample
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/models/vae.py DELETED
@@ -1,585 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional, Tuple, Union
3
-
4
- import numpy as np
5
- import torch
6
- import torch.nn as nn
7
-
8
- from ..configuration_utils import ConfigMixin, register_to_config
9
- from ..modeling_utils import ModelMixin
10
- from ..utils import BaseOutput
11
- from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
12
-
13
-
14
- @dataclass
15
- class DecoderOutput(BaseOutput):
16
- """
17
- Output of decoding method.
18
-
19
- Args:
20
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
21
- Decoded output sample of the model. Output of the last layer of the model.
22
- """
23
-
24
- sample: torch.FloatTensor
25
-
26
-
27
- @dataclass
28
- class VQEncoderOutput(BaseOutput):
29
- """
30
- Output of VQModel encoding method.
31
-
32
- Args:
33
- latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
34
- Encoded output sample of the model. Output of the last layer of the model.
35
- """
36
-
37
- latents: torch.FloatTensor
38
-
39
-
40
- @dataclass
41
- class AutoencoderKLOutput(BaseOutput):
42
- """
43
- Output of AutoencoderKL encoding method.
44
-
45
- Args:
46
- latent_dist (`DiagonalGaussianDistribution`):
47
- Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
48
- `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
49
- """
50
-
51
- latent_dist: "DiagonalGaussianDistribution"
52
-
53
-
54
- class Encoder(nn.Module):
55
- def __init__(
56
- self,
57
- in_channels=3,
58
- out_channels=3,
59
- down_block_types=("DownEncoderBlock2D",),
60
- block_out_channels=(64,),
61
- layers_per_block=2,
62
- act_fn="silu",
63
- double_z=True,
64
- ):
65
- super().__init__()
66
- self.layers_per_block = layers_per_block
67
-
68
- self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
69
-
70
- self.mid_block = None
71
- self.down_blocks = nn.ModuleList([])
72
-
73
- # down
74
- output_channel = block_out_channels[0]
75
- for i, down_block_type in enumerate(down_block_types):
76
- input_channel = output_channel
77
- output_channel = block_out_channels[i]
78
- is_final_block = i == len(block_out_channels) - 1
79
-
80
- down_block = get_down_block(
81
- down_block_type,
82
- num_layers=self.layers_per_block,
83
- in_channels=input_channel,
84
- out_channels=output_channel,
85
- add_downsample=not is_final_block,
86
- resnet_eps=1e-6,
87
- downsample_padding=0,
88
- resnet_act_fn=act_fn,
89
- attn_num_head_channels=None,
90
- temb_channels=None,
91
- )
92
- self.down_blocks.append(down_block)
93
-
94
- # mid
95
- self.mid_block = UNetMidBlock2D(
96
- in_channels=block_out_channels[-1],
97
- resnet_eps=1e-6,
98
- resnet_act_fn=act_fn,
99
- output_scale_factor=1,
100
- resnet_time_scale_shift="default",
101
- attn_num_head_channels=None,
102
- resnet_groups=32,
103
- temb_channels=None,
104
- )
105
-
106
- # out
107
- num_groups_out = 32
108
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups_out, eps=1e-6)
109
- self.conv_act = nn.SiLU()
110
-
111
- conv_out_channels = 2 * out_channels if double_z else out_channels
112
- self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
113
-
114
- def forward(self, x):
115
- sample = x
116
- sample = self.conv_in(sample)
117
-
118
- # down
119
- for down_block in self.down_blocks:
120
- sample = down_block(sample)
121
-
122
- # middle
123
- sample = self.mid_block(sample)
124
-
125
- # post-process
126
- sample = self.conv_norm_out(sample)
127
- sample = self.conv_act(sample)
128
- sample = self.conv_out(sample)
129
-
130
- return sample
131
-
132
-
133
- class Decoder(nn.Module):
134
- def __init__(
135
- self,
136
- in_channels=3,
137
- out_channels=3,
138
- up_block_types=("UpDecoderBlock2D",),
139
- block_out_channels=(64,),
140
- layers_per_block=2,
141
- act_fn="silu",
142
- ):
143
- super().__init__()
144
- self.layers_per_block = layers_per_block
145
-
146
- self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
147
-
148
- self.mid_block = None
149
- self.up_blocks = nn.ModuleList([])
150
-
151
- # mid
152
- self.mid_block = UNetMidBlock2D(
153
- in_channels=block_out_channels[-1],
154
- resnet_eps=1e-6,
155
- resnet_act_fn=act_fn,
156
- output_scale_factor=1,
157
- resnet_time_scale_shift="default",
158
- attn_num_head_channels=None,
159
- resnet_groups=32,
160
- temb_channels=None,
161
- )
162
-
163
- # up
164
- reversed_block_out_channels = list(reversed(block_out_channels))
165
- output_channel = reversed_block_out_channels[0]
166
- for i, up_block_type in enumerate(up_block_types):
167
- prev_output_channel = output_channel
168
- output_channel = reversed_block_out_channels[i]
169
-
170
- is_final_block = i == len(block_out_channels) - 1
171
-
172
- up_block = get_up_block(
173
- up_block_type,
174
- num_layers=self.layers_per_block + 1,
175
- in_channels=prev_output_channel,
176
- out_channels=output_channel,
177
- prev_output_channel=None,
178
- add_upsample=not is_final_block,
179
- resnet_eps=1e-6,
180
- resnet_act_fn=act_fn,
181
- attn_num_head_channels=None,
182
- temb_channels=None,
183
- )
184
- self.up_blocks.append(up_block)
185
- prev_output_channel = output_channel
186
-
187
- # out
188
- num_groups_out = 32
189
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups_out, eps=1e-6)
190
- self.conv_act = nn.SiLU()
191
- self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
192
-
193
- def forward(self, z):
194
- sample = z
195
- sample = self.conv_in(sample)
196
-
197
- # middle
198
- sample = self.mid_block(sample)
199
-
200
- # up
201
- for up_block in self.up_blocks:
202
- sample = up_block(sample)
203
-
204
- # post-process
205
- sample = self.conv_norm_out(sample)
206
- sample = self.conv_act(sample)
207
- sample = self.conv_out(sample)
208
-
209
- return sample
210
-
211
-
212
- class VectorQuantizer(nn.Module):
213
- """
214
- Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
215
- multiplications and allows for post-hoc remapping of indices.
216
- """
217
-
218
- # NOTE: due to a bug the beta term was applied to the wrong term. for
219
- # backwards compatibility we use the buggy version by default, but you can
220
- # specify legacy=False to fix it.
221
- def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
222
- super().__init__()
223
- self.n_e = n_e
224
- self.e_dim = e_dim
225
- self.beta = beta
226
- self.legacy = legacy
227
-
228
- self.embedding = nn.Embedding(self.n_e, self.e_dim)
229
- self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
230
-
231
- self.remap = remap
232
- if self.remap is not None:
233
- self.register_buffer("used", torch.tensor(np.load(self.remap)))
234
- self.re_embed = self.used.shape[0]
235
- self.unknown_index = unknown_index # "random" or "extra" or integer
236
- if self.unknown_index == "extra":
237
- self.unknown_index = self.re_embed
238
- self.re_embed = self.re_embed + 1
239
- print(
240
- f"Remapping {self.n_e} indices to {self.re_embed} indices. "
241
- f"Using {self.unknown_index} for unknown indices."
242
- )
243
- else:
244
- self.re_embed = n_e
245
-
246
- self.sane_index_shape = sane_index_shape
247
-
248
- def remap_to_used(self, inds):
249
- ishape = inds.shape
250
- assert len(ishape) > 1
251
- inds = inds.reshape(ishape[0], -1)
252
- used = self.used.to(inds)
253
- match = (inds[:, :, None] == used[None, None, ...]).long()
254
- new = match.argmax(-1)
255
- unknown = match.sum(2) < 1
256
- if self.unknown_index == "random":
257
- new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
258
- else:
259
- new[unknown] = self.unknown_index
260
- return new.reshape(ishape)
261
-
262
- def unmap_to_all(self, inds):
263
- ishape = inds.shape
264
- assert len(ishape) > 1
265
- inds = inds.reshape(ishape[0], -1)
266
- used = self.used.to(inds)
267
- if self.re_embed > self.used.shape[0]: # extra token
268
- inds[inds >= self.used.shape[0]] = 0 # simply set to zero
269
- back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
270
- return back.reshape(ishape)
271
-
272
- def forward(self, z):
273
- # reshape z -> (batch, height, width, channel) and flatten
274
- z = z.permute(0, 2, 3, 1).contiguous()
275
- z_flattened = z.view(-1, self.e_dim)
276
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
277
-
278
- d = (
279
- torch.sum(z_flattened**2, dim=1, keepdim=True)
280
- + torch.sum(self.embedding.weight**2, dim=1)
281
- - 2 * torch.einsum("bd,dn->bn", z_flattened, self.embedding.weight.t())
282
- )
283
-
284
- min_encoding_indices = torch.argmin(d, dim=1)
285
- z_q = self.embedding(min_encoding_indices).view(z.shape)
286
- perplexity = None
287
- min_encodings = None
288
-
289
- # compute loss for embedding
290
- if not self.legacy:
291
- loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
292
- else:
293
- loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
294
-
295
- # preserve gradients
296
- z_q = z + (z_q - z).detach()
297
-
298
- # reshape back to match original input shape
299
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
300
-
301
- if self.remap is not None:
302
- min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
303
- min_encoding_indices = self.remap_to_used(min_encoding_indices)
304
- min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
305
-
306
- if self.sane_index_shape:
307
- min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
308
-
309
- return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
310
-
311
- def get_codebook_entry(self, indices, shape):
312
- # shape specifying (batch, height, width, channel)
313
- if self.remap is not None:
314
- indices = indices.reshape(shape[0], -1) # add batch axis
315
- indices = self.unmap_to_all(indices)
316
- indices = indices.reshape(-1) # flatten again
317
-
318
- # get quantized latent vectors
319
- z_q = self.embedding(indices)
320
-
321
- if shape is not None:
322
- z_q = z_q.view(shape)
323
- # reshape back to match original input shape
324
- z_q = z_q.permute(0, 3, 1, 2).contiguous()
325
-
326
- return z_q
327
-
328
-
329
- class DiagonalGaussianDistribution(object):
330
- def __init__(self, parameters, deterministic=False):
331
- self.parameters = parameters
332
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
333
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
334
- self.deterministic = deterministic
335
- self.std = torch.exp(0.5 * self.logvar)
336
- self.var = torch.exp(self.logvar)
337
- if self.deterministic:
338
- self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
339
-
340
- def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
341
- device = self.parameters.device
342
- sample_device = "cpu" if device.type == "mps" else device
343
- sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
344
- x = self.mean + self.std * sample
345
- return x
346
-
347
- def kl(self, other=None):
348
- if self.deterministic:
349
- return torch.Tensor([0.0])
350
- else:
351
- if other is None:
352
- return 0.5 * torch.sum(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
353
- else:
354
- return 0.5 * torch.sum(
355
- torch.pow(self.mean - other.mean, 2) / other.var
356
- + self.var / other.var
357
- - 1.0
358
- - self.logvar
359
- + other.logvar,
360
- dim=[1, 2, 3],
361
- )
362
-
363
- def nll(self, sample, dims=[1, 2, 3]):
364
- if self.deterministic:
365
- return torch.Tensor([0.0])
366
- logtwopi = np.log(2.0 * np.pi)
367
- return 0.5 * torch.sum(logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, dim=dims)
368
-
369
- def mode(self):
370
- return self.mean
371
-
372
-
373
- class VQModel(ModelMixin, ConfigMixin):
374
- r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
375
- Kavukcuoglu.
376
-
377
- This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
378
- implements for all the model (such as downloading or saving, etc.)
379
-
380
- Parameters:
381
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
382
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
383
- down_block_types (`Tuple[str]`, *optional*, defaults to :
384
- obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
385
- up_block_types (`Tuple[str]`, *optional*, defaults to :
386
- obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
387
- block_out_channels (`Tuple[int]`, *optional*, defaults to :
388
- obj:`(64,)`): Tuple of block output channels.
389
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
390
- latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
391
- sample_size (`int`, *optional*, defaults to `32`): TODO
392
- num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
393
- """
394
-
395
- @register_to_config
396
- def __init__(
397
- self,
398
- in_channels: int = 3,
399
- out_channels: int = 3,
400
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
401
- up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
402
- block_out_channels: Tuple[int] = (64,),
403
- layers_per_block: int = 1,
404
- act_fn: str = "silu",
405
- latent_channels: int = 3,
406
- sample_size: int = 32,
407
- num_vq_embeddings: int = 256,
408
- ):
409
- super().__init__()
410
-
411
- # pass init params to Encoder
412
- self.encoder = Encoder(
413
- in_channels=in_channels,
414
- out_channels=latent_channels,
415
- down_block_types=down_block_types,
416
- block_out_channels=block_out_channels,
417
- layers_per_block=layers_per_block,
418
- act_fn=act_fn,
419
- double_z=False,
420
- )
421
-
422
- self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
423
- self.quantize = VectorQuantizer(
424
- num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
425
- )
426
- self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
427
-
428
- # pass init params to Decoder
429
- self.decoder = Decoder(
430
- in_channels=latent_channels,
431
- out_channels=out_channels,
432
- up_block_types=up_block_types,
433
- block_out_channels=block_out_channels,
434
- layers_per_block=layers_per_block,
435
- act_fn=act_fn,
436
- )
437
-
438
- def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
439
- h = self.encoder(x)
440
- h = self.quant_conv(h)
441
-
442
- if not return_dict:
443
- return (h,)
444
-
445
- return VQEncoderOutput(latents=h)
446
-
447
- def decode(
448
- self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
449
- ) -> Union[DecoderOutput, torch.FloatTensor]:
450
- # also go through quantization layer
451
- if not force_not_quantize:
452
- quant, emb_loss, info = self.quantize(h)
453
- else:
454
- quant = h
455
- quant = self.post_quant_conv(quant)
456
- dec = self.decoder(quant)
457
-
458
- return dec
459
-
460
- # if not return_dict:
461
- # return (dec,)
462
- #
463
- # return DecoderOutput(sample=dec)
464
-
465
- def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
466
- r"""
467
- Args:
468
- sample (`torch.FloatTensor`): Input sample.
469
- return_dict (`bool`, *optional*, defaults to `True`):
470
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
471
- """
472
- x = sample
473
- h = self.encode(x).latents
474
- dec = self.decode(h).sample
475
-
476
- if not return_dict:
477
- return (dec,)
478
-
479
- return DecoderOutput(sample=dec)
480
-
481
-
482
- class AutoencoderKL(ModelMixin, ConfigMixin):
483
- r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
484
- and Max Welling.
485
-
486
- This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
487
- implements for all the model (such as downloading or saving, etc.)
488
-
489
- Parameters:
490
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
491
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
492
- down_block_types (`Tuple[str]`, *optional*, defaults to :
493
- obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
494
- up_block_types (`Tuple[str]`, *optional*, defaults to :
495
- obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
496
- block_out_channels (`Tuple[int]`, *optional*, defaults to :
497
- obj:`(64,)`): Tuple of block output channels.
498
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
499
- latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
500
- sample_size (`int`, *optional*, defaults to `32`): TODO
501
- """
502
-
503
- @register_to_config
504
- def __init__(
505
- self,
506
- in_channels: int = 3,
507
- out_channels: int = 3,
508
- down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
509
- up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
510
- block_out_channels: Tuple[int] = (64,),
511
- layers_per_block: int = 1,
512
- act_fn: str = "silu",
513
- latent_channels: int = 4,
514
- sample_size: int = 32,
515
- ):
516
- super().__init__()
517
-
518
- # pass init params to Encoder
519
- self.encoder = Encoder(
520
- in_channels=in_channels,
521
- out_channels=latent_channels,
522
- down_block_types=down_block_types,
523
- block_out_channels=block_out_channels,
524
- layers_per_block=layers_per_block,
525
- act_fn=act_fn,
526
- double_z=True,
527
- )
528
-
529
- # pass init params to Decoder
530
- self.decoder = Decoder(
531
- in_channels=latent_channels,
532
- out_channels=out_channels,
533
- up_block_types=up_block_types,
534
- block_out_channels=block_out_channels,
535
- layers_per_block=layers_per_block,
536
- act_fn=act_fn,
537
- )
538
-
539
- self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
540
- self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
541
-
542
- def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
543
- h = self.encoder(x)
544
- moments = self.quant_conv(h)
545
- posterior = DiagonalGaussianDistribution(moments)
546
-
547
- if not return_dict:
548
- return (posterior,)
549
-
550
- return AutoencoderKLOutput(latent_dist=posterior)
551
-
552
- def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
553
- z = self.post_quant_conv(z)
554
- dec = self.decoder(z)
555
-
556
- return dec
557
- #
558
- # if not return_dict:
559
- # return (dec,)
560
- #
561
- # return DecoderOutput(sample=dec)
562
-
563
- def forward(
564
- self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
565
- ) -> Union[DecoderOutput, torch.FloatTensor]:
566
- r"""
567
- Args:
568
- sample (`torch.FloatTensor`): Input sample.
569
- sample_posterior (`bool`, *optional*, defaults to `False`):
570
- Whether to sample from the posterior.
571
- return_dict (`bool`, *optional*, defaults to `True`):
572
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
573
- """
574
- x = sample
575
- posterior = self.encode(x).latent_dist
576
- if sample_posterior:
577
- z = posterior.sample()
578
- else:
579
- z = posterior.mode()
580
- dec = self.decode(z).sample
581
-
582
- if not return_dict:
583
- return (dec,)
584
-
585
- return DecoderOutput(sample=dec)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/onnx_utils.py DELETED
@@ -1,189 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 The HuggingFace Inc. team.
3
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
-
18
- import os
19
- import shutil
20
- from pathlib import Path
21
- from typing import Optional, Union
22
-
23
- import numpy as np
24
-
25
- from huggingface_hub import hf_hub_download
26
-
27
- from .utils import is_onnx_available, logging
28
-
29
-
30
- if is_onnx_available():
31
- import onnxruntime as ort
32
-
33
-
34
- ONNX_WEIGHTS_NAME = "model.onnx"
35
-
36
-
37
- logger = logging.get_logger(__name__)
38
-
39
-
40
- class OnnxRuntimeModel:
41
- base_model_prefix = "onnx_model"
42
-
43
- def __init__(self, model=None, **kwargs):
44
- logger.info("`diffusers.OnnxRuntimeModel` is experimental and might change in the future.")
45
- self.model = model
46
- self.model_save_dir = kwargs.get("model_save_dir", None)
47
- self.latest_model_name = kwargs.get("latest_model_name", "model.onnx")
48
-
49
- def __call__(self, **kwargs):
50
- inputs = {k: np.array(v) for k, v in kwargs.items()}
51
- return self.model.run(None, inputs)
52
-
53
- @staticmethod
54
- def load_model(path: Union[str, Path], provider=None):
55
- """
56
- Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
57
-
58
- Arguments:
59
- path (`str` or `Path`):
60
- Directory from which to load
61
- provider(`str`, *optional*):
62
- Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
63
- """
64
- if provider is None:
65
- logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
66
- provider = "CPUExecutionProvider"
67
-
68
- return ort.InferenceSession(path, providers=[provider])
69
-
70
- def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
71
- """
72
- Save a model and its configuration file to a directory, so that it can be re-loaded using the
73
- [`~optimum.onnxruntime.modeling_ort.ORTModel.from_pretrained`] class method. It will always save the
74
- latest_model_name.
75
-
76
- Arguments:
77
- save_directory (`str` or `Path`):
78
- Directory where to save the model file.
79
- file_name(`str`, *optional*):
80
- Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to save the
81
- model with a different name.
82
- """
83
- model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
84
-
85
- src_path = self.model_save_dir.joinpath(self.latest_model_name)
86
- dst_path = Path(save_directory).joinpath(model_file_name)
87
- if not src_path.samefile(dst_path):
88
- shutil.copyfile(src_path, dst_path)
89
-
90
- def save_pretrained(
91
- self,
92
- save_directory: Union[str, os.PathLike],
93
- **kwargs,
94
- ):
95
- """
96
- Save a model to a directory, so that it can be re-loaded using the [`~OnnxModel.from_pretrained`] class
97
- method.:
98
-
99
- Arguments:
100
- save_directory (`str` or `os.PathLike`):
101
- Directory to which to save. Will be created if it doesn't exist.
102
- """
103
- if os.path.isfile(save_directory):
104
- logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
105
- return
106
-
107
- os.makedirs(save_directory, exist_ok=True)
108
-
109
- # saving model weights/files
110
- self._save_pretrained(save_directory, **kwargs)
111
-
112
- @classmethod
113
- def _from_pretrained(
114
- cls,
115
- model_id: Union[str, Path],
116
- use_auth_token: Optional[Union[bool, str, None]] = None,
117
- revision: Optional[Union[str, None]] = None,
118
- force_download: bool = False,
119
- cache_dir: Optional[str] = None,
120
- file_name: Optional[str] = None,
121
- provider: Optional[str] = None,
122
- **kwargs,
123
- ):
124
- """
125
- Load a model from a directory or the HF Hub.
126
-
127
- Arguments:
128
- model_id (`str` or `Path`):
129
- Directory from which to load
130
- use_auth_token (`str` or `bool`):
131
- Is needed to load models from a private or gated repository
132
- revision (`str`):
133
- Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id
134
- cache_dir (`Union[str, Path]`, *optional*):
135
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
136
- standard cache should not be used.
137
- force_download (`bool`, *optional*, defaults to `False`):
138
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
139
- cached versions if they exist.
140
- file_name(`str`):
141
- Overwrites the default model file name from `"model.onnx"` to `file_name`. This allows you to load
142
- different model files from the same repository or directory.
143
- provider(`str`):
144
- The ONNX runtime provider, e.g. `CPUExecutionProvider` or `CUDAExecutionProvider`.
145
- kwargs (`Dict`, *optional*):
146
- kwargs will be passed to the model during initialization
147
- """
148
- model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
149
- # load model from local directory
150
- if os.path.isdir(model_id):
151
- model = OnnxRuntimeModel.load_model(os.path.join(model_id, model_file_name), provider=provider)
152
- kwargs["model_save_dir"] = Path(model_id)
153
- # load model from hub
154
- else:
155
- # download model
156
- model_cache_path = hf_hub_download(
157
- repo_id=model_id,
158
- filename=model_file_name,
159
- use_auth_token=use_auth_token,
160
- revision=revision,
161
- cache_dir=cache_dir,
162
- force_download=force_download,
163
- )
164
- kwargs["model_save_dir"] = Path(model_cache_path).parent
165
- kwargs["latest_model_name"] = Path(model_cache_path).name
166
- model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider)
167
- return cls(model=model, **kwargs)
168
-
169
- @classmethod
170
- def from_pretrained(
171
- cls,
172
- model_id: Union[str, Path],
173
- force_download: bool = True,
174
- use_auth_token: Optional[str] = None,
175
- cache_dir: Optional[str] = None,
176
- **model_kwargs,
177
- ):
178
- revision = None
179
- if len(str(model_id).split("@")) == 2:
180
- model_id, revision = model_id.split("@")
181
-
182
- return cls._from_pretrained(
183
- model_id=model_id,
184
- revision=revision,
185
- cache_dir=cache_dir,
186
- force_download=force_download,
187
- use_auth_token=use_auth_token,
188
- **model_kwargs,
189
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/optimization.py DELETED
@@ -1,275 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """PyTorch optimization for diffusion models."""
16
-
17
- import math
18
- from enum import Enum
19
- from typing import Optional, Union
20
-
21
- from torch.optim import Optimizer
22
- from torch.optim.lr_scheduler import LambdaLR
23
-
24
- from .utils import logging
25
-
26
-
27
- logger = logging.get_logger(__name__)
28
-
29
-
30
- class SchedulerType(Enum):
31
- LINEAR = "linear"
32
- COSINE = "cosine"
33
- COSINE_WITH_RESTARTS = "cosine_with_restarts"
34
- POLYNOMIAL = "polynomial"
35
- CONSTANT = "constant"
36
- CONSTANT_WITH_WARMUP = "constant_with_warmup"
37
-
38
-
39
- def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
40
- """
41
- Create a schedule with a constant learning rate, using the learning rate set in optimizer.
42
-
43
- Args:
44
- optimizer ([`~torch.optim.Optimizer`]):
45
- The optimizer for which to schedule the learning rate.
46
- last_epoch (`int`, *optional*, defaults to -1):
47
- The index of the last epoch when resuming training.
48
-
49
- Return:
50
- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
51
- """
52
- return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
53
-
54
-
55
- def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
56
- """
57
- Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
58
- increases linearly between 0 and the initial lr set in the optimizer.
59
-
60
- Args:
61
- optimizer ([`~torch.optim.Optimizer`]):
62
- The optimizer for which to schedule the learning rate.
63
- num_warmup_steps (`int`):
64
- The number of steps for the warmup phase.
65
- last_epoch (`int`, *optional*, defaults to -1):
66
- The index of the last epoch when resuming training.
67
-
68
- Return:
69
- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
70
- """
71
-
72
- def lr_lambda(current_step: int):
73
- if current_step < num_warmup_steps:
74
- return float(current_step) / float(max(1.0, num_warmup_steps))
75
- return 1.0
76
-
77
- return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
78
-
79
-
80
- def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
81
- """
82
- Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
83
- a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
84
-
85
- Args:
86
- optimizer ([`~torch.optim.Optimizer`]):
87
- The optimizer for which to schedule the learning rate.
88
- num_warmup_steps (`int`):
89
- The number of steps for the warmup phase.
90
- num_training_steps (`int`):
91
- The total number of training steps.
92
- last_epoch (`int`, *optional*, defaults to -1):
93
- The index of the last epoch when resuming training.
94
-
95
- Return:
96
- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
97
- """
98
-
99
- def lr_lambda(current_step: int):
100
- if current_step < num_warmup_steps:
101
- return float(current_step) / float(max(1, num_warmup_steps))
102
- return max(
103
- 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
104
- )
105
-
106
- return LambdaLR(optimizer, lr_lambda, last_epoch)
107
-
108
-
109
- def get_cosine_schedule_with_warmup(
110
- optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
111
- ):
112
- """
113
- Create a schedule with a learning rate that decreases following the values of the cosine function between the
114
- initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
115
- initial lr set in the optimizer.
116
-
117
- Args:
118
- optimizer ([`~torch.optim.Optimizer`]):
119
- The optimizer for which to schedule the learning rate.
120
- num_warmup_steps (`int`):
121
- The number of steps for the warmup phase.
122
- num_training_steps (`int`):
123
- The total number of training steps.
124
- num_cycles (`float`, *optional*, defaults to 0.5):
125
- The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
126
- following a half-cosine).
127
- last_epoch (`int`, *optional*, defaults to -1):
128
- The index of the last epoch when resuming training.
129
-
130
- Return:
131
- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
132
- """
133
-
134
- def lr_lambda(current_step):
135
- if current_step < num_warmup_steps:
136
- return float(current_step) / float(max(1, num_warmup_steps))
137
- progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
138
- return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
139
-
140
- return LambdaLR(optimizer, lr_lambda, last_epoch)
141
-
142
-
143
- def get_cosine_with_hard_restarts_schedule_with_warmup(
144
- optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
145
- ):
146
- """
147
- Create a schedule with a learning rate that decreases following the values of the cosine function between the
148
- initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
149
- linearly between 0 and the initial lr set in the optimizer.
150
-
151
- Args:
152
- optimizer ([`~torch.optim.Optimizer`]):
153
- The optimizer for which to schedule the learning rate.
154
- num_warmup_steps (`int`):
155
- The number of steps for the warmup phase.
156
- num_training_steps (`int`):
157
- The total number of training steps.
158
- num_cycles (`int`, *optional*, defaults to 1):
159
- The number of hard restarts to use.
160
- last_epoch (`int`, *optional*, defaults to -1):
161
- The index of the last epoch when resuming training.
162
-
163
- Return:
164
- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
165
- """
166
-
167
- def lr_lambda(current_step):
168
- if current_step < num_warmup_steps:
169
- return float(current_step) / float(max(1, num_warmup_steps))
170
- progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
171
- if progress >= 1.0:
172
- return 0.0
173
- return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))))
174
-
175
- return LambdaLR(optimizer, lr_lambda, last_epoch)
176
-
177
-
178
- def get_polynomial_decay_schedule_with_warmup(
179
- optimizer, num_warmup_steps, num_training_steps, lr_end=1e-7, power=1.0, last_epoch=-1
180
- ):
181
- """
182
- Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
183
- optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
184
- initial lr set in the optimizer.
185
-
186
- Args:
187
- optimizer ([`~torch.optim.Optimizer`]):
188
- The optimizer for which to schedule the learning rate.
189
- num_warmup_steps (`int`):
190
- The number of steps for the warmup phase.
191
- num_training_steps (`int`):
192
- The total number of training steps.
193
- lr_end (`float`, *optional*, defaults to 1e-7):
194
- The end LR.
195
- power (`float`, *optional*, defaults to 1.0):
196
- Power factor.
197
- last_epoch (`int`, *optional*, defaults to -1):
198
- The index of the last epoch when resuming training.
199
-
200
- Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
201
- implementation at
202
- https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
203
-
204
- Return:
205
- `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
206
-
207
- """
208
-
209
- lr_init = optimizer.defaults["lr"]
210
- if not (lr_init > lr_end):
211
- raise ValueError(f"lr_end ({lr_end}) must be be smaller than initial lr ({lr_init})")
212
-
213
- def lr_lambda(current_step: int):
214
- if current_step < num_warmup_steps:
215
- return float(current_step) / float(max(1, num_warmup_steps))
216
- elif current_step > num_training_steps:
217
- return lr_end / lr_init # as LambdaLR multiplies by lr_init
218
- else:
219
- lr_range = lr_init - lr_end
220
- decay_steps = num_training_steps - num_warmup_steps
221
- pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
222
- decay = lr_range * pct_remaining**power + lr_end
223
- return decay / lr_init # as LambdaLR multiplies by lr_init
224
-
225
- return LambdaLR(optimizer, lr_lambda, last_epoch)
226
-
227
-
228
- TYPE_TO_SCHEDULER_FUNCTION = {
229
- SchedulerType.LINEAR: get_linear_schedule_with_warmup,
230
- SchedulerType.COSINE: get_cosine_schedule_with_warmup,
231
- SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
232
- SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
233
- SchedulerType.CONSTANT: get_constant_schedule,
234
- SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
235
- }
236
-
237
-
238
- def get_scheduler(
239
- name: Union[str, SchedulerType],
240
- optimizer: Optimizer,
241
- num_warmup_steps: Optional[int] = None,
242
- num_training_steps: Optional[int] = None,
243
- ):
244
- """
245
- Unified API to get any scheduler from its name.
246
-
247
- Args:
248
- name (`str` or `SchedulerType`):
249
- The name of the scheduler to use.
250
- optimizer (`torch.optim.Optimizer`):
251
- The optimizer that will be used during training.
252
- num_warmup_steps (`int`, *optional*):
253
- The number of warmup steps to do. This is not required by all schedulers (hence the argument being
254
- optional), the function will raise an error if it's unset and the scheduler type requires it.
255
- num_training_steps (`int``, *optional*):
256
- The number of training steps to do. This is not required by all schedulers (hence the argument being
257
- optional), the function will raise an error if it's unset and the scheduler type requires it.
258
- """
259
- name = SchedulerType(name)
260
- schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
261
- if name == SchedulerType.CONSTANT:
262
- return schedule_func(optimizer)
263
-
264
- # All other schedulers require `num_warmup_steps`
265
- if num_warmup_steps is None:
266
- raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
267
-
268
- if name == SchedulerType.CONSTANT_WITH_WARMUP:
269
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps)
270
-
271
- # All other schedulers require `num_training_steps`
272
- if num_training_steps is None:
273
- raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
274
-
275
- return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/pipeline_utils.py DELETED
@@ -1,417 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2022 The HuggingFace Inc. team.
3
- # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
- import importlib
18
- import inspect
19
- import os
20
- from dataclasses import dataclass
21
- from typing import List, Optional, Union
22
-
23
- import numpy as np
24
- import torch
25
-
26
- import diffusers
27
- import PIL
28
- from huggingface_hub import snapshot_download
29
- from PIL import Image
30
- from tqdm.auto import tqdm
31
-
32
- from .configuration_utils import ConfigMixin
33
- from .utils import DIFFUSERS_CACHE, BaseOutput, logging
34
-
35
-
36
- INDEX_FILE = "diffusion_pytorch_model.bin"
37
-
38
-
39
- logger = logging.get_logger(__name__)
40
-
41
-
42
- LOADABLE_CLASSES = {
43
- "diffusers": {
44
- "ModelMixin": ["save_pretrained", "from_pretrained"],
45
- "SchedulerMixin": ["save_config", "from_config"],
46
- "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
47
- "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
48
- },
49
- "transformers": {
50
- "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
51
- "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
52
- "PreTrainedModel": ["save_pretrained", "from_pretrained"],
53
- "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
54
- },
55
- }
56
-
57
- ALL_IMPORTABLE_CLASSES = {}
58
- for library in LOADABLE_CLASSES:
59
- ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
60
-
61
-
62
- @dataclass
63
- class ImagePipelineOutput(BaseOutput):
64
- """
65
- Output class for image pipelines.
66
-
67
- Args:
68
- images (`List[PIL.Image.Image]` or `np.ndarray`)
69
- List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
70
- num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
71
- """
72
-
73
- images: Union[List[PIL.Image.Image], np.ndarray]
74
-
75
-
76
- class DiffusionPipeline(ConfigMixin):
77
- r"""
78
- Base class for all models.
79
-
80
- [`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
81
- and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
82
-
83
- - move all PyTorch modules to the device of your choice
84
- - enabling/disabling the progress bar for the denoising iteration
85
-
86
- Class attributes:
87
-
88
- - **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
89
- compenents of the diffusion pipeline.
90
- """
91
- config_name = "model_index.json"
92
-
93
- def register_modules(self, **kwargs):
94
- # import it here to avoid circular import
95
- from diffusers import pipelines
96
-
97
- for name, module in kwargs.items():
98
- # retrive library
99
- library = module.__module__.split(".")[0]
100
-
101
- # check if the module is a pipeline module
102
- pipeline_dir = module.__module__.split(".")[-2]
103
- path = module.__module__.split(".")
104
- is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
105
-
106
- # if library is not in LOADABLE_CLASSES, then it is a custom module.
107
- # Or if it's a pipeline module, then the module is inside the pipeline
108
- # folder so we set the library to module name.
109
- if library not in LOADABLE_CLASSES or is_pipeline_module:
110
- library = pipeline_dir
111
-
112
- # retrive class_name
113
- class_name = module.__class__.__name__
114
-
115
- register_dict = {name: (library, class_name)}
116
-
117
- # save model index config
118
- self.register_to_config(**register_dict)
119
-
120
- # set models
121
- setattr(self, name, module)
122
-
123
- def save_pretrained(self, save_directory: Union[str, os.PathLike]):
124
- """
125
- Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
126
- a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
127
- method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
128
-
129
- Arguments:
130
- save_directory (`str` or `os.PathLike`):
131
- Directory to which to save. Will be created if it doesn't exist.
132
- """
133
- self.save_config(save_directory)
134
-
135
- model_index_dict = dict(self.config)
136
- model_index_dict.pop("_class_name")
137
- model_index_dict.pop("_diffusers_version")
138
- model_index_dict.pop("_module", None)
139
-
140
- for pipeline_component_name in model_index_dict.keys():
141
- sub_model = getattr(self, pipeline_component_name)
142
- model_cls = sub_model.__class__
143
-
144
- save_method_name = None
145
- # search for the model's base class in LOADABLE_CLASSES
146
- for library_name, library_classes in LOADABLE_CLASSES.items():
147
- library = importlib.import_module(library_name)
148
- for base_class, save_load_methods in library_classes.items():
149
- class_candidate = getattr(library, base_class)
150
- if issubclass(model_cls, class_candidate):
151
- # if we found a suitable base class in LOADABLE_CLASSES then grab its save method
152
- save_method_name = save_load_methods[0]
153
- break
154
- if save_method_name is not None:
155
- break
156
-
157
- save_method = getattr(sub_model, save_method_name)
158
- save_method(os.path.join(save_directory, pipeline_component_name))
159
-
160
- def to(self, torch_device: Optional[Union[str, torch.device]] = None):
161
- if torch_device is None:
162
- return self
163
-
164
- module_names, _ = self.extract_init_dict(dict(self.config))
165
- for name in module_names.keys():
166
- module = getattr(self, name)
167
- if isinstance(module, torch.nn.Module):
168
- module.to(torch_device)
169
- return self
170
-
171
- @property
172
- def device(self) -> torch.device:
173
- r"""
174
- Returns:
175
- `torch.device`: The torch device on which the pipeline is located.
176
- """
177
- module_names, _ = self.extract_init_dict(dict(self.config))
178
- for name in module_names.keys():
179
- module = getattr(self, name)
180
- if isinstance(module, torch.nn.Module):
181
- return module.device
182
- return torch.device("cpu")
183
-
184
- @classmethod
185
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
186
- r"""
187
- Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
188
-
189
- The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
190
-
191
- The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
192
- pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
193
- task.
194
-
195
- The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
196
- weights are discarded.
197
-
198
- Parameters:
199
- pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
200
- Can be either:
201
-
202
- - A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
203
- https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
204
- `CompVis/ldm-text2im-large-256`.
205
- - A path to a *directory* containing pipeline weights saved using
206
- [`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
207
- torch_dtype (`str` or `torch.dtype`, *optional*):
208
- Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
209
- will be automatically derived from the model's weights.
210
- force_download (`bool`, *optional*, defaults to `False`):
211
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
212
- cached versions if they exist.
213
- resume_download (`bool`, *optional*, defaults to `False`):
214
- Whether or not to delete incompletely received files. Will attempt to resume the download if such a
215
- file exists.
216
- proxies (`Dict[str, str]`, *optional*):
217
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
218
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
219
- output_loading_info(`bool`, *optional*, defaults to `False`):
220
- Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
221
- local_files_only(`bool`, *optional*, defaults to `False`):
222
- Whether or not to only look at local files (i.e., do not try to download the model).
223
- use_auth_token (`str` or *bool*, *optional*):
224
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
225
- when running `huggingface-cli login` (stored in `~/.huggingface`).
226
- revision (`str`, *optional*, defaults to `"main"`):
227
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
228
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
229
- identifier allowed by git.
230
- mirror (`str`, *optional*):
231
- Mirror source to accelerate downloads in China. If you are from China and have an accessibility
232
- problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
233
- Please refer to the mirror site for more information. specify the folder name here.
234
-
235
- kwargs (remaining dictionary of keyword arguments, *optional*):
236
- Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
237
- speficic pipeline class. The overritten components are then directly passed to the pipelines `__init__`
238
- method. See example below for more information.
239
-
240
- <Tip>
241
-
242
- Passing `use_auth_token=True`` is required when you want to use a private model, *e.g.*
243
- `"CompVis/stable-diffusion-v1-4"`
244
-
245
- </Tip>
246
-
247
- <Tip>
248
-
249
- Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
250
- this method in a firewalled environment.
251
-
252
- </Tip>
253
-
254
- Examples:
255
-
256
- ```py
257
- >>> from diffusers import DiffusionPipeline
258
-
259
- >>> # Download pipeline from huggingface.co and cache.
260
- >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
261
-
262
- >>> # Download pipeline that requires an authorization token
263
- >>> # For more information on access tokens, please refer to this section
264
- >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
265
- >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
266
-
267
- >>> # Download pipeline, but overwrite scheduler
268
- >>> from diffusers import LMSDiscreteScheduler
269
-
270
- >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
271
- >>> pipeline = DiffusionPipeline.from_pretrained(
272
- ... "CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True
273
- ... )
274
- ```
275
- """
276
- cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
277
- resume_download = kwargs.pop("resume_download", False)
278
- proxies = kwargs.pop("proxies", None)
279
- local_files_only = kwargs.pop("local_files_only", False)
280
- use_auth_token = kwargs.pop("use_auth_token", None)
281
- revision = kwargs.pop("revision", None)
282
- torch_dtype = kwargs.pop("torch_dtype", None)
283
- provider = kwargs.pop("provider", None)
284
-
285
- # 1. Download the checkpoints and configs
286
- # use snapshot download here to get it working from from_pretrained
287
- if not os.path.isdir(pretrained_model_name_or_path):
288
- cached_folder = snapshot_download(
289
- pretrained_model_name_or_path,
290
- cache_dir=cache_dir,
291
- resume_download=resume_download,
292
- proxies=proxies,
293
- local_files_only=local_files_only,
294
- use_auth_token=use_auth_token,
295
- revision=revision,
296
- )
297
- else:
298
- cached_folder = pretrained_model_name_or_path
299
-
300
- config_dict = cls.get_config_dict(cached_folder)
301
-
302
- # 2. Load the pipeline class, if using custom module then load it from the hub
303
- # if we load from explicit class, let's use it
304
- if cls != DiffusionPipeline:
305
- pipeline_class = cls
306
- else:
307
- diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
308
- pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
309
-
310
- # some modules can be passed directly to the init
311
- # in this case they are already instantiated in `kwargs`
312
- # extract them here
313
- expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys())
314
- passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
315
-
316
- init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
317
-
318
- init_kwargs = {}
319
-
320
- # import it here to avoid circular import
321
- from diffusers import pipelines
322
-
323
- # 3. Load each module in the pipeline
324
- for name, (library_name, class_name) in init_dict.items():
325
- is_pipeline_module = hasattr(pipelines, library_name)
326
- loaded_sub_model = None
327
-
328
- # if the model is in a pipeline module, then we load it from the pipeline
329
- if name in passed_class_obj:
330
- # 1. check that passed_class_obj has correct parent class
331
- if not is_pipeline_module:
332
- library = importlib.import_module(library_name)
333
- class_obj = getattr(library, class_name)
334
- importable_classes = LOADABLE_CLASSES[library_name]
335
- class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
336
-
337
- expected_class_obj = None
338
- for class_name, class_candidate in class_candidates.items():
339
- if issubclass(class_obj, class_candidate):
340
- expected_class_obj = class_candidate
341
-
342
- if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
343
- raise ValueError(
344
- f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
345
- f" {expected_class_obj}"
346
- )
347
- else:
348
- logger.warn(
349
- f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
350
- " has the correct type"
351
- )
352
-
353
- # set passed class object
354
- loaded_sub_model = passed_class_obj[name]
355
- elif is_pipeline_module:
356
- pipeline_module = getattr(pipelines, library_name)
357
- class_obj = getattr(pipeline_module, class_name)
358
- importable_classes = ALL_IMPORTABLE_CLASSES
359
- class_candidates = {c: class_obj for c in importable_classes.keys()}
360
- else:
361
- # else we just import it from the library.
362
- library = importlib.import_module(library_name)
363
- class_obj = getattr(library, class_name)
364
- importable_classes = LOADABLE_CLASSES[library_name]
365
- class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
366
-
367
- if loaded_sub_model is None:
368
- load_method_name = None
369
- for class_name, class_candidate in class_candidates.items():
370
- if issubclass(class_obj, class_candidate):
371
- load_method_name = importable_classes[class_name][1]
372
-
373
- load_method = getattr(class_obj, load_method_name)
374
-
375
- loading_kwargs = {}
376
- if issubclass(class_obj, torch.nn.Module):
377
- loading_kwargs["torch_dtype"] = torch_dtype
378
- if issubclass(class_obj, diffusers.OnnxRuntimeModel):
379
- loading_kwargs["provider"] = provider
380
-
381
- # check if the module is in a subdirectory
382
- if os.path.isdir(os.path.join(cached_folder, name)):
383
- loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
384
- else:
385
- # else load from the root directory
386
- loaded_sub_model = load_method(cached_folder, **loading_kwargs)
387
-
388
- init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
389
-
390
- # 4. Instantiate the pipeline
391
- model = pipeline_class(**init_kwargs)
392
- return model
393
-
394
- @staticmethod
395
- def numpy_to_pil(images):
396
- """
397
- Convert a numpy image or a batch of images to a PIL image.
398
- """
399
- if images.ndim == 3:
400
- images = images[None, ...]
401
- images = (images * 255).round().astype("uint8")
402
- pil_images = [Image.fromarray(image) for image in images]
403
-
404
- return pil_images
405
-
406
- def progress_bar(self, iterable):
407
- if not hasattr(self, "_progress_bar_config"):
408
- self._progress_bar_config = {}
409
- elif not isinstance(self._progress_bar_config, dict):
410
- raise ValueError(
411
- f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
412
- )
413
-
414
- return tqdm(iterable, **self._progress_bar_config)
415
-
416
- def set_progress_bar_config(self, **kwargs):
417
- self._progress_bar_config = kwargs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/pipelines/__init__.py DELETED
@@ -1,19 +0,0 @@
1
- from ..utils import is_onnx_available, is_transformers_available
2
- from .ddim import DDIMPipeline
3
- from .ddpm import DDPMPipeline
4
- from .latent_diffusion_uncond import LDMPipeline
5
- from .pndm import PNDMPipeline
6
- from .score_sde_ve import ScoreSdeVePipeline
7
- from .stochastic_karras_ve import KarrasVePipeline
8
-
9
-
10
- if is_transformers_available():
11
- from .latent_diffusion import LDMTextToImagePipeline
12
- from .stable_diffusion import (
13
- StableDiffusionImg2ImgPipeline,
14
- StableDiffusionInpaintPipeline,
15
- StableDiffusionPipeline,
16
- )
17
-
18
- if is_transformers_available() and is_onnx_available():
19
- from .stable_diffusion import StableDiffusionOnnxPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
diffusers/pipelines/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (829 Bytes)
 
diffusers/pipelines/ddim/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- # flake8: noqa
2
- from .pipeline_ddim import DDIMPipeline
 
 
 
diffusers/pipelines/ddim/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (207 Bytes)