|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class Modulation(nn.Module):
    """Condition-driven affine modulation (scale and shift) of a sequence.

    A small MLP maps ``condition`` to ``2 * embedding_dim`` values. These are
    split into ``scale`` and ``shift`` and applied as
    ``x * (1 + scale) + shift``, with a broadcast axis inserted at dim 1.

    Args:
        embedding_dim: Number of channels of ``x`` that are modulated.
        condition_dim: Number of channels of ``condition``.
        zero_init: If True, the output projection starts at zero, so the
            module initially returns ``x`` unchanged.
        single_layer: If True, the hidden ``condition_dim -> condition_dim``
            projection is replaced by an identity (SiLU is still applied).
    """

    def __init__(
        self,
        embedding_dim: int,
        condition_dim: int,
        zero_init: bool = False,
        single_layer: bool = False,
    ):
        super().__init__()
        # Attribute names are state_dict keys; keep them stable.
        self.silu = nn.SiLU()
        self.linear1 = (
            nn.Identity()
            if single_layer
            else nn.Linear(condition_dim, condition_dim)
        )
        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)

        if zero_init:
            for param in (self.linear2.weight, self.linear2.bias):
                nn.init.zeros_(param)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled and shifted by values predicted from ``condition``.

        The projection is split along dim 1 and broadcast back onto ``x``
        via ``unsqueeze(1)``. This presumably expects ``condition`` to be
        (batch, condition_dim) and ``x`` to be (batch, seq, embedding_dim);
        TODO confirm with callers.
        """
        hidden = self.linear1(condition)
        hidden = self.silu(hidden)
        params = self.linear2(hidden)
        scale, shift = params.chunk(2, dim=1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
|
|
|
|
class FeedForward(nn.Module):
    r"""
    A feed-forward layer: activation projection -> dropout -> output projection
    (-> optional final dropout).

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"`, `"geglu-approximate"`.
        final_dropout (`bool`, *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim

        # The activation module also performs the dim -> inner_dim projection.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            # NOTE: this maps to ApproximateGELU, which is not gated;
            # kept as-is so existing configs/checkpoints keep working.
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            # Previously an unknown name left `act_fn` unbound and crashed
            # later with UnboundLocalError; fail early with a clear message.
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate'."
            )

        # Module order (and thus state_dict indices net.0 .. net.3) is part of
        # the checkpoint format: activation, dropout, output projection.
        self.net = nn.ModuleList(
            [act_fn, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)]
        )
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply every module of ``self.net`` to ``hidden_states`` in order."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
|
|
|
|
|
class Attention(nn.Module): |
|
def __init__( |
|
self, |
|
query_dim: int, |
|
heads: int = 8, |
|
dim_head: int = 64, |
|
dropout: float = 0.0, |
|
bias: bool = False, |
|
out_bias: bool = True, |
|
): |
|
super().__init__() |
|
self.inner_dim = dim_head * heads |
|
self.num_heads = heads |
|
self.scale = dim_head**-0.5 |
|
self.dropout = dropout |
|
|
|
|
|
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias) |
|
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias) |
|
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias) |
|
|
|
|
|
self.to_out = nn.ModuleList( |
|
[ |
|
nn.Linear(self.inner_dim, query_dim, bias=out_bias), |
|
nn.Dropout(dropout), |
|
] |
|
) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
) -> torch.Tensor: |
|
batch_size, sequence_length, _ = hidden_states.shape |
|
|
|
|
|
query = self.to_q(hidden_states) |
|
key = self.to_k(hidden_states) |
|
value = self.to_v(hidden_states) |
|
|
|
|
|
query = query.reshape( |
|
batch_size, sequence_length, self.num_heads, -1 |
|
).transpose(1, 2) |
|
key = key.reshape(batch_size, sequence_length, self.num_heads, -1).transpose( |
|
1, 2 |
|
) |
|
value = value.reshape( |
|
batch_size, sequence_length, self.num_heads, -1 |
|
).transpose(1, 2) |
|
|
|
|
|
hidden_states = torch.nn.functional.scaled_dot_product_attention( |
|
query, |
|
key, |
|
value, |
|
attn_mask=attention_mask, |
|
scale=self.scale, |
|
) |
|
|
|
|
|
hidden_states = hidden_states.transpose(1, 2).reshape( |
|
batch_size, sequence_length, self.inner_dim |
|
) |
|
|
|
|
|
for module in self.to_out: |
|
hidden_states = module(hidden_states) |
|
|
|
return hidden_states |
|
|
|
|
|
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: residual self-attention, then residual feed-forward.

    Args:
        dim: Channel size of the hidden states.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Channels per attention head.
        activation_fn: Activation name forwarded to ``FeedForward``.
        attention_bias: Whether the attention q/k/v projections have a bias.
        norm_elementwise_affine: Whether the LayerNorms have learnable affine parameters.
        norm_eps: Epsilon used by both LayerNorms.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
    ):
        super().__init__()
        layer_norm_kwargs = {
            "elementwise_affine": norm_elementwise_affine,
            "eps": norm_eps,
        }

        # Attribute names (norm1 / attn1 / norm3 / ff) are state_dict keys.
        self.norm1 = nn.LayerNorm(dim, **layer_norm_kwargs)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
        )

        self.norm3 = nn.LayerNorm(dim, **layer_norm_kwargs)
        self.ff = FeedForward(dim, activation_fn=activation_fn)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Run the attention and feed-forward sub-layers, each with a residual add."""
        attn_output = self.attn1(
            self.norm1(hidden_states), attention_mask=attention_mask
        )
        hidden_states = attn_output + hidden_states

        ff_output = self.ff(self.norm3(hidden_states))
        return ff_output + hidden_states
|
|
|
|
|
class GELU(nn.Module):
    r"""
    Linear projection followed by GELU, with tanh-approximation support via `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU to ``gate``, returning a tensor of the same dtype.

        On MPS devices the activation is computed in float32 and cast back,
        presumably to sidestep reduced-precision GELU on that backend.
        """
        on_mps = gate.device.type == "mps"
        source = gate.to(dtype=torch.float32) if on_mps else gate
        activated = F.gelu(source, approximate=self.approximate)
        return activated.to(dtype=gate.dtype) if on_mps else activated

    def forward(self, hidden_states):
        """Project ``hidden_states`` to ``dim_out`` channels and apply GELU."""
        return self.gelu(self.proj(hidden_states))
|
|
|
|
|
class GEGLU(nn.Module):
    r"""
    A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.

    The input is projected to `2 * dim_out` channels and split in half along the last
    dimension; the first half is multiplied by GELU of the second half.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU to ``gate``; on MPS compute in float32 and cast back to its dtype."""
        if gate.device.type != "mps":
            return F.gelu(gate)

        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

    def forward(self, hidden_states, scale: float = 1.0):
        """Return ``value * gelu(gate)`` where ``value, gate`` halve the projection.

        Args:
            hidden_states: Input whose last dimension has ``dim_in`` channels.
            scale: Ignored by this implementation; kept in the signature so
                existing call sites that pass it keep working.
        """
        # The former `args = ()` / `*args` splat was dead code and is removed.
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)
|
|
|
|
|
class ApproximateGELU(nn.Module):
    r"""
    Linear projection followed by the sigmoid approximation of GELU,
    `x * sigmoid(1.702 * x)`. For more details, see section 2 of
    https://arxiv.org/abs/1606.08415.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` to ``dim_out`` channels and apply the approximate GELU."""
        projected = self.proj(x)
        gate = torch.sigmoid(1.702 * projected)
        return projected * gate
|
|