import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
# source: https://github.com/lucidrains/conformer/blob/master/conformer/conformer.py
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
def calc_same_padding(kernel_size):
    # (left, right) padding that keeps the output length equal to the input length for a stride-1 conv
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
class Swish(nn.Module):
def forward(self, x):
return x * x.sigmoid()
class GLU(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
out, gate = x.chunk(2, dim=self.dim)
return out * gate.sigmoid()
class DepthWiseConv1d(nn.Module):
def __init__(self, chan_in, chan_out, kernel_size, padding):
super().__init__()
self.padding = padding
        # groups = chan_in makes this a depthwise convolution (one filter per input channel)
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
def forward(self, x):
x = F.pad(x, self.padding)
return self.conv(x)
# attention, feedforward, and conv module
# scales the wrapped module's output by a constant (used for the 0.5 Macaron feedforward halves)
class Scale(nn.Module):
def __init__(self, scale, fn):
super().__init__()
self.fn = fn
self.scale = scale
def forward(self, x, **kwargs):
return self.fn(x, **kwargs) * self.scale
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
x = self.norm(x)
return self.fn(x, **kwargs)
class Attention(nn.Module):
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
max_pos_emb = 512
):
super().__init__()
inner_dim = dim_head * heads
        self.heads = heads
self.scale = dim_head ** -0.5
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim)
self.max_pos_emb = max_pos_emb
        # Shaw-style relative position table covering offsets in [-max_pos_emb, max_pos_emb]
        self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)
self.dropout = nn.Dropout(dropout)
def forward(self, x, context = None, mask = None, context_mask = None):
n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context)
context = default(context, x)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
# shaw's relative positional embedding
seq = torch.arange(n, device = device)
dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
rel_pos_emb = self.rel_pos_emb(dist).to(q)
pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
dots = dots + pos_attn
        if exists(mask) or exists(context_mask):
            # masks are boolean (batch, seq) tensors; default to all-True so a missing mask keeps every position
            mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device, dtype = torch.bool))
            context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device, dtype = torch.bool))
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)
attn = dots.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return self.dropout(out)
class FeedForward(nn.Module):
def __init__(
self,
dim,
mult = 4,
dropout = 0.
):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult),
Swish(),
nn.Dropout(dropout),
nn.Linear(dim * mult, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class ConformerConvModule(nn.Module):
def __init__(
self,
dim,
causal = False,
expansion_factor = 2,
kernel_size = 31,
dropout = 0.):
super().__init__()
inner_dim = dim * expansion_factor
        # causal: pad only on the left; otherwise use symmetric 'same' padding
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
self.net = nn.Sequential(
nn.LayerNorm(dim),
Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),   # pointwise conv, channels doubled for the GLU gate
            GLU(dim=1),                         # gating halves channels back to inner_dim
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
Swish(),
nn.Conv1d(inner_dim, dim, 1),
Rearrange('b c n -> b n c'),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
# Conformer Block
class ConformerBlock(nn.Module):
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 31,
attn_dropout = 0.,
ff_dropout = 0.,
conv_dropout = 0.
):
super().__init__()
self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
self.conv = ConformerConvModule(dim = dim, causal = False, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
        # wrap submodules with pre-norm; the two feedforwards contribute half-steps (Macaron-Net scheme)
        self.attn = PreNorm(dim, self.attn)
self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
self.post_norm = nn.LayerNorm(dim)
    def forward(self, x, mask = None):
        # Macaron structure: half-FFN -> self-attention -> conv module -> half-FFN -> final layer norm
        x = self.ff1(x) + x
x = self.attn(x, mask = mask) + x
x = self.conv(x) + x
x = self.ff2(x) + x
x = self.post_norm(x)
return x
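# A minimal, hypothetical usage sketch (not part of the original file): it builds a
# ConformerBlock with illustrative hyperparameters, runs a random (batch, time, dim)
# tensor through it with a boolean padding mask, and checks that the shape is preserved.
if __name__ == '__main__':
    block = ConformerBlock(
        dim = 256,
        dim_head = 64,
        heads = 4,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31
    )
    x = torch.randn(2, 100, 256)                    # (batch, time, feature dim)
    mask = torch.ones(2, 100, dtype = torch.bool)   # True marks valid (non-padded) frames
    out = block(x, mask = mask)
    print(out.shape)                                # expected: torch.Size([2, 100, 256])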