Sapir commited on
Commit
cef1afc
·
1 Parent(s): ebaff66

inference working.

Browse files
eval.py CHANGED
@@ -1,31 +1,45 @@
1
  import torch
2
- from vae.causal_video_autoencoder import CausalVideoAutoencoder
3
- from transformer.transformer3d import Trasformer3D
4
  from patchify.symmetric import SymmetricPatchifier
 
 
 
 
5
 
6
 
7
  model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
8
- vae_path = "/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000"
9
  dtype = torch.float32
10
  vae = CausalVideoAutoencoder.from_pretrained(
11
  pretrained_model_name_or_path=vae_local_path,
12
  revision=False,
13
  torch_dtype=torch.bfloat16,
14
  load_in_8bit=False,
15
- )
16
- transformer_config_path = "/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json"
17
- transformer_config = Transformer3D.load_config(config_local_path)
18
- transformer = Transformer3D.from_config(config)
19
- transformer_local_path = "/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-nl/ckpt/01760000/model.p"
20
  transformer_ckpt_state_dict = torch.load(transformer_local_path)
21
  transformer.load_state_dict(transformer_ckpt_state_dict, True)
 
22
  unet = transformer
23
- scheduler_config_path = "/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json"
24
- scheduler_config = RectifiedFlowScheduler.load_config(config_local_path)
25
- scheduler = RectifiedFlowScheduler.from_config(config)
26
  patchifier = SymmetricPatchifier(patch_size=1)
 
27
 
 
 
 
 
 
 
 
28
 
 
29
 
30
  pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
31
  safety_checker=None,
@@ -41,13 +55,17 @@ height=512
41
  width=768
42
  num_frames=57
43
  frame_rate=25
44
- sample = {
45
- "prompt_embeds": None, # (B, L, E)
46
- 'prompt_attention_mask': None, # (B , L)
47
- 'negative_prompt_embeds': None,' # (B, L, E)
48
- 'negative_prompt': None,
49
- 'negative_prompt_attention_mask': None # (B , L)
50
- }
 
 
 
 
51
 
52
 
53
 
@@ -64,5 +82,7 @@ images = pipeline(
64
  frame_rate=frame_rate,
65
  **sample,
66
  is_video=True,
67
- vae_per_channel_noramlize=True,
68
- ).images
 
 
 
1
  import torch
2
+ from vae.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
3
+ from transformer.transformer3d import Transformer3DModel
4
  from patchify.symmetric import SymmetricPatchifier
5
+ from scheduler.rf import RectifiedFlowScheduler
6
+ from pipeline.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
7
+ from pathlib import Path
8
+ from transformers import T5EncoderModel
9
 
10
 
11
  model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
12
+ vae_local_path = Path("/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000")
13
  dtype = torch.float32
14
  vae = CausalVideoAutoencoder.from_pretrained(
15
  pretrained_model_name_or_path=vae_local_path,
16
  revision=False,
17
  torch_dtype=torch.bfloat16,
18
  load_in_8bit=False,
19
+ ).cuda()
20
+ transformer_config_path = Path("/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json")
21
+ transformer_config = Transformer3DModel.load_config(transformer_config_path)
22
+ transformer = Transformer3DModel.from_config(transformer_config)
23
+ transformer_local_path = Path("/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-nl/ckpt/01760000/model.pt")
24
  transformer_ckpt_state_dict = torch.load(transformer_local_path)
25
  transformer.load_state_dict(transformer_ckpt_state_dict, True)
26
+ transformer = transformer.cuda()
27
  unet = transformer
28
+ scheduler_config_path = Path("/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json")
29
+ scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
30
+ scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
31
  patchifier = SymmetricPatchifier(patch_size=1)
32
+ # text_encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl")
33
 
34
+ submodel_dict = {
35
+ "unet": unet,
36
+ "transformer": transformer,
37
+ "patchifier": patchifier,
38
+ "text_encoder": None,
39
+ "scheduler": scheduler,
40
+ "vae": vae,
41
 
42
+ }
43
 
44
  pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
45
  safety_checker=None,
 
55
  width=768
56
  num_frames=57
57
  frame_rate=25
58
+ # sample = {
59
+ # "prompt": "A cat", # (B, L, E)
60
+ # 'prompt_attention_mask': None, # (B , L)
61
+ # 'negative_prompt': "Ugly deformed",
62
+ # 'negative_prompt_attention_mask': None # (B , L)
63
+ # }
64
+
65
+ sample = torch.load("/opt/sample.pt")
66
+ for _, item in sample.items():
67
+ if item is not None:
68
+ item = item.cuda()
69
 
70
 
71
 
 
82
  frame_rate=frame_rate,
83
  **sample,
84
  is_video=True,
85
+ vae_per_channel_normalize=True,
86
+ ).images
87
+
88
+ print()
patchify/symmetric.py CHANGED
@@ -6,8 +6,49 @@ from diffusers.configuration_utils import ConfigMixin
6
  from einops import rearrange
7
  from torch import Tensor
8
 
9
- from txt2img.common.torch_utils import append_dims
10
- from txt2img.config.diffusion_parts import PatchifierConfig, PatchifierName
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def pixart_alpha_patchify(
 
6
  from einops import rearrange
7
  from torch import Tensor
8
 
9
+ from utils.torch_utils import append_dims
10
+
11
+
12
+ class Patchifier(ConfigMixin, ABC):
13
+ def __init__(self, patch_size: int):
14
+ super().__init__()
15
+ self._patch_size = (1, patch_size, patch_size)
16
+
17
+ @abstractmethod
18
+ def patchify(self, latents: Tensor, frame_rates: Tensor, scale_grid: bool) -> Tuple[Tensor, Tensor]:
19
+ pass
20
+
21
+ @abstractmethod
22
+ def unpatchify(
23
+ self, latents: Tensor, output_height: int, output_width: int, output_num_frames: int, out_channels: int
24
+ ) -> Tuple[Tensor, Tensor]:
25
+ pass
26
+
27
+ @property
28
+ def patch_size(self):
29
+ return self._patch_size
30
+
31
+ def get_grid(self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device):
32
+ f = orig_num_frames // self._patch_size[0]
33
+ h = orig_height // self._patch_size[1]
34
+ w = orig_width // self._patch_size[2]
35
+ grid_h = torch.arange(h, dtype=torch.float32, device=device)
36
+ grid_w = torch.arange(w, dtype=torch.float32, device=device)
37
+ grid_f = torch.arange(f, dtype=torch.float32, device=device)
38
+ grid = torch.meshgrid(grid_f, grid_h, grid_w)
39
+ grid = torch.stack(grid, dim=0)
40
+ grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
41
+
42
+ if scale_grid is not None:
43
+ for i in range(3):
44
+ if isinstance(scale_grid[i], Tensor):
45
+ scale = append_dims(scale_grid[i], grid.ndim - 1)
46
+ else:
47
+ scale = scale_grid[i]
48
+ grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
49
+
50
+ grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
51
+ return grid
52
 
53
 
54
  def pixart_alpha_patchify(
pipeline/pipeline_video_pixart_alpha.py CHANGED
@@ -5,9 +5,12 @@ import math
5
  import re
6
  import urllib.parse as ul
7
  from typing import Callable, Dict, List, Optional, Tuple, Union
 
 
8
 
9
  import torch
10
  import torch.nn.functional as F
 
11
  from diffusers.image_processor import VaeImageProcessor
12
  from diffusers.models import AutoencoderKL
13
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -24,17 +27,15 @@ from diffusers.utils.torch_utils import randn_tensor
24
  from einops import rearrange
25
  from transformers import T5EncoderModel, T5Tokenizer
26
 
27
- from dataset_metadata.data_field_name import DataFieldName
28
- from txt2img.config.eval import ValLossConfig
29
- from txt2img.diffusers_schedulers.rf_scheduler import TimestepShifter
30
- from txt2img.diffusion.loss.losses import DiffusionLoss
31
- from txt2img.diffusion.models.pixart.transformer_3d import Transformer3DModel
32
- from txt2img.diffusion.patchify import Patchifier
33
- from txt2img.diffusion.vae_encode import get_vae_size_scale_factor, vae_decode, vae_encode
34
- from txt2img.vae.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
35
 
36
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
 
 
38
  if is_bs4_available():
39
  from bs4 import BeautifulSoup
40
 
@@ -581,79 +582,8 @@ class VideoPixArtAlphaPipeline(DiffusionPipeline):
581
 
582
  return samples
583
 
584
- @torch.no_grad()
585
- def calculate_val_loss(
586
- self,
587
- batch: Dict[str, torch.Tensor],
588
- loss_obj: DiffusionLoss,
589
- val_loss_config: ValLossConfig,
590
- vae_per_channel_normalize: bool,
591
- ) -> torch.Tensor:
592
- if DataFieldName.VIDEO in batch:
593
- media_items = batch[DataFieldName.VIDEO]
594
- else:
595
- media_items = batch[DataFieldName.IMAGE]
596
- media_items = media_items.to(dtype=self.vae.dtype)
597
-
598
- if DataFieldName.VIDEO_AVERAGE_FPS in batch:
599
- frame_rates = batch[DataFieldName.VIDEO_AVERAGE_FPS]
600
- else:
601
- frame_rates = torch.ones(media_items.shape[0], 1, device=media_items.device) * 25.0
602
- frame_rates = frame_rates / self.video_scale_factor
603
-
604
- if DataFieldName.T5_EMBEDDING in batch:
605
- prompt_embeds = batch[DataFieldName.T5_EMBEDDING].to(dtype=self.transformer.dtype)
606
- prompt_attn_mask = batch[DataFieldName.T5_EMBEDDING_MASK]
607
-
608
- else:
609
- text = batch[DataFieldName.CAPTION]
610
- prompt_embeds, prompt_attn_mask, _, _ = self.encode_prompt(text)
611
-
612
- latents = vae_encode(media_items, self.vae, vae_per_channel_normalize=vae_per_channel_normalize).float()
613
- b, _, f, h, w = latents.shape
614
- if self.patchifier:
615
- scale_grid = (
616
- (1 / frame_rates, self.vae_scale_factor, self.vae_scale_factor) if self.transformer.use_rope else None
617
- )
618
- indices_grid = self.patchifier.get_grid(
619
- orig_num_frames=f,
620
- orig_height=h,
621
- orig_width=w,
622
- batch_size=b,
623
- scale_grid=scale_grid,
624
- device=self.device,
625
- )
626
- latents = self.patchifier.patchify(latents=latents)
627
-
628
- noise = torch.randn_like(latents)
629
- noise_cond = torch.linspace(val_loss_config.min_step, val_loss_config.max_step, b, device=latents.device)
630
-
631
- if isinstance(self.scheduler, TimestepShifter):
632
- noise_cond = self.scheduler.shift_timesteps(latents, noise_cond)
633
-
634
- noise_cond = noise_cond[:, None]
635
- noisy_latents = self.scheduler.add_noise(latents, noise, noise_cond)
636
-
637
- pred_mean = self.transformer(
638
- hidden_states=noisy_latents.to(self.transformer.dtype),
639
- timestep=noise_cond,
640
- encoder_hidden_states=prompt_embeds,
641
- encoder_attention_mask=prompt_attn_mask,
642
- indices_grid=indices_grid,
643
- ).sample.float()
644
-
645
- loss = loss_obj(
646
- pred_mean=pred_mean,
647
- x_start=latents,
648
- noise=noise,
649
- x_t=noisy_latents,
650
- noise_cond=noise_cond,
651
- )
652
-
653
- return loss
654
 
655
  @torch.no_grad()
656
- @replace_example_docstring(EXAMPLE_DOC_STRING)
657
  def __call__(
658
  self,
659
  height: int,
 
5
  import re
6
  import urllib.parse as ul
7
  from typing import Callable, Dict, List, Optional, Tuple, Union
8
+ from abc import ABC, abstractmethod
9
+
10
 
11
  import torch
12
  import torch.nn.functional as F
13
+ from torch import Tensor
14
  from diffusers.image_processor import VaeImageProcessor
15
  from diffusers.models import AutoencoderKL
16
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
27
  from einops import rearrange
28
  from transformers import T5EncoderModel, T5Tokenizer
29
 
30
+ from transformer.transformer3d import Transformer3DModel
31
+ from patchify.symmetric import Patchifier
32
+ from vae.vae_encode import get_vae_size_scale_factor, vae_decode, vae_encode
33
+ from vae.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
34
+ from scheduler.rf import TimestepShifter
 
 
 
35
 
36
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
 
38
+
39
  if is_bs4_available():
40
  from bs4 import BeautifulSoup
41
 
 
582
 
583
  return samples
584
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
585
 
586
  @torch.no_grad()
 
587
  def __call__(
588
  self,
589
  height: int,
scheduler/rf.py CHANGED
@@ -9,7 +9,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
9
  from diffusers.utils import BaseOutput
10
  from torch import Tensor
11
 
12
- from txt2img.common.torch_utils import append_dims
13
 
14
 
15
  def simple_diffusion_resolution_dependent_timestep_shift(
 
9
  from diffusers.utils import BaseOutput
10
  from torch import Tensor
11
 
12
+ from utils.torch_utils import append_dims
13
 
14
 
15
  def simple_diffusion_resolution_dependent_timestep_shift(
transformer/attention.py ADDED
@@ -0,0 +1,1064 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ from importlib import import_module
3
+ from typing import Any, Dict, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
8
+ from diffusers.models.attention import _chunked_feed_forward
9
+ from diffusers.models.attention_processor import (
10
+ LoRAAttnAddedKVProcessor,
11
+ LoRAAttnProcessor,
12
+ LoRAAttnProcessor2_0,
13
+ LoRAXFormersAttnProcessor,
14
+ SpatialNorm,
15
+ )
16
+ from diffusers.models.lora import LoRACompatibleLinear
17
+ from diffusers.models.normalization import RMSNorm
18
+ from diffusers.utils import deprecate, logging
19
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
20
+ from einops import rearrange
21
+ from torch import nn
22
+
23
+ # code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
24
+
25
+ logger = logging.get_logger(__name__)
26
+
27
+
28
+ @maybe_allow_in_graph
29
+ class BasicTransformerBlock(nn.Module):
30
+ r"""
31
+ A basic Transformer block.
32
+
33
+ Parameters:
34
+ dim (`int`): The number of channels in the input and output.
35
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
36
+ attention_head_dim (`int`): The number of channels in each head.
37
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
38
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
39
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40
+ num_embeds_ada_norm (:
41
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
42
+ attention_bias (:
43
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
44
+ only_cross_attention (`bool`, *optional*):
45
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
46
+ double_self_attention (`bool`, *optional*):
47
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
48
+ upcast_attention (`bool`, *optional*):
49
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
50
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
51
+ Whether to use learnable elementwise affine parameters for normalization.
52
+ qk_norm (`str`, *optional*, defaults to None):
53
+ Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
54
+ adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`):
55
+ The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none".
56
+ standardization_norm (`str`, *optional*, defaults to `"layer_norm"`):
57
+ The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
58
+ final_dropout (`bool` *optional*, defaults to False):
59
+ Whether to apply a final dropout after the last feed-forward layer.
60
+ attention_type (`str`, *optional*, defaults to `"default"`):
61
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
62
+ positional_embeddings (`str`, *optional*, defaults to `None`):
63
+ The type of positional embeddings to apply to.
64
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
65
+ The maximum number of positional embeddings to apply.
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ dim: int,
71
+ num_attention_heads: int,
72
+ attention_head_dim: int,
73
+ dropout=0.0,
74
+ cross_attention_dim: Optional[int] = None,
75
+ activation_fn: str = "geglu",
76
+ num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument
77
+ attention_bias: bool = False,
78
+ only_cross_attention: bool = False,
79
+ double_self_attention: bool = False,
80
+ upcast_attention: bool = False,
81
+ norm_elementwise_affine: bool = True,
82
+ adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none'
83
+ standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm'
84
+ norm_eps: float = 1e-5,
85
+ qk_norm: Optional[str] = None,
86
+ final_dropout: bool = False,
87
+ attention_type: str = "default", # pylint: disable=unused-argument
88
+ ff_inner_dim: Optional[int] = None,
89
+ ff_bias: bool = True,
90
+ attention_out_bias: bool = True,
91
+ use_tpu_flash_attention: bool = False,
92
+ use_rope: bool = False,
93
+ ):
94
+ super().__init__()
95
+ self.only_cross_attention = only_cross_attention
96
+ self.use_tpu_flash_attention = use_tpu_flash_attention
97
+ self.adaptive_norm = adaptive_norm
98
+
99
+ assert standardization_norm in ["layer_norm", "rms_norm"]
100
+ assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
101
+
102
+ make_norm_layer = nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
103
+
104
+ # Define 3 blocks. Each block has its own normalization layer.
105
+ # 1. Self-Attn
106
+ self.norm1 = make_norm_layer(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
107
+
108
+ self.attn1 = Attention(
109
+ query_dim=dim,
110
+ heads=num_attention_heads,
111
+ dim_head=attention_head_dim,
112
+ dropout=dropout,
113
+ bias=attention_bias,
114
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
115
+ upcast_attention=upcast_attention,
116
+ out_bias=attention_out_bias,
117
+ use_tpu_flash_attention=use_tpu_flash_attention,
118
+ qk_norm=qk_norm,
119
+ use_rope=use_rope,
120
+ )
121
+
122
+ # 2. Cross-Attn
123
+ if cross_attention_dim is not None or double_self_attention:
124
+ self.attn2 = Attention(
125
+ query_dim=dim,
126
+ cross_attention_dim=cross_attention_dim if not double_self_attention else None,
127
+ heads=num_attention_heads,
128
+ dim_head=attention_head_dim,
129
+ dropout=dropout,
130
+ bias=attention_bias,
131
+ upcast_attention=upcast_attention,
132
+ out_bias=attention_out_bias,
133
+ use_tpu_flash_attention=use_tpu_flash_attention,
134
+ qk_norm=qk_norm,
135
+ use_rope=use_rope,
136
+ ) # is self-attn if encoder_hidden_states is none
137
+
138
+ if adaptive_norm == "none":
139
+ self.attn2_norm = make_norm_layer(dim, norm_eps, norm_elementwise_affine)
140
+ else:
141
+ self.attn2 = None
142
+ self.attn2_norm = None
143
+
144
+ self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine)
145
+
146
+ # 3. Feed-forward
147
+ self.ff = FeedForward(
148
+ dim,
149
+ dropout=dropout,
150
+ activation_fn=activation_fn,
151
+ final_dropout=final_dropout,
152
+ inner_dim=ff_inner_dim,
153
+ bias=ff_bias,
154
+ )
155
+
156
+ # 5. Scale-shift for PixArt-Alpha.
157
+ if adaptive_norm != "none":
158
+ num_ada_params = 4 if adaptive_norm == "single_scale" else 6
159
+ self.scale_shift_table = nn.Parameter(torch.randn(num_ada_params, dim) / dim**0.5)
160
+
161
+ # let chunk size default to None
162
+ self._chunk_size = None
163
+ self._chunk_dim = 0
164
+
165
+
166
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
167
+ # Sets chunk feed-forward
168
+ self._chunk_size = chunk_size
169
+ self._chunk_dim = dim
170
+
171
+ def forward(
172
+ self,
173
+ hidden_states: torch.FloatTensor,
174
+ freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
175
+ attention_mask: Optional[torch.FloatTensor] = None,
176
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
177
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
178
+ timestep: Optional[torch.LongTensor] = None,
179
+ cross_attention_kwargs: Dict[str, Any] = None,
180
+ class_labels: Optional[torch.LongTensor] = None,
181
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
182
+ ) -> torch.FloatTensor:
183
+ if cross_attention_kwargs is not None:
184
+ if cross_attention_kwargs.get("scale", None) is not None:
185
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
186
+
187
+ # Notice that normalization is always applied before the real computation in the following blocks.
188
+ # 0. Self-Attention
189
+ batch_size = hidden_states.shape[0]
190
+
191
+ norm_hidden_states = self.norm1(hidden_states)
192
+
193
+ # Apply ada_norm_single
194
+ if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
195
+ assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim]
196
+ num_ada_params = self.scale_shift_table.shape[0]
197
+ ada_values = self.scale_shift_table[None, None] + timestep.reshape(
198
+ batch_size, timestep.shape[1], num_ada_params, -1
199
+ )
200
+ if self.adaptive_norm == "single_scale_shift":
201
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
202
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
203
+ else:
204
+ scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
205
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa)
206
+ elif self.adaptive_norm == "none":
207
+ scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None
208
+ else:
209
+ raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
210
+
211
+ norm_hidden_states = norm_hidden_states.squeeze(1) # TODO: Check if this is needed
212
+
213
+ # 1. Prepare GLIGEN inputs
214
+ cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
215
+
216
+ attn_output = self.attn1(
217
+ norm_hidden_states,
218
+ freqs_cis=freqs_cis,
219
+ encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
220
+ attention_mask=attention_mask,
221
+ **cross_attention_kwargs,
222
+ )
223
+ if gate_msa is not None:
224
+ attn_output = gate_msa * attn_output
225
+
226
+ hidden_states = attn_output + hidden_states
227
+ if hidden_states.ndim == 4:
228
+ hidden_states = hidden_states.squeeze(1)
229
+
230
+ # 3. Cross-Attention
231
+ if self.attn2 is not None:
232
+ if self.adaptive_norm == "none":
233
+ attn_input = self.attn2_norm(hidden_states)
234
+ else:
235
+ attn_input = hidden_states
236
+ attn_output = self.attn2(
237
+ attn_input,
238
+ freqs_cis=freqs_cis,
239
+ encoder_hidden_states=encoder_hidden_states,
240
+ attention_mask=encoder_attention_mask,
241
+ **cross_attention_kwargs,
242
+ )
243
+ hidden_states = attn_output + hidden_states
244
+
245
+ # 4. Feed-forward
246
+ norm_hidden_states = self.norm2(hidden_states)
247
+ if self.adaptive_norm == "single_scale_shift":
248
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
249
+ elif self.adaptive_norm == "single_scale":
250
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp)
251
+ elif self.adaptive_norm == "none":
252
+ pass
253
+ else:
254
+ raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
255
+
256
+ if self._chunk_size is not None:
257
+ # "feed_forward_chunk_size" can be used to save memory
258
+ ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
259
+ else:
260
+ ff_output = self.ff(norm_hidden_states)
261
+ if gate_mlp is not None:
262
+ ff_output = gate_mlp * ff_output
263
+
264
+ hidden_states = ff_output + hidden_states
265
+ if hidden_states.ndim == 4:
266
+ hidden_states = hidden_states.squeeze(1)
267
+
268
+ return hidden_states
269
+
270
+
271
+ @maybe_allow_in_graph
272
+ class Attention(nn.Module):
273
+ r"""
274
+ A cross attention layer.
275
+
276
+ Parameters:
277
+ query_dim (`int`):
278
+ The number of channels in the query.
279
+ cross_attention_dim (`int`, *optional*):
280
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
281
+ heads (`int`, *optional*, defaults to 8):
282
+ The number of heads to use for multi-head attention.
283
+ dim_head (`int`, *optional*, defaults to 64):
284
+ The number of channels in each head.
285
+ dropout (`float`, *optional*, defaults to 0.0):
286
+ The dropout probability to use.
287
+ bias (`bool`, *optional*, defaults to False):
288
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
289
+ upcast_attention (`bool`, *optional*, defaults to False):
290
+ Set to `True` to upcast the attention computation to `float32`.
291
+ upcast_softmax (`bool`, *optional*, defaults to False):
292
+ Set to `True` to upcast the softmax computation to `float32`.
293
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
294
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
295
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
296
+ The number of groups to use for the group norm in the cross attention.
297
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
298
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
299
+ norm_num_groups (`int`, *optional*, defaults to `None`):
300
+ The number of groups to use for the group norm in the attention.
301
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
302
+ The number of channels to use for the spatial normalization.
303
+ out_bias (`bool`, *optional*, defaults to `True`):
304
+ Set to `True` to use a bias in the output linear layer.
305
+ scale_qk (`bool`, *optional*, defaults to `True`):
306
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
307
+ qk_norm (`str`, *optional*, defaults to None):
308
+ Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
309
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
310
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
311
+ `added_kv_proj_dim` is not `None`.
312
+ eps (`float`, *optional*, defaults to 1e-5):
313
+ An additional value added to the denominator in group normalization that is used for numerical stability.
314
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
315
+ A factor to rescale the output by dividing it with this value.
316
+ residual_connection (`bool`, *optional*, defaults to `False`):
317
+ Set to `True` to add the residual connection to the output.
318
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
319
+ Set to `True` if the attention block is loaded from a deprecated state dict.
320
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
321
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
322
+ `AttnProcessor` otherwise.
323
+ """
324
+
325
+ def __init__(
326
+ self,
327
+ query_dim: int,
328
+ cross_attention_dim: Optional[int] = None,
329
+ heads: int = 8,
330
+ dim_head: int = 64,
331
+ dropout: float = 0.0,
332
+ bias: bool = False,
333
+ upcast_attention: bool = False,
334
+ upcast_softmax: bool = False,
335
+ cross_attention_norm: Optional[str] = None,
336
+ cross_attention_norm_num_groups: int = 32,
337
+ added_kv_proj_dim: Optional[int] = None,
338
+ norm_num_groups: Optional[int] = None,
339
+ spatial_norm_dim: Optional[int] = None,
340
+ out_bias: bool = True,
341
+ scale_qk: bool = True,
342
+ qk_norm: Optional[str] = None,
343
+ only_cross_attention: bool = False,
344
+ eps: float = 1e-5,
345
+ rescale_output_factor: float = 1.0,
346
+ residual_connection: bool = False,
347
+ _from_deprecated_attn_block: bool = False,
348
+ processor: Optional["AttnProcessor"] = None,
349
+ out_dim: int = None,
350
+ use_tpu_flash_attention: bool = False,
351
+ use_rope: bool = False,
352
+ ):
353
+ super().__init__()
354
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
355
+ self.query_dim = query_dim
356
+ self.use_bias = bias
357
+ self.is_cross_attention = cross_attention_dim is not None
358
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
359
+ self.upcast_attention = upcast_attention
360
+ self.upcast_softmax = upcast_softmax
361
+ self.rescale_output_factor = rescale_output_factor
362
+ self.residual_connection = residual_connection
363
+ self.dropout = dropout
364
+ self.fused_projections = False
365
+ self.out_dim = out_dim if out_dim is not None else query_dim
366
+ self.use_tpu_flash_attention = use_tpu_flash_attention
367
+ self.use_rope = use_rope
368
+
369
+ # we make use of this private variable to know whether this class is loaded
370
+ # with an deprecated state dict so that we can convert it on the fly
371
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
372
+
373
+ self.scale_qk = scale_qk
374
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
375
+
376
+ if qk_norm is None:
377
+ self.q_norm = nn.Identity()
378
+ self.k_norm = nn.Identity()
379
+ elif qk_norm == "rms_norm":
380
+ self.q_norm = RMSNorm(dim_head * heads, eps=1e-5)
381
+ self.k_norm = RMSNorm(dim_head * heads, eps=1e-5)
382
+ elif qk_norm == "layer_norm":
383
+ self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
384
+ self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
385
+ else:
386
+ raise ValueError(f"Unsupported qk_norm method: {qk_norm}")
387
+
388
+ self.heads = out_dim // dim_head if out_dim is not None else heads
389
+ # for slice_size > 0 the attention score computation
390
+ # is split across the batch axis to save memory
391
+ # You can set slice_size with `set_attention_slice`
392
+ self.sliceable_head_dim = heads
393
+
394
+ self.added_kv_proj_dim = added_kv_proj_dim
395
+ self.only_cross_attention = only_cross_attention
396
+
397
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
398
+ raise ValueError(
399
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
400
+ )
401
+
402
+ if norm_num_groups is not None:
403
+ self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
404
+ else:
405
+ self.group_norm = None
406
+
407
+ if spatial_norm_dim is not None:
408
+ self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim)
409
+ else:
410
+ self.spatial_norm = None
411
+
412
+ if cross_attention_norm is None:
413
+ self.norm_cross = None
414
+ elif cross_attention_norm == "layer_norm":
415
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
416
+ elif cross_attention_norm == "group_norm":
417
+ if self.added_kv_proj_dim is not None:
418
+ # The given `encoder_hidden_states` are initially of shape
419
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
420
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
421
+ # before the projection, so we need to use `added_kv_proj_dim` as
422
+ # the number of channels for the group norm.
423
+ norm_cross_num_channels = added_kv_proj_dim
424
+ else:
425
+ norm_cross_num_channels = self.cross_attention_dim
426
+
427
+ self.norm_cross = nn.GroupNorm(
428
+ num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
429
+ )
430
+ else:
431
+ raise ValueError(
432
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
433
+ )
434
+
435
+ linear_cls = nn.Linear
436
+
437
+ self.linear_cls = linear_cls
438
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
439
+
440
+ if not self.only_cross_attention:
441
+ # only relevant for the `AddedKVProcessor` classes
442
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
443
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
444
+ else:
445
+ self.to_k = None
446
+ self.to_v = None
447
+
448
+ if self.added_kv_proj_dim is not None:
449
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
450
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
451
+
452
+ self.to_out = nn.ModuleList([])
453
+ self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
454
+ self.to_out.append(nn.Dropout(dropout))
455
+
456
+ # set attention processor
457
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
458
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
459
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
460
+ if processor is None:
461
+ processor = AttnProcessor2_0()
462
+ self.set_processor(processor)
463
+
464
+ def set_processor(self, processor: "AttnProcessor") -> None:
465
+ r"""
466
+ Set the attention processor to use.
467
+
468
+ Args:
469
+ processor (`AttnProcessor`):
470
+ The attention processor to use.
471
+ """
472
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
473
+ # pop `processor` from `self._modules`
474
+ if (
475
+ hasattr(self, "processor")
476
+ and isinstance(self.processor, torch.nn.Module)
477
+ and not isinstance(processor, torch.nn.Module)
478
+ ):
479
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
480
+ self._modules.pop("processor")
481
+
482
+ self.processor = processor
483
+
484
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor": # noqa: F821
485
+ r"""
486
+ Get the attention processor in use.
487
+
488
+ Args:
489
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
490
+ Set to `True` to return the deprecated LoRA attention processor.
491
+
492
+ Returns:
493
+ "AttentionProcessor": The attention processor in use.
494
+ """
495
+ if not return_deprecated_lora:
496
+ return self.processor
497
+
498
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
499
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
500
+ # with PEFT is completed.
501
+ is_lora_activated = {
502
+ name: module.lora_layer is not None
503
+ for name, module in self.named_modules()
504
+ if hasattr(module, "lora_layer")
505
+ }
506
+
507
+ # 1. if no layer has a LoRA activated we can return the processor as usual
508
+ if not any(is_lora_activated.values()):
509
+ return self.processor
510
+
511
+ # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
512
+ is_lora_activated.pop("add_k_proj", None)
513
+ is_lora_activated.pop("add_v_proj", None)
514
+ # 2. else it is not posssible that only some layers have LoRA activated
515
+ if not all(is_lora_activated.values()):
516
+ raise ValueError(
517
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
518
+ )
519
+
520
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
521
+ non_lora_processor_cls_name = self.processor.__class__.__name__
522
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
523
+
524
+ hidden_size = self.inner_dim
525
+
526
+ # now create a LoRA attention processor from the LoRA layers
527
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
528
+ kwargs = {
529
+ "cross_attention_dim": self.cross_attention_dim,
530
+ "rank": self.to_q.lora_layer.rank,
531
+ "network_alpha": self.to_q.lora_layer.network_alpha,
532
+ "q_rank": self.to_q.lora_layer.rank,
533
+ "q_hidden_size": self.to_q.lora_layer.out_features,
534
+ "k_rank": self.to_k.lora_layer.rank,
535
+ "k_hidden_size": self.to_k.lora_layer.out_features,
536
+ "v_rank": self.to_v.lora_layer.rank,
537
+ "v_hidden_size": self.to_v.lora_layer.out_features,
538
+ "out_rank": self.to_out[0].lora_layer.rank,
539
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
540
+ }
541
+
542
+ if hasattr(self.processor, "attention_op"):
543
+ kwargs["attention_op"] = self.processor.attention_op
544
+
545
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
546
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
547
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
548
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
549
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
550
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
551
+ lora_processor = lora_processor_cls(
552
+ hidden_size,
553
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
554
+ rank=self.to_q.lora_layer.rank,
555
+ network_alpha=self.to_q.lora_layer.network_alpha,
556
+ )
557
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
558
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
559
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
560
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
561
+
562
+ # only save if used
563
+ if self.add_k_proj.lora_layer is not None:
564
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
565
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
566
+ else:
567
+ lora_processor.add_k_proj_lora = None
568
+ lora_processor.add_v_proj_lora = None
569
+ else:
570
+ raise ValueError(f"{lora_processor_cls} does not exist.")
571
+
572
+ return lora_processor
573
+
574
+ def forward(
575
+ self,
576
+ hidden_states: torch.FloatTensor,
577
+ freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
578
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
579
+ attention_mask: Optional[torch.FloatTensor] = None,
580
+ **cross_attention_kwargs,
581
+ ) -> torch.Tensor:
582
+ r"""
583
+ The forward method of the `Attention` class.
584
+
585
+ Args:
586
+ hidden_states (`torch.Tensor`):
587
+ The hidden states of the query.
588
+ encoder_hidden_states (`torch.Tensor`, *optional*):
589
+ The hidden states of the encoder.
590
+ attention_mask (`torch.Tensor`, *optional*):
591
+ The attention mask to use. If `None`, no mask is applied.
592
+ **cross_attention_kwargs:
593
+ Additional keyword arguments to pass along to the cross attention.
594
+
595
+ Returns:
596
+ `torch.Tensor`: The output of the attention layer.
597
+ """
598
+ # The `Attention` class can call different attention processors / attention functions
599
+ # here we simply pass along all tensors to the selected processor class
600
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
601
+
602
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
603
+ unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
604
+ if len(unused_kwargs) > 0:
605
+ logger.warning(
606
+ f"cross_attention_kwargs {unused_kwargs} are not expected by"
607
+ f" {self.processor.__class__.__name__} and will be ignored."
608
+ )
609
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
610
+
611
+ return self.processor(
612
+ self,
613
+ hidden_states,
614
+ freqs_cis=freqs_cis,
615
+ encoder_hidden_states=encoder_hidden_states,
616
+ attention_mask=attention_mask,
617
+ **cross_attention_kwargs,
618
+ )
619
+
620
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
621
+ r"""
622
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
623
+ is the number of heads initialized while constructing the `Attention` class.
624
+
625
+ Args:
626
+ tensor (`torch.Tensor`): The tensor to reshape.
627
+
628
+ Returns:
629
+ `torch.Tensor`: The reshaped tensor.
630
+ """
631
+ head_size = self.heads
632
+ batch_size, seq_len, dim = tensor.shape
633
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
634
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
635
+ return tensor
636
+
637
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
638
+ r"""
639
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
640
+ the number of heads initialized while constructing the `Attention` class.
641
+
642
+ Args:
643
+ tensor (`torch.Tensor`): The tensor to reshape.
644
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
645
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
646
+
647
+ Returns:
648
+ `torch.Tensor`: The reshaped tensor.
649
+ """
650
+
651
+ head_size = self.heads
652
+ if tensor.ndim == 3:
653
+ batch_size, seq_len, dim = tensor.shape
654
+ extra_dim = 1
655
+ else:
656
+ batch_size, extra_dim, seq_len, dim = tensor.shape
657
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
658
+ tensor = tensor.permute(0, 2, 1, 3)
659
+
660
+ if out_dim == 3:
661
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
662
+
663
+ return tensor
664
+
665
+ def get_attention_scores(
666
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: torch.Tensor = None
667
+ ) -> torch.Tensor:
668
+ r"""
669
+ Compute the attention scores.
670
+
671
+ Args:
672
+ query (`torch.Tensor`): The query tensor.
673
+ key (`torch.Tensor`): The key tensor.
674
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
675
+
676
+ Returns:
677
+ `torch.Tensor`: The attention probabilities/scores.
678
+ """
679
+ dtype = query.dtype
680
+ if self.upcast_attention:
681
+ query = query.float()
682
+ key = key.float()
683
+
684
+ if attention_mask is None:
685
+ baddbmm_input = torch.empty(
686
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
687
+ )
688
+ beta = 0
689
+ else:
690
+ baddbmm_input = attention_mask
691
+ beta = 1
692
+
693
+ attention_scores = torch.baddbmm(
694
+ baddbmm_input,
695
+ query,
696
+ key.transpose(-1, -2),
697
+ beta=beta,
698
+ alpha=self.scale,
699
+ )
700
+ del baddbmm_input
701
+
702
+ if self.upcast_softmax:
703
+ attention_scores = attention_scores.float()
704
+
705
+ attention_probs = attention_scores.softmax(dim=-1)
706
+ del attention_scores
707
+
708
+ attention_probs = attention_probs.to(dtype)
709
+
710
+ return attention_probs
711
+
712
+ def prepare_attention_mask(
713
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
714
+ ) -> torch.Tensor:
715
+ r"""
716
+ Prepare the attention mask for the attention computation.
717
+
718
+ Args:
719
+ attention_mask (`torch.Tensor`):
720
+ The attention mask to prepare.
721
+ target_length (`int`):
722
+ The target length of the attention mask. This is the length of the attention mask after padding.
723
+ batch_size (`int`):
724
+ The batch size, which is used to repeat the attention mask.
725
+ out_dim (`int`, *optional*, defaults to `3`):
726
+ The output dimension of the attention mask. Can be either `3` or `4`.
727
+
728
+ Returns:
729
+ `torch.Tensor`: The prepared attention mask.
730
+ """
731
+ head_size = self.heads
732
+ if attention_mask is None:
733
+ return attention_mask
734
+
735
+ current_length: int = attention_mask.shape[-1]
736
+ if current_length != target_length:
737
+ if attention_mask.device.type == "mps":
738
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
739
+ # Instead, we can manually construct the padding tensor.
740
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
741
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
742
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
743
+ else:
744
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
745
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
746
+ # remaining_length: int = target_length - current_length
747
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
748
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
749
+
750
+ if out_dim == 3:
751
+ if attention_mask.shape[0] < batch_size * head_size:
752
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
753
+ elif out_dim == 4:
754
+ attention_mask = attention_mask.unsqueeze(1)
755
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
756
+
757
+ return attention_mask
758
+
759
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
760
+ r"""
761
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
762
+ `Attention` class.
763
+
764
+ Args:
765
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
766
+
767
+ Returns:
768
+ `torch.Tensor`: The normalized encoder hidden states.
769
+ """
770
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
771
+
772
+ if isinstance(self.norm_cross, nn.LayerNorm):
773
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
774
+ elif isinstance(self.norm_cross, nn.GroupNorm):
775
+ # Group norm norms along the channels dimension and expects
776
+ # input to be in the shape of (N, C, *). In this case, we want
777
+ # to norm along the hidden dimension, so we need to move
778
+ # (batch_size, sequence_length, hidden_size) ->
779
+ # (batch_size, hidden_size, sequence_length)
780
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
781
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
782
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
783
+ else:
784
+ assert False
785
+
786
+ return encoder_hidden_states
787
+
788
+ @staticmethod
789
+ def apply_rotary_emb(
790
+ input_tensor: torch.Tensor,
791
+ freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
792
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
793
+ cos_freqs = freqs_cis[0]
794
+ sin_freqs = freqs_cis[1]
795
+
796
+ t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
797
+ t1, t2 = t_dup.unbind(dim=-1)
798
+ t_dup = torch.stack((-t2, t1), dim=-1)
799
+ input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
800
+
801
+ out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
802
+
803
+ return out
804
+
805
+
806
+ class AttnProcessor2_0:
807
+ r"""
808
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
809
+ """
810
+
811
+ def __init__(self):
812
+ pass
813
+
814
+ def __call__(
815
+ self,
816
+ attn: Attention,
817
+ hidden_states: torch.FloatTensor,
818
+ freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
819
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
820
+ attention_mask: Optional[torch.FloatTensor] = None,
821
+ temb: Optional[torch.FloatTensor] = None,
822
+ *args,
823
+ **kwargs,
824
+ ) -> torch.FloatTensor:
825
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
826
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
827
+ deprecate("scale", "1.0.0", deprecation_message)
828
+
829
+ residual = hidden_states
830
+ if attn.spatial_norm is not None:
831
+ hidden_states = attn.spatial_norm(hidden_states, temb)
832
+
833
+ input_ndim = hidden_states.ndim
834
+
835
+ if input_ndim == 4:
836
+ batch_size, channel, height, width = hidden_states.shape
837
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
838
+
839
+ batch_size, sequence_length, _ = (
840
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
841
+ )
842
+
843
+ if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
844
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
845
+ # scaled_dot_product_attention expects attention_mask shape to be
846
+ # (batch, heads, source_length, target_length)
847
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
848
+
849
+ if attn.group_norm is not None:
850
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
851
+
852
+ query = attn.to_q(hidden_states)
853
+ query = attn.q_norm(query)
854
+
855
+ if encoder_hidden_states is not None:
856
+ if attn.norm_cross:
857
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
858
+ key = attn.to_k(encoder_hidden_states)
859
+ key = attn.k_norm(key)
860
+ else: # if no context provided do self-attention
861
+ encoder_hidden_states = hidden_states
862
+ key = attn.to_k(hidden_states)
863
+ key = attn.k_norm(key)
864
+ if attn.use_rope:
865
+ key = attn.apply_rotary_emb(key, freqs_cis)
866
+ query = attn.apply_rotary_emb(query, freqs_cis)
867
+
868
+ value = attn.to_v(encoder_hidden_states)
869
+
870
+ inner_dim = key.shape[-1]
871
+ head_dim = inner_dim // attn.heads
872
+
873
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
874
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
875
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
876
+
877
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
878
+
879
+ if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention'
880
+ q_segment_indexes = None
881
+ if attention_mask is not None: # if mask is required need to tune both segmenIds fields
882
+ # attention_mask = torch.squeeze(attention_mask).to(torch.float32)
883
+ attention_mask = attention_mask.to(torch.float32)
884
+ q_segment_indexes = torch.ones(batch_size, query.shape[2], device=query.device, dtype=torch.float32)
885
+ assert (
886
+ attention_mask.shape[1] == key.shape[2]
887
+ ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"
888
+
889
+ assert (
890
+ query.shape[2] % 128 == 0
891
+ ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]"
892
+ assert (
893
+ key.shape[2] % 128 == 0
894
+ ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]"
895
+
896
+ # run the TPU kernel implemented in jax with pallas
897
+ hidden_states = flash_attention(
898
+ q=query,
899
+ k=key,
900
+ v=value,
901
+ q_segment_ids=q_segment_indexes,
902
+ kv_segment_ids=attention_mask,
903
+ sm_scale=attn.scale,
904
+ )
905
+ else:
906
+ hidden_states = F.scaled_dot_product_attention(
907
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
908
+ )
909
+
910
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
911
+ hidden_states = hidden_states.to(query.dtype)
912
+
913
+ # linear proj
914
+ hidden_states = attn.to_out[0](hidden_states)
915
+ # dropout
916
+ hidden_states = attn.to_out[1](hidden_states)
917
+
918
+ if input_ndim == 4:
919
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
920
+
921
+ if attn.residual_connection:
922
+ hidden_states = hidden_states + residual
923
+
924
+ hidden_states = hidden_states / attn.rescale_output_factor
925
+
926
+ return hidden_states
927
+
928
+
929
+ class AttnProcessor:
930
+ r"""
931
+ Default processor for performing attention-related computations.
932
+ """
933
+
934
+ def __call__(
935
+ self,
936
+ attn: Attention,
937
+ hidden_states: torch.FloatTensor,
938
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
939
+ attention_mask: Optional[torch.FloatTensor] = None,
940
+ temb: Optional[torch.FloatTensor] = None,
941
+ *args,
942
+ **kwargs,
943
+ ) -> torch.Tensor:
944
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
945
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
946
+ deprecate("scale", "1.0.0", deprecation_message)
947
+
948
+ residual = hidden_states
949
+
950
+ if attn.spatial_norm is not None:
951
+ hidden_states = attn.spatial_norm(hidden_states, temb)
952
+
953
+ input_ndim = hidden_states.ndim
954
+
955
+ if input_ndim == 4:
956
+ batch_size, channel, height, width = hidden_states.shape
957
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
958
+
959
+ batch_size, sequence_length, _ = (
960
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
961
+ )
962
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
963
+
964
+ if attn.group_norm is not None:
965
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
966
+
967
+ query = attn.to_q(hidden_states)
968
+
969
+ if encoder_hidden_states is None:
970
+ encoder_hidden_states = hidden_states
971
+ elif attn.norm_cross:
972
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
973
+
974
+ key = attn.to_k(encoder_hidden_states)
975
+ value = attn.to_v(encoder_hidden_states)
976
+
977
+ query = attn.head_to_batch_dim(query)
978
+ key = attn.head_to_batch_dim(key)
979
+ value = attn.head_to_batch_dim(value)
980
+
981
+ query = attn.q_norm(query)
982
+ key = attn.k_norm(key)
983
+
984
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
985
+ hidden_states = torch.bmm(attention_probs, value)
986
+ hidden_states = attn.batch_to_head_dim(hidden_states)
987
+
988
+ # linear proj
989
+ hidden_states = attn.to_out[0](hidden_states)
990
+ # dropout
991
+ hidden_states = attn.to_out[1](hidden_states)
992
+
993
+ if input_ndim == 4:
994
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
995
+
996
+ if attn.residual_connection:
997
+ hidden_states = hidden_states + residual
998
+
999
+ hidden_states = hidden_states / attn.rescale_output_factor
1000
+
1001
+ return hidden_states
1002
+
1003
+
1004
+ class FeedForward(nn.Module):
1005
+ r"""
1006
+ A feed-forward layer.
1007
+
1008
+ Parameters:
1009
+ dim (`int`): The number of channels in the input.
1010
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1011
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1012
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1013
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1014
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1015
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1016
+ """
1017
+
1018
+ def __init__(
1019
+ self,
1020
+ dim: int,
1021
+ dim_out: Optional[int] = None,
1022
+ mult: int = 4,
1023
+ dropout: float = 0.0,
1024
+ activation_fn: str = "geglu",
1025
+ final_dropout: bool = False,
1026
+ inner_dim=None,
1027
+ bias: bool = True,
1028
+ ):
1029
+ super().__init__()
1030
+ if inner_dim is None:
1031
+ inner_dim = int(dim * mult)
1032
+ dim_out = dim_out if dim_out is not None else dim
1033
+ linear_cls = nn.Linear
1034
+
1035
+ if activation_fn == "gelu":
1036
+ act_fn = GELU(dim, inner_dim, bias=bias)
1037
+ elif activation_fn == "gelu-approximate":
1038
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1039
+ elif activation_fn == "geglu":
1040
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
1041
+ elif activation_fn == "geglu-approximate":
1042
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1043
+ else:
1044
+ raise ValueError(f"Unsupported activation function: {activation_fn}")
1045
+
1046
+ self.net = nn.ModuleList([])
1047
+ # project in
1048
+ self.net.append(act_fn)
1049
+ # project dropout
1050
+ self.net.append(nn.Dropout(dropout))
1051
+ # project out
1052
+ self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
1053
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1054
+ if final_dropout:
1055
+ self.net.append(nn.Dropout(dropout))
1056
+
1057
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
1058
+ compatible_cls = (GEGLU, LoRACompatibleLinear)
1059
+ for module in self.net:
1060
+ if isinstance(module, compatible_cls):
1061
+ hidden_states = module(hidden_states, scale)
1062
+ else:
1063
+ hidden_states = module(hidden_states)
1064
+ return hidden_states
transformer/transformer3d.py CHANGED
@@ -11,10 +11,7 @@ from diffusers.models.normalization import AdaLayerNormSingle
11
  from diffusers.utils import BaseOutput, is_torch_version
12
  from torch import nn
13
 
14
- from txt2img.common import dist_util, logger
15
- from txt2img.config.weights_init_config import WeightsInitConfig, WeightsInitModeName
16
- from txt2img.diffusion.models.pixart.attention import BasicTransformerBlock
17
- from txt2img.diffusion.models.pixart.embeddings import get_3d_sincos_pos_embed
18
 
19
 
20
  @dataclass
@@ -146,64 +143,6 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
146
 
147
  self.gradient_checkpointing = False
148
 
149
- def set_use_tpu_flash_attention(self):
150
- r"""
151
- Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
152
- attention kernel.
153
- """
154
- logger.info(" ENABLE TPU FLASH ATTENTION -> TRUE")
155
- # if using TPU -> configure components to use TPU flash attention
156
- if dist_util.acceleration_type() == dist_util.AccelerationType.TPU:
157
- self.use_tpu_flash_attention = True
158
- # push config down to the attention modules
159
- for block in self.transformer_blocks:
160
- block.set_use_tpu_flash_attention()
161
-
162
- def initialize(self, weights_init: WeightsInitConfig):
163
- if weights_init.mode != WeightsInitModeName.PixArt and weights_init.mode != WeightsInitModeName.Xora:
164
- return
165
-
166
- def _basic_init(module):
167
- if isinstance(module, nn.Linear):
168
- torch.nn.init.xavier_uniform_(module.weight)
169
- if module.bias is not None:
170
- nn.init.constant_(module.bias, 0)
171
-
172
- self.apply(_basic_init)
173
-
174
- # Initialize timestep embedding MLP:
175
- nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_1.weight, std=weights_init.embedding_std)
176
- nn.init.normal_(self.adaln_single.emb.timestep_embedder.linear_2.weight, std=weights_init.embedding_std)
177
- nn.init.normal_(self.adaln_single.linear.weight, std=weights_init.embedding_std)
178
-
179
- if hasattr(self.adaln_single.emb, "resolution_embedder"):
180
- nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_1.weight, std=weights_init.embedding_std)
181
- nn.init.normal_(self.adaln_single.emb.resolution_embedder.linear_2.weight, std=weights_init.embedding_std)
182
- if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"):
183
- nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight, std=weights_init.embedding_std)
184
- nn.init.normal_(self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight, std=weights_init.embedding_std)
185
-
186
- # Initialize caption embedding MLP:
187
- nn.init.normal_(self.caption_projection.linear_1.weight, std=weights_init.embedding_std)
188
- nn.init.normal_(self.caption_projection.linear_1.weight, std=weights_init.embedding_std)
189
-
190
- # Zero-out adaLN modulation layers in PixArt blocks:
191
- for block in self.transformer_blocks:
192
- if weights_init.mode == WeightsInitModeName.Xora:
193
- nn.init.constant_(block.attn1.to_out[0].weight, 0)
194
- nn.init.constant_(block.attn1.to_out[0].bias, 0)
195
-
196
- nn.init.constant_(block.attn2.to_out[0].weight, 0)
197
- nn.init.constant_(block.attn2.to_out[0].bias, 0)
198
-
199
- if weights_init.mode == WeightsInitModeName.Xora:
200
- nn.init.constant_(block.ff.net[2].weight, 0)
201
- nn.init.constant_(block.ff.net[2].bias, 0)
202
-
203
- # Zero-out output layers:
204
- nn.init.constant_(self.proj_out.weight, 0)
205
- nn.init.constant_(self.proj_out.bias, 0)
206
-
207
  def _set_gradient_checkpointing(self, module, value=False):
208
  if hasattr(module, "gradient_checkpointing"):
209
  module.gradient_checkpointing = value
@@ -348,14 +287,10 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
348
  if self.timestep_scale_multiplier:
349
  timestep = self.timestep_scale_multiplier * timestep
350
 
351
- if self.positional_embedding_type == "absolute":
352
- pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to(hidden_states.device)
353
- if self.project_to_2d_pos:
354
- pos_embed = self.to_2d_proj(pos_embed_3d)
355
- hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype)
356
- freqs_cis = None
357
- elif self.positional_embedding_type == "rope":
358
  freqs_cis = self.precompute_freqs_cis(indices_grid)
 
 
359
 
360
  batch_size = hidden_states.shape[0]
361
  timestep, embedded_timestep = self.adaln_single(
@@ -423,14 +358,3 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
423
 
424
  return Transformer3DModelOutput(sample=hidden_states)
425
 
426
- def get_absolute_pos_embed(self, grid):
427
- grid_np = grid[0].cpu().numpy()
428
- embed_dim_3d = math.ceil((self.inner_dim / 2) * 3) if self.project_to_2d_pos else self.inner_dim
429
- pos_embed = get_3d_sincos_pos_embed( # (f h w)
430
- embed_dim_3d,
431
- grid_np,
432
- h=int(max(grid_np[1]) + 1),
433
- w=int(max(grid_np[2]) + 1),
434
- f=int(max(grid_np[0] + 1)),
435
- )
436
- return torch.from_numpy(pos_embed).float().unsqueeze(0)
 
11
  from diffusers.utils import BaseOutput, is_torch_version
12
  from torch import nn
13
 
14
+ from transformer.attention import BasicTransformerBlock
 
 
 
15
 
16
 
17
  @dataclass
 
143
 
144
  self.gradient_checkpointing = False
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  def _set_gradient_checkpointing(self, module, value=False):
147
  if hasattr(module, "gradient_checkpointing"):
148
  module.gradient_checkpointing = value
 
287
  if self.timestep_scale_multiplier:
288
  timestep = self.timestep_scale_multiplier * timestep
289
 
290
+ if self.positional_embedding_type == "rope":
 
 
 
 
 
 
291
  freqs_cis = self.precompute_freqs_cis(indices_grid)
292
+ else:
293
+ raise NotImplementedError("Only rope pos embed supported.")
294
 
295
  batch_size = hidden_states.shape[0]
296
  timestep, embedded_timestep = self.adaln_single(
 
358
 
359
  return Transformer3DModelOutput(sample=hidden_states)
360
 
 
 
 
 
 
 
 
 
 
 
 
utils/torch_utils.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
4
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
5
+ dims_to_append = target_dims - x.ndim
6
+ if dims_to_append < 0:
7
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
8
+ elif dims_to_append == 0:
9
+ return x
10
+ return x[(...,) + (None,) * dims_to_append]
vae/{causal_video_encoder.py → autoencoders/causal_video_autoencoder.py} RENAMED
@@ -9,10 +9,9 @@ import numpy as np
9
  from einops import rearrange
10
  from torch import nn
11
 
12
- from txt2img.common import logger
13
- from txt2img.vae.layers.conv_nd_factory import make_conv_nd, make_linear_nd
14
- from txt2img.vae.layers.pixel_norm import PixelNorm
15
- from txt2img.vae.vae import AutoencoderKLWrapper
16
 
17
 
18
  class CausalVideoAutoencoder(AutoencoderKLWrapper):
@@ -139,7 +138,7 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
139
  key = key.replace(k, v)
140
 
141
  if "norm" in key and key not in model_keys:
142
- logger.info(f"Removing key {key} from state_dict as it is not present in the model")
143
  continue
144
 
145
  converted_state_dict[key] = value
 
9
  from einops import rearrange
10
  from torch import nn
11
 
12
+ from vae.layers.conv_nd_factory import make_conv_nd, make_linear_nd
13
+ from vae.layers.pixel_norm import PixelNorm
14
+ from vae.vae import AutoencoderKLWrapper
 
15
 
16
 
17
  class CausalVideoAutoencoder(AutoencoderKLWrapper):
 
138
  key = key.replace(k, v)
139
 
140
  if "norm" in key and key not in model_keys:
141
+ print(f"Removing key {key} from state_dict as it is not present in the model")
142
  continue
143
 
144
  converted_state_dict[key] = value
vae/layers/causal_conv3d.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class CausalConv3d(nn.Module):
8
+ def __init__(
9
+ self,
10
+ in_channels,
11
+ out_channels,
12
+ kernel_size: int = 3,
13
+ stride: Union[int, Tuple[int]] = 1,
14
+ **kwargs,
15
+ ):
16
+ super().__init__()
17
+
18
+ self.in_channels = in_channels
19
+ self.out_channels = out_channels
20
+
21
+ kernel_size = (kernel_size, kernel_size, kernel_size)
22
+ self.time_kernel_size = kernel_size[0]
23
+
24
+ dilation = kwargs.pop("dilation", 1)
25
+ dilation = (dilation, 1, 1)
26
+
27
+ height_pad = kernel_size[1] // 2
28
+ width_pad = kernel_size[2] // 2
29
+ padding = (0, height_pad, width_pad)
30
+
31
+ self.conv = nn.Conv3d(
32
+ in_channels,
33
+ out_channels,
34
+ kernel_size,
35
+ stride=stride,
36
+ dilation=dilation,
37
+ padding=padding,
38
+ padding_mode="zeros",
39
+ )
40
+
41
+ def forward(self, x, causal: bool = True):
42
+ if causal:
43
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1))
44
+ x = torch.concatenate((first_frame_pad, x), dim=2)
45
+ else:
46
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))
47
+ last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))
48
+ x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
49
+ x = self.conv(x)
50
+ return x
51
+
52
+ @property
53
+ def weight(self):
54
+ return self.conv.weight
vae/layers/conv_nd_factory.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+
5
+ from vae.layers.dual_conv3d import DualConv3d
6
+ from vae.layers.causal_conv3d import CausalConv3d
7
+
8
+
9
+ def make_conv_nd(
10
+ dims: Union[int, Tuple[int, int]],
11
+ in_channels: int,
12
+ out_channels: int,
13
+ kernel_size: int,
14
+ stride=1,
15
+ padding=0,
16
+ dilation=1,
17
+ groups=1,
18
+ bias=True,
19
+ causal=False,
20
+ ):
21
+ if dims == 2:
22
+ return torch.nn.Conv2d(
23
+ in_channels=in_channels,
24
+ out_channels=out_channels,
25
+ kernel_size=kernel_size,
26
+ stride=stride,
27
+ padding=padding,
28
+ dilation=dilation,
29
+ groups=groups,
30
+ bias=bias,
31
+ )
32
+ elif dims == 3:
33
+ if causal:
34
+ return CausalConv3d(
35
+ in_channels=in_channels,
36
+ out_channels=out_channels,
37
+ kernel_size=kernel_size,
38
+ stride=stride,
39
+ padding=padding,
40
+ dilation=dilation,
41
+ groups=groups,
42
+ bias=bias,
43
+ )
44
+ return torch.nn.Conv3d(
45
+ in_channels=in_channels,
46
+ out_channels=out_channels,
47
+ kernel_size=kernel_size,
48
+ stride=stride,
49
+ padding=padding,
50
+ dilation=dilation,
51
+ groups=groups,
52
+ bias=bias,
53
+ )
54
+ elif dims == (2, 1):
55
+ return DualConv3d(
56
+ in_channels=in_channels,
57
+ out_channels=out_channels,
58
+ kernel_size=kernel_size,
59
+ stride=stride,
60
+ padding=padding,
61
+ bias=bias,
62
+ )
63
+ else:
64
+ raise ValueError(f"unsupported dimensions: {dims}")
65
+
66
+
67
+ def make_linear_nd(
68
+ dims: int,
69
+ in_channels: int,
70
+ out_channels: int,
71
+ bias=True,
72
+ ):
73
+ if dims == 2:
74
+ return torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
75
+ elif dims == 3 or dims == (2, 1):
76
+ return torch.nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias)
77
+ else:
78
+ raise ValueError(f"unsupported dimensions: {dims}")
vae/layers/dual_conv3d.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+
9
+
10
+ class DualConv3d(nn.Module):
11
+ def __init__(
12
+ self,
13
+ in_channels,
14
+ out_channels,
15
+ kernel_size,
16
+ stride: Union[int, Tuple[int, int, int]] = 1,
17
+ padding: Union[int, Tuple[int, int, int]] = 0,
18
+ dilation: Union[int, Tuple[int, int, int]] = 1,
19
+ groups=1,
20
+ bias=True,
21
+ ):
22
+ super(DualConv3d, self).__init__()
23
+
24
+ self.in_channels = in_channels
25
+ self.out_channels = out_channels
26
+ # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
27
+ if isinstance(kernel_size, int):
28
+ kernel_size = (kernel_size, kernel_size, kernel_size)
29
+ if kernel_size == (1, 1, 1):
30
+ raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.")
31
+ if isinstance(stride, int):
32
+ stride = (stride, stride, stride)
33
+ if isinstance(padding, int):
34
+ padding = (padding, padding, padding)
35
+ if isinstance(dilation, int):
36
+ dilation = (dilation, dilation, dilation)
37
+
38
+ # Set parameters for convolutions
39
+ self.groups = groups
40
+ self.bias = bias
41
+
42
+ # Define the size of the channels after the first convolution
43
+ intermediate_channels = out_channels if in_channels < out_channels else in_channels
44
+
45
+ # Define parameters for the first convolution
46
+ self.weight1 = nn.Parameter(
47
+ torch.Tensor(intermediate_channels, in_channels // groups, 1, kernel_size[1], kernel_size[2])
48
+ )
49
+ self.stride1 = (1, stride[1], stride[2])
50
+ self.padding1 = (0, padding[1], padding[2])
51
+ self.dilation1 = (1, dilation[1], dilation[2])
52
+ if bias:
53
+ self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
54
+ else:
55
+ self.register_parameter("bias1", None)
56
+
57
+ # Define parameters for the second convolution
58
+ self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1))
59
+ self.stride2 = (stride[0], 1, 1)
60
+ self.padding2 = (padding[0], 0, 0)
61
+ self.dilation2 = (dilation[0], 1, 1)
62
+ if bias:
63
+ self.bias2 = nn.Parameter(torch.Tensor(out_channels))
64
+ else:
65
+ self.register_parameter("bias2", None)
66
+
67
+ # Initialize weights and biases
68
+ self.reset_parameters()
69
+
70
+ def reset_parameters(self):
71
+ nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
72
+ nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
73
+ if self.bias:
74
+ fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
75
+ bound1 = 1 / math.sqrt(fan_in1)
76
+ nn.init.uniform_(self.bias1, -bound1, bound1)
77
+ fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
78
+ bound2 = 1 / math.sqrt(fan_in2)
79
+ nn.init.uniform_(self.bias2, -bound2, bound2)
80
+
81
+ def forward(self, x, use_conv3d=False, skip_time_conv=False):
82
+ if use_conv3d:
83
+ return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
84
+ else:
85
+ return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
86
+
87
+ def forward_with_3d(self, x, skip_time_conv):
88
+ # First convolution
89
+ x = F.conv3d(x, self.weight1, self.bias1, self.stride1, self.padding1, self.dilation1, self.groups)
90
+
91
+ if skip_time_conv:
92
+ return x
93
+
94
+ # Second convolution
95
+ x = F.conv3d(x, self.weight2, self.bias2, self.stride2, self.padding2, self.dilation2, self.groups)
96
+
97
+ return x
98
+
99
+ def forward_with_2d(self, x, skip_time_conv):
100
+ b, c, d, h, w = x.shape
101
+
102
+ # First 2D convolution
103
+ x = rearrange(x, "b c d h w -> (b d) c h w")
104
+ # Squeeze the depth dimension out of weight1 since it's 1
105
+ weight1 = self.weight1.squeeze(2)
106
+ # Select stride, padding, and dilation for the 2D convolution
107
+ stride1 = (self.stride1[1], self.stride1[2])
108
+ padding1 = (self.padding1[1], self.padding1[2])
109
+ dilation1 = (self.dilation1[1], self.dilation1[2])
110
+ x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
111
+
112
+ _, _, h, w = x.shape
113
+
114
+ if skip_time_conv:
115
+ x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
116
+ return x
117
+
118
+ # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
119
+ x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
120
+
121
+ # Reshape weight2 to match the expected dimensions for conv1d
122
+ weight2 = self.weight2.squeeze(-1).squeeze(-1)
123
+ # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
124
+ stride2 = self.stride2[0]
125
+ padding2 = self.padding2[0]
126
+ dilation2 = self.dilation2[0]
127
+ x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
128
+ x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
129
+
130
+ return x
131
+
132
+ @property
133
+ def weight(self):
134
+ return self.weight2
135
+
136
+
137
+ def test_dual_conv3d_consistency():
138
+ # Initialize parameters
139
+ in_channels = 3
140
+ out_channels = 5
141
+ kernel_size = (3, 3, 3)
142
+ stride = (2, 2, 2)
143
+ padding = (1, 1, 1)
144
+
145
+ # Create an instance of the DualConv3d class
146
+ dual_conv3d = DualConv3d(
147
+ in_channels=in_channels,
148
+ out_channels=out_channels,
149
+ kernel_size=kernel_size,
150
+ stride=stride,
151
+ padding=padding,
152
+ bias=True,
153
+ )
154
+
155
+ # Example input tensor
156
+ test_input = torch.randn(1, 3, 10, 10, 10)
157
+
158
+ # Perform forward passes with both 3D and 2D settings
159
+ output_conv3d = dual_conv3d(test_input, use_conv3d=True)
160
+ output_2d = dual_conv3d(test_input, use_conv3d=False)
161
+
162
+ # Assert that the outputs from both methods are sufficiently close
163
+ assert torch.allclose(
164
+ output_conv3d, output_2d, atol=1e-6
165
+ ), "Outputs are not consistent between 3D and 2D convolutions."
vae/layers/pixel_norm.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class PixelNorm(nn.Module):
6
+ def __init__(self, dim=1, eps=1e-8):
7
+ super(PixelNorm, self).__init__()
8
+ self.dim = dim
9
+ self.eps = eps
10
+
11
+ def forward(self, x):
12
+ return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
vae/vae.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import math
5
+ import torch.nn as nn
6
+ from diffusers import ConfigMixin, ModelMixin
7
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
8
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
9
+ from vae.layers.conv_nd_factory import make_conv_nd
10
+
11
+
12
+ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
13
+ """Variational Autoencoder (VAE) model with KL loss.
14
+
15
+ VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
16
+ This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss.
17
+
18
+ Args:
19
+ encoder (`nn.Module`):
20
+ Encoder module.
21
+ decoder (`nn.Module`):
22
+ Decoder module.
23
+ latent_channels (`int`, *optional*, defaults to 4):
24
+ Number of latent channels.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ encoder: nn.Module,
30
+ decoder: nn.Module,
31
+ latent_channels: int = 4,
32
+ dims: int = 2,
33
+ sample_size=512,
34
+ use_quant_conv: bool = True,
35
+ ):
36
+ super().__init__()
37
+
38
+ # pass init params to Encoder
39
+ self.encoder = encoder
40
+ self.use_quant_conv = use_quant_conv
41
+
42
+ # pass init params to Decoder
43
+ quant_dims = 2 if dims == 2 else 3
44
+ self.decoder = decoder
45
+ if use_quant_conv:
46
+ self.quant_conv = make_conv_nd(quant_dims, 2 * latent_channels, 2 * latent_channels, 1)
47
+ self.post_quant_conv = make_conv_nd(quant_dims, latent_channels, latent_channels, 1)
48
+ else:
49
+ self.quant_conv = nn.Identity()
50
+ self.post_quant_conv = nn.Identity()
51
+ self.use_z_tiling = False
52
+ self.use_hw_tiling = False
53
+ self.dims = dims
54
+ self.z_sample_size = 1
55
+
56
+ # only relevant if vae tiling is enabled
57
+ self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
58
+
59
+ def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
60
+ self.tile_sample_min_size = sample_size
61
+ num_blocks = len(self.encoder.down_blocks)
62
+ self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
63
+ self.tile_overlap_factor = overlap_factor
64
+
65
+ def enable_z_tiling(self, z_sample_size: int = 8):
66
+ r"""
67
+ Enable tiling during VAE decoding.
68
+
69
+ When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several
70
+ steps. This is useful to save some memory and allow larger batch sizes.
71
+ """
72
+ self.use_z_tiling = z_sample_size > 1
73
+ self.z_sample_size = z_sample_size
74
+ assert (
75
+ z_sample_size % 8 == 0 or z_sample_size == 1
76
+ ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."
77
+
78
+ def disable_z_tiling(self):
79
+ r"""
80
+ Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing
81
+ decoding in one step.
82
+ """
83
+ self.use_z_tiling = False
84
+
85
+ def enable_hw_tiling(self):
86
+ r"""
87
+ Enable tiling during VAE decoding along the height and width dimension.
88
+ """
89
+ self.use_hw_tiling = True
90
+
91
+ def disable_hw_tiling(self):
92
+ r"""
93
+ Disable tiling during VAE decoding along the height and width dimension.
94
+ """
95
+ self.use_hw_tiling = False
96
+
97
+ def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
98
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
99
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
100
+ row_limit = self.tile_latent_min_size - blend_extent
101
+
102
+ # Split the image into 512x512 tiles and encode them separately.
103
+ rows = []
104
+ for i in range(0, x.shape[3], overlap_size):
105
+ row = []
106
+ for j in range(0, x.shape[4], overlap_size):
107
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
108
+ tile = self.encoder(tile)
109
+ tile = self.quant_conv(tile)
110
+ row.append(tile)
111
+ rows.append(row)
112
+ result_rows = []
113
+ for i, row in enumerate(rows):
114
+ result_row = []
115
+ for j, tile in enumerate(row):
116
+ # blend the above tile and the left tile
117
+ # to the current tile and add the current tile to the result row
118
+ if i > 0:
119
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
120
+ if j > 0:
121
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
122
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
123
+ result_rows.append(torch.cat(result_row, dim=4))
124
+
125
+ moments = torch.cat(result_rows, dim=3)
126
+ return moments
127
+
128
+ def blend_z(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
129
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
130
+ for z in range(blend_extent):
131
+ b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (1 - z / blend_extent) + b[:, :, z, :, :] * (
132
+ z / blend_extent
133
+ )
134
+ return b
135
+
136
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
137
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
138
+ for y in range(blend_extent):
139
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
140
+ y / blend_extent
141
+ )
142
+ return b
143
+
144
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
145
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
146
+ for x in range(blend_extent):
147
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
148
+ x / blend_extent
149
+ )
150
+ return b
151
+
152
+ def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
153
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
154
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
155
+ row_limit = self.tile_sample_min_size - blend_extent
156
+ tile_target_shape = (*target_shape[:3], self.tile_sample_min_size, self.tile_sample_min_size)
157
+ # Split z into overlapping 64x64 tiles and decode them separately.
158
+ # The tiles have an overlap to avoid seams between tiles.
159
+ rows = []
160
+ for i in range(0, z.shape[3], overlap_size):
161
+ row = []
162
+ for j in range(0, z.shape[4], overlap_size):
163
+ tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
164
+ tile = self.post_quant_conv(tile)
165
+ decoded = self.decoder(tile, target_shape=tile_target_shape)
166
+ row.append(decoded)
167
+ rows.append(row)
168
+ result_rows = []
169
+ for i, row in enumerate(rows):
170
+ result_row = []
171
+ for j, tile in enumerate(row):
172
+ # blend the above tile and the left tile
173
+ # to the current tile and add the current tile to the result row
174
+ if i > 0:
175
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
176
+ if j > 0:
177
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
178
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
179
+ result_rows.append(torch.cat(result_row, dim=4))
180
+
181
+ dec = torch.cat(result_rows, dim=3)
182
+ return dec
183
+
184
+ def encode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
185
+ if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
186
+ num_splits = z.shape[2] // self.z_sample_size
187
+ sizes = [self.z_sample_size] * num_splits
188
+ sizes = sizes + [z.shape[2] - sum(sizes)] if z.shape[2] - sum(sizes) > 0 else sizes
189
+ tiles = z.split(sizes, dim=2)
190
+ moments_tiles = [
191
+ self._hw_tiled_encode(z_tile, return_dict) if self.use_hw_tiling else self._encode(z_tile)
192
+ for z_tile in tiles
193
+ ]
194
+ moments = torch.cat(moments_tiles, dim=2)
195
+
196
+ else:
197
+ moments = self._hw_tiled_encode(z, return_dict) if self.use_hw_tiling else self._encode(z)
198
+
199
+ posterior = DiagonalGaussianDistribution(moments)
200
+ if not return_dict:
201
+ return (posterior,)
202
+
203
+ return AutoencoderKLOutput(latent_dist=posterior)
204
+
205
+ def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput:
206
+ h = self.encoder(x)
207
+ moments = self.quant_conv(h)
208
+ return moments
209
+
210
+ def _decode(self, z: torch.FloatTensor, target_shape=None) -> Union[DecoderOutput, torch.FloatTensor]:
211
+ z = self.post_quant_conv(z)
212
+ dec = self.decoder(z, target_shape=target_shape)
213
+ return dec
214
+
215
+ def decode(
216
+ self, z: torch.FloatTensor, return_dict: bool = True, target_shape=None
217
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
218
+ assert target_shape is not None, "target_shape must be provided for decoding"
219
+ if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
220
+ reduction_factor = int(
221
+ self.encoder.patch_size_t
222
+ * 2 ** (len(self.encoder.down_blocks) - 1 - math.sqrt(self.encoder.patch_size))
223
+ )
224
+ split_size = self.z_sample_size // reduction_factor
225
+ num_splits = z.shape[2] // split_size
226
+
227
+ # copy target shape, and divide frame dimension (=2) by the context size
228
+ target_shape_split = list(target_shape)
229
+ target_shape_split[2] = target_shape[2] // num_splits
230
+
231
+ decoded_tiles = [
232
+ (
233
+ self._hw_tiled_decode(z_tile, target_shape_split)
234
+ if self.use_hw_tiling
235
+ else self._decode(z_tile, target_shape=target_shape_split)
236
+ )
237
+ for z_tile in torch.tensor_split(z, num_splits, dim=2)
238
+ ]
239
+ decoded = torch.cat(decoded_tiles, dim=2)
240
+ else:
241
+ decoded = (
242
+ self._hw_tiled_decode(z, target_shape)
243
+ if self.use_hw_tiling
244
+ else self._decode(z, target_shape=target_shape)
245
+ )
246
+
247
+ if not return_dict:
248
+ return (decoded,)
249
+
250
+ return DecoderOutput(sample=decoded)
251
+
252
+ def forward(
253
+ self,
254
+ sample: torch.FloatTensor,
255
+ sample_posterior: bool = False,
256
+ return_dict: bool = True,
257
+ generator: Optional[torch.Generator] = None,
258
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
259
+ r"""
260
+ Args:
261
+ sample (`torch.FloatTensor`): Input sample.
262
+ sample_posterior (`bool`, *optional*, defaults to `False`):
263
+ Whether to sample from the posterior.
264
+ return_dict (`bool`, *optional*, defaults to `True`):
265
+ Whether to return a [`DecoderOutput`] instead of a plain tuple.
266
+ generator (`torch.Generator`, *optional*):
267
+ Generator used to sample from the posterior.
268
+ """
269
+ x = sample
270
+ posterior = self.encode(x).latent_dist
271
+ if sample_posterior:
272
+ z = posterior.sample(generator=generator)
273
+ else:
274
+ z = posterior.mode()
275
+ dec = self.decode(z, target_shape=sample.shape).sample
276
+
277
+ if not return_dict:
278
+ return (dec,)
279
+
280
+ return DecoderOutput(sample=dec)
vae/vae_encode.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from diffusers import AutoencoderKL
4
+ from einops import rearrange
5
+ from torch import Tensor
6
+ from torch.nn import functional
7
+
8
+
9
+ from vae.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
10
+
11
+ class Downsample3D(nn.Module):
12
+ def __init__(self, dims, in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1):
13
+ super().__init__()
14
+ stride: int = 2
15
+ self.padding = padding
16
+ self.in_channels = in_channels
17
+ self.dims = dims
18
+ self.conv = make_conv_nd(
19
+ dims=dims,
20
+ in_channels=in_channels,
21
+ out_channels=out_channels,
22
+ kernel_size=kernel_size,
23
+ stride=stride,
24
+ padding=padding,
25
+ )
26
+
27
+ def forward(self, x, downsample_in_time=True):
28
+ conv = self.conv
29
+ if self.padding == 0:
30
+ if self.dims == 2:
31
+ padding = (0, 1, 0, 1)
32
+ else:
33
+ padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
34
+
35
+ x = functional.pad(x, padding, mode="constant", value=0)
36
+
37
+ if self.dims == (2, 1) and not downsample_in_time:
38
+ return conv(x, skip_time_conv=True)
39
+
40
+ return conv(x)
41
+
42
+
43
+
44
+ def vae_encode(media_items: Tensor, vae: AutoencoderKL, split_size: int = 1, vae_per_channel_normalize=False) -> Tensor:
45
+ """
46
+ Encodes media items (images or videos) into latent representations using a specified VAE model.
47
+ The function supports processing batches of images or video frames and can handle the processing
48
+ in smaller sub-batches if needed.
49
+
50
+ Args:
51
+ media_items (Tensor): A torch Tensor containing the media items to encode. The expected
52
+ shape is (batch_size, channels, height, width) for images or (batch_size, channels,
53
+ frames, height, width) for videos.
54
+ vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
55
+ pre-configured and loaded with the appropriate model weights.
56
+ split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
57
+ If set to more than 1, the input media items are processed in smaller batches according to
58
+ this value. Defaults to 1, which processes all items in a single batch.
59
+
60
+ Returns:
61
+ Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
62
+ to match the input shape, scaled by the model's configuration.
63
+
64
+ Examples:
65
+ >>> import torch
66
+ >>> from diffusers import AutoencoderKL
67
+ >>> vae = AutoencoderKL.from_pretrained('your-model-name')
68
+ >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames.
69
+ >>> latents = vae_encode(images, vae)
70
+ >>> print(latents.shape) # Output shape will depend on the model's latent configuration.
71
+
72
+ Note:
73
+ In case of a video, the function encodes the media item frame-by frame.
74
+ """
75
+ is_video_shaped = media_items.dim() == 5
76
+ batch_size, channels = media_items.shape[0:2]
77
+
78
+ if channels != 3:
79
+ raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
80
+
81
+ if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
82
+ media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
83
+ if split_size > 1:
84
+ if len(media_items) % split_size != 0:
85
+ raise ValueError("Error: The batch size must be divisible by 'train.vae_bs_split")
86
+ encode_bs = len(media_items) // split_size
87
+ # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
88
+ latents = []
89
+ for image_batch in media_items.split(encode_bs):
90
+ latents.append(vae.encode(image_batch).latent_dist.sample())
91
+ latents = torch.cat(latents, dim=0)
92
+ else:
93
+ latents = vae.encode(media_items).latent_dist.sample()
94
+
95
+ latents = normalize_latents(latents, vae, vae_per_channel_normalize)
96
+ if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
97
+ latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
98
+ return latents
99
+
100
+
101
+ def vae_decode(
102
+ latents: Tensor, vae: AutoencoderKL, is_video: bool = True, split_size: int = 1, vae_per_channel_normalize=False
103
+ ) -> Tensor:
104
+ is_video_shaped = latents.dim() == 5
105
+ batch_size = latents.shape[0]
106
+
107
+ if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
108
+ latents = rearrange(latents, "b c n h w -> (b n) c h w")
109
+ if split_size > 1:
110
+ if len(latents) % split_size != 0:
111
+ raise ValueError("Error: The batch size must be divisible by 'train.vae_bs_split")
112
+ encode_bs = len(latents) // split_size
113
+ image_batch = [
114
+ _run_decoder(latent_batch, vae, is_video, vae_per_channel_normalize)
115
+ for latent_batch in latents.split(encode_bs)
116
+ ]
117
+ images = torch.cat(image_batch, dim=0)
118
+ else:
119
+ images = _run_decoder(latents, vae, is_video, vae_per_channel_normalize)
120
+
121
+ if is_video_shaped and not isinstance(vae, (CausalVideoAutoencoder)):
122
+ images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
123
+ return images
124
+
125
+
126
+ def _run_decoder(latents: Tensor, vae: AutoencoderKL, is_video: bool, vae_per_channel_normalize=False) -> Tensor:
127
+ if isinstance(vae, (CausalVideoAutoencoder)):
128
+ *_, fl, hl, wl = latents.shape
129
+ temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
130
+ latents = latents.to(vae.dtype)
131
+ image = vae.decode(
132
+ un_normalize_latents(latents, vae, vae_per_channel_normalize),
133
+ return_dict=False,
134
+ target_shape=(1, 3, fl * temporal_scale if is_video else 1, hl * spatial_scale, wl * spatial_scale),
135
+ )[0]
136
+ else:
137
+ image = vae.decode(
138
+ un_normalize_latents(latents, vae, vae_per_channel_normalize),
139
+ return_dict=False,
140
+ )[0]
141
+ return image
142
+
143
+
144
+ def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
145
+ if isinstance(vae, CausalVideoAutoencoder):
146
+ spatial = vae.spatial_downscale_factor
147
+ temporal = vae.temporal_downscale_factor
148
+ else:
149
+ down_blocks = len([block for block in vae.encoder.down_blocks if isinstance(block.downsample, Downsample3D)])
150
+ spatial = vae.config.patch_size * 2**down_blocks
151
+ temporal = vae.config.patch_size_t * 2 ** down_blocks if isinstance(vae) else 1
152
+
153
+ return (temporal, spatial, spatial)
154
+
155
+
156
+ def normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor:
157
+ return (
158
+ (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
159
+ / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
160
+ if vae_per_channel_normalize
161
+ else latents * vae.config.scaling_factor
162
+ )
163
+
164
+
165
+ def un_normalize_latents(latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False) -> Tensor:
166
+ return (
167
+ latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
168
+ + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
169
+ if vae_per_channel_normalize
170
+ else latents / vae.config.scaling_factor
171
+ )