# NOTE: two "Build error" banner lines from the hosting page appeared here; they are not part of the module.
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from [source link missing] under the MIT license.
# LICENSE is in incl_licenses directory.
# Adapted from [source link missing] under the MIT license.
# LICENSE is in incl_licenses directory.
# Adapted from [source link missing] under the MIT license.
# LICENSE is in incl_licenses directory.
from einops import rearrange, repeat
from einops_exts import rearrange_many
import numpy as np
import torch
from torch import einsum, nn
import torch.nn.functional as F
def exists(val):
    """Return ``True`` when ``val`` is anything other than ``None``."""
    return val is not None
def FeedForward(dim, mult=4):
    """Build a pre-normalised, bias-free two-layer MLP.

    Args:
        dim: Size of the input and output feature dimension.
        mult: Expansion factor; the hidden width is ``int(dim * mult)``.

    Returns:
        An ``nn.Sequential`` of LayerNorm -> Linear -> GELU -> Linear.
    """
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        nn.GELU(),
        nn.Linear(inner_dim, dim, bias=False),
    )
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention.

    Args:
        temperature: Divisor applied to the queries before the dot product.
        attn_dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        ``k`` is transposed on axes 2 and 3, so the inputs are expected to be
        4-D with the sequence and feature axes last.

        Args:
            q: Query tensor.
            k: Key tensor.
            v: Value tensor.
            mask: Optional tensor broadcastable to the attention scores;
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tuple ``(output, attn)`` of the attended values and the
            (dropout-applied) attention weights.
        """
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            # Use the dtype's own minimum instead of a hard-coded -1e9, which
            # is not representable in float16.
            attn = attn.masked_fill(mask == 0, torch.finfo(attn.dtype).min)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a residual connection and post-LayerNorm.

    Args:
        n_head: Number of attention heads.
        d_model: Feature size of the inputs and of the output.
        d_k: Per-head size of the query/key projections.
        d_v: Per-head size of the value projection.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, q, k, v, mask=None):
        """Run multi-head attention.

        Args:
            q: Queries, ``b x lq x d_model``.
            k: Keys, ``b x lk x d_model``.
            v: Values, ``b x lv x d_model``.
            mask: Optional mask; a head axis is inserted at dim 1 before it is
                handed to the attention module.

        Returns:
            Tuple ``(output, attn)`` where ``output`` is ``b x lq x d_model``
            and ``attn`` holds the per-head attention weights.
        """
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        residual = q

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)  # For head axis broadcasting.

        q, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads
        # together: b x lq x (n*dv)
        q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        q = self.dropout(self.fc(q))
        q += residual

        q = self.layer_norm(q)

        return q, attn
class PositionwiseFeedForward(nn.Module):
    """Two-layer ReLU feed-forward block with residual and post-LayerNorm.

    Args:
        d_in: Feature size of the input and output.
        d_hid: Width of the hidden layer.
        dropout: Dropout probability applied to the second layer's output.
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return ``LayerNorm(x + Dropout(W2(ReLU(W1(x)))))``."""
        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual  # in-place on the fresh projection output, not on the input

        x = self.layer_norm(x)

        return x
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence.

    The table is registered as the buffer ``pos_table`` with shape
    ``(1, n_position, d_hid)``.

    Args:
        d_hid: Feature size of the encoded inputs.
        n_position: Number of positions stored in the table.
    """

    def __init__(self, d_hid, n_position=200):
        super().__init__()
        self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid))

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Return the ``(1, n_position, d_hid)`` float32 sinusoid table."""
        # angle[p, j] = p / 10000 ** (2 * (j // 2) / d_hid), computed for all
        # positions and channels at once instead of nested Python loops.
        positions = np.arange(n_position, dtype=np.float64)[:, None]
        exponents = 2 * (np.arange(d_hid) // 2) / d_hid
        sinusoid_table = positions / np.power(10000.0, exponents)

        sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i
        sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1

        return torch.from_numpy(sinusoid_table).float().unsqueeze(0)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        # The buffer carries no gradient and the addition allocates a new
        # tensor, so no clone/detach of the slice is needed.
        return x + self.pos_table[:, :x.size(1)]
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a position-wise FFN.

    Args:
        d_model: Feature size of the layer's input and output.
        d_inner: Hidden width of the feed-forward sub-layer.
        n_head: Number of attention heads.
        d_k: Per-head query/key size.
        d_v: Per-head value size.
        dropout: Dropout probability passed to both sub-layers.
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.0):
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        """Return ``(enc_output, enc_slf_attn)`` for ``enc_input``.

        ``enc_input`` is used as query, key and value of the attention
        sub-layer; ``slf_attn_mask`` is forwarded to it unchanged.
        """
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn
class TransformerEncoder(nn.Module):
    """Stack of self-attention encoder layers over pre-embedded inputs.

    Args:
        d_word_vec: Feature size used for the positional-encoding table.
        n_layers: Number of ``EncoderLayer`` modules in the stack.
        n_head: Number of attention heads per layer.
        d_k: Per-head query/key size.
        d_v: Per-head value size.
        d_model: Feature size of the encoder input and output.
        d_inner: Hidden width of each layer's feed-forward sub-layer.
        dropout: Dropout probability used throughout.
        n_position: Number of positions in the sinusoid table; a value
            ``<= 0`` disables positional encoding.
        scale_emb: When true, inputs are multiplied by ``d_model ** 0.5``
            before the positional encoding is added.
    """

    def __init__(
            self, d_word_vec=512, n_layers=6, n_head=8, d_k=64, d_v=64,
            d_model=512, d_inner=2048, dropout=0.0, n_position=16, scale_emb=True):
        super().__init__()

        if n_position > 0:
            self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        else:
            # nn.Identity instead of a lambda so the module stays picklable.
            self.position_enc = nn.Identity()
        self.dropout = nn.Dropout(p=dropout)
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.scale_emb = scale_emb
        self.d_model = d_model

    def forward(self, src_seq, return_attns=False):
        """Encode ``src_seq``.

        Args:
            src_seq: 3-D tensor, or a 2-D tensor which gets a length-1 axis
                inserted at dim 1.
            return_attns: When true, also return each layer's attention
                weights.

        Returns:
            The encoded tensor, or ``(encoded, attn_list)`` when
            ``return_attns`` is true.

        Raises:
            ValueError: If ``src_seq`` is neither 2-D nor 3-D.
        """
        if src_seq.dim() == 2:
            src_seq = src_seq.unsqueeze(1)
        if src_seq.dim() != 3:
            raise ValueError(
                f"expected a 2-D or 3-D src_seq, got shape {tuple(src_seq.shape)}")

        enc_slf_attn_list = []

        # No attention mask is applied; every position sees every other one.
        causal_mask = None

        enc_output = src_seq
        if self.scale_emb:
            enc_output = enc_output * self.d_model ** 0.5
        enc_output = self.dropout(self.position_enc(enc_output))
        enc_output = self.layer_norm(enc_output)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=causal_mask)
            if return_attns:
                enc_slf_attn_list.append(enc_slf_attn)

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output
# Gated cross attention
class MaskedCrossAttention(nn.Module):
    """Cross attention from text tokens to audio (media) embeddings.

    Padded media positions are masked via ``media_mask``; when
    ``media_locations`` is given, each text token is additionally restricted
    to the media windows selected by the preceding media markers.

    Args:
        dim: Feature size of the text tokens.
        dim_audio: Feature size of the media embeddings.
        max_window_per_audio: Number of media positions that make up one
            window; the media length must be a multiple of it.
        dim_head: Per-head feature size.
        heads: Number of attention heads.
        only_attend_immediate_media: Stored on the module. ``forward``
            currently asserts that it is ``False``.
    """

    def __init__(
        self,
        *,
        dim,
        dim_audio,
        max_window_per_audio,
        dim_head=64,
        heads=8,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.max_window_per_audio = max_window_per_audio
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim_audio, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        self.only_attend_immediate_media = only_attend_immediate_media

    def _build_few_shot_mask(self, media_locations, batch_size, n_text, n_media, device):
        """Return a ``(batch, n_text, n_media)`` bool mask of allowed text->media pairs.

        Text from the k-th nonzero entry of ``media_locations`` up to the next
        one (or the end) may look at media positions ``[0, (k + 1) * window)``,
        or only at the k-th window in immediate mode. Text before the first
        nonzero entry shares the first window.
        """
        window = self.max_window_per_audio
        mask = torch.zeros(batch_size, n_text, n_media, dtype=torch.bool, device=device)

        for batch_idx in range(batch_size):
            locs = media_locations[batch_idx].nonzero()  # locations of <audio>
            if locs.dim() > 1:
                locs = locs.squeeze(-1)
            n_locs = len(locs)

            # NOTE(review): a row without any nonzero location raises
            # IndexError on ``locs[0]`` below - confirm callers never pass one.
            for i in range(-1, n_locs):
                if i == -1:
                    text_start = 0
                    text_end = n_text if n_locs == 1 else locs[0]
                elif i == n_locs - 1:
                    text_start, text_end = locs[i], n_text
                else:
                    text_start, text_end = locs[i], locs[i + 1]

                segment = max(i, 0)
                if self.only_attend_immediate_media:
                    window_start = segment * window
                else:
                    window_start = 0
                window_end = (segment + 1) * window

                mask[batch_idx, text_start:text_end, window_start:window_end] = True

        return mask

    def forward(
        self,
        x,
        media, media_mask,
        media_locations=None,
        use_cached_media=False
    ):
        """Attend from text ``x`` to ``media``.

        Args:
            x: Text features, ``(b, T_txt, dim)``.
            media: Media features, ``(b, t, 1, dim_audio)``.
            media_mask: ``(b, t, 1)`` tensor; zero entries mark padded media
                positions that must not be attended to.
            media_locations: Optional ``(b, T_txt)`` tensor whose nonzero
                entries mark media tokens in the text.
            use_cached_media: When true, the check that ``media_locations``
                matches the text length is skipped.

        Returns:
            Tensor of shape ``(b, T_txt, dim)``.
        """
        if not use_cached_media:
            assert (
                media_locations.shape[1] == x.shape[1]
            ), f"media_location.shape is {media_locations.shape} but x.shape is {x.shape}"

        T_txt = x.shape[1]
        B, L = media.shape[:2]
        assert media.shape[2] == 1  # extra dim
        assert L % self.max_window_per_audio == 0  # whole number of windows

        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        media = rearrange(media, "b t n d -> b (t n) d")

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)

        q = q * self.scale

        sim = einsum("... i d, ... j d -> ... i j", q, k)
        mask_value = -torch.finfo(sim.dtype).max

        # mask padded audio embeddings
        media_mask = rearrange(media_mask, "b i n -> b 1 1 (i n)").bool()  # n = 1 is extra dim
        sim = sim.masked_fill(~media_mask, mask_value)

        assert self.only_attend_immediate_media is False

        # mask media locations
        if exists(media_locations):
            few_shot_mask = self._build_few_shot_mask(
                media_locations, B, T_txt, L, sim.device)
            sim = sim.masked_fill(~few_shot_mask.unsqueeze(1), mask_value)

        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # NOTE: a former post-softmax step that zeroed attention for "text
        # without media" referenced the undefined name ``text_time`` and sat
        # behind the assertion above, so it could never run; it was removed.

        out = einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
class GatedCrossAttentionBlock(nn.Module):
    """Masked cross attention and feed-forward, each behind a tanh gate.

    Both gates are initialised to zero, so ``tanh(gate)`` is zero and the
    block returns its input unchanged until the gates are trained.

    Args:
        dim: Feature size of the text tokens.
        dim_audio: Feature size of the media embeddings.
        max_window_per_audio: Window size forwarded to the attention module.
        dim_head: Per-head feature size.
        heads: Number of attention heads.
        ff_mult: Hidden-width multiplier of the feed-forward sub-layer.
        only_attend_immediate_media: Forwarded to the attention module.
    """

    def __init__(
        self,
        *,
        dim,
        dim_audio,
        max_window_per_audio,
        dim_head=64,
        heads=8,
        ff_mult=4,
        only_attend_immediate_media=True,
    ):
        super().__init__()
        self.attn = MaskedCrossAttention(
            dim=dim,
            dim_audio=dim_audio,
            max_window_per_audio=max_window_per_audio,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
        )
        self.attn_gate = nn.Parameter(torch.tensor([0.0]))

        self.ff = FeedForward(dim, mult=ff_mult)
        self.ff_gate = nn.Parameter(torch.tensor([0.0]))

    def forward(
        self,
        x,
        media,
        media_mask,
        media_locations=None,
        use_cached_media=False,
    ):
        """Return ``x`` after the gated attention and gated feed-forward residuals.

        All arguments other than ``x`` are passed through to the attention
        module unchanged.
        """
        attended = self.attn(
            x,
            media,
            media_mask,
            media_locations=media_locations,
            use_cached_media=use_cached_media,
        )
        x = attended * self.attn_gate.tanh() + x
        x = self.ff(x) * self.ff_gate.tanh() + x

        return x
if __name__ == '__main__':
    # Smoke test: push a random batch through a default encoder and print
    # the output shape. Uses the GPU when present, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    enc = TransformerEncoder().to(device)
    x = torch.randn(4, 512).to(device)
    output = enc(x)
    # NOTE(review): this flag is set after the forward pass and nothing in
    # this file reads it - confirm whether it is still needed.
    enc._use_gradient_checkpointing = True
    print(output.shape)