# hma/genie/st_transformer.py
from typing import Optional

import torch
from torch import nn, Tensor
from einops import rearrange

from genie.attention import SelfAttention


class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout."""
def __init__(
self,
d_model: int,
mlp_ratio: float = 4.0,
mlp_bias: bool = True,
mlp_drop: float = 0.0,
) -> None:
super().__init__()
hidden_dim = int(d_model * mlp_ratio)
self.fc1 = nn.Linear(d_model, hidden_dim, bias=mlp_bias)
self.act = nn.GELU()
self.fc2 = nn.Linear(hidden_dim, d_model, bias=mlp_bias)
self.drop = nn.Dropout(mlp_drop)
def forward(self, x: Tensor) -> Tensor:
x = self.drop(self.act(self.fc1(x)))
x = self.drop(self.fc2(x))
return x


class STBlock(nn.Module):
    # Spatio-temporal block; see Figure 4 of https://arxiv.org/pdf/2402.15391.pdf
def __init__(
self,
num_heads: int,
d_model: int,
qkv_bias: bool = False,
proj_bias: bool = True,
qk_norm: bool = True,
use_mup: bool = True,
        attn_drop: float = 0.05,
mlp_ratio: float = 4.0,
mlp_bias: bool = True,
mlp_drop: float = 0.05,
# action relevant
action_processing: str = "mlp",
jointly_predict_actions: bool = False,
mask_token_id: int = 0
) -> None:
super().__init__()
self.norm1 = nn.Identity() if qk_norm else nn.LayerNorm(d_model, eps=1e-05)
# sequence dim is over each frame's 16x16 patch tokens
self.spatial_attn = SelfAttention(
num_heads=num_heads,
d_model=d_model,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
qk_norm=qk_norm,
use_mup=use_mup,
attn_drop=attn_drop,
)
# sequence dim is over time sequence (16)
self.temporal_attn = SelfAttention(
num_heads=num_heads,
d_model=d_model,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
qk_norm=qk_norm,
use_mup=use_mup,
attn_drop=attn_drop,
)
self.action_prediction = jointly_predict_actions
self.action_processing = action_processing
self.norm2 = nn.Identity() if qk_norm else nn.LayerNorm(d_model, eps=1e-05)
self.mlp = Mlp(d_model=d_model, mlp_ratio=mlp_ratio, mlp_bias=mlp_bias, mlp_drop=mlp_drop)
self.action_projectors = None # set at run-time
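        # Note (added for clarity): action_projectors is expected to be a mapping (e.g. an
        # nn.ModuleDict) from a domain name to a module, installed at run-time by the
        # surrounding model. Based on how it is called in forward():
        #   - "mlp":             projector(action_ids) maps [B, T, action_dim] -> [B, T, d_model]
        #   - "cross_attention": projector(x, action_ids, action_ids) acts as query/key/value cross-attention
        #   - "modulate":        projector(x, action_ids) returns a residual modulation of x
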
    def forward(self, x_TSC: Tensor, action_ids: Optional[Tensor] = None, domain: Optional[str] = None) -> Tensor:
        """
        Main forward pass of the STBlock. It applies action conditioning (with several options),
        (bidirectional) spatial attention, (causal) temporal attention, and action masking.
        """
T, S = x_TSC.size(1), x_TSC.size(2)
        # Spatial attention: fold time into the batch so attention runs over the S patch tokens of each frame
        x_SC = rearrange(x_TSC, 'B T S C -> (B T) S C')
        x_SC = x_SC + self.spatial_attn(self.norm1(x_SC))
        # Rearrange for temporal processing: fold space into the batch so the sequence dim is the T timesteps
        x_TC = rearrange(x_SC, '(B T) S C -> (B S) T C', T=T)
if action_ids is not None and domain is not None and self.action_projectors is not None:
# action_ids: [B, T, D]. Only apply to video parts
if "mlp" in self.action_processing:
action_ids = self.action_projectors[domain](action_ids) # does not depend on x_TC
x_TC = rearrange(x_TC, '(B S) T C -> B S T C', S=S)
x_TC = x_TC + action_ids[:, None, :x_TC.shape[2]] # expand across spatial
x_TC = rearrange(x_TC, 'B S T C -> (B S) T C', S=S)
elif "cross_attention" in self.action_processing:
x_TC = x_TC + self.action_projectors[domain](x_TC, action_ids, action_ids)
            elif "modulate" in self.action_processing:
                x_TC = x_TC + self.action_projectors[domain](x_TC, action_ids)
        # Apply causal attention over the time dimension
        x_TC = x_TC + self.temporal_attn(x_TC, causal=True)  # shape ((B*S), T, C), e.g. [256, 16, 256]
x_TC = x_TC + self.mlp(self.norm2(x_TC))
x_TSC = rearrange(x_TC, '(B S) T C -> B T S C', S=S)
return x_TSC


class STTransformerDecoder(nn.Module):
def __init__(
self,
num_layers: int,
num_heads: int,
d_model: int,
qkv_bias: bool = False,
proj_bias: bool = True,
qk_norm: bool = True,
use_mup: bool = True,
attn_drop: float = 0.0,
mlp_ratio: float = 4.0,
mlp_bias: bool = True,
mlp_drop: float = 0.0,
# action relevant
action_processing: str = "mlp",
jointly_predict_actions: bool = False,
random_dummy_action: bool = True,
mask_token_id: int = 0,
):
super().__init__()
self.layers = nn.ModuleList([STBlock(
num_heads=num_heads,
d_model=d_model,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
qk_norm=qk_norm,
use_mup=use_mup,
attn_drop=attn_drop,
mlp_ratio=mlp_ratio,
mlp_bias=mlp_bias,
mlp_drop=mlp_drop,
action_processing=action_processing,
jointly_predict_actions=jointly_predict_actions,
mask_token_id=mask_token_id
) for _ in range(num_layers)])
self.apply(self._init_weights)

    def _init_weights(self, m):
        """
        Weight initialization: Xavier-uniform (gain 0.1) for linear layers; unit weight and zero bias for LayerNorm.
        """
if isinstance(m, nn.Linear):
torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)

    def forward(self, tgt: Tensor, action_ids: Optional[Tensor] = None, domain: str = "") -> Tensor:
x = tgt
for layer in self.layers:
x = layer(x, action_ids=action_ids, domain=domain)
return x
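

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not the original training configuration).
    # The hyperparameters, tensor sizes, the "robot" domain name and action_dim below are
    # assumptions; the real model wires action_projectors at run-time elsewhere. Requires
    # genie.attention.SelfAttention to behave as it is used in STBlock above.
    d_model, action_dim = 64, 7
    decoder = STTransformerDecoder(num_layers=2, num_heads=4, d_model=d_model)

    video_tokens = torch.randn(2, 8, 256, d_model)   # [B, T, S, C]: 2 clips, 8 frames, 16x16 patch tokens
    out = decoder(video_tokens)                      # action conditioning is skipped when action_ids is None
    print(out.shape)                                 # torch.Size([2, 8, 256, 64])

    # Optionally wire up run-time MLP action projectors for a hypothetical "robot" domain
    for layer in decoder.layers:
        layer.action_projectors = nn.ModuleDict({"robot": nn.Linear(action_dim, d_model)})
    actions = torch.randn(2, 8, action_dim)          # [B, T, action_dim]
    out = decoder(video_tokens, action_ids=actions, domain="robot")
    print(out.shape)                                 # torch.Size([2, 8, 256, 64])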