import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange
from einops.layers.torch import Rearrange

# source: https://github.com/lucidrains/conformer/blob/master/conformer/conformer.py

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
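
# Worked example (illustrative, not part of the original source):
# calc_same_padding(31) == (15, 15) and calc_same_padding(4) == (2, 1),
# so total padding is always kernel_size - 1 and the depthwise conv below
# preserves the sequence length.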

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim = self.dim)
        return out * gate.sigmoid()

class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

# attention, feedforward, and conv module

class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        max_pos_emb = 512
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.max_pos_emb = max_pos_emb
        self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context = None, mask = None, context_mask = None):
        n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context)
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # Shaw's relative positional embedding: clamped pairwise distances index a
        # learned per-head-dimension embedding table, added to the content scores
        seq = torch.arange(n, device = device)
        dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
        dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
        rel_pos_emb = self.rel_pos_emb(dist).to(q)
        pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
        dots = dots + pos_attn

        if exists(mask) or exists(context_mask):
            # masks are boolean, True for positions that may be attended to
            mask = default(mask, torch.ones(*x.shape[:2], device = device, dtype = torch.bool))
            context_mask = default(context_mask, mask) if not has_context else default(context_mask, torch.ones(*context.shape[:2], device = device, dtype = torch.bool))
            mask_value = -torch.finfo(dots.dtype).max
            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return self.dropout(out)
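
# Shape sketch (illustrative, not part of the original source): self-attention
# maps (batch, seq_len, dim) -> (batch, seq_len, dim); the optional boolean
# masks mark valid (non-padding) positions, e.g.
#   attn = Attention(dim = 512, heads = 8, dim_head = 64)
#   out = attn(torch.randn(2, 128, 512), mask = torch.ones(2, 128, dtype = torch.bool))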

class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim = 1),
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding),
            nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
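
# Shape sketch (illustrative, not part of the original source): the conv module
# also preserves (batch, seq_len, dim); a pointwise conv expands to
# 2 * dim * expansion_factor channels, a GLU gate halves them, the depthwise
# conv mixes along time, and a final pointwise conv projects back to dim, e.g.
#   conv = ConformerConvModule(dim = 512, kernel_size = 31)
#   out = conv(torch.randn(2, 128, 512))   # (2, 128, 512)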

# Conformer Block

class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.
    ):
        super().__init__()
        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
        self.conv = ConformerConvModule(dim = dim, causal = False, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask = None):
        x = self.ff1(x) + x
        x = self.attn(x, mask = mask) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x
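
# Minimal usage sketch (not part of the original source): run a dummy batch
# through one ConformerBlock; the block preserves the (batch, seq_len, dim) shape.
if __name__ == '__main__':
    block = ConformerBlock(
        dim = 512,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.
    )
    x = torch.randn(1, 1024, 512)
    out = block(x)
    assert out.shape == (1, 1024, 512)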