hatmanstack committed on
Commit 78b9267 · 1 Parent(s): cfae7c7

Back to StabilityXL base

app.py CHANGED
@@ -1,46 +1,22 @@
 import torch
 import random
-import spaces ## For ZeroGPU
+import spaces
 import gradio as gr
 from PIL import Image
-from models_transformer_sd3 import SD3Transformer2DModel
-from pipeline_stable_diffusion_3_ipa import StableDiffusion3Pipeline
-import gc
-import os
-from huggingface_hub import login
-
-TOKEN = os.getenv('TOKEN')
-login(TOKEN)
-
-model_path = 'stabilityai/stable-diffusion-3.5-large'
-ip_adapter_path = './ip-adapter.bin'
-image_encoder_path = "google/siglip-so400m-patch14-384"
+from diffusers import AutoPipelineForText2Image
+from diffusers.utils import load_image
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
-transformer = SD3Transformer2DModel.from_pretrained(
-    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
-)
-
-pipe = StableDiffusion3Pipeline.from_pretrained(
-    model_path, transformer=transformer, torch_dtype=torch.bfloat16
-) ## For ZeroGPU no .to("cuda")
-
-pipe.init_ipadapter(
-    ip_adapter_path=ip_adapter_path,
-    image_encoder_path=image_encoder_path,
-    nb_token=64,
-)
-
+pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype)
+pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
 pipe.to(device)
-
 def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     if randomize_seed:
         seed = random.randint(0, 2000)
     return seed
-
-@spaces.GPU() ## For ZeroGPU
+
+@spaces.GPU()
 def create_image(image_pil,
                  prompt,
                  n_prompt,
@@ -66,31 +42,24 @@ def create_image(image_pil,
         "down": {"block_2": [0.0, control_scale]},
         "up": {"block_0": [0.0, control_scale, 0.0]},
     }
-    #pipe.set_ip_adapter_scale(scale) ##Waiting for Diffuser integration of SD3 pipeline
-
-    style_image = Image.open(image_pil).convert('RGB')
+    pipe.set_ip_adapter_scale(scale)
+
+    style_image = load_image(image_pil)
+    generator = torch.Generator().manual_seed(randomize_seed_fn(seed, True))
 
 
     image = pipe(
-        width=1024,
-        height=1024,
-        prompt=prompt,
-        negative_prompt="lowres, low quality, worst quality",
-        num_inference_steps=24,
-        guidance_scale=guidance_scale,
-        generator=torch.Generator("cuda").manual_seed(randomize_seed_fn(seed, True)), ## For ZeroGPU no device="cpu"
-        clip_image=style_image,
-        ipadapter_scale=scale,
-    ).images[0]
-
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        gc.collect()
+        prompt=prompt,
+        ip_adapter_image=style_image,
+        negative_prompt=n_prompt,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        generator=generator,
+    ).images[0]
 
     return image
 
 
-
 
 # Description
 title = r"""
@@ -113,6 +82,7 @@ article = r"""
   author={Wang, Haofan and Wang, Qixun and Bai, Xu and Qin, Zekui and Chen, Anthony},
   journal={arXiv preprint arXiv:2404.02733},
   year={2024}
+}
 ```
 """
 
@@ -176,4 +146,4 @@ with block:
 
     gr.Markdown(article)
 
-block.launch(show_error=True, share=True)
+block.launch(show_error=True)
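
For reference, a minimal standalone sketch of the setup this commit switches back to: SDXL base with the stock diffusers IP-Adapter, assembled from the lines added to app.py above. The prompt, negative prompt, control_scale, seed, step count, and file paths here are illustrative placeholders, not values taken from the Space.

```python
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# SDXL base plus the standard diffusers IP-Adapter, as loaded in the new app.py.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=dtype
)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
pipe.to(device)

# Block-wise adapter scales, mirroring the dict built inside create_image();
# control_scale is an illustrative value.
control_scale = 0.8
pipe.set_ip_adapter_scale({
    "down": {"block_2": [0.0, control_scale]},
    "up": {"block_0": [0.0, control_scale, 0.0]},
})

style_image = load_image("style.png")  # hypothetical path to a style reference image
generator = torch.Generator().manual_seed(42)

image = pipe(
    prompt="a cat, in the style of the reference image",
    ip_adapter_image=style_image,
    negative_prompt="lowres, low quality, worst quality",
    guidance_scale=5.0,
    num_inference_steps=25,
    generator=generator,
).images[0]
image.save("out.png")
```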
.gitattributes → gitattributes RENAMED
@@ -1,35 +1,35 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
 *.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
models_attention.py DELETED
@@ -1,1279 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from typing import Any, Dict, List, Optional, Tuple
15
-
16
- import torch
17
- import torch.nn.functional as F
18
- from torch import nn
19
-
20
- from diffusers.utils import deprecate, logging
21
- from diffusers.utils.torch_utils import maybe_allow_in_graph
22
- from diffusers.models.activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
23
- from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0
24
- from diffusers.models.embeddings import SinusoidalPositionalEmbedding
25
- from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
26
-
27
-
28
- logger = logging.get_logger(__name__)
29
-
30
-
31
-
32
- def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
33
- # "feed_forward_chunk_size" can be used to save memory
34
- if hidden_states.shape[chunk_dim] % chunk_size != 0:
35
- raise ValueError(
36
- f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
37
- )
38
-
39
- num_chunks = hidden_states.shape[chunk_dim] // chunk_size
40
- ff_output = torch.cat(
41
- [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
42
- dim=chunk_dim,
43
- )
44
- return ff_output
45
-
46
- @maybe_allow_in_graph
47
- class SD35AdaLayerNormZeroX(nn.Module):
48
- r"""
49
- Norm layer adaptive layer norm zero (AdaLN-Zero).
50
-
51
- Parameters:
52
- embedding_dim (`int`): The size of each embedding vector.
53
- num_embeddings (`int`): The size of the embeddings dictionary.
54
- """
55
-
56
- def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
57
- super().__init__()
58
-
59
- self.silu = nn.SiLU()
60
- self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
61
- if norm_type == "layer_norm":
62
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
63
- else:
64
- raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
65
-
66
- def forward(
67
- self,
68
- hidden_states: torch.Tensor,
69
- emb: Optional[torch.Tensor] = None,
70
- ) -> Tuple[torch.Tensor, ...]:
71
- emb = self.linear(self.silu(emb))
72
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
73
- 9, dim=1
74
- )
75
- norm_hidden_states = self.norm(hidden_states)
76
- hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
77
- norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
78
- return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
79
-
80
- @maybe_allow_in_graph
81
- class GatedSelfAttentionDense(nn.Module):
82
- r"""
83
- A gated self-attention dense layer that combines visual features and object features.
84
-
85
- Parameters:
86
- query_dim (`int`): The number of channels in the query.
87
- context_dim (`int`): The number of channels in the context.
88
- n_heads (`int`): The number of heads to use for attention.
89
- d_head (`int`): The number of channels in each head.
90
- """
91
-
92
- def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
93
- super().__init__()
94
-
95
- # we need a linear projection since we need cat visual feature and obj feature
96
- self.linear = nn.Linear(context_dim, query_dim)
97
-
98
- self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
99
- self.ff = FeedForward(query_dim, activation_fn="geglu")
100
-
101
- self.norm1 = nn.LayerNorm(query_dim)
102
- self.norm2 = nn.LayerNorm(query_dim)
103
-
104
- self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
105
- self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
106
-
107
- self.enabled = True
108
-
109
- def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
110
- if not self.enabled:
111
- return x
112
-
113
- n_visual = x.shape[1]
114
- objs = self.linear(objs)
115
-
116
- x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
117
- x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
118
-
119
- return x
120
-
121
-
122
- @maybe_allow_in_graph
123
- class JointTransformerBlock(nn.Module):
124
- r"""
125
- A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
126
-
127
- Reference: https://arxiv.org/abs/2403.03206
128
-
129
- Parameters:
130
- dim (`int`): The number of channels in the input and output.
131
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
132
- attention_head_dim (`int`): The number of channels in each head.
133
- context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
134
- processing of `context` conditions.
135
- """
136
-
137
- def __init__(
138
- self,
139
- dim: int,
140
- num_attention_heads: int,
141
- attention_head_dim: int,
142
- context_pre_only: bool = False,
143
- qk_norm: Optional[str] = None,
144
- use_dual_attention: bool = False,
145
- ):
146
- super().__init__()
147
-
148
- self.use_dual_attention = use_dual_attention
149
- self.context_pre_only = context_pre_only
150
- context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
151
-
152
- if use_dual_attention:
153
- self.norm1 = SD35AdaLayerNormZeroX(dim)
154
- else:
155
- self.norm1 = AdaLayerNormZero(dim)
156
-
157
- if context_norm_type == "ada_norm_continous":
158
- self.norm1_context = AdaLayerNormContinuous(
159
- dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
160
- )
161
- elif context_norm_type == "ada_norm_zero":
162
- self.norm1_context = AdaLayerNormZero(dim)
163
- else:
164
- raise ValueError(
165
- f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
166
- )
167
-
168
- if hasattr(F, "scaled_dot_product_attention"):
169
- processor = JointAttnProcessor2_0()
170
- else:
171
- raise ValueError(
172
- "The current PyTorch version does not support the `scaled_dot_product_attention` function."
173
- )
174
-
175
- self.attn = Attention(
176
- query_dim=dim,
177
- cross_attention_dim=None,
178
- added_kv_proj_dim=dim,
179
- dim_head=attention_head_dim,
180
- heads=num_attention_heads,
181
- out_dim=dim,
182
- context_pre_only=context_pre_only,
183
- bias=True,
184
- processor=processor,
185
- qk_norm=qk_norm,
186
- eps=1e-6,
187
- )
188
-
189
- if use_dual_attention:
190
- self.attn2 = Attention(
191
- query_dim=dim,
192
- cross_attention_dim=None,
193
- dim_head=attention_head_dim,
194
- heads=num_attention_heads,
195
- out_dim=dim,
196
- bias=True,
197
- processor=processor,
198
- qk_norm=qk_norm,
199
- eps=1e-6,
200
- )
201
- else:
202
- self.attn2 = None
203
-
204
- self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
205
- self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
206
-
207
- if not context_pre_only:
208
- self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
209
- self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
210
- else:
211
- self.norm2_context = None
212
- self.ff_context = None
213
-
214
- # let chunk size default to None
215
- self._chunk_size = None
216
- self._chunk_dim = 0
217
-
218
- # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
219
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
220
- # Sets chunk feed-forward
221
- self._chunk_size = chunk_size
222
- self._chunk_dim = dim
223
-
224
- def forward(
225
- self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor,
226
- joint_attention_kwargs=None,
227
- ):
228
- if self.use_dual_attention:
229
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
230
- hidden_states, emb=temb
231
- )
232
- else:
233
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
234
-
235
- if self.context_pre_only:
236
- norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
237
- else:
238
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
239
- encoder_hidden_states, emb=temb
240
- )
241
-
242
- # Attention.
243
- attn_output, context_attn_output = self.attn(
244
- hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states,
245
- **({} if joint_attention_kwargs is None else joint_attention_kwargs),
246
- )
247
-
248
- # Process attention outputs for the `hidden_states`.
249
- attn_output = gate_msa.unsqueeze(1) * attn_output
250
- hidden_states = hidden_states + attn_output
251
-
252
- if self.use_dual_attention:
253
- attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **({} if joint_attention_kwargs is None else joint_attention_kwargs),)
254
- attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
255
- hidden_states = hidden_states + attn_output2
256
-
257
- norm_hidden_states = self.norm2(hidden_states)
258
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
259
- if self._chunk_size is not None:
260
- # "feed_forward_chunk_size" can be used to save memory
261
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
262
- else:
263
- ff_output = self.ff(norm_hidden_states)
264
- ff_output = gate_mlp.unsqueeze(1) * ff_output
265
-
266
- hidden_states = hidden_states + ff_output
267
-
268
- # Process attention outputs for the `encoder_hidden_states`.
269
- if self.context_pre_only:
270
- encoder_hidden_states = None
271
- else:
272
- context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
273
- encoder_hidden_states = encoder_hidden_states + context_attn_output
274
-
275
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
276
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
277
- if self._chunk_size is not None:
278
- # "feed_forward_chunk_size" can be used to save memory
279
- context_ff_output = _chunked_feed_forward(
280
- self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
281
- )
282
- else:
283
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
284
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
285
-
286
- return encoder_hidden_states, hidden_states
287
-
288
-
289
- @maybe_allow_in_graph
290
- class BasicTransformerBlock(nn.Module):
291
- r"""
292
- A basic Transformer block.
293
-
294
- Parameters:
295
- dim (`int`): The number of channels in the input and output.
296
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
297
- attention_head_dim (`int`): The number of channels in each head.
298
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
299
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
300
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
301
- num_embeds_ada_norm (:
302
- obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
303
- attention_bias (:
304
- obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
305
- only_cross_attention (`bool`, *optional*):
306
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
307
- double_self_attention (`bool`, *optional*):
308
- Whether to use two self-attention layers. In this case no cross attention layers are used.
309
- upcast_attention (`bool`, *optional*):
310
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
311
- norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
312
- Whether to use learnable elementwise affine parameters for normalization.
313
- norm_type (`str`, *optional*, defaults to `"layer_norm"`):
314
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
315
- final_dropout (`bool` *optional*, defaults to False):
316
- Whether to apply a final dropout after the last feed-forward layer.
317
- attention_type (`str`, *optional*, defaults to `"default"`):
318
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
319
- positional_embeddings (`str`, *optional*, defaults to `None`):
320
- The type of positional embeddings to apply to.
321
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
322
- The maximum number of positional embeddings to apply.
323
- """
324
-
325
- def __init__(
326
- self,
327
- dim: int,
328
- num_attention_heads: int,
329
- attention_head_dim: int,
330
- dropout=0.0,
331
- cross_attention_dim: Optional[int] = None,
332
- activation_fn: str = "geglu",
333
- num_embeds_ada_norm: Optional[int] = None,
334
- attention_bias: bool = False,
335
- only_cross_attention: bool = False,
336
- double_self_attention: bool = False,
337
- upcast_attention: bool = False,
338
- norm_elementwise_affine: bool = True,
339
- norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single', 'ada_norm_continuous', 'layer_norm_i2vgen'
340
- norm_eps: float = 1e-5,
341
- final_dropout: bool = False,
342
- attention_type: str = "default",
343
- positional_embeddings: Optional[str] = None,
344
- num_positional_embeddings: Optional[int] = None,
345
- ada_norm_continous_conditioning_embedding_dim: Optional[int] = None,
346
- ada_norm_bias: Optional[int] = None,
347
- ff_inner_dim: Optional[int] = None,
348
- ff_bias: bool = True,
349
- attention_out_bias: bool = True,
350
- ):
351
- super().__init__()
352
- self.dim = dim
353
- self.num_attention_heads = num_attention_heads
354
- self.attention_head_dim = attention_head_dim
355
- self.dropout = dropout
356
- self.cross_attention_dim = cross_attention_dim
357
- self.activation_fn = activation_fn
358
- self.attention_bias = attention_bias
359
- self.double_self_attention = double_self_attention
360
- self.norm_elementwise_affine = norm_elementwise_affine
361
- self.positional_embeddings = positional_embeddings
362
- self.num_positional_embeddings = num_positional_embeddings
363
- self.only_cross_attention = only_cross_attention
364
-
365
- # We keep these boolean flags for backward-compatibility.
366
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
367
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
368
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
369
- self.use_layer_norm = norm_type == "layer_norm"
370
- self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
371
-
372
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
373
- raise ValueError(
374
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
375
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
376
- )
377
-
378
- self.norm_type = norm_type
379
- self.num_embeds_ada_norm = num_embeds_ada_norm
380
-
381
- if positional_embeddings and (num_positional_embeddings is None):
382
- raise ValueError(
383
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
384
- )
385
-
386
- if positional_embeddings == "sinusoidal":
387
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
388
- else:
389
- self.pos_embed = None
390
-
391
- # Define 3 blocks. Each block has its own normalization layer.
392
- # 1. Self-Attn
393
- if norm_type == "ada_norm":
394
- self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
395
- elif norm_type == "ada_norm_zero":
396
- self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
397
- elif norm_type == "ada_norm_continuous":
398
- self.norm1 = AdaLayerNormContinuous(
399
- dim,
400
- ada_norm_continous_conditioning_embedding_dim,
401
- norm_elementwise_affine,
402
- norm_eps,
403
- ada_norm_bias,
404
- "rms_norm",
405
- )
406
- else:
407
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
408
-
409
- self.attn1 = Attention(
410
- query_dim=dim,
411
- heads=num_attention_heads,
412
- dim_head=attention_head_dim,
413
- dropout=dropout,
414
- bias=attention_bias,
415
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
416
- upcast_attention=upcast_attention,
417
- out_bias=attention_out_bias,
418
- )
419
-
420
- # 2. Cross-Attn
421
- if cross_attention_dim is not None or double_self_attention:
422
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
423
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
424
- # the second cross attention block.
425
- if norm_type == "ada_norm":
426
- self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
427
- elif norm_type == "ada_norm_continuous":
428
- self.norm2 = AdaLayerNormContinuous(
429
- dim,
430
- ada_norm_continous_conditioning_embedding_dim,
431
- norm_elementwise_affine,
432
- norm_eps,
433
- ada_norm_bias,
434
- "rms_norm",
435
- )
436
- else:
437
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
438
-
439
- self.attn2 = Attention(
440
- query_dim=dim,
441
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
442
- heads=num_attention_heads,
443
- dim_head=attention_head_dim,
444
- dropout=dropout,
445
- bias=attention_bias,
446
- upcast_attention=upcast_attention,
447
- out_bias=attention_out_bias,
448
- ) # is self-attn if encoder_hidden_states is none
449
- else:
450
- if norm_type == "ada_norm_single": # For Latte
451
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
452
- else:
453
- self.norm2 = None
454
- self.attn2 = None
455
-
456
- # 3. Feed-forward
457
- if norm_type == "ada_norm_continuous":
458
- self.norm3 = AdaLayerNormContinuous(
459
- dim,
460
- ada_norm_continous_conditioning_embedding_dim,
461
- norm_elementwise_affine,
462
- norm_eps,
463
- ada_norm_bias,
464
- "layer_norm",
465
- )
466
-
467
- elif norm_type in ["ada_norm_zero", "ada_norm", "layer_norm"]:
468
- self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
469
- elif norm_type == "layer_norm_i2vgen":
470
- self.norm3 = None
471
-
472
- self.ff = FeedForward(
473
- dim,
474
- dropout=dropout,
475
- activation_fn=activation_fn,
476
- final_dropout=final_dropout,
477
- inner_dim=ff_inner_dim,
478
- bias=ff_bias,
479
- )
480
-
481
- # 4. Fuser
482
- if attention_type == "gated" or attention_type == "gated-text-image":
483
- self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
484
-
485
- # 5. Scale-shift for PixArt-Alpha.
486
- if norm_type == "ada_norm_single":
487
- self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
488
-
489
- # let chunk size default to None
490
- self._chunk_size = None
491
- self._chunk_dim = 0
492
-
493
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
494
- # Sets chunk feed-forward
495
- self._chunk_size = chunk_size
496
- self._chunk_dim = dim
497
-
498
- def forward(
499
- self,
500
- hidden_states: torch.Tensor,
501
- attention_mask: Optional[torch.Tensor] = None,
502
- encoder_hidden_states: Optional[torch.Tensor] = None,
503
- encoder_attention_mask: Optional[torch.Tensor] = None,
504
- timestep: Optional[torch.LongTensor] = None,
505
- cross_attention_kwargs: Dict[str, Any] = None,
506
- class_labels: Optional[torch.LongTensor] = None,
507
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
508
- ) -> torch.Tensor:
509
- if cross_attention_kwargs is not None:
510
- if cross_attention_kwargs.get("scale", None) is not None:
511
- logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
512
-
513
- # Notice that normalization is always applied before the real computation in the following blocks.
514
- # 0. Self-Attention
515
- batch_size = hidden_states.shape[0]
516
-
517
- if self.norm_type == "ada_norm":
518
- norm_hidden_states = self.norm1(hidden_states, timestep)
519
- elif self.norm_type == "ada_norm_zero":
520
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
521
- hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
522
- )
523
- elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
524
- norm_hidden_states = self.norm1(hidden_states)
525
- elif self.norm_type == "ada_norm_continuous":
526
- norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
527
- elif self.norm_type == "ada_norm_single":
528
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
529
- self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
530
- ).chunk(6, dim=1)
531
- norm_hidden_states = self.norm1(hidden_states)
532
- norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
533
- else:
534
- raise ValueError("Incorrect norm used")
535
-
536
- if self.pos_embed is not None:
537
- norm_hidden_states = self.pos_embed(norm_hidden_states)
538
-
539
- # 1. Prepare GLIGEN inputs
540
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
541
- gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
542
-
543
- attn_output = self.attn1(
544
- norm_hidden_states,
545
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
546
- attention_mask=attention_mask,
547
- **cross_attention_kwargs,
548
- )
549
-
550
- if self.norm_type == "ada_norm_zero":
551
- attn_output = gate_msa.unsqueeze(1) * attn_output
552
- elif self.norm_type == "ada_norm_single":
553
- attn_output = gate_msa * attn_output
554
-
555
- hidden_states = attn_output + hidden_states
556
- if hidden_states.ndim == 4:
557
- hidden_states = hidden_states.squeeze(1)
558
-
559
- # 1.2 GLIGEN Control
560
- if gligen_kwargs is not None:
561
- hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
562
-
563
- # 3. Cross-Attention
564
- if self.attn2 is not None:
565
- if self.norm_type == "ada_norm":
566
- norm_hidden_states = self.norm2(hidden_states, timestep)
567
- elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
568
- norm_hidden_states = self.norm2(hidden_states)
569
- elif self.norm_type == "ada_norm_single":
570
- # For PixArt norm2 isn't applied here:
571
- # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
572
- norm_hidden_states = hidden_states
573
- elif self.norm_type == "ada_norm_continuous":
574
- norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
575
- else:
576
- raise ValueError("Incorrect norm")
577
-
578
- if self.pos_embed is not None and self.norm_type != "ada_norm_single":
579
- norm_hidden_states = self.pos_embed(norm_hidden_states)
580
-
581
- attn_output = self.attn2(
582
- norm_hidden_states,
583
- encoder_hidden_states=encoder_hidden_states,
584
- attention_mask=encoder_attention_mask,
585
- **cross_attention_kwargs,
586
- )
587
- hidden_states = attn_output + hidden_states
588
-
589
- # 4. Feed-forward
590
- # i2vgen doesn't have this norm 🤷‍♂️
591
- if self.norm_type == "ada_norm_continuous":
592
- norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
593
- elif not self.norm_type == "ada_norm_single":
594
- norm_hidden_states = self.norm3(hidden_states)
595
-
596
- if self.norm_type == "ada_norm_zero":
597
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
598
-
599
- if self.norm_type == "ada_norm_single":
600
- norm_hidden_states = self.norm2(hidden_states)
601
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
602
-
603
- if self._chunk_size is not None:
604
- # "feed_forward_chunk_size" can be used to save memory
605
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
606
- else:
607
- ff_output = self.ff(norm_hidden_states)
608
-
609
- if self.norm_type == "ada_norm_zero":
610
- ff_output = gate_mlp.unsqueeze(1) * ff_output
611
- elif self.norm_type == "ada_norm_single":
612
- ff_output = gate_mlp * ff_output
613
-
614
- hidden_states = ff_output + hidden_states
615
- if hidden_states.ndim == 4:
616
- hidden_states = hidden_states.squeeze(1)
617
-
618
- return hidden_states
619
-
620
-
621
- class LuminaFeedForward(nn.Module):
622
- r"""
623
- A feed-forward layer.
624
-
625
- Parameters:
626
- hidden_size (`int`):
627
- The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
628
- hidden representations.
629
- intermediate_size (`int`): The intermediate dimension of the feedforward layer.
630
- multiple_of (`int`, *optional*): Value to ensure hidden dimension is a multiple
631
- of this value.
632
- ffn_dim_multiplier (float, *optional*): Custom multiplier for hidden
633
- dimension. Defaults to None.
634
- """
635
-
636
- def __init__(
637
- self,
638
- dim: int,
639
- inner_dim: int,
640
- multiple_of: Optional[int] = 256,
641
- ffn_dim_multiplier: Optional[float] = None,
642
- ):
643
- super().__init__()
644
- inner_dim = int(2 * inner_dim / 3)
645
- # custom hidden_size factor multiplier
646
- if ffn_dim_multiplier is not None:
647
- inner_dim = int(ffn_dim_multiplier * inner_dim)
648
- inner_dim = multiple_of * ((inner_dim + multiple_of - 1) // multiple_of)
649
-
650
- self.linear_1 = nn.Linear(
651
- dim,
652
- inner_dim,
653
- bias=False,
654
- )
655
- self.linear_2 = nn.Linear(
656
- inner_dim,
657
- dim,
658
- bias=False,
659
- )
660
- self.linear_3 = nn.Linear(
661
- dim,
662
- inner_dim,
663
- bias=False,
664
- )
665
- self.silu = FP32SiLU()
666
-
667
- def forward(self, x):
668
- return self.linear_2(self.silu(self.linear_1(x)) * self.linear_3(x))
669
-
670
-
671
- @maybe_allow_in_graph
672
- class TemporalBasicTransformerBlock(nn.Module):
673
- r"""
674
- A basic Transformer block for video like data.
675
-
676
- Parameters:
677
- dim (`int`): The number of channels in the input and output.
678
- time_mix_inner_dim (`int`): The number of channels for temporal attention.
679
- num_attention_heads (`int`): The number of heads to use for multi-head attention.
680
- attention_head_dim (`int`): The number of channels in each head.
681
- cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
682
- """
683
-
684
- def __init__(
685
- self,
686
- dim: int,
687
- time_mix_inner_dim: int,
688
- num_attention_heads: int,
689
- attention_head_dim: int,
690
- cross_attention_dim: Optional[int] = None,
691
- ):
692
- super().__init__()
693
- self.is_res = dim == time_mix_inner_dim
694
-
695
- self.norm_in = nn.LayerNorm(dim)
696
-
697
- # Define 3 blocks. Each block has its own normalization layer.
698
- # 1. Self-Attn
699
- self.ff_in = FeedForward(
700
- dim,
701
- dim_out=time_mix_inner_dim,
702
- activation_fn="geglu",
703
- )
704
-
705
- self.norm1 = nn.LayerNorm(time_mix_inner_dim)
706
- self.attn1 = Attention(
707
- query_dim=time_mix_inner_dim,
708
- heads=num_attention_heads,
709
- dim_head=attention_head_dim,
710
- cross_attention_dim=None,
711
- )
712
-
713
- # 2. Cross-Attn
714
- if cross_attention_dim is not None:
715
- # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
716
- # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
717
- # the second cross attention block.
718
- self.norm2 = nn.LayerNorm(time_mix_inner_dim)
719
- self.attn2 = Attention(
720
- query_dim=time_mix_inner_dim,
721
- cross_attention_dim=cross_attention_dim,
722
- heads=num_attention_heads,
723
- dim_head=attention_head_dim,
724
- ) # is self-attn if encoder_hidden_states is none
725
- else:
726
- self.norm2 = None
727
- self.attn2 = None
728
-
729
- # 3. Feed-forward
730
- self.norm3 = nn.LayerNorm(time_mix_inner_dim)
731
- self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
732
-
733
- # let chunk size default to None
734
- self._chunk_size = None
735
- self._chunk_dim = None
736
-
737
- def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
738
- # Sets chunk feed-forward
739
- self._chunk_size = chunk_size
740
- # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
741
- self._chunk_dim = 1
742
-
743
- def forward(
744
- self,
745
- hidden_states: torch.Tensor,
746
- num_frames: int,
747
- encoder_hidden_states: Optional[torch.Tensor] = None,
748
- ) -> torch.Tensor:
749
- # Notice that normalization is always applied before the real computation in the following blocks.
750
- # 0. Self-Attention
751
- batch_size = hidden_states.shape[0]
752
-
753
- batch_frames, seq_length, channels = hidden_states.shape
754
- batch_size = batch_frames // num_frames
755
-
756
- hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
757
- hidden_states = hidden_states.permute(0, 2, 1, 3)
758
- hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
759
-
760
- residual = hidden_states
761
- hidden_states = self.norm_in(hidden_states)
762
-
763
- if self._chunk_size is not None:
764
- hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size)
765
- else:
766
- hidden_states = self.ff_in(hidden_states)
767
-
768
- if self.is_res:
769
- hidden_states = hidden_states + residual
770
-
771
- norm_hidden_states = self.norm1(hidden_states)
772
- attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
773
- hidden_states = attn_output + hidden_states
774
-
775
- # 3. Cross-Attention
776
- if self.attn2 is not None:
777
- norm_hidden_states = self.norm2(hidden_states)
778
- attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
779
- hidden_states = attn_output + hidden_states
780
-
781
- # 4. Feed-forward
782
- norm_hidden_states = self.norm3(hidden_states)
783
-
784
- if self._chunk_size is not None:
785
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
786
- else:
787
- ff_output = self.ff(norm_hidden_states)
788
-
789
- if self.is_res:
790
- hidden_states = ff_output + hidden_states
791
- else:
792
- hidden_states = ff_output
793
-
794
- hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
795
- hidden_states = hidden_states.permute(0, 2, 1, 3)
796
- hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
797
-
798
- return hidden_states
799
-
800
-
801
- class SkipFFTransformerBlock(nn.Module):
802
- def __init__(
803
- self,
804
- dim: int,
805
- num_attention_heads: int,
806
- attention_head_dim: int,
807
- kv_input_dim: int,
808
- kv_input_dim_proj_use_bias: bool,
809
- dropout=0.0,
810
- cross_attention_dim: Optional[int] = None,
811
- attention_bias: bool = False,
812
- attention_out_bias: bool = True,
813
- ):
814
- super().__init__()
815
- if kv_input_dim != dim:
816
- self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias)
817
- else:
818
- self.kv_mapper = None
819
-
820
- self.norm1 = RMSNorm(dim, 1e-06)
821
-
822
- self.attn1 = Attention(
823
- query_dim=dim,
824
- heads=num_attention_heads,
825
- dim_head=attention_head_dim,
826
- dropout=dropout,
827
- bias=attention_bias,
828
- cross_attention_dim=cross_attention_dim,
829
- out_bias=attention_out_bias,
830
- )
831
-
832
- self.norm2 = RMSNorm(dim, 1e-06)
833
-
834
- self.attn2 = Attention(
835
- query_dim=dim,
836
- cross_attention_dim=cross_attention_dim,
837
- heads=num_attention_heads,
838
- dim_head=attention_head_dim,
839
- dropout=dropout,
840
- bias=attention_bias,
841
- out_bias=attention_out_bias,
842
- )
843
-
844
- def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs):
845
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
846
-
847
- if self.kv_mapper is not None:
848
- encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))
849
-
850
- norm_hidden_states = self.norm1(hidden_states)
851
-
852
- attn_output = self.attn1(
853
- norm_hidden_states,
854
- encoder_hidden_states=encoder_hidden_states,
855
- **cross_attention_kwargs,
856
- )
857
-
858
- hidden_states = attn_output + hidden_states
859
-
860
- norm_hidden_states = self.norm2(hidden_states)
861
-
862
- attn_output = self.attn2(
863
- norm_hidden_states,
864
- encoder_hidden_states=encoder_hidden_states,
865
- **cross_attention_kwargs,
866
- )
867
-
868
- hidden_states = attn_output + hidden_states
869
-
870
- return hidden_states
871
-
872
-
873
- @maybe_allow_in_graph
874
- class FreeNoiseTransformerBlock(nn.Module):
875
- r"""
876
- A FreeNoise Transformer block.
877
-
878
- Parameters:
879
- dim (`int`):
880
- The number of channels in the input and output.
881
- num_attention_heads (`int`):
882
- The number of heads to use for multi-head attention.
883
- attention_head_dim (`int`):
884
- The number of channels in each head.
885
- dropout (`float`, *optional*, defaults to 0.0):
886
- The dropout probability to use.
887
- cross_attention_dim (`int`, *optional*):
888
- The size of the encoder_hidden_states vector for cross attention.
889
- activation_fn (`str`, *optional*, defaults to `"geglu"`):
890
- Activation function to be used in feed-forward.
891
- num_embeds_ada_norm (`int`, *optional*):
892
- The number of diffusion steps used during training. See `Transformer2DModel`.
893
- attention_bias (`bool`, defaults to `False`):
894
- Configure if the attentions should contain a bias parameter.
895
- only_cross_attention (`bool`, defaults to `False`):
896
- Whether to use only cross-attention layers. In this case two cross attention layers are used.
897
- double_self_attention (`bool`, defaults to `False`):
898
- Whether to use two self-attention layers. In this case no cross attention layers are used.
899
- upcast_attention (`bool`, defaults to `False`):
900
- Whether to upcast the attention computation to float32. This is useful for mixed precision training.
901
- norm_elementwise_affine (`bool`, defaults to `True`):
902
- Whether to use learnable elementwise affine parameters for normalization.
903
- norm_type (`str`, defaults to `"layer_norm"`):
904
- The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
905
- final_dropout (`bool` defaults to `False`):
906
- Whether to apply a final dropout after the last feed-forward layer.
907
- attention_type (`str`, defaults to `"default"`):
908
- The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
909
- positional_embeddings (`str`, *optional*):
910
- The type of positional embeddings to apply to.
911
- num_positional_embeddings (`int`, *optional*, defaults to `None`):
912
- The maximum number of positional embeddings to apply.
913
- ff_inner_dim (`int`, *optional*):
914
- Hidden dimension of feed-forward MLP.
915
- ff_bias (`bool`, defaults to `True`):
916
- Whether or not to use bias in feed-forward MLP.
917
- attention_out_bias (`bool`, defaults to `True`):
918
- Whether or not to use bias in attention output project layer.
919
- context_length (`int`, defaults to `16`):
920
- The maximum number of frames that the FreeNoise block processes at once.
921
- context_stride (`int`, defaults to `4`):
922
- The number of frames to be skipped before starting to process a new batch of `context_length` frames.
923
- weighting_scheme (`str`, defaults to `"pyramid"`):
924
- The weighting scheme to use for weighting averaging of processed latent frames. As described in the
925
- Equation 9. of the [FreeNoise](https://arxiv.org/abs/2310.15169) paper, "pyramid" is the default setting
926
- used.
927
- """
928
-
929
- def __init__(
930
- self,
931
- dim: int,
932
- num_attention_heads: int,
933
- attention_head_dim: int,
934
- dropout: float = 0.0,
935
- cross_attention_dim: Optional[int] = None,
936
- activation_fn: str = "geglu",
937
- num_embeds_ada_norm: Optional[int] = None,
938
- attention_bias: bool = False,
939
- only_cross_attention: bool = False,
940
- double_self_attention: bool = False,
941
- upcast_attention: bool = False,
942
- norm_elementwise_affine: bool = True,
943
- norm_type: str = "layer_norm",
944
- norm_eps: float = 1e-5,
945
- final_dropout: bool = False,
946
- positional_embeddings: Optional[str] = None,
947
- num_positional_embeddings: Optional[int] = None,
948
- ff_inner_dim: Optional[int] = None,
949
- ff_bias: bool = True,
950
- attention_out_bias: bool = True,
951
- context_length: int = 16,
952
- context_stride: int = 4,
953
- weighting_scheme: str = "pyramid",
954
- ):
955
- super().__init__()
956
- self.dim = dim
957
- self.num_attention_heads = num_attention_heads
958
- self.attention_head_dim = attention_head_dim
959
- self.dropout = dropout
960
- self.cross_attention_dim = cross_attention_dim
961
- self.activation_fn = activation_fn
962
- self.attention_bias = attention_bias
963
- self.double_self_attention = double_self_attention
964
- self.norm_elementwise_affine = norm_elementwise_affine
965
- self.positional_embeddings = positional_embeddings
966
- self.num_positional_embeddings = num_positional_embeddings
967
- self.only_cross_attention = only_cross_attention
968
-
969
- self.set_free_noise_properties(context_length, context_stride, weighting_scheme)
970
-
971
- # We keep these boolean flags for backward-compatibility.
972
- self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
973
- self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
974
- self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
975
- self.use_layer_norm = norm_type == "layer_norm"
976
- self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
977
-
978
- if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
979
- raise ValueError(
980
- f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
981
- f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
982
- )
983
-
984
- self.norm_type = norm_type
985
- self.num_embeds_ada_norm = num_embeds_ada_norm
986
-
987
- if positional_embeddings and (num_positional_embeddings is None):
988
- raise ValueError(
989
- "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
990
- )
991
-
992
- if positional_embeddings == "sinusoidal":
993
- self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
994
- else:
995
- self.pos_embed = None
996
-
997
- # Define 3 blocks. Each block has its own normalization layer.
998
- # 1. Self-Attn
999
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
1000
-
1001
- self.attn1 = Attention(
1002
- query_dim=dim,
1003
- heads=num_attention_heads,
1004
- dim_head=attention_head_dim,
1005
- dropout=dropout,
1006
- bias=attention_bias,
1007
- cross_attention_dim=cross_attention_dim if only_cross_attention else None,
1008
- upcast_attention=upcast_attention,
1009
- out_bias=attention_out_bias,
1010
- )
1011
-
1012
- # 2. Cross-Attn
1013
- if cross_attention_dim is not None or double_self_attention:
1014
- self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
1015
-
1016
- self.attn2 = Attention(
1017
- query_dim=dim,
1018
- cross_attention_dim=cross_attention_dim if not double_self_attention else None,
1019
- heads=num_attention_heads,
1020
- dim_head=attention_head_dim,
1021
- dropout=dropout,
1022
- bias=attention_bias,
1023
- upcast_attention=upcast_attention,
1024
- out_bias=attention_out_bias,
1025
- ) # is self-attn if encoder_hidden_states is none
1026
-
1027
- # 3. Feed-forward
1028
- self.ff = FeedForward(
1029
- dim,
1030
- dropout=dropout,
1031
- activation_fn=activation_fn,
1032
- final_dropout=final_dropout,
1033
- inner_dim=ff_inner_dim,
1034
- bias=ff_bias,
1035
- )
1036
-
1037
- self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine)
1038
-
1039
- # let chunk size default to None
1040
- self._chunk_size = None
1041
- self._chunk_dim = 0
1042
-
1043
- def _get_frame_indices(self, num_frames: int) -> List[Tuple[int, int]]:
1044
- frame_indices = []
1045
- for i in range(0, num_frames - self.context_length + 1, self.context_stride):
1046
- window_start = i
1047
- window_end = min(num_frames, i + self.context_length)
1048
- frame_indices.append((window_start, window_end))
1049
- return frame_indices
1050
-
1051
- def _get_frame_weights(self, num_frames: int, weighting_scheme: str = "pyramid") -> List[float]:
1052
- if weighting_scheme == "flat":
1053
- weights = [1.0] * num_frames
1054
-
1055
- elif weighting_scheme == "pyramid":
1056
- if num_frames % 2 == 0:
1057
- # num_frames = 4 => [1, 2, 2, 1]
1058
- mid = num_frames // 2
1059
- weights = list(range(1, mid + 1))
1060
- weights = weights + weights[::-1]
1061
- else:
1062
- # num_frames = 5 => [1, 2, 3, 2, 1]
1063
- mid = (num_frames + 1) // 2
1064
- weights = list(range(1, mid))
1065
- weights = weights + [mid] + weights[::-1]
1066
-
1067
- elif weighting_scheme == "delayed_reverse_sawtooth":
1068
- if num_frames % 2 == 0:
1069
- # num_frames = 4 => [0.01, 2, 2, 1]
1070
- mid = num_frames // 2
1071
- weights = [0.01] * (mid - 1) + [mid]
1072
- weights = weights + list(range(mid, 0, -1))
1073
- else:
1074
- # num_frames = 5 => [0.01, 0.01, 3, 2, 1]
1075
- mid = (num_frames + 1) // 2
1076
- weights = [0.01] * mid
1077
- weights = weights + list(range(mid, 0, -1))
1078
- else:
1079
- raise ValueError(f"Unsupported value for weighting_scheme={weighting_scheme}")
1080
-
1081
- return weights
1082
-
1083
- def set_free_noise_properties(
1084
- self, context_length: int, context_stride: int, weighting_scheme: str = "pyramid"
1085
- ) -> None:
1086
- self.context_length = context_length
1087
- self.context_stride = context_stride
1088
- self.weighting_scheme = weighting_scheme
1089
-
1090
- def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0) -> None:
1091
- # Sets chunk feed-forward
1092
- self._chunk_size = chunk_size
1093
- self._chunk_dim = dim
1094
-
1095
- def forward(
1096
- self,
1097
- hidden_states: torch.Tensor,
1098
- attention_mask: Optional[torch.Tensor] = None,
1099
- encoder_hidden_states: Optional[torch.Tensor] = None,
1100
- encoder_attention_mask: Optional[torch.Tensor] = None,
1101
- cross_attention_kwargs: Dict[str, Any] = None,
1102
- *args,
1103
- **kwargs,
1104
- ) -> torch.Tensor:
1105
- if cross_attention_kwargs is not None:
1106
- if cross_attention_kwargs.get("scale", None) is not None:
1107
- logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1108
-
1109
- cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
1110
-
1111
- # hidden_states: [B x H x W, F, C]
1112
- device = hidden_states.device
1113
- dtype = hidden_states.dtype
1114
-
1115
- num_frames = hidden_states.size(1)
1116
- frame_indices = self._get_frame_indices(num_frames)
1117
- frame_weights = self._get_frame_weights(self.context_length, self.weighting_scheme)
1118
- frame_weights = torch.tensor(frame_weights, device=device, dtype=dtype).unsqueeze(0).unsqueeze(-1)
1119
- is_last_frame_batch_complete = frame_indices[-1][1] == num_frames
1120
-
1121
- # Handle out-of-bounds case if num_frames isn't perfectly divisible by context_length
1122
- # For example, num_frames=25, context_length=16, context_stride=4, then we expect the ranges:
1123
- # [(0, 16), (4, 20), (8, 24), (10, 26)]
1124
- if not is_last_frame_batch_complete:
1125
- if num_frames < self.context_length:
1126
- raise ValueError(f"Expected {num_frames=} to be greater or equal than {self.context_length=}")
1127
- last_frame_batch_length = num_frames - frame_indices[-1][1]
1128
- frame_indices.append((num_frames - self.context_length, num_frames))
1129
-
1130
- num_times_accumulated = torch.zeros((1, num_frames, 1), device=device)
1131
- accumulated_values = torch.zeros_like(hidden_states)
1132
-
1133
- for i, (frame_start, frame_end) in enumerate(frame_indices):
1134
- # The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
1135
- # cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
1136
- # essentially a non-multiple of `context_length`.
1137
- weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
1138
- weights *= frame_weights
1139
-
1140
- hidden_states_chunk = hidden_states[:, frame_start:frame_end]
1141
-
1142
- # Notice that normalization is always applied before the real computation in the following blocks.
1143
- # 1. Self-Attention
1144
- norm_hidden_states = self.norm1(hidden_states_chunk)
1145
-
1146
- if self.pos_embed is not None:
1147
- norm_hidden_states = self.pos_embed(norm_hidden_states)
1148
-
1149
- attn_output = self.attn1(
1150
- norm_hidden_states,
1151
- encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
1152
- attention_mask=attention_mask,
1153
- **cross_attention_kwargs,
1154
- )
1155
-
1156
- hidden_states_chunk = attn_output + hidden_states_chunk
1157
- if hidden_states_chunk.ndim == 4:
1158
- hidden_states_chunk = hidden_states_chunk.squeeze(1)
1159
-
1160
- # 2. Cross-Attention
1161
- if self.attn2 is not None:
1162
- norm_hidden_states = self.norm2(hidden_states_chunk)
1163
-
1164
- if self.pos_embed is not None and self.norm_type != "ada_norm_single":
1165
- norm_hidden_states = self.pos_embed(norm_hidden_states)
1166
-
1167
- attn_output = self.attn2(
1168
- norm_hidden_states,
1169
- encoder_hidden_states=encoder_hidden_states,
1170
- attention_mask=encoder_attention_mask,
1171
- **cross_attention_kwargs,
1172
- )
1173
- hidden_states_chunk = attn_output + hidden_states_chunk
1174
-
1175
- if i == len(frame_indices) - 1 and not is_last_frame_batch_complete:
1176
- accumulated_values[:, -last_frame_batch_length:] += (
1177
- hidden_states_chunk[:, -last_frame_batch_length:] * weights[:, -last_frame_batch_length:]
1178
- )
1179
- num_times_accumulated[:, -last_frame_batch_length:] += weights[:, -last_frame_batch_length]
1180
- else:
1181
- accumulated_values[:, frame_start:frame_end] += hidden_states_chunk * weights
1182
- num_times_accumulated[:, frame_start:frame_end] += weights
1183
-
1184
- # TODO(aryan): Maybe this could be done in a better way.
1185
- #
1186
- # Previously, this was:
1187
- # hidden_states = torch.where(
1188
- # num_times_accumulated > 0, accumulated_values / num_times_accumulated, accumulated_values
1189
- # )
1190
- #
1191
- # The reasoning for the change here is `torch.where` became a bottleneck at some point when golfing memory
1192
- # spikes. It is particularly noticeable when the number of frames is high. My understanding is that this comes
1193
- # from tensors being copied - which is why we resort to spliting and concatenating here. I've not particularly
1194
- # looked into this deeply because other memory optimizations led to more pronounced reductions.
1195
- hidden_states = torch.cat(
1196
- [
1197
- torch.where(num_times_split > 0, accumulated_split / num_times_split, accumulated_split)
1198
- for accumulated_split, num_times_split in zip(
1199
- accumulated_values.split(self.context_length, dim=1),
1200
- num_times_accumulated.split(self.context_length, dim=1),
1201
- )
1202
- ],
1203
- dim=1,
1204
- ).to(dtype)
1205
-
1206
- # 3. Feed-forward
1207
- norm_hidden_states = self.norm3(hidden_states)
1208
-
1209
- if self._chunk_size is not None:
1210
- ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
1211
- else:
1212
- ff_output = self.ff(norm_hidden_states)
1213
-
1214
- hidden_states = ff_output + hidden_states
1215
- if hidden_states.ndim == 4:
1216
- hidden_states = hidden_states.squeeze(1)
1217
-
1218
- return hidden_states
1219
-
1220
-
1221
- class FeedForward(nn.Module):
1222
- r"""
1223
- A feed-forward layer.
1224
-
1225
- Parameters:
1226
- dim (`int`): The number of channels in the input.
1227
- dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1228
- mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1229
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1230
- activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1231
- final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1232
- bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1233
- """
1234
-
1235
- def __init__(
1236
- self,
1237
- dim: int,
1238
- dim_out: Optional[int] = None,
1239
- mult: int = 4,
1240
- dropout: float = 0.0,
1241
- activation_fn: str = "geglu",
1242
- final_dropout: bool = False,
1243
- inner_dim=None,
1244
- bias: bool = True,
1245
- ):
1246
- super().__init__()
1247
- if inner_dim is None:
1248
- inner_dim = int(dim * mult)
1249
- dim_out = dim_out if dim_out is not None else dim
1250
-
1251
- if activation_fn == "gelu":
1252
- act_fn = GELU(dim, inner_dim, bias=bias)
1253
- if activation_fn == "gelu-approximate":
1254
- act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1255
- elif activation_fn == "geglu":
1256
- act_fn = GEGLU(dim, inner_dim, bias=bias)
1257
- elif activation_fn == "geglu-approximate":
1258
- act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1259
- elif activation_fn == "swiglu":
1260
- act_fn = SwiGLU(dim, inner_dim, bias=bias)
1261
-
1262
- self.net = nn.ModuleList([])
1263
- # project in
1264
- self.net.append(act_fn)
1265
- # project dropout
1266
- self.net.append(nn.Dropout(dropout))
1267
- # project out
1268
- self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
1269
- # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1270
- if final_dropout:
1271
- self.net.append(nn.Dropout(dropout))
1272
-
1273
- def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1274
- if len(args) > 0 or kwargs.get("scale", None) is not None:
1275
- deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1276
- deprecate("scale", "1.0.0", deprecation_message)
1277
- for module in self.net:
1278
- hidden_states = module(hidden_states)
1279
- return hidden_states
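For reference, the FeedForward block deleted above is the standard transformer MLP: an activation/projection stage, dropout, then a linear layer back to the output width. A minimal, self-contained sketch of the same idea in plain PyTorch (GEGLU variant only; the names below are illustrative, not the diffusers API):

import torch
import torch.nn as nn

class GEGLU(nn.Module):
    # Gated GELU: project to twice the hidden width, gate one half with GELU of the other.
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden, gate = self.proj(x).chunk(2, dim=-1)
        return hidden * nn.functional.gelu(gate)

class SimpleFeedForward(nn.Module):
    # Mirrors the deleted FeedForward: act_fn -> dropout -> projection back to `dim`.
    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0):
        super().__init__()
        inner_dim = dim * mult
        self.net = nn.Sequential(GEGLU(dim, inner_dim), nn.Dropout(dropout), nn.Linear(inner_dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

ff = SimpleFeedForward(dim=64)
print(ff(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])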
models_resampler.py DELETED
@@ -1,304 +0,0 @@
1
- # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
- import math
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- from diffusers.models.embeddings import Timesteps, TimestepEmbedding
8
-
9
- def get_timestep_embedding(
10
- timesteps: torch.Tensor,
11
- embedding_dim: int,
12
- flip_sin_to_cos: bool = False,
13
- downscale_freq_shift: float = 1,
14
- scale: float = 1,
15
- max_period: int = 10000,
16
- ):
17
- """
18
- This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
19
-
20
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
21
- These may be fractional.
22
- :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
23
- embeddings. :return: an [N x dim] Tensor of positional embeddings.
24
- """
25
- assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
26
-
27
- half_dim = embedding_dim // 2
28
- exponent = -math.log(max_period) * torch.arange(
29
- start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
30
- )
31
- exponent = exponent / (half_dim - downscale_freq_shift)
32
-
33
- emb = torch.exp(exponent)
34
- emb = timesteps[:, None].float() * emb[None, :]
35
-
36
- # scale embeddings
37
- emb = scale * emb
38
-
39
- # concat sine and cosine embeddings
40
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
41
-
42
- # flip sine and cosine embeddings
43
- if flip_sin_to_cos:
44
- emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
45
-
46
- # zero pad
47
- if embedding_dim % 2 == 1:
48
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
49
- return emb
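As a quick orientation to get_timestep_embedding above: half of the output dimensions carry sines and half carry cosines, at frequencies spaced geometrically down from 1 to 1/max_period. A compact, self-contained sketch of the core computation (for brevity this divides by `half` rather than `half_dim - downscale_freq_shift`, and assumes an even `dim`):

import math
import torch

def sinusoidal_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    # timesteps: (N,) -> (N, dim)
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    angles = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0, 10, 999]), dim=320)
print(emb.shape)  # torch.Size([3, 320])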
50
-
51
-
52
- # FFN
53
- def FeedForward(dim, mult=4):
54
- inner_dim = int(dim * mult)
55
- return nn.Sequential(
56
- nn.LayerNorm(dim),
57
- nn.Linear(dim, inner_dim, bias=False),
58
- nn.GELU(),
59
- nn.Linear(inner_dim, dim, bias=False),
60
- )
61
-
62
-
63
- def reshape_tensor(x, heads):
64
- bs, length, width = x.shape
65
- #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
66
- x = x.view(bs, length, heads, -1)
67
- # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
68
- x = x.transpose(1, 2)
69
- # (bs, n_heads, length, dim_per_head) --> (bs, n_heads, length, dim_per_head); heads are kept as a separate axis
70
- x = x.reshape(bs, heads, length, -1)
71
- return x
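reshape_tensor above just splits the channel dimension into attention heads and moves the head axis forward; the same thing in two lines, with concrete shapes:

import torch

x = torch.randn(2, 10, 8 * 64)              # (bs, length, heads * dim_per_head)
x = x.view(2, 10, 8, 64).transpose(1, 2)    # (bs, heads, length, dim_per_head)
print(x.shape)                              # torch.Size([2, 8, 10, 64])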
72
-
73
-
74
- class PerceiverAttention(nn.Module):
75
- def __init__(self, *, dim, dim_head=64, heads=8):
76
- super().__init__()
77
- self.scale = dim_head**-0.5
78
- self.dim_head = dim_head
79
- self.heads = heads
80
- inner_dim = dim_head * heads
81
-
82
- self.norm1 = nn.LayerNorm(dim)
83
- self.norm2 = nn.LayerNorm(dim)
84
-
85
- self.to_q = nn.Linear(dim, inner_dim, bias=False)
86
- self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
87
- self.to_out = nn.Linear(inner_dim, dim, bias=False)
88
-
89
-
90
- def forward(self, x, latents, shift=None, scale=None):
91
- """
92
- Args:
93
- x (torch.Tensor): image features
94
- shape (b, n1, D)
95
- latents (torch.Tensor): latent features
96
- shape (b, n2, D)
97
- """
98
- x = self.norm1(x)
99
- latents = self.norm2(latents)
100
-
101
- if shift is not None and scale is not None:
102
- latents = latents * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
103
-
104
- b, l, _ = latents.shape
105
-
106
- q = self.to_q(latents)
107
- kv_input = torch.cat((x, latents), dim=-2)
108
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
109
-
110
- q = reshape_tensor(q, self.heads)
111
- k = reshape_tensor(k, self.heads)
112
- v = reshape_tensor(v, self.heads)
113
-
114
- # attention
115
- scale = 1 / math.sqrt(math.sqrt(self.dim_head))
116
- weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
117
- weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
118
- out = weight @ v
119
-
120
- out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
121
-
122
- return self.to_out(out)
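One detail worth calling out in PerceiverAttention: the 1/sqrt(d) attention scale is split as 1/d**0.25 applied to both q and k before the matmul, which keeps intermediate values smaller under fp16. A quick check that the two forms agree numerically:

import math
import torch

d = 64
q = torch.randn(2, 8, 16, d, dtype=torch.float64)
k = torch.randn(2, 8, 16, d, dtype=torch.float64)

scale = 1 / math.sqrt(math.sqrt(d))
split_scaled = (q * scale) @ (k * scale).transpose(-2, -1)   # scale shared between q and k
textbook = (q @ k.transpose(-2, -1)) / math.sqrt(d)          # usual post-matmul scaling
print(torch.allclose(split_scaled, textbook))                # True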
123
-
124
-
125
- class Resampler(nn.Module):
126
- def __init__(
127
- self,
128
- dim=1024,
129
- depth=8,
130
- dim_head=64,
131
- heads=16,
132
- num_queries=8,
133
- embedding_dim=768,
134
- output_dim=1024,
135
- ff_mult=4,
136
- *args,
137
- **kwargs,
138
- ):
139
- super().__init__()
140
-
141
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
142
-
143
- self.proj_in = nn.Linear(embedding_dim, dim)
144
-
145
- self.proj_out = nn.Linear(dim, output_dim)
146
- self.norm_out = nn.LayerNorm(output_dim)
147
-
148
- self.layers = nn.ModuleList([])
149
- for _ in range(depth):
150
- self.layers.append(
151
- nn.ModuleList(
152
- [
153
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
154
- FeedForward(dim=dim, mult=ff_mult),
155
- ]
156
- )
157
- )
158
-
159
- def forward(self, x):
160
-
161
- latents = self.latents.repeat(x.size(0), 1, 1)
162
-
163
- x = self.proj_in(x)
164
-
165
- for attn, ff in self.layers:
166
- latents = attn(x, latents) + latents
167
- latents = ff(latents) + latents
168
-
169
- latents = self.proj_out(latents)
170
- return self.norm_out(latents)
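The Resampler above follows the Perceiver pattern: a small set of learned latent tokens repeatedly cross-attends to the projected image features and is then projected to the target width, compressing hundreds of image tokens into a fixed number of queries. A rough sketch of that pattern with torch.nn.MultiheadAttention (illustrative only, not the deleted module's exact math):

import torch
import torch.nn as nn

class TinyResampler(nn.Module):
    def __init__(self, dim=256, num_queries=8, embedding_dim=768, output_dim=1024, depth=2, heads=4):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.blocks = nn.ModuleList(nn.MultiheadAttention(dim, heads, batch_first=True) for _ in range(depth))
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

    def forward(self, image_feats: torch.Tensor) -> torch.Tensor:
        x = self.proj_in(image_feats)                     # (b, n_tokens, dim)
        latents = self.latents.expand(x.size(0), -1, -1)  # (b, num_queries, dim)
        for attn in self.blocks:
            kv = torch.cat([x, latents], dim=1)           # latents attend to image tokens + themselves
            out, _ = attn(latents, kv, kv)
            latents = latents + out
        return self.norm_out(self.proj_out(latents))

feats = torch.randn(2, 257, 768)                          # e.g. CLIP/SigLIP patch features
print(TinyResampler()(feats).shape)                       # torch.Size([2, 8, 1024])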
171
-
172
-
173
- class TimeResampler(nn.Module):
174
- def __init__(
175
- self,
176
- dim=1024,
177
- depth=8,
178
- dim_head=64,
179
- heads=16,
180
- num_queries=8,
181
- embedding_dim=768,
182
- output_dim=1024,
183
- ff_mult=4,
184
- timestep_in_dim=320,
185
- timestep_flip_sin_to_cos=True,
186
- timestep_freq_shift=0,
187
- ):
188
- super().__init__()
189
-
190
- self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
191
-
192
- self.proj_in = nn.Linear(embedding_dim, dim)
193
-
194
- self.proj_out = nn.Linear(dim, output_dim)
195
- self.norm_out = nn.LayerNorm(output_dim)
196
-
197
- self.layers = nn.ModuleList([])
198
- for _ in range(depth):
199
- self.layers.append(
200
- nn.ModuleList(
201
- [
202
- # msa
203
- PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
204
- # ff
205
- FeedForward(dim=dim, mult=ff_mult),
206
- # adaLN
207
- nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True))
208
- ]
209
- )
210
- )
211
-
212
- # time
213
- self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift)
214
- self.time_embedding = TimestepEmbedding(timestep_in_dim, dim, act_fn="silu")
215
-
216
- # adaLN
217
- # self.adaLN_modulation = nn.Sequential(
218
- # nn.SiLU(),
219
- # nn.Linear(timestep_out_dim, 6 * timestep_out_dim, bias=True)
220
- # )
221
-
222
-
223
- def forward(self, x, timestep, need_temb=False):
224
- timestep_emb = self.embedding_time(x, timestep) # bs, dim
225
-
226
- latents = self.latents.repeat(x.size(0), 1, 1)
227
-
228
- x = self.proj_in(x)
229
- x = x + timestep_emb[:, None]
230
-
231
- for attn, ff, adaLN_modulation in self.layers:
232
- shift_msa, scale_msa, shift_mlp, scale_mlp = adaLN_modulation(timestep_emb).chunk(4, dim=1)
233
- latents = attn(x, latents, shift_msa, scale_msa) + latents
234
-
235
- res = latents
236
- for idx_ff in range(len(ff)):
237
- layer_ff = ff[idx_ff]
238
- latents = layer_ff(latents)
239
- if idx_ff == 0 and isinstance(layer_ff, nn.LayerNorm): # adaLN
240
- latents = latents * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
241
- latents = latents + res
242
-
243
- # latents = ff(latents) + latents
244
-
245
- latents = self.proj_out(latents)
246
- latents = self.norm_out(latents)
247
-
248
- if need_temb:
249
- return latents, timestep_emb
250
- else:
251
- return latents
252
-
253
-
254
-
255
- def embedding_time(self, sample, timestep):
256
-
257
- # 1. time
258
- timesteps = timestep
259
- if not torch.is_tensor(timesteps):
260
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
261
- # This would be a good case for the `match` statement (Python 3.10+)
262
- is_mps = sample.device.type == "mps"
263
- if isinstance(timestep, float):
264
- dtype = torch.float32 if is_mps else torch.float64
265
- else:
266
- dtype = torch.int32 if is_mps else torch.int64
267
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
268
- elif len(timesteps.shape) == 0:
269
- timesteps = timesteps[None].to(sample.device)
270
-
271
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
272
- timesteps = timesteps.expand(sample.shape[0])
273
-
274
- t_emb = self.time_proj(timesteps)
275
-
276
- # timesteps does not contain any weights and will always return f32 tensors
277
- # but time_embedding might actually be running in fp16. so we need to cast here.
278
- # there might be better ways to encapsulate this.
279
- t_emb = t_emb.to(dtype=sample.dtype)
280
-
281
- emb = self.time_embedding(t_emb, None)
282
- return emb
283
-
284
-
285
-
286
-
287
-
288
- if __name__ == '__main__':
289
- model = TimeResampler(
290
- dim=1280,
291
- depth=4,
292
- dim_head=64,
293
- heads=20,
294
- num_queries=16,
295
- embedding_dim=512,
296
- output_dim=2048,
297
- ff_mult=4,
298
- timestep_in_dim=320,
299
- timestep_flip_sin_to_cos=True,
300
- timestep_freq_shift=0,
301
- in_channel_extra_emb=2048,
302
- )
303
-
304
-
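The TimeResampler deleted above differs from the plain Resampler mainly in how the timestep enters: the timestep embedding is added to the projected image tokens and also drives per-layer shift/scale (adaLN) modulation. A self-contained sketch of just that conditioning step, with hypothetical dimensions:

import torch
import torch.nn as nn

dim, batch = 1280, 2
timestep_emb = torch.randn(batch, dim)        # output of the time_proj -> time_embedding stack
image_tokens = torch.randn(batch, 64, dim)    # projected image features

# 1. Inject the timestep into every token (x = x + timestep_emb[:, None] in the deleted code).
image_tokens = image_tokens + timestep_emb[:, None]

# 2. Produce modulation parameters from the same embedding (the per-layer adaLN branch).
ada_ln = nn.Sequential(nn.SiLU(), nn.Linear(dim, 4 * dim, bias=True))
shift_msa, scale_msa, shift_mlp, scale_mlp = ada_ln(timestep_emb).chunk(4, dim=1)

# 3. Apply them to normalized latents: norm(x) * (1 + scale) + shift.
latents = torch.randn(batch, 16, dim)
norm = nn.LayerNorm(dim, elementwise_affine=False)
modulated = norm(latents) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
print(modulated.shape)                        # torch.Size([2, 16, 1280])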
models_transformer_sd3.py DELETED
@@ -1,375 +0,0 @@
1
- # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
-
16
- from typing import Any, Dict, List, Optional, Tuple, Union
17
-
18
- import torch
19
- import torch.nn as nn
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
23
- from models_attention import JointTransformerBlock
24
- from diffusers.models.attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
25
- from diffusers.models.modeling_utils import ModelMixin
26
- from diffusers.models.normalization import AdaLayerNormContinuous
27
- from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
28
- from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
29
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
30
-
31
-
32
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
-
34
-
35
- class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
36
- """
37
- The Transformer model introduced in Stable Diffusion 3.
38
-
39
- Reference: https://arxiv.org/abs/2403.03206
40
-
41
- Parameters:
42
- sample_size (`int`): The width of the latent images. This is fixed during training since
43
- it is used to learn a number of position embeddings.
44
- patch_size (`int`): Patch size to turn the input data into small patches.
45
- in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
46
- num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
47
- attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
48
- num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
49
- cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
50
- caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
51
- pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
52
- out_channels (`int`, defaults to 16): Number of output channels.
53
-
54
- """
55
-
56
- _supports_gradient_checkpointing = True
57
-
58
- @register_to_config
59
- def __init__(
60
- self,
61
- sample_size: int = 128,
62
- patch_size: int = 2,
63
- in_channels: int = 16,
64
- num_layers: int = 18,
65
- attention_head_dim: int = 64,
66
- num_attention_heads: int = 18,
67
- joint_attention_dim: int = 4096,
68
- caption_projection_dim: int = 1152,
69
- pooled_projection_dim: int = 2048,
70
- out_channels: int = 16,
71
- pos_embed_max_size: int = 96,
72
- dual_attention_layers: Tuple[
73
- int, ...
74
- ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
75
- qk_norm: Optional[str] = None,
76
- ):
77
- super().__init__()
78
- default_out_channels = in_channels
79
- self.out_channels = out_channels if out_channels is not None else default_out_channels
80
- self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
81
-
82
- self.pos_embed = PatchEmbed(
83
- height=self.config.sample_size,
84
- width=self.config.sample_size,
85
- patch_size=self.config.patch_size,
86
- in_channels=self.config.in_channels,
87
- embed_dim=self.inner_dim,
88
- pos_embed_max_size=pos_embed_max_size, # hard-code for now.
89
- )
90
- self.time_text_embed = CombinedTimestepTextProjEmbeddings(
91
- embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
92
- )
93
- self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
94
-
95
- # `attention_head_dim` is doubled to account for the mixing.
96
- # It needs to crafted when we get the actual checkpoints.
97
- self.transformer_blocks = nn.ModuleList(
98
- [
99
- JointTransformerBlock(
100
- dim=self.inner_dim,
101
- num_attention_heads=self.config.num_attention_heads,
102
- attention_head_dim=self.config.attention_head_dim,
103
- context_pre_only=i == num_layers - 1,
104
- qk_norm=qk_norm,
105
- use_dual_attention=True if i in dual_attention_layers else False,
106
- )
107
- for i in range(self.config.num_layers)
108
- ]
109
- )
110
-
111
- self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
112
- self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
113
-
114
- self.gradient_checkpointing = False
115
-
116
- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
117
- def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
118
- """
119
- Sets the attention processor to use [feed forward
120
- chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
121
-
122
- Parameters:
123
- chunk_size (`int`, *optional*):
124
- The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
125
- over each tensor of dim=`dim`.
126
- dim (`int`, *optional*, defaults to `0`):
127
- The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
128
- or dim=1 (sequence length).
129
- """
130
- if dim not in [0, 1]:
131
- raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
132
-
133
- # By default chunk size is 1
134
- chunk_size = chunk_size or 1
135
-
136
- def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
137
- if hasattr(module, "set_chunk_feed_forward"):
138
- module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
139
-
140
- for child in module.children():
141
- fn_recursive_feed_forward(child, chunk_size, dim)
142
-
143
- for module in self.children():
144
- fn_recursive_feed_forward(module, chunk_size, dim)
145
-
146
- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
147
- def disable_forward_chunking(self):
148
- def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
149
- if hasattr(module, "set_chunk_feed_forward"):
150
- module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
151
-
152
- for child in module.children():
153
- fn_recursive_feed_forward(child, chunk_size, dim)
154
-
155
- for module in self.children():
156
- fn_recursive_feed_forward(module, None, 0)
157
-
158
- @property
159
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
160
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
161
- r"""
162
- Returns:
163
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
164
- indexed by its weight name.
165
- """
166
- # set recursively
167
- processors = {}
168
-
169
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
170
- if hasattr(module, "get_processor"):
171
- processors[f"{name}.processor"] = module.get_processor()
172
-
173
- for sub_name, child in module.named_children():
174
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
175
-
176
- return processors
177
-
178
- for name, module in self.named_children():
179
- fn_recursive_add_processors(name, module, processors)
180
-
181
- return processors
182
-
183
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
184
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
185
- r"""
186
- Sets the attention processor to use to compute attention.
187
-
188
- Parameters:
189
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
190
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
191
- for **all** `Attention` layers.
192
-
193
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
194
- processor. This is strongly recommended when setting trainable attention processors.
195
-
196
- """
197
- count = len(self.attn_processors.keys())
198
-
199
- if isinstance(processor, dict) and len(processor) != count:
200
- raise ValueError(
201
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
202
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
203
- )
204
-
205
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
206
- if hasattr(module, "set_processor"):
207
- if not isinstance(processor, dict):
208
- module.set_processor(processor)
209
- else:
210
- module.set_processor(processor.pop(f"{name}.processor"))
211
-
212
- for sub_name, child in module.named_children():
213
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
214
-
215
- for name, module in self.named_children():
216
- fn_recursive_attn_processor(name, module, processor)
217
-
218
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedJointAttnProcessor2_0
219
- def fuse_qkv_projections(self):
220
- """
221
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
222
- are fused. For cross-attention modules, key and value projection matrices are fused.
223
-
224
- <Tip warning={true}>
225
-
226
- This API is πŸ§ͺ experimental.
227
-
228
- </Tip>
229
- """
230
- self.original_attn_processors = None
231
-
232
- for _, attn_processor in self.attn_processors.items():
233
- if "Added" in str(attn_processor.__class__.__name__):
234
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
235
-
236
- self.original_attn_processors = self.attn_processors
237
-
238
- for module in self.modules():
239
- if isinstance(module, Attention):
240
- module.fuse_projections(fuse=True)
241
-
242
- self.set_attn_processor(FusedJointAttnProcessor2_0())
243
-
244
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
245
- def unfuse_qkv_projections(self):
246
- """Disables the fused QKV projection if enabled.
247
-
248
- <Tip warning={true}>
249
-
250
- This API is πŸ§ͺ experimental.
251
-
252
- </Tip>
253
-
254
- """
255
- if self.original_attn_processors is not None:
256
- self.set_attn_processor(self.original_attn_processors)
257
-
258
- def _set_gradient_checkpointing(self, module, value=False):
259
- if hasattr(module, "gradient_checkpointing"):
260
- module.gradient_checkpointing = value
261
-
262
- def forward(
263
- self,
264
- hidden_states: torch.FloatTensor,
265
- encoder_hidden_states: torch.FloatTensor = None,
266
- pooled_projections: torch.FloatTensor = None,
267
- timestep: torch.LongTensor = None,
268
- block_controlnet_hidden_states: List = None,
269
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
270
- return_dict: bool = True,
271
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
272
- """
273
- The [`SD3Transformer2DModel`] forward method.
274
-
275
- Args:
276
- hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
277
- Input `hidden_states`.
278
- encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
279
- Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
280
- pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
281
- from the embeddings of input conditions.
282
- timestep ( `torch.LongTensor`):
283
- Used to indicate denoising step.
284
- block_controlnet_hidden_states: (`list` of `torch.Tensor`):
285
- A list of tensors that if specified are added to the residuals of transformer blocks.
286
- joint_attention_kwargs (`dict`, *optional*):
287
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
288
- `self.processor` in
289
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
290
- return_dict (`bool`, *optional*, defaults to `True`):
291
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
292
- tuple.
293
-
294
- Returns:
295
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
296
- `tuple` where the first element is the sample tensor.
297
- """
298
- if joint_attention_kwargs is not None:
299
- joint_attention_kwargs = joint_attention_kwargs.copy()
300
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
301
- else:
302
- lora_scale = 1.0
303
-
304
- if USE_PEFT_BACKEND:
305
- # weight the lora layers by setting `lora_scale` for each PEFT layer
306
- scale_lora_layers(self, lora_scale)
307
- else:
308
- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
309
- logger.warning(
310
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
311
- )
312
-
313
- height, width = hidden_states.shape[-2:]
314
-
315
- hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
316
- temb = self.time_text_embed(timestep, pooled_projections)
317
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
318
-
319
- for index_block, block in enumerate(self.transformer_blocks):
320
- if self.training and self.gradient_checkpointing:
321
-
322
- def create_custom_forward(module, return_dict=None):
323
- def custom_forward(*inputs):
324
- if return_dict is not None:
325
- return module(*inputs, return_dict=return_dict)
326
- else:
327
- return module(*inputs)
328
-
329
- return custom_forward
330
-
331
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
332
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
333
- create_custom_forward(block),
334
- hidden_states,
335
- encoder_hidden_states,
336
- temb,
337
- joint_attention_kwargs,
338
- **ckpt_kwargs,
339
- )
340
-
341
- else:
342
- encoder_hidden_states, hidden_states = block(
343
- hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb,
344
- joint_attention_kwargs=joint_attention_kwargs,
345
- )
346
-
347
- # controlnet residual
348
- if block_controlnet_hidden_states is not None and block.context_pre_only is False:
349
- interval_control = len(self.transformer_blocks) // len(block_controlnet_hidden_states)
350
- hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
351
-
352
- hidden_states = self.norm_out(hidden_states, temb)
353
- hidden_states = self.proj_out(hidden_states)
354
-
355
- # unpatchify
356
- patch_size = self.config.patch_size
357
- height = height // patch_size
358
- width = width // patch_size
359
-
360
- hidden_states = hidden_states.reshape(
361
- shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
362
- )
363
- hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
364
- output = hidden_states.reshape(
365
- shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
366
- )
367
-
368
- if USE_PEFT_BACKEND:
369
- # remove `lora_scale` from each PEFT layer
370
- unscale_lora_layers(self, lora_scale)
371
-
372
- if not return_dict:
373
- return (output,)
374
-
375
- return Transformer2DModelOutput(sample=output)
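The tail of forward() above is the "unpatchify" step: every token carries a patch_size x patch_size x out_channels patch, and the reshape/einsum rearranges the token sequence back into a latent image. A standalone sketch of that rearrangement with toy dimensions:

import torch

batch, out_channels, patch_size = 1, 16, 2
height, width = 64, 64                         # latent resolution
h, w = height // patch_size, width // patch_size

tokens = torch.randn(batch, h * w, patch_size * patch_size * out_channels)  # output of proj_out
x = tokens.reshape(batch, h, w, patch_size, patch_size, out_channels)
x = torch.einsum("nhwpqc->nchpwq", x)
latent = x.reshape(batch, out_channels, h * patch_size, w * patch_size)
print(latent.shape)                            # torch.Size([1, 16, 64, 64])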
pipeline_stable_diffusion_3_ipa.py DELETED
@@ -1,1235 +0,0 @@
1
- # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import inspect
16
- from typing import Any, Callable, Dict, List, Optional, Union
17
-
18
- import torch
19
- import torch.nn as nn
20
- import torch.nn.functional as F
21
- from transformers import (
22
- CLIPTextModelWithProjection,
23
- CLIPTokenizer,
24
- T5EncoderModel,
25
- T5TokenizerFast,
26
- )
27
-
28
- from diffusers.image_processor import VaeImageProcessor
29
- from diffusers.loaders import FromSingleFileMixin, SD3LoraLoaderMixin
30
- from diffusers.models.autoencoders import AutoencoderKL
31
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
32
- from diffusers.utils import (
33
- USE_PEFT_BACKEND,
34
- is_torch_xla_available,
35
- logging,
36
- replace_example_docstring,
37
- scale_lora_layers,
38
- unscale_lora_layers,
39
- )
40
- from diffusers.utils.torch_utils import randn_tensor
41
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
42
- from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
43
-
44
- from models_resampler import TimeResampler
45
- from models_transformer_sd3 import SD3Transformer2DModel
46
- from diffusers.models.normalization import RMSNorm
47
- from einops import rearrange
48
-
49
-
50
- if is_torch_xla_available():
51
- import torch_xla.core.xla_model as xm
52
-
53
- XLA_AVAILABLE = True
54
- else:
55
- XLA_AVAILABLE = False
56
-
57
-
58
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
59
-
60
- EXAMPLE_DOC_STRING = """
61
- Examples:
62
- ```py
63
- >>> import torch
64
- >>> from diffusers import StableDiffusion3Pipeline
65
-
66
- >>> pipe = StableDiffusion3Pipeline.from_pretrained(
67
- ... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
68
- ... )
69
- >>> pipe.to("cuda")
70
- >>> prompt = "A cat holding a sign that says hello world"
71
- >>> image = pipe(prompt).images[0]
72
- >>> image.save("sd3.png")
73
- ```
74
- """
75
-
76
-
77
- class AdaLayerNorm(nn.Module):
78
- """
79
- Adaptive layer norm (adaLN) conditioned on a timestep embedding, with 'normal' (shift/scale) and 'zero' (adaLN-Zero) modes.
80
-
81
- Parameters:
82
- embedding_dim (`int`): The size of each embedding vector.
83
- time_embedding_dim (`int`, *optional*): The size of the conditioning (timestep) embedding; defaults to `embedding_dim`.
84
- """
85
-
86
- def __init__(self, embedding_dim: int, time_embedding_dim=None, mode='normal'):
87
- super().__init__()
88
-
89
- self.silu = nn.SiLU()
90
- num_params_dict = dict(
91
- zero=6,
92
- normal=2,
93
- )
94
- num_params = num_params_dict[mode]
95
- self.linear = nn.Linear(time_embedding_dim or embedding_dim, num_params * embedding_dim, bias=True)
96
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
97
- self.mode = mode
98
-
99
- def forward(
100
- self,
101
- x,
102
- hidden_dtype = None,
103
- emb = None,
104
- ):
105
- emb = self.linear(self.silu(emb))
106
- if self.mode == 'normal':
107
- shift_msa, scale_msa = emb.chunk(2, dim=1)
108
- x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
109
- return x
110
-
111
- elif self.mode == 'zero':
112
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
113
- x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
114
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
115
-
116
-
117
- class JointIPAttnProcessor(torch.nn.Module):
118
- """Attention processor used typically in processing the SD3-like self-attention projections."""
119
-
120
- def __init__(
121
- self,
122
- hidden_size=None,
123
- cross_attention_dim=None,
124
- ip_hidden_states_dim=None,
125
- ip_encoder_hidden_states_dim=None,
126
- head_dim=None,
127
- timesteps_emb_dim=1280,
128
- ):
129
- super().__init__()
130
-
131
- self.norm_ip = AdaLayerNorm(ip_hidden_states_dim, time_embedding_dim=timesteps_emb_dim)
132
- self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
133
- self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False)
134
- self.norm_q = RMSNorm(head_dim, 1e-6)
135
- self.norm_k = RMSNorm(head_dim, 1e-6)
136
- self.norm_ip_k = RMSNorm(head_dim, 1e-6)
137
-
138
-
139
- def __call__(
140
- self,
141
- attn,
142
- hidden_states: torch.FloatTensor,
143
- encoder_hidden_states: torch.FloatTensor = None,
144
- attention_mask: Optional[torch.FloatTensor] = None,
145
- emb_dict=None,
146
- *args,
147
- **kwargs,
148
- ) -> torch.FloatTensor:
149
- residual = hidden_states
150
-
151
- batch_size = hidden_states.shape[0]
152
-
153
- # `sample` projections.
154
- query = attn.to_q(hidden_states)
155
- key = attn.to_k(hidden_states)
156
- value = attn.to_v(hidden_states)
157
- img_query = query
158
- img_key = key
159
- img_value = value
160
-
161
- inner_dim = key.shape[-1]
162
- head_dim = inner_dim // attn.heads
163
-
164
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
165
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
166
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
167
-
168
- if attn.norm_q is not None:
169
- query = attn.norm_q(query)
170
- if attn.norm_k is not None:
171
- key = attn.norm_k(key)
172
-
173
- # `context` projections.
174
- if encoder_hidden_states is not None:
175
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
176
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
177
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
178
-
179
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
180
- batch_size, -1, attn.heads, head_dim
181
- ).transpose(1, 2)
182
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
183
- batch_size, -1, attn.heads, head_dim
184
- ).transpose(1, 2)
185
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
186
- batch_size, -1, attn.heads, head_dim
187
- ).transpose(1, 2)
188
-
189
- if attn.norm_added_q is not None:
190
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
191
- if attn.norm_added_k is not None:
192
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
193
-
194
- query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
195
- key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
196
- value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
197
-
198
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
199
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
200
- hidden_states = hidden_states.to(query.dtype)
201
-
202
- if encoder_hidden_states is not None:
203
- # Split the attention outputs.
204
- hidden_states, encoder_hidden_states = (
205
- hidden_states[:, : residual.shape[1]],
206
- hidden_states[:, residual.shape[1] :],
207
- )
208
- if not attn.context_pre_only:
209
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
210
-
211
-
212
- # IPadapter
213
- ip_hidden_states = emb_dict.get('ip_hidden_states', None)
214
- ip_hidden_states = self.get_ip_hidden_states(
215
- attn,
216
- img_query,
217
- ip_hidden_states,
218
- img_key,
219
- img_value,
220
- None,
221
- None,
222
- emb_dict['temb'],
223
- )
224
- if ip_hidden_states is not None:
225
- hidden_states = hidden_states + ip_hidden_states * emb_dict.get('scale', 1.0)
226
-
227
-
228
- # linear proj
229
- hidden_states = attn.to_out[0](hidden_states)
230
- # dropout
231
- hidden_states = attn.to_out[1](hidden_states)
232
-
233
- if encoder_hidden_states is not None:
234
- return hidden_states, encoder_hidden_states
235
- else:
236
- return hidden_states
237
-
238
-
239
- def get_ip_hidden_states(self, attn, query, ip_hidden_states, img_key=None, img_value=None, text_key=None, text_value=None, temb=None):
240
- if ip_hidden_states is None:
241
- return None
242
-
243
- if not hasattr(self, 'to_k_ip') or not hasattr(self, 'to_v_ip'):
244
- return None
245
-
246
- # norm ip input
247
- norm_ip_hidden_states = self.norm_ip(ip_hidden_states, emb=temb)
248
-
249
- # to k and v
250
- ip_key = self.to_k_ip(norm_ip_hidden_states)
251
- ip_value = self.to_v_ip(norm_ip_hidden_states)
252
-
253
- # reshape
254
- query = rearrange(query, 'b l (h d) -> b h l d', h=attn.heads)
255
- img_key = rearrange(img_key, 'b l (h d) -> b h l d', h=attn.heads)
256
- img_value = rearrange(img_value, 'b l (h d) -> b h l d', h=attn.heads)
257
- ip_key = rearrange(ip_key, 'b l (h d) -> b h l d', h=attn.heads)
258
- ip_value = rearrange(ip_value, 'b l (h d) -> b h l d', h=attn.heads)
259
-
260
- # norm
261
- query = self.norm_q(query)
262
- img_key = self.norm_k(img_key)
263
- ip_key = self.norm_ip_k(ip_key)
264
-
265
- # cat img
266
- key = torch.cat([img_key, ip_key], dim=2)
267
- value = torch.cat([img_value, ip_value], dim=2)
268
-
269
- #
270
- ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
271
- ip_hidden_states = rearrange(ip_hidden_states, 'b h l d -> b l (h d)')
272
- ip_hidden_states = ip_hidden_states.to(query.dtype)
273
- return ip_hidden_states
274
-
275
-
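get_ip_hidden_states above is the decoupled IP-Adapter attention: the image-token queries attend over the original keys/values concatenated with extra keys/values projected from the IP (image-prompt) tokens, and the result is added to the regular attention output scaled by the user-set strength. A minimal sketch of that key/value concatenation (shapes are illustrative; token counts kept small):

import torch
import torch.nn.functional as F

b, heads, d = 2, 8, 64
img_q = torch.randn(b, heads, 256, d)          # image-token queries
img_k = torch.randn(b, heads, 256, d)          # original keys / values
img_v = torch.randn(b, heads, 256, d)
ip_k = torch.randn(b, heads, 16, d)            # keys / values projected from the IP tokens
ip_v = torch.randn(b, heads, 16, d)

key = torch.cat([img_k, ip_k], dim=2)          # queries see image tokens and IP tokens together
value = torch.cat([img_v, ip_v], dim=2)
ip_hidden_states = F.scaled_dot_product_attention(img_q, key, value)
print(ip_hidden_states.shape)                  # torch.Size([2, 8, 256, 64])
# In the deleted processor this is then added as: hidden_states += ip_hidden_states * scale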
276
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
277
- def retrieve_timesteps(
278
- scheduler,
279
- num_inference_steps: Optional[int] = None,
280
- device: Optional[Union[str, torch.device]] = None,
281
- timesteps: Optional[List[int]] = None,
282
- sigmas: Optional[List[float]] = None,
283
- **kwargs,
284
- ):
285
- """
286
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
287
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
288
-
289
- Args:
290
- scheduler (`SchedulerMixin`):
291
- The scheduler to get timesteps from.
292
- num_inference_steps (`int`):
293
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
294
- must be `None`.
295
- device (`str` or `torch.device`, *optional*):
296
- The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
297
- timesteps (`List[int]`, *optional*):
298
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
299
- `num_inference_steps` and `sigmas` must be `None`.
300
- sigmas (`List[float]`, *optional*):
301
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
302
- `num_inference_steps` and `timesteps` must be `None`.
303
-
304
- Returns:
305
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
306
- second element is the number of inference steps.
307
- """
308
- if timesteps is not None and sigmas is not None:
309
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
310
- if timesteps is not None:
311
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
312
- if not accepts_timesteps:
313
- raise ValueError(
314
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
315
- f" timestep schedules. Please check whether you are using the correct scheduler."
316
- )
317
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
318
- timesteps = scheduler.timesteps
319
- num_inference_steps = len(timesteps)
320
- elif sigmas is not None:
321
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
322
- if not accept_sigmas:
323
- raise ValueError(
324
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
325
- f" sigmas schedules. Please check whether you are using the correct scheduler."
326
- )
327
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
328
- timesteps = scheduler.timesteps
329
- num_inference_steps = len(timesteps)
330
- else:
331
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
332
- timesteps = scheduler.timesteps
333
- return timesteps, num_inference_steps
334
-
335
-
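retrieve_timesteps' main trick is feature detection: it inspects the scheduler's set_timesteps signature before forwarding custom timesteps or sigmas, and otherwise lets the scheduler build its own schedule. A small sketch of that check (FlowMatchEulerDiscreteScheduler is used here as an example; any diffusers scheduler should behave the same way):

import inspect
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler()

# Feature-detect what the scheduler's set_timesteps accepts.
params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
print("accepts custom timesteps:", "timesteps" in params)
print("accepts custom sigmas:", "sigmas" in params)

# Default path: the scheduler builds its own schedule.
scheduler.set_timesteps(num_inference_steps=24)
print(len(scheduler.timesteps))                # 24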
336
- class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
337
- r"""
338
- Args:
339
- transformer ([`SD3Transformer2DModel`]):
340
- Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
341
- scheduler ([`FlowMatchEulerDiscreteScheduler`]):
342
- A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
343
- vae ([`AutoencoderKL`]):
344
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
345
- text_encoder ([`CLIPTextModelWithProjection`]):
346
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
347
- specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
348
- with an additional added projection layer that is initialized with a diagonal matrix with the `hidden_size`
349
- as its dimension.
350
- text_encoder_2 ([`CLIPTextModelWithProjection`]):
351
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
352
- specifically the
353
- [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
354
- variant.
355
- text_encoder_3 ([`T5EncoderModel`]):
356
- Frozen text-encoder. Stable Diffusion 3 uses
357
- [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
358
- [t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
359
- tokenizer (`CLIPTokenizer`):
360
- Tokenizer of class
361
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
362
- tokenizer_2 (`CLIPTokenizer`):
363
- Second Tokenizer of class
364
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
365
- tokenizer_3 (`T5TokenizerFast`):
366
- Tokenizer of class
367
- [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
368
- """
369
-
370
- model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
371
- _optional_components = []
372
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
373
-
374
- def __init__(
375
- self,
376
- transformer: SD3Transformer2DModel,
377
- scheduler: FlowMatchEulerDiscreteScheduler,
378
- vae: AutoencoderKL,
379
- text_encoder: CLIPTextModelWithProjection,
380
- tokenizer: CLIPTokenizer,
381
- text_encoder_2: CLIPTextModelWithProjection,
382
- tokenizer_2: CLIPTokenizer,
383
- text_encoder_3: T5EncoderModel,
384
- tokenizer_3: T5TokenizerFast,
385
- ):
386
- super().__init__()
387
-
388
- self.register_modules(
389
- vae=vae,
390
- text_encoder=text_encoder,
391
- text_encoder_2=text_encoder_2,
392
- text_encoder_3=text_encoder_3,
393
- tokenizer=tokenizer,
394
- tokenizer_2=tokenizer_2,
395
- tokenizer_3=tokenizer_3,
396
- transformer=transformer,
397
- scheduler=scheduler,
398
- )
399
- self.vae_scale_factor = (
400
- 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
401
- )
402
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
403
- self.tokenizer_max_length = (
404
- self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
405
- )
406
- self.default_sample_size = (
407
- self.transformer.config.sample_size
408
- if hasattr(self, "transformer") and self.transformer is not None
409
- else 128
410
- )
411
-
412
- def _get_t5_prompt_embeds(
413
- self,
414
- prompt: Union[str, List[str]] = None,
415
- num_images_per_prompt: int = 1,
416
- max_sequence_length: int = 256,
417
- device: Optional[torch.device] = None,
418
- dtype: Optional[torch.dtype] = None,
419
- ):
420
- device = device or self._execution_device
421
- dtype = dtype or self.text_encoder.dtype
422
-
423
- prompt = [prompt] if isinstance(prompt, str) else prompt
424
- batch_size = len(prompt)
425
-
426
- if self.text_encoder_3 is None:
427
- return torch.zeros(
428
- (
429
- batch_size * num_images_per_prompt,
430
- self.tokenizer_max_length,
431
- self.transformer.config.joint_attention_dim,
432
- ),
433
- device=device,
434
- dtype=dtype,
435
- )
436
-
437
- text_inputs = self.tokenizer_3(
438
- prompt,
439
- padding="max_length",
440
- max_length=max_sequence_length,
441
- truncation=True,
442
- add_special_tokens=True,
443
- return_tensors="pt",
444
- )
445
- text_input_ids = text_inputs.input_ids
446
- untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
447
-
448
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
449
- removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
450
- logger.warning(
451
- "The following part of your input was truncated because `max_sequence_length` is set to "
452
- f" {max_sequence_length} tokens: {removed_text}"
453
- )
454
-
455
- prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
456
-
457
- dtype = self.text_encoder_3.dtype
458
- prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
459
-
460
- _, seq_len, _ = prompt_embeds.shape
461
-
462
- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
463
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
464
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
465
-
466
- return prompt_embeds
467
-
468
- def _get_clip_prompt_embeds(
469
- self,
470
- prompt: Union[str, List[str]],
471
- num_images_per_prompt: int = 1,
472
- device: Optional[torch.device] = None,
473
- clip_skip: Optional[int] = None,
474
- clip_model_index: int = 0,
475
- ):
476
- device = device or self._execution_device
477
-
478
- clip_tokenizers = [self.tokenizer, self.tokenizer_2]
479
- clip_text_encoders = [self.text_encoder, self.text_encoder_2]
480
-
481
- tokenizer = clip_tokenizers[clip_model_index]
482
- text_encoder = clip_text_encoders[clip_model_index]
483
-
484
- prompt = [prompt] if isinstance(prompt, str) else prompt
485
- batch_size = len(prompt)
486
-
487
- text_inputs = tokenizer(
488
- prompt,
489
- padding="max_length",
490
- max_length=self.tokenizer_max_length,
491
- truncation=True,
492
- return_tensors="pt",
493
- )
494
-
495
- text_input_ids = text_inputs.input_ids
496
- untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
497
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
498
- removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
499
- logger.warning(
500
- "The following part of your input was truncated because CLIP can only handle sequences up to"
501
- f" {self.tokenizer_max_length} tokens: {removed_text}"
502
- )
503
- prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
504
- pooled_prompt_embeds = prompt_embeds[0]
505
-
506
- if clip_skip is None:
507
- prompt_embeds = prompt_embeds.hidden_states[-2]
508
- else:
509
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
510
-
511
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
512
-
513
- _, seq_len, _ = prompt_embeds.shape
514
- # duplicate text embeddings for each generation per prompt, using mps friendly method
515
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
516
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
517
-
518
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
519
- pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
520
-
521
- return prompt_embeds, pooled_prompt_embeds
522
-
523
- def encode_prompt(
524
- self,
525
- prompt: Union[str, List[str]],
526
- prompt_2: Union[str, List[str]],
527
- prompt_3: Union[str, List[str]],
528
- device: Optional[torch.device] = None,
529
- num_images_per_prompt: int = 1,
530
- do_classifier_free_guidance: bool = True,
531
- negative_prompt: Optional[Union[str, List[str]]] = None,
532
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
533
- negative_prompt_3: Optional[Union[str, List[str]]] = None,
534
- prompt_embeds: Optional[torch.FloatTensor] = None,
535
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
536
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
537
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
538
- clip_skip: Optional[int] = None,
539
- max_sequence_length: int = 256,
540
- lora_scale: Optional[float] = None,
541
- ):
542
- r"""
543
-
544
- Args:
545
- prompt (`str` or `List[str]`, *optional*):
546
- prompt to be encoded
547
- prompt_2 (`str` or `List[str]`, *optional*):
548
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
549
- used in all text-encoders
550
- prompt_3 (`str` or `List[str]`, *optional*):
551
- The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
552
- used in all text-encoders
553
- device: (`torch.device`):
554
- torch device
555
- num_images_per_prompt (`int`):
556
- number of images that should be generated per prompt
557
- do_classifier_free_guidance (`bool`):
558
- whether to use classifier free guidance or not
559
- negative_prompt (`str` or `List[str]`, *optional*):
560
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
561
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
562
- less than `1`).
563
- negative_prompt_2 (`str` or `List[str]`, *optional*):
564
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
565
- `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
566
- negative_prompt_2 (`str` or `List[str]`, *optional*):
567
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
568
- `text_encoder_3`. If not defined, `negative_prompt` is used in both text-encoders
569
- prompt_embeds (`torch.FloatTensor`, *optional*):
570
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
571
- provided, text embeddings will be generated from `prompt` input argument.
572
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
573
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
574
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
575
- argument.
576
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
577
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
578
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
579
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
580
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
581
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
582
- input argument.
583
- clip_skip (`int`, *optional*):
584
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
585
- the output of the pre-final layer will be used for computing the prompt embeddings.
586
- lora_scale (`float`, *optional*):
587
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
588
- """
589
- device = device or self._execution_device
590
-
591
- # set lora scale so that monkey patched LoRA
592
- # function of text encoder can correctly access it
593
- if lora_scale is not None and isinstance(self, SD3LoraLoaderMixin):
594
- self._lora_scale = lora_scale
595
-
596
- # dynamically adjust the LoRA scale
597
- if self.text_encoder is not None and USE_PEFT_BACKEND:
598
- scale_lora_layers(self.text_encoder, lora_scale)
599
- if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
600
- scale_lora_layers(self.text_encoder_2, lora_scale)
601
-
602
- prompt = [prompt] if isinstance(prompt, str) else prompt
603
- if prompt is not None:
604
- batch_size = len(prompt)
605
- else:
606
- batch_size = prompt_embeds.shape[0]
607
-
608
- if prompt_embeds is None:
609
- prompt_2 = prompt_2 or prompt
610
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
611
-
612
- prompt_3 = prompt_3 or prompt
613
- prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
614
-
615
- prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
616
- prompt=prompt,
617
- device=device,
618
- num_images_per_prompt=num_images_per_prompt,
619
- clip_skip=clip_skip,
620
- clip_model_index=0,
621
- )
622
- prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
623
- prompt=prompt_2,
624
- device=device,
625
- num_images_per_prompt=num_images_per_prompt,
626
- clip_skip=clip_skip,
627
- clip_model_index=1,
628
- )
629
- clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
630
-
631
- t5_prompt_embed = self._get_t5_prompt_embeds(
632
- prompt=prompt_3,
633
- num_images_per_prompt=num_images_per_prompt,
634
- max_sequence_length=max_sequence_length,
635
- device=device,
636
- )
637
-
638
- clip_prompt_embeds = torch.nn.functional.pad(
639
- clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
640
- )
641
-
642
- prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
643
- pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
644
-
645
- if do_classifier_free_guidance and negative_prompt_embeds is None:
646
- negative_prompt = negative_prompt or ""
647
- negative_prompt_2 = negative_prompt_2 or negative_prompt
648
- negative_prompt_3 = negative_prompt_3 or negative_prompt
649
-
650
- # normalize str to list
651
- negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
652
- negative_prompt_2 = (
653
- batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
654
- )
655
- negative_prompt_3 = (
656
- batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
657
- )
658
-
659
- if prompt is not None and type(prompt) is not type(negative_prompt):
660
- raise TypeError(
661
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
662
- f" {type(prompt)}."
663
- )
664
- elif batch_size != len(negative_prompt):
665
- raise ValueError(
666
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
667
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
668
- " the batch size of `prompt`."
669
- )
670
-
671
- negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
672
- negative_prompt,
673
- device=device,
674
- num_images_per_prompt=num_images_per_prompt,
675
- clip_skip=None,
676
- clip_model_index=0,
677
- )
678
- negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
679
- negative_prompt_2,
680
- device=device,
681
- num_images_per_prompt=num_images_per_prompt,
682
- clip_skip=None,
683
- clip_model_index=1,
684
- )
685
- negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
686
-
687
- t5_negative_prompt_embed = self._get_t5_prompt_embeds(
688
- prompt=negative_prompt_3,
689
- num_images_per_prompt=num_images_per_prompt,
690
- max_sequence_length=max_sequence_length,
691
- device=device,
692
- )
693
-
694
- negative_clip_prompt_embeds = torch.nn.functional.pad(
695
- negative_clip_prompt_embeds,
696
- (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
697
- )
698
-
699
- negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
700
- negative_pooled_prompt_embeds = torch.cat(
701
- [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
702
- )
703
-
704
- if self.text_encoder is not None:
705
- if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
706
- # Retrieve the original scale by scaling back the LoRA layers
707
- unscale_lora_layers(self.text_encoder, lora_scale)
708
-
709
- if self.text_encoder_2 is not None:
710
- if isinstance(self, SD3LoraLoaderMixin) and USE_PEFT_BACKEND:
711
- # Retrieve the original scale by scaling back the LoRA layers
712
- unscale_lora_layers(self.text_encoder_2, lora_scale)
713
-
714
- return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
715
-
716
- def check_inputs(
717
- self,
718
- prompt,
719
- prompt_2,
720
- prompt_3,
721
- height,
722
- width,
723
- negative_prompt=None,
724
- negative_prompt_2=None,
725
- negative_prompt_3=None,
726
- prompt_embeds=None,
727
- negative_prompt_embeds=None,
728
- pooled_prompt_embeds=None,
729
- negative_pooled_prompt_embeds=None,
730
- callback_on_step_end_tensor_inputs=None,
731
- max_sequence_length=None,
732
- ):
733
- if height % 8 != 0 or width % 8 != 0:
734
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
735
-
736
- if callback_on_step_end_tensor_inputs is not None and not all(
737
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
738
- ):
739
- raise ValueError(
740
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
741
- )
742
-
743
- if prompt is not None and prompt_embeds is not None:
744
- raise ValueError(
745
- f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
746
- " only forward one of the two."
747
- )
748
- elif prompt_2 is not None and prompt_embeds is not None:
749
- raise ValueError(
750
- f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
751
- " only forward one of the two."
752
- )
753
- elif prompt_3 is not None and prompt_embeds is not None:
754
- raise ValueError(
755
- f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
756
- " only forward one of the two."
757
- )
758
- elif prompt is None and prompt_embeds is None:
759
- raise ValueError(
760
- "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
761
- )
762
- elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
763
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
764
- elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
765
- raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
766
- elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
767
- raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
768
-
769
- if negative_prompt is not None and negative_prompt_embeds is not None:
770
- raise ValueError(
771
- f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
772
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
773
- )
774
- elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
775
- raise ValueError(
776
- f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
777
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
778
- )
779
- elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
780
- raise ValueError(
781
- f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
782
- f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
783
- )
784
-
785
- if prompt_embeds is not None and negative_prompt_embeds is not None:
786
- if prompt_embeds.shape != negative_prompt_embeds.shape:
787
- raise ValueError(
788
- "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
789
- f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
790
- f" {negative_prompt_embeds.shape}."
791
- )
792
-
793
- if prompt_embeds is not None and pooled_prompt_embeds is None:
794
- raise ValueError(
795
- "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
796
- )
797
-
798
- if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
799
- raise ValueError(
800
- "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
801
- )
802
-
803
- if max_sequence_length is not None and max_sequence_length > 512:
804
- raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
805
-
806
- def prepare_latents(
807
- self,
808
- batch_size,
809
- num_channels_latents,
810
- height,
811
- width,
812
- dtype,
813
- device,
814
- generator,
815
- latents=None,
816
- ):
817
- if latents is not None:
818
- return latents.to(device=device, dtype=dtype)
819
-
820
- shape = (
821
- batch_size,
822
- num_channels_latents,
823
- int(height) // self.vae_scale_factor,
824
- int(width) // self.vae_scale_factor,
825
- )
826
-
827
- if isinstance(generator, list) and len(generator) != batch_size:
828
- raise ValueError(
829
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
830
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
831
- )
832
-
833
- latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
834
-
835
- return latents
836
-
837
- @property
838
- def guidance_scale(self):
839
- return self._guidance_scale
840
-
841
- @property
842
- def clip_skip(self):
843
- return self._clip_skip
844
-
845
- # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
846
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
847
- # corresponds to doing no classifier free guidance.
848
- @property
849
- def do_classifier_free_guidance(self):
850
- return self._guidance_scale > 1
851
-
852
- @property
853
- def joint_attention_kwargs(self):
854
- return self._joint_attention_kwargs
855
-
856
- @property
857
- def num_timesteps(self):
858
- return self._num_timesteps
859
-
860
- @property
861
- def interrupt(self):
862
- return self._interrupt
863
-
864
-
865
- @torch.inference_mode()
866
- def init_ipadapter(self, ip_adapter_path, image_encoder_path, nb_token, output_dim=2432):
867
- from transformers import SiglipVisionModel, SiglipImageProcessor
868
- state_dict = torch.load(ip_adapter_path, map_location="cpu")
869
-
870
- device, dtype = self.transformer.device, self.transformer.dtype
871
- image_encoder = SiglipVisionModel.from_pretrained(image_encoder_path)
872
- image_processor = SiglipImageProcessor.from_pretrained(image_encoder_path)
873
- image_encoder.eval()
874
- image_encoder.to(device, dtype=dtype)
875
- self.image_encoder = image_encoder
876
- self.clip_image_processor = image_processor
877
-
878
- sample_class = TimeResampler
879
- image_proj_model = sample_class(
880
- dim=1280,
881
- depth=4,
882
- dim_head=64,
883
- heads=20,
884
- num_queries=nb_token,
885
- embedding_dim=1152,
886
- output_dim=output_dim,
887
- ff_mult=4,
888
- timestep_in_dim=320,
889
- timestep_flip_sin_to_cos=True,
890
- timestep_freq_shift=0,
891
- )
892
- image_proj_model.eval()
893
- image_proj_model.to(device, dtype=dtype)
894
- key_name = image_proj_model.load_state_dict(state_dict["image_proj"], strict=False)
895
- print(f"=> loading image_proj_model: {key_name}")
896
-
897
- self.image_proj_model = image_proj_model
898
-
899
-
900
- attn_procs = {}
901
- transformer = self.transformer
902
- for idx_name, name in enumerate(transformer.attn_processors.keys()):
903
- hidden_size = transformer.config.attention_head_dim * transformer.config.num_attention_heads
904
- ip_hidden_states_dim = transformer.config.attention_head_dim * transformer.config.num_attention_heads
905
- ip_encoder_hidden_states_dim = transformer.config.caption_projection_dim
906
-
907
- attn_procs[name] = JointIPAttnProcessor(
908
- hidden_size=hidden_size,
909
- cross_attention_dim=transformer.config.caption_projection_dim,
910
- ip_hidden_states_dim=ip_hidden_states_dim,
911
- ip_encoder_hidden_states_dim=ip_encoder_hidden_states_dim,
912
- head_dim=transformer.config.attention_head_dim,
913
- timesteps_emb_dim=1280,
914
- ).to(device, dtype=dtype)
915
-
916
- self.transformer.set_attn_processor(attn_procs)
917
- tmp_ip_layers = torch.nn.ModuleList(self.transformer.attn_processors.values())
918
-
919
- key_name = tmp_ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
920
- print(f"=> loading ip_adapter: {key_name}")
921
-
922
-
923
- @torch.inference_mode()
924
- def encode_clip_image_emb(self, clip_image, device, dtype):
925
-
926
- # clip
927
- clip_image_tensor = self.clip_image_processor(images=clip_image, return_tensors="pt").pixel_values
928
- clip_image_tensor = clip_image_tensor.to(device, dtype=dtype)
929
- clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2]
930
- clip_image_embeds = torch.cat([torch.zeros_like(clip_image_embeds), clip_image_embeds], dim=0)
931
-
932
- return clip_image_embeds
933
-
934
-
935
-
936
- @torch.no_grad()
937
- @replace_example_docstring(EXAMPLE_DOC_STRING)
938
- def __call__(
939
- self,
940
- prompt: Union[str, List[str]] = None,
941
- prompt_2: Optional[Union[str, List[str]]] = None,
942
- prompt_3: Optional[Union[str, List[str]]] = None,
943
- height: Optional[int] = None,
944
- width: Optional[int] = None,
945
- num_inference_steps: int = 28,
946
- timesteps: List[int] = None,
947
- guidance_scale: float = 7.0,
948
- negative_prompt: Optional[Union[str, List[str]]] = None,
949
- negative_prompt_2: Optional[Union[str, List[str]]] = None,
950
- negative_prompt_3: Optional[Union[str, List[str]]] = None,
951
- num_images_per_prompt: Optional[int] = 1,
952
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
953
- latents: Optional[torch.FloatTensor] = None,
954
- prompt_embeds: Optional[torch.FloatTensor] = None,
955
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
956
- pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
957
- negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
958
- output_type: Optional[str] = "pil",
959
- return_dict: bool = True,
960
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
961
- clip_skip: Optional[int] = None,
962
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
963
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
964
- max_sequence_length: int = 256,
965
-
966
- # ipa
967
- clip_image=None,
968
- ipadapter_scale=1.0,
969
- ):
970
- r"""
971
- Function invoked when calling the pipeline for generation.
972
-
973
- Args:
974
- prompt (`str` or `List[str]`, *optional*):
975
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
976
- instead.
977
- prompt_2 (`str` or `List[str]`, *optional*):
978
- The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
979
- will be used instead
980
- prompt_3 (`str` or `List[str]`, *optional*):
981
- The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
982
- will be used instead
983
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
984
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
985
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
986
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
987
- num_inference_steps (`int`, *optional*, defaults to 28):
988
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
989
- expense of slower inference.
990
- timesteps (`List[int]`, *optional*):
991
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
992
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
993
- passed will be used. Must be in descending order.
994
- guidance_scale (`float`, *optional*, defaults to 7.0):
995
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
996
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
997
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
998
- 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
999
- usually at the expense of lower image quality.
1000
- negative_prompt (`str` or `List[str]`, *optional*):
1001
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
1002
- `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1003
- less than `1`).
1004
- negative_prompt_2 (`str` or `List[str]`, *optional*):
1005
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1006
- `text_encoder_2`. If not defined, `negative_prompt` is used instead
1007
- negative_prompt_3 (`str` or `List[str]`, *optional*):
1008
- The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
1009
- `text_encoder_3`. If not defined, `negative_prompt` is used instead
1010
- num_images_per_prompt (`int`, *optional*, defaults to 1):
1011
- The number of images to generate per prompt.
1012
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1013
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1014
- to make generation deterministic.
1015
- latents (`torch.FloatTensor`, *optional*):
1016
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1017
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1018
- tensor will be generated by sampling using the supplied random `generator`.
1019
- prompt_embeds (`torch.FloatTensor`, *optional*):
1020
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1021
- provided, text embeddings will be generated from `prompt` input argument.
1022
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1023
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1024
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1025
- argument.
1026
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1027
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1028
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
1029
- negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1030
- Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1031
- weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1032
- input argument.
1033
- output_type (`str`, *optional*, defaults to `"pil"`):
1034
- The output format of the generated image. Choose between
1035
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1036
- return_dict (`bool`, *optional*, defaults to `True`):
1037
- Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
1038
- of a plain tuple.
1039
- joint_attention_kwargs (`dict`, *optional*):
1040
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1041
- `self.processor` in
1042
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1043
- callback_on_step_end (`Callable`, *optional*):
1044
- A function that is called at the end of each denoising step during inference. The function is called
1045
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1046
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1047
- `callback_on_step_end_tensor_inputs`.
1048
- callback_on_step_end_tensor_inputs (`List`, *optional*):
1049
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1050
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1051
- `._callback_tensor_inputs` attribute of your pipeline class.
1052
- max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
1053
-
1054
- Examples:
1055
-
1056
- Returns:
1057
- [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
1058
- [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
1059
- `tuple`. When returning a tuple, the first element is a list with the generated images.
1060
- """
1061
-
1062
- height = height or self.default_sample_size * self.vae_scale_factor
1063
- width = width or self.default_sample_size * self.vae_scale_factor
1064
-
1065
- # 1. Check inputs. Raise error if not correct
1066
- self.check_inputs(
1067
- prompt,
1068
- prompt_2,
1069
- prompt_3,
1070
- height,
1071
- width,
1072
- negative_prompt=negative_prompt,
1073
- negative_prompt_2=negative_prompt_2,
1074
- negative_prompt_3=negative_prompt_3,
1075
- prompt_embeds=prompt_embeds,
1076
- negative_prompt_embeds=negative_prompt_embeds,
1077
- pooled_prompt_embeds=pooled_prompt_embeds,
1078
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1079
- callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
1080
- max_sequence_length=max_sequence_length,
1081
- )
1082
-
1083
- self._guidance_scale = guidance_scale
1084
- self._clip_skip = clip_skip
1085
- self._joint_attention_kwargs = joint_attention_kwargs
1086
- self._interrupt = False
1087
-
1088
- # 2. Define call parameters
1089
- if prompt is not None and isinstance(prompt, str):
1090
- batch_size = 1
1091
- elif prompt is not None and isinstance(prompt, list):
1092
- batch_size = len(prompt)
1093
- else:
1094
- batch_size = prompt_embeds.shape[0]
1095
-
1096
- device = self._execution_device
1097
- dtype = self.transformer.dtype
1098
-
1099
- lora_scale = (
1100
- self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
1101
- )
1102
- (
1103
- prompt_embeds,
1104
- negative_prompt_embeds,
1105
- pooled_prompt_embeds,
1106
- negative_pooled_prompt_embeds,
1107
- ) = self.encode_prompt(
1108
- prompt=prompt,
1109
- prompt_2=prompt_2,
1110
- prompt_3=prompt_3,
1111
- negative_prompt=negative_prompt,
1112
- negative_prompt_2=negative_prompt_2,
1113
- negative_prompt_3=negative_prompt_3,
1114
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1115
- prompt_embeds=prompt_embeds,
1116
- negative_prompt_embeds=negative_prompt_embeds,
1117
- pooled_prompt_embeds=pooled_prompt_embeds,
1118
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1119
- device=device,
1120
- clip_skip=self.clip_skip,
1121
- num_images_per_prompt=num_images_per_prompt,
1122
- max_sequence_length=max_sequence_length,
1123
- lora_scale=lora_scale,
1124
- )
1125
-
1126
- if self.do_classifier_free_guidance:
1127
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1128
- pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
1129
-
1130
- # 3. prepare clip emb
1131
- clip_image = clip_image.resize((max(clip_image.size), max(clip_image.size)))
1132
- clip_image_embeds = self.encode_clip_image_emb(clip_image, device, dtype)
1133
-
1134
- # 4. Prepare timesteps
1135
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1136
- num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1137
- self._num_timesteps = len(timesteps)
1138
-
1139
- # 5. Prepare latent variables
1140
- num_channels_latents = self.transformer.config.in_channels
1141
- latents = self.prepare_latents(
1142
- batch_size * num_images_per_prompt,
1143
- num_channels_latents,
1144
- height,
1145
- width,
1146
- prompt_embeds.dtype,
1147
- device,
1148
- generator,
1149
- latents,
1150
- )
1151
-
1152
- # 6. Denoising loop
1153
- with self.progress_bar(total=num_inference_steps) as progress_bar:
1154
- for i, t in enumerate(timesteps):
1155
- if self.interrupt:
1156
- continue
1157
-
1158
- # expand the latents if we are doing classifier free guidance
1159
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1160
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1161
- timestep = t.expand(latent_model_input.shape[0])
1162
-
1163
- image_prompt_embeds, timestep_emb = self.image_proj_model(
1164
- clip_image_embeds,
1165
- timestep.to(dtype=latents.dtype),
1166
- need_temb=True
1167
- )
1168
-
1169
- joint_attention_kwargs = dict(
1170
- emb_dict=dict(
1171
- ip_hidden_states=image_prompt_embeds,
1172
- temb=timestep_emb,
1173
- scale=ipadapter_scale,
1174
- )
1175
- )
1176
-
1177
- noise_pred = self.transformer(
1178
- hidden_states=latent_model_input,
1179
- timestep=timestep,
1180
- encoder_hidden_states=prompt_embeds,
1181
- pooled_projections=pooled_prompt_embeds,
1182
- joint_attention_kwargs=joint_attention_kwargs,
1183
- return_dict=False,
1184
- )[0]
1185
-
1186
- # perform guidance
1187
- if self.do_classifier_free_guidance:
1188
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1189
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1190
-
1191
- # compute the previous noisy sample x_t -> x_t-1
1192
- latents_dtype = latents.dtype
1193
- latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1194
-
1195
- if latents.dtype != latents_dtype:
1196
- if torch.backends.mps.is_available():
1197
- # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1198
- latents = latents.to(latents_dtype)
1199
-
1200
- if callback_on_step_end is not None:
1201
- callback_kwargs = {}
1202
- for k in callback_on_step_end_tensor_inputs:
1203
- callback_kwargs[k] = locals()[k]
1204
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1205
-
1206
- latents = callback_outputs.pop("latents", latents)
1207
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1208
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1209
- negative_pooled_prompt_embeds = callback_outputs.pop(
1210
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1211
- )
1212
-
1213
- # call the callback, if provided
1214
- if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1215
- progress_bar.update()
1216
-
1217
- if XLA_AVAILABLE:
1218
- xm.mark_step()
1219
-
1220
- if output_type == "latent":
1221
- image = latents
1222
-
1223
- else:
1224
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1225
-
1226
- image = self.vae.decode(latents, return_dict=False)[0]
1227
- image = self.image_processor.postprocess(image, output_type=output_type)
1228
-
1229
- # Offload all models
1230
- self.maybe_free_model_hooks()
1231
-
1232
- if not return_dict:
1233
- return (image,)
1234
-
1235
- return StableDiffusion3PipelineOutput(images=image)
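For readers skimming the removed file, the most involved piece is the embedding assembly in `encode_prompt`: the two CLIP streams are concatenated on the feature axis, zero-padded up to the T5 width, and then joined with the T5 tokens along the sequence axis, while the pooled CLIP outputs are concatenated separately. The toy sketch below reproduces just that tensor plumbing with dummy data; the widths used (768 and 1280 for the CLIP encoders, 4096 for T5) are the usual SD3 values and are an assumption here, not something stated in the diff.

```python
# Standalone sketch of the embedding assembly done in encode_prompt() above,
# using random stand-in tensors instead of real text-encoder outputs.
import torch

batch, clip_len, t5_len = 1, 77, 256

clip_1 = torch.randn(batch, clip_len, 768)    # stand-in for text_encoder output
clip_2 = torch.randn(batch, clip_len, 1280)   # stand-in for text_encoder_2 output
t5 = torch.randn(batch, t5_len, 4096)         # stand-in for text_encoder_3 output

# 1. Concatenate the two CLIP streams on the feature axis -> (1, 77, 2048)
clip_cat = torch.cat([clip_1, clip_2], dim=-1)

# 2. Zero-pad the CLIP features up to the T5 width -> (1, 77, 4096)
clip_cat = torch.nn.functional.pad(clip_cat, (0, t5.shape[-1] - clip_cat.shape[-1]))

# 3. Concatenate along the sequence axis -> (1, 77 + 256, 4096)
prompt_embeds = torch.cat([clip_cat, t5], dim=-2)
print(prompt_embeds.shape)  # torch.Size([1, 333, 4096])
```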
 
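The denoising loop follows the standard classifier-free-guidance pattern: the latents are duplicated, the transformer prediction is split into an unconditional and a conditional half, and the two are blended with `guidance_scale`. The same doubling is why `encode_clip_image_emb` prepends `torch.zeros_like(clip_image_embeds)`, giving the unconditional half an empty image prompt. A minimal numeric sketch (tensor shapes here are arbitrary, not the pipeline's real latent shape):

```python
# Toy version of the guidance step from the denoising loop above.
import torch

guidance_scale = 7.0
noise_pred = torch.randn(2, 16, 128, 128)  # doubled batch: [uncond, cond]

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 16, 128, 128])
```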
requirements.txt CHANGED
@@ -2,6 +2,4 @@ diffusers
2
  torch
3
  transformers
4
  accelerate
5
- Pillow
6
- einops
7
- sentencepiece
 
2
  torch
3
  transformers
4
  accelerate
5
+ Pillow
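With the SD3 IP-Adapter pipeline gone, `einops` and `sentencepiece` (presumably pulled in for its resampler and T5 tokenizer) are dropped, while `Pillow` stays alongside the core packages. A quick, illustrative check that the remaining set imports cleanly; the module names below are the standard import names and this snippet is not part of the repo:

```python
# Illustrative sanity check for the trimmed requirements.txt (Pillow imports as "PIL").
import importlib

for name in ("diffusers", "torch", "transformers", "accelerate", "PIL"):
    module = importlib.import_module(name)
    print(name, getattr(module, "__version__", "unknown"))
```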