# jammmmm's picture
# Add spar3d demo files
# 38dbec8
# raw
# history blame
# 9.21 kB
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class Modulation(nn.Module):
    r"""
    Scale-and-shift modulation of a tensor, driven by a condition tensor.

    The condition is passed through `linear1` (an identity when `single_layer` is set), a SiLU and `linear2`.
    The result is split into two halves along dim 1, used as a multiplicative `scale` and an additive `shift`.

    Parameters:
        embedding_dim (`int`): Size of each of the two halves (scale and shift) produced by the last linear layer.
        condition_dim (`int`): The number of features of the condition input.
        zero_init (`bool`, *optional*, defaults to False): Zero the weight and bias of the last linear layer.
        single_layer (`bool`, *optional*, defaults to False): Use `nn.Identity` instead of the first linear layer.
    """

    def __init__(
        self,
        embedding_dim: int,
        condition_dim: int,
        zero_init: bool = False,
        single_layer: bool = False,
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear1 = (
            nn.Identity() if single_layer else nn.Linear(condition_dim, condition_dim)
        )
        self.linear2 = nn.Linear(condition_dim, embedding_dim * 2)

        # Only the last linear layer is zero-initialised; with zeros there the
        # module computes x * (1 + 0) + 0, i.e. it returns its input unchanged.
        if zero_init:
            for parameter in (self.linear2.weight, self.linear2.bias):
                nn.init.zeros_(parameter)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Return `x * (1 + scale) + shift`, with scale/shift derived from `condition`."""
        projected = self.linear1(condition)
        modulation = self.linear2(self.silu(projected))
        scale, shift = modulation.chunk(2, dim=1)
        # A new axis is inserted at position 1 so scale/shift broadcast over
        # that axis of x (presumably the token axis - confirm against callers).
        return x * (1 + scale[:, None]) + shift[:, None]
class FeedForward(nn.Module):
    r"""
    A feed-forward layer.

    The layer is stored as a `nn.ModuleList` applied in order: activation block (which also performs the input
    projection), dropout, output linear projection and, optionally, a final dropout.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            One of `"gelu"`, `"gelu-approximate"`, `"geglu"` or `"geglu-approximate"`.
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.

    Raises:
        ValueError: If `activation_fn` is not one of the supported names.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
    ):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out if dim_out is not None else dim
        linear_cls = nn.Linear

        # A single if/elif chain with an explicit fallback: an unknown name used
        # to leave `act_fn` unbound and fail later with an UnboundLocalError.
        if activation_fn == "gelu":
            act_fn = GELU(dim, inner_dim)
        elif activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh")
        elif activation_fn == "geglu":
            act_fn = GEGLU(dim, inner_dim)
        elif activation_fn == "geglu-approximate":
            act_fn = ApproximateGELU(dim, inner_dim)
        else:
            raise ValueError(
                f"Unsupported activation_fn {activation_fn!r}; expected one of "
                "'gelu', 'gelu-approximate', 'geglu', 'geglu-approximate'."
            )

        self.net = nn.ModuleList([])
        # project in
        self.net.append(act_fn)
        # project dropout
        self.net.append(nn.Dropout(dropout))
        # project out
        self.net.append(linear_cls(inner_dim, dim_out))
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            self.net.append(nn.Dropout(dropout))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply every module of `self.net` in order and return the result."""
        for module in self.net:
            hidden_states = module(hidden_states)
        return hidden_states
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
out_bias: bool = True,
):
super().__init__()
self.inner_dim = dim_head * heads
self.num_heads = heads
self.scale = dim_head**-0.5
self.dropout = dropout
# Linear projections
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
# Output projection
self.to_out = nn.ModuleList(
[
nn.Linear(self.inner_dim, query_dim, bias=out_bias),
nn.Dropout(dropout),
]
)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = hidden_states.shape
# Project queries, keys, and values
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
# Reshape for multi-head attention
query = query.reshape(
batch_size, sequence_length, self.num_heads, -1
).transpose(1, 2)
key = key.reshape(batch_size, sequence_length, self.num_heads, -1).transpose(
1, 2
)
value = value.reshape(
batch_size, sequence_length, self.num_heads, -1
).transpose(1, 2)
# Compute scaled dot product attention
hidden_states = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
scale=self.scale,
)
# Reshape and project output
hidden_states = hidden_states.transpose(1, 2).reshape(
batch_size, sequence_length, self.inner_dim
)
# Apply output projection and dropout
for module in self.to_out:
hidden_states = module(hidden_states)
return hidden_states
class BasicTransformerBlock(nn.Module):
    r"""
    A pre-norm transformer block: LayerNorm + self-attention, then LayerNorm + feed-forward, each with a residual
    connection.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        attention_bias (`bool`, *optional*, defaults to False): Whether the attention q/k/v projections have a bias.
        norm_elementwise_affine (`bool`, *optional*, defaults to True): Passed to both `nn.LayerNorm` layers.
        norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon passed to both `nn.LayerNorm` layers.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        activation_fn: str = "geglu",
        attention_bias: bool = False,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
    ):
        super().__init__()
        layer_norm_kwargs = {
            "elementwise_affine": norm_elementwise_affine,
            "eps": norm_eps,
        }

        # Self-attention branch.
        self.norm1 = nn.LayerNorm(dim, **layer_norm_kwargs)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
        )

        # Feed-forward branch.
        self.norm3 = nn.LayerNorm(dim, **layer_norm_kwargs)
        self.ff = FeedForward(dim, activation_fn=activation_fn)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        attention_mask: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """Apply the attention and feed-forward sub-layers, each added to its own input."""
        # Self-attention on the normalised input, plus residual.
        attn_output = self.attn1(
            self.norm1(hidden_states),
            attention_mask=attention_mask,
        )
        hidden_states = attn_output + hidden_states

        # Feed-forward on the normalised result, plus residual.
        return self.ff(self.norm3(hidden_states)) + hidden_states
class GELU(nn.Module):
    r"""
    Linear projection followed by GELU, with optional tanh approximation via `approximate="tanh"`.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU to `gate`, going through float32 on mps devices."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            upcast = gate.to(dtype=torch.float32)
            activated = F.gelu(upcast, approximate=self.approximate)
            return activated.to(dtype=gate.dtype)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        """Project `hidden_states` to `dim_out` channels and apply GELU."""
        return self.gelu(self.proj(hidden_states))
class GEGLU(nn.Module):
    r"""
    Gated linear unit with a GELU gate, a variant from https://arxiv.org/abs/2002.05202.

    A single linear layer produces `2 * dim_out` channels; the first half is the value and the second half is
    the gate.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        """Apply GELU to `gate`, going through float32 on mps devices."""
        if gate.device.type == "mps":
            # mps: gelu is not implemented for float16
            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
        return F.gelu(gate)

    def forward(self, hidden_states, scale: float = 1.0):
        """Return `value * gelu(gate)`; `scale` is accepted but not used."""
        projected = self.proj(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        return value * self.gelu(gate)
class ApproximateGELU(nn.Module):
    r"""
    Linear projection followed by the sigmoid approximation of GELU, `x * sigmoid(1.702 * x)`. For more details,
    see section 2: https://arxiv.org/abs/1606.08415.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
    """

    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project `x` to `dim_out` channels and apply the approximate GELU."""
        projected = self.proj(x)
        gate = torch.sigmoid(1.702 * projected)
        return projected * gate