from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F class Modulation(nn.Module): def __init__( self, embedding_dim: int, condition_dim: int, zero_init: bool = False, single_layer: bool = False, ): super().__init__() self.silu = nn.SiLU() if single_layer: self.linear1 = nn.Identity() else: self.linear1 = nn.Linear(condition_dim, condition_dim) self.linear2 = nn.Linear(condition_dim, embedding_dim * 2) # Only zero init the last linear layer if zero_init: nn.init.zeros_(self.linear2.weight) nn.init.zeros_(self.linear2.bias) def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor: emb = self.linear2(self.silu(self.linear1(condition))) scale, shift = torch.chunk(emb, 2, dim=1) x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) return x class FeedForward(nn.Module): r""" A feed-forward layer. Parameters: dim (`int`): The number of channels in the input. dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. """ def __init__( self, dim: int, dim_out: Optional[int] = None, mult: int = 4, dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, ): super().__init__() inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim linear_cls = nn.Linear if activation_fn == "gelu": act_fn = GELU(dim, inner_dim) if activation_fn == "gelu-approximate": act_fn = GELU(dim, inner_dim, approximate="tanh") elif activation_fn == "geglu": act_fn = GEGLU(dim, inner_dim) elif activation_fn == "geglu-approximate": act_fn = ApproximateGELU(dim, inner_dim) self.net = nn.ModuleList([]) # project in self.net.append(act_fn) # project dropout self.net.append(nn.Dropout(dropout)) # project out self.net.append(linear_cls(inner_dim, dim_out)) # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout if final_dropout: self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for module in self.net: hidden_states = module(hidden_states) return hidden_states class Attention(nn.Module): def __init__( self, query_dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0, bias: bool = False, out_bias: bool = True, ): super().__init__() self.inner_dim = dim_head * heads self.num_heads = heads self.scale = dim_head**-0.5 self.dropout = dropout # Linear projections self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) # Output projection self.to_out = nn.ModuleList( [ nn.Linear(self.inner_dim, query_dim, bias=out_bias), nn.Dropout(dropout), ] ) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = hidden_states.shape # Project queries, keys, and values query = self.to_q(hidden_states) key = self.to_k(hidden_states) value = self.to_v(hidden_states) # Reshape for multi-head attention query = query.reshape( batch_size, sequence_length, self.num_heads, -1 ).transpose(1, 2) key = key.reshape(batch_size, sequence_length, self.num_heads, -1).transpose( 1, 2 ) value = value.reshape( batch_size, sequence_length, self.num_heads, -1 ).transpose(1, 2) # Compute scaled dot product attention hidden_states = torch.nn.functional.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, scale=self.scale, ) # Reshape and project output hidden_states = hidden_states.transpose(1, 2).reshape( batch_size, sequence_length, self.inner_dim ) # Apply output projection and dropout for module in self.to_out: hidden_states = module(hidden_states) return hidden_states class BasicTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, activation_fn: str = "geglu", attention_bias: bool = False, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, ): super().__init__() # Self-Attn self.norm1 = nn.LayerNorm( dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps ) self.attn1 = Attention( query_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, bias=attention_bias, ) # Feed-forward self.norm3 = nn.LayerNorm( dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps ) self.ff = FeedForward( dim, activation_fn=activation_fn, ) def forward( self, hidden_states: torch.FloatTensor, attention_mask: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: # Self-Attention norm_hidden_states = self.norm1(hidden_states) hidden_states = ( self.attn1( norm_hidden_states, attention_mask=attention_mask, ) + hidden_states ) # Feed-forward ff_output = self.ff(self.norm3(hidden_states)) hidden_states = ff_output + hidden_states return hidden_states class GELU(nn.Module): r""" GELU activation function with tanh approximation support with `approximate="tanh"`. Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. """ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): super().__init__() self.proj = nn.Linear(dim_in, dim_out) self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: if gate.device.type != "mps": return F.gelu(gate, approximate=self.approximate) # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to( dtype=gate.dtype ) def forward(self, hidden_states): hidden_states = self.proj(hidden_states) hidden_states = self.gelu(hidden_states) return hidden_states class GEGLU(nn.Module): r""" A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. """ def __init__(self, dim_in: int, dim_out: int): super().__init__() linear_cls = nn.Linear self.proj = linear_cls(dim_in, dim_out * 2) def gelu(self, gate: torch.Tensor) -> torch.Tensor: if gate.device.type != "mps": return F.gelu(gate) # mps: gelu is not implemented for float16 return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) def forward(self, hidden_states, scale: float = 1.0): args = () hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1) return hidden_states * self.gelu(gate) class ApproximateGELU(nn.Module): r""" The approximate form of Gaussian Error Linear Unit (GELU). For more details, see section 2: https://arxiv.org/abs/1606.08415. Parameters: dim_in (`int`): The number of channels in the input. dim_out (`int`): The number of channels in the output. """ def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.proj(x) return x * torch.sigmoid(1.702 * x)