# Xhr0306's picture
# update
# 15fa80a
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import Categorical
import models.pos_encoding as pos_encoding
import numpy as np
class Text2Motion_Transformer(nn.Module):
    """Autoregressive transformer mapping a conditioning feature to token indices.

    The model is a ``CrossCondTransBase`` followed by a ``CrossCondTransHead``.
    During sampling, index ``num_vq`` is treated as the stop token.
    """

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 clip_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4,
                 ):
        """Build the base and head stacks.

        Args:
            num_vq: Number of regular token indices; ``num_vq`` itself is the
                stop token checked in ``sample``.
            embed_dim: Width of the transformer.
            clip_dim: Width of the conditioning feature fed to the base.
            block_size: Maximum number of sampling steps / sequence positions.
            num_layers: Number of blocks in the base and in the head.
            n_head: Number of attention heads per block.
            drop_out_rate: Dropout probability passed to both stacks.
            fc_rate: MLP expansion factor passed to both stacks.
        """
        super().__init__()
        # Arguments shared, in this order, by both sub-networks.
        shared = (block_size, num_layers, n_head, drop_out_rate, fc_rate)
        self.trans_base = CrossCondTransBase(num_vq, embed_dim, clip_dim, *shared)
        self.trans_head = CrossCondTransHead(num_vq, embed_dim, *shared)
        self.block_size = block_size
        self.num_vq = num_vq

    def get_block_size(self):
        """Return the configured block size."""
        return self.block_size

    def forward(self, idxs, clip_feature):
        """Return logits for the token prefix ``idxs`` conditioned on ``clip_feature``."""
        return self.trans_head(self.trans_base(idxs, clip_feature))

    def sample(self, clip_feature, if_categorial=False,att=False):
        """Generate token indices one step at a time.

        Args:
            clip_feature: Conditioning feature passed unchanged to ``forward``.
            if_categorial: If true, draw each token from the softmax
                distribution; otherwise take the arg-max token.
            att: If ``att == True``, return the attention weights recorded by
                the first block of ``trans_base`` right after the first
                forward pass instead of sampling.

        Returns:
            The generated indices concatenated along dim 1. Generation stops
            when the stop token ``num_vq`` is produced; if the loop reaches its
            last step, the token produced at that step is dropped.

        NOTE(review): the stop-token checks reduce a tensor to a bool, so this
        looks like it expects batch size 1 - confirm with callers.
        NOTE(review): if the stop token is chosen at the very first step,
        ``xs`` is never assigned and the final ``return`` raises
        ``UnboundLocalError``. Left unchanged because callers may depend on it.
        """
        last_step = self.block_size - 1
        for step in range(self.block_size):
            # The first step has no token prefix: only the condition is fed.
            prefix = [] if step == 0 else xs
            logits = self.forward(prefix, clip_feature)
            if step == 0 and att == True:
                return self.trans_base.blocks[0].get_attention_weights()

            # Distribution over the next token, read from the last position.
            probs = F.softmax(logits[:, -1, :], dim=-1)

            if if_categorial:
                idx = Categorical(probs).sample()
                if idx == self.num_vq:
                    break
                idx = idx.unsqueeze(-1)
            else:
                idx = torch.topk(probs, k=1, dim=-1).indices
                if idx[0] == self.num_vq:
                    break

            # Append the new token to the running sequence.
            xs = idx if step == 0 else torch.cat((xs, idx), dim=1)

            if step == last_step:
                return xs[:, :-1]
        return xs
class CausalCrossConditionalSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    Each position may attend only to itself and to earlier positions. The
    attention weights of the most recent forward pass (after dropout) are kept
    on the module and exposed through ``get_attention_weights``.
    """

    def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1):
        """Create the projections and the causal mask.

        Args:
            embed_dim: Channel width of the input; must be divisible by ``n_head``.
            block_size: Side length of the causal mask, i.e. the longest
                sequence ``forward`` can handle.
            n_head: Number of attention heads.
            drop_out_rate: Dropout probability for the attention weights and
                for the projected output.

        Raises:
            AssertionError: If ``embed_dim`` is not divisible by ``n_head``.
        """
        super().__init__()
        # The channels are split evenly across heads in forward(), so the
        # divisibility check must use the actual head count, not a fixed 8.
        assert embed_dim % n_head == 0, "embed_dim must be divisible by n_head"
        # key, query, value projections for all heads
        self.key = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        self.attn_drop = nn.Dropout(drop_out_rate)
        self.resid_drop = nn.Dropout(drop_out_rate)

        self.proj = nn.Linear(embed_dim, embed_dim)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("mask", causal.view(1, 1, block_size, block_size))
        self.n_head = n_head
        # Attention weights of the latest forward pass; None until forward() runs.
        self.att = None

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: Tensor of shape (B, T, C) with ``T`` no larger than the mask size.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.size()
        head_dim = C // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs): move the head axis next to the batch axis.
            return t.view(B, T, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Positions to the right of the query get -inf so softmax gives them weight 0.
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
        # Remember the (post-dropout) weights for later inspection.
        self.att = att

        # output projection
        y = self.resid_drop(self.proj(y))
        return y

    def get_attention_weights(self):
        """Return the attention weights stored by the last ``forward`` call (or None)."""
        return self.att
class Block(nn.Module):
    """Transformer block: layer-normed causal attention, then a layer-normed MLP.

    Both sub-layers are wrapped in residual connections. A block created with
    ``num == 0`` also keeps the attention weights of its latest forward pass.
    """

    def __init__(self, embed_dim=512, block_size=16, n_head=8, drop_out_rate=0.1, fc_rate=4,num_layers=-1,num=None):
        """Create the sub-layers.

        Args:
            embed_dim: Channel width of the block.
            block_size: Mask size forwarded to the attention layer.
            n_head: Number of attention heads.
            drop_out_rate: Dropout probability for attention and MLP.
            fc_rate: Hidden width of the MLP as a multiple of ``embed_dim``.
            num_layers: Stored on the instance; not otherwise used here.
            num: Index of this block; ``0`` enables attention-weight capture.
        """
        super().__init__()
        self.num_layers = num_layers
        self.num = num
        # Filled in by forward() only when this is block 0.
        self.attn_weight = None

        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.attn = CausalCrossConditionalSelfAttention(
            embed_dim=embed_dim,
            block_size=block_size,
            n_head=n_head,
            drop_out_rate=drop_out_rate,
        )

        hidden_dim = fc_rate * embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(drop_out_rate),
        )

    def forward(self, x):
        """Run attention and MLP sub-layers with residual connections."""
        attended = x + self.attn(self.ln1(x))
        if self.num == 0:
            # Snapshot what the attention layer recorded for this pass.
            self.attn_weight = self.attn.get_attention_weights()
        return attended + self.mlp(self.ln2(attended))

    def get_attention_weights(self):
        """Return the attention weights captured by ``forward`` (None if never captured)."""
        return self.attn_weight
class CrossCondTransBase(nn.Module):
    """Lower half of the transformer: embeds the condition and token prefix.

    The conditioning feature is projected to ``embed_dim`` and placed in front
    of the embedded tokens; the result goes through a positional embedding and
    a stack of ``Block`` layers.
    """

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 clip_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4,
                 ):
        """Create embeddings and transformer blocks.

        Args:
            num_vq: Number of regular token indices; the token table has
                ``num_vq + 2`` rows.
            embed_dim: Width of the transformer.
            clip_dim: Width of the incoming conditioning feature.
            block_size: Maximum total sequence length (condition + tokens).
            num_layers: Number of ``Block`` layers.
            n_head: Number of attention heads per block.
            drop_out_rate: Dropout probability inside the blocks.
            fc_rate: MLP expansion factor inside the blocks.
        """
        super().__init__()
        self.tok_emb = nn.Embedding(num_vq + 2, embed_dim)
        self.cond_emb = nn.Linear(clip_dim, embed_dim)
        # pos_embedding and drop are not used by forward(); they are retained so
        # the module's parameter set stays the same (presumably for checkpoint
        # compatibility - confirm before removing).
        self.pos_embedding = nn.Embedding(block_size, embed_dim)
        self.drop = nn.Dropout(drop_out_rate)
        # transformer block
        self.blocks = nn.Sequential(*[
            Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate, num=layer_idx)
            for layer_idx in range(num_layers)
        ])
        self.pos_embed = pos_encoding.PositionEmbedding(block_size, embed_dim, 0.0, False)

        self.block_size = block_size
        self.first_att_weights = None

        self.apply(self._init_weights)

    def get_block_size(self):
        """Return the configured block size."""
        return self.block_size

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02) and reset LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _embed_condition(self, clip_feature):
        """Project the conditioning feature to a single (B, 1, embed_dim) token.

        The feature is cast to the projection's own dtype first, so features
        in another precision work on both the empty-prefix and the
        non-empty-prefix paths.
        """
        cond = clip_feature.to(self.cond_emb.weight.dtype)
        return self.cond_emb(cond).unsqueeze(1)

    def forward(self, idx, clip_feature):
        """Embed the condition plus token prefix and run the blocks.

        Args:
            idx: Token indices of shape (B, T), or an empty sequence when only
                the condition should be fed (first sampling step).
            clip_feature: Conditioning feature of shape (B, clip_dim).

        Returns:
            Features of shape (B, T + 1, embed_dim) (T = 0 for an empty prefix).

        Raises:
            AssertionError: If the prefix plus the condition token would exceed
                ``block_size`` positions.
        """
        if len(idx) == 0:
            token_embeddings = self._embed_condition(clip_feature)
        else:
            _, t = idx.size()
            # One extra position is taken by the condition token, and the
            # blocks' causal mask covers only block_size positions, so the
            # token prefix itself must be strictly shorter than block_size.
            assert t < self.block_size, "Cannot forward, model block size is exhausted."
            # forward the Trans model
            token_embeddings = torch.cat(
                [self._embed_condition(clip_feature), self.tok_emb(idx)], dim=1
            )

        x = self.pos_embed(token_embeddings)
        x = self.blocks(x)
        return x
class CrossCondTransHead(nn.Module):
    """Upper half of the transformer: more blocks, a final norm and the logit layer.

    The output layer has ``num_vq + 1`` units and no bias.
    """

    def __init__(self,
                 num_vq=1024,
                 embed_dim=512,
                 block_size=16,
                 num_layers=2,
                 n_head=8,
                 drop_out_rate=0.1,
                 fc_rate=4):
        """Create the block stack, final LayerNorm and output projection.

        Args:
            num_vq: Number of regular token indices; logits have ``num_vq + 1`` entries.
            embed_dim: Width of the transformer.
            block_size: Mask size forwarded to every block.
            num_layers: Number of ``Block`` layers.
            n_head: Number of attention heads per block.
            drop_out_rate: Dropout probability inside the blocks.
            fc_rate: MLP expansion factor inside the blocks.
        """
        super().__init__()

        layers = [
            Block(embed_dim, block_size, n_head, drop_out_rate, fc_rate, num=layer_idx)
            for layer_idx in range(num_layers)
        ]
        self.blocks = nn.Sequential(*layers)
        self.ln_f = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_vq + 1, bias=False)

        self.block_size = block_size

        self.apply(self._init_weights)

    def get_block_size(self):
        """Return the configured block size."""
        return self.block_size

    def _init_weights(self, module):
        """Initialise one sub-module; invoked for every sub-module via ``apply``."""
        if isinstance(module, nn.LayerNorm):
            # Identity-like normalisation: zero shift, unit scale.
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x):
        """Map features of shape (..., embed_dim) to logits of shape (..., num_vq + 1)."""
        return self.head(self.ln_f(self.blocks(x)))