# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/mlfoundations/open_flamingo under the MIT license.
# LICENSE is in incl_licenses directory.
# Adapted from https://github.com/lucidrains/flamingo-pytorch under the MIT license.
# LICENSE is in incl_licenses directory.
# Adapted from https://github.com/jadore801120/attention-is-all-you-need-pytorch under the MIT license.
# LICENSE is in incl_licenses directory.
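"""Transformer encoder and Flamingo-style gated cross-attention blocks.

Contains a Transformer encoder with sinusoidal positional encoding (adapted from
jadore801120/attention-is-all-you-need-pytorch) and masked, gated cross-attention
from text tokens to windowed audio embeddings (adapted from open_flamingo and
lucidrains/flamingo-pytorch).
"""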
from einops import rearrange, repeat
from einops_exts import rearrange_many
import numpy as np
import torch
from torch import einsum, nn
import torch.nn.functional as F
def exists(val):
return val is not None
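
# pre-norm feed-forward block (LayerNorm -> Linear -> GELU -> Linear), used as
# the feed-forward half of GatedCrossAttentionBlock below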
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.dropout = nn.Dropout(attn_dropout)
def forward(self, q, k, v, mask=None):
attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
if mask is not None:
attn = attn.masked_fill(mask == 0, -1e9)
attn = self.dropout(F.softmax(attn, dim=-1))
output = torch.matmul(attn, v)
return output, attn
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
def forward(self, q, k, v, mask=None):
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
residual = q
        # Pass through the pre-attention projections: b x lq x (n*dk) for q/k, b x lv x (n*dv) for v
        # Separate the different heads: b x lq x n x dk
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
# Transpose for attention dot product: b x n x lq x dv
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if mask is not None:
mask = mask.unsqueeze(1) # For head axis broadcasting.
q, attn = self.attention(q, k, v, mask=mask)
# Transpose to move the head dimension back: b x lq x n x dv
# Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
q = self.dropout(self.fc(q))
q += residual
q = self.layer_norm(q)
return q, attn
class PositionwiseFeedForward(nn.Module):
    ''' A position-wise two-layer feed-forward module '''
def __init__(self, d_in, d_hid, dropout=0.1):
super().__init__()
self.w_1 = nn.Linear(d_in, d_hid) # position-wise
self.w_2 = nn.Linear(d_hid, d_in) # position-wise
self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
residual = x
x = self.w_2(F.relu(self.w_1(x)))
x = self.dropout(x)
x += residual
x = self.layer_norm(x)
return x
class PositionalEncoding(nn.Module):
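    ''' Sinusoidal positional encoding from "Attention Is All You Need" '''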
def __init__(self, d_hid, n_position=200):
super(PositionalEncoding, self).__init__()
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))  # (1, n_position, d_hid); saved with the module but not trained
def _get_sinusoid_encoding_table(self, n_position, d_hid):
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
def forward(self, x):
return x + self.pos_table[:, :x.size(1)].clone().detach()
class EncoderLayer(nn.Module):
''' Compose with two layers '''
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0):
super(EncoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
def forward(self, enc_input, slf_attn_mask=None):
enc_output, enc_slf_attn = self.slf_attn(
enc_input, enc_input, enc_input, mask=slf_attn_mask)
enc_output = self.pos_ffn(enc_output)
return enc_output, enc_slf_attn
class TransformerEncoder(nn.Module):
    ''' An encoder model with a self-attention mechanism. '''
def __init__(
self, d_word_vec=512, n_layers=6, n_head=8, d_k=64, d_v=64,
d_model=512, d_inner=2048, dropout=0.0, n_position=16, scale_emb=True):
super().__init__()
        if n_position > 0:
            self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        else:
            self.position_enc = nn.Identity()
self.dropout = nn.Dropout(p=dropout)
self.layer_stack = nn.ModuleList([
EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
for _ in range(n_layers)])
self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
self.scale_emb = scale_emb
self.d_model = d_model
def forward(self, src_seq, return_attns=False):
        if len(src_seq.shape) == 2:  # treat a (B, D) input as a length-1 sequence
            src_seq = src_seq.unsqueeze(1)
B, L, D = src_seq.shape
enc_slf_attn_list = []
        causal_mask = None  # no mask is applied; the encoder attends bidirectionally
enc_output = src_seq
if self.scale_emb:
enc_output = enc_output * self.d_model ** 0.5
enc_output = self.dropout(self.position_enc(enc_output))
enc_output = self.layer_norm(enc_output)
for enc_layer in self.layer_stack:
enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=causal_mask)
enc_slf_attn_list += [enc_slf_attn] if return_attns else []
if return_attns:
return enc_output, enc_slf_attn_list
return enc_output
# gated cross attention
class MaskedCrossAttention(nn.Module):
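    ''' Cross-attention from text tokens to windowed audio embeddings, masked by <audio> token positions '''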
def __init__(
self,
*,
dim,
dim_audio,
max_window_per_audio,
dim_head=64,
heads=8,
only_attend_immediate_media=True,
):
super().__init__()
self.max_window_per_audio = max_window_per_audio
self.scale = dim_head**-0.5
self.heads = heads
inner_dim = dim_head * heads
self.norm = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim_audio, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
self.only_attend_immediate_media = only_attend_immediate_media
def forward(
self,
x,
media, media_mask,
media_locations=None,
use_cached_media=False
):
if not use_cached_media:
            assert (
                media_locations.shape[1] == x.shape[1]
            ), f"media_locations.shape is {media_locations.shape} but x.shape is {x.shape}"
T_txt = x.shape[1]
        B, L = media.shape[:2]
        assert media.shape[2] == 1  # singleton extra dim per audio window
        assert L % self.max_window_per_audio == 0  # a whole number of windows per audio (typically 4 or 8)
h = self.heads
x = self.norm(x)
q = self.to_q(x)
media = rearrange(media, "b t n d -> b (t n) d")
k, v = self.to_kv(media).chunk(2, dim=-1)
q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
q = q * self.scale
sim = einsum("... i d, ... j d -> ... i j", q, k)
# mask padded audio embeddings
media_mask = rearrange(media_mask, "b i n -> b 1 1 (i n)").bool() # n = 1 is extra dim
sim = sim.masked_fill(~media_mask, -torch.finfo(sim.dtype).max)
        assert self.only_attend_immediate_media is False  # text always attends to all preceding audio windows
# mask media locations
if exists(media_locations):
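            # Build a (B, T_txt, L) boolean mask: text tokens between the i-th
            # and (i+1)-th <audio> location may attend to every audio window up
            # to and including the i-th audio's windows.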
            few_shot_mask = torch.zeros(B, T_txt, L, dtype=torch.bool, device=sim.device)
for batch_idx in range(B):
media_locations_b = media_locations[batch_idx].nonzero() # locations of <audio>
if len(media_locations_b.shape) > 1:
media_locations_b = media_locations_b.squeeze(-1)
for i in range(-1, len(media_locations_b)):
if i == -1:
if len(media_locations_b) == 1:
text_start, text_end = 0, T_txt
else:
text_start, text_end = 0, media_locations_b[i+1]
elif i == len(media_locations_b) - 1:
text_start, text_end = media_locations_b[i], T_txt
else:
text_start, text_end = media_locations_b[i], media_locations_b[i+1]
                    if self.only_attend_immediate_media:
                        look_at_window_start = max(i, 0) * self.max_window_per_audio
                    else:
                        look_at_window_start = 0
                    look_at_window_end = (max(i, 0) + 1) * self.max_window_per_audio
few_shot_mask[batch_idx, text_start:text_end, look_at_window_start:look_at_window_end] = True
sim = sim.masked_fill(~few_shot_mask.unsqueeze(1), -torch.finfo(sim.dtype).max)
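        # subtract the per-query max before softmax for numerical stability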
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
        if exists(media_locations) and self.only_attend_immediate_media:
            # Unreachable given the assert above; kept from the open_flamingo
            # implementation. text_time counts the <audio> tokens seen up to each
            # text position, so positions that precede all media are zeroed out.
            text_time = media_locations.cumsum(dim=-1)
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(
                text_without_media_mask, "b i -> b 1 i 1"
            )
            attn = attn.masked_fill(text_without_media_mask, 0.0)
out = einsum("... i j, ... j d -> ... i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
return self.to_out(out)
class GatedCrossAttentionBlock(nn.Module):
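    ''' Flamingo-style gated cross-attention block: both tanh gates start at 0, so the block is an identity map at initialization '''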
def __init__(
self,
*,
dim,
dim_audio,
max_window_per_audio,
dim_head=64,
heads=8,
ff_mult=4,
only_attend_immediate_media=True,
):
super().__init__()
self.attn = MaskedCrossAttention(
dim=dim,
dim_audio=dim_audio,
max_window_per_audio=max_window_per_audio,
dim_head=dim_head,
heads=heads,
only_attend_immediate_media=only_attend_immediate_media,
)
self.attn_gate = nn.Parameter(torch.tensor([0.0]))
self.ff = FeedForward(dim, mult=ff_mult)
self.ff_gate = nn.Parameter(torch.tensor([0.0]))
def forward(
self,
x,
media,
media_mask,
media_locations=None,
use_cached_media=False,
):
x = (
self.attn(
x,
media,
media_mask,
media_locations=media_locations,
use_cached_media=use_cached_media,
)
* self.attn_gate.tanh()
+ x
)
x = self.ff(x) * self.ff_gate.tanh() + x
return x
if __name__ == '__main__':
    enc = TransformerEncoder().cuda()
    enc._use_gradient_checkpointing = True  # not used in this file; presumably read by wrapper code
    x = torch.randn(4, 512).cuda()  # a 2-D input is unsqueezed to (4, 1, 512) inside forward
    output = enc(x)
    print(output.shape)  # torch.Size([4, 1, 512])
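
    # A minimal sketch of the gated cross-attention path. The sizes below
    # (dim=512, dim_audio=768, 4 windows per audio, 2 audio clips) are
    # illustrative assumptions, not taken from any particular checkpoint.
    xattn = GatedCrossAttentionBlock(
        dim=512, dim_audio=768, max_window_per_audio=4,
        only_attend_immediate_media=False,  # required by the assert in MaskedCrossAttention
    ).cuda()
    text = torch.randn(2, 10, 512).cuda()     # (B, T_txt, dim)
    audio = torch.randn(2, 8, 1, 768).cuda()  # (B, L, 1, dim_audio); L = 2 audios x 4 windows
    audio_mask = torch.ones(2, 8, 1).cuda()   # all audio embeddings are valid
    locations = torch.zeros(2, 10, dtype=torch.bool).cuda()
    locations[:, 0] = locations[:, 5] = True  # an <audio> token at positions 0 and 5
    out = xattn(text, audio, audio_mask, media_locations=locations)
    print(out.shape)  # torch.Size([2, 10, 512])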