import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

class CausalConv1d(torch.nn.Conv1d):
    # 1D convolution made causal: pad both sides by (kernel_size - 1) * dilation,
    # then trim the right side so each output step sees only current and past inputs.
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 groups=1,
                 bias=True):
        self.__padding = (kernel_size - 1) * dilation
        super(CausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.__padding,
            dilation=dilation,
            groups=groups,
            bias=bias)

    def forward(self, input):
        result = super(CausalConv1d, self).forward(input)
        if self.__padding != 0:
            return result[:, :, :-self.__padding]  # drop the right-side padding
        return result
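
# A minimal usage sketch (not part of the original module; shapes are illustrative
# assumptions): the pad-then-trim scheme should preserve the sequence length.
def _demo_causal_conv():
    conv = CausalConv1d(in_channels=8, out_channels=16, kernel_size=3, dilation=2)
    x = torch.randn(4, 8, 100)       # (batch, channels, seq_len)
    y = conv(x)
    assert y.shape == (4, 16, 100)   # sequence length is unchanged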

class TriangularCausalMask():
    # Boolean mask of shape (b, 1, n, n): True strictly above the diagonal,
    # i.e. the future positions that attention must not look at.
    def __init__(self, b, n, device="cpu"):
        mask_shape = [b, 1, n, n]
        with torch.no_grad():
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask
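
# A minimal sketch (illustrative only): for n = 4, position i may attend to
# positions 0..i and nothing later.
def _demo_causal_mask():
    m = TriangularCausalMask(b=1, n=4).mask
    # m[0, 0] ==
    # [[False,  True,  True,  True],
    #  [False, False,  True,  True],
    #  [False, False, False,  True],
    #  [False, False, False, False]]
    assert m.shape == (1, 1, 4, 4)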

class MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_keys, mask_flag, r_att_drop=0.1):
        super(MultiheadAttention, self).__init__()
        self.h, self.d, self.mask_flag = n_heads, d_keys, mask_flag
        self.proj_q = nn.Linear(d_model, self.h * self.d)
        self.proj_k = nn.Linear(d_model, self.h * self.d)
        self.proj_v = nn.Linear(d_model, self.h * self.d)
        self.proj_out = nn.Linear(self.h * self.d, d_model)
        self.dropout = nn.Dropout(r_att_drop)

    def forward(self, q, k, v):
        b, n_q, n_k, h, d = q.size(0), q.size(1), k.size(1), self.h, self.d
        q, k, v = self.proj_q(q), self.proj_k(k), self.proj_v(v)    # b, n_*, h*d
        q, k, v = map(lambda x: x.reshape(b, -1, h, d), [q, k, v])  # b, n_*, h, d
        scores = torch.einsum('bnhd,bmhd->bhnm', (q, k))            # b, h, n_q, n_k
        if self.mask_flag:
            # block attention to future positions with the causal mask
            att_mask = TriangularCausalMask(b, n_q, device=q.device)
            scores.masked_fill_(att_mask.mask, -np.inf)
        att = F.softmax(scores / sqrt(self.d), dim=-1)              # b, h, n_q, n_k
        att = self.dropout(att)
        att_out = torch.einsum('bhnm,bmhd->bnhd', (att, v))         # b, n_q, h, d
        att_out = att_out.reshape(b, -1, h * d)                     # b, n_q, h*d
        out = self.proj_out(att_out)                                # b, n_q, d_model
        return out
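
# A minimal self-attention sketch (hyperparameters are illustrative assumptions,
# not from the original): with mask_flag=True the triangular causal mask is applied,
# so the module behaves as masked self-attention in a decoder-style block.
if __name__ == "__main__":
    d_model, n_heads, d_keys = 64, 4, 16
    attn = MultiheadAttention(d_model, n_heads, d_keys, mask_flag=True)
    x = torch.randn(2, 10, d_model)    # (batch, seq_len, d_model)
    out = attn(x, x, x)                # self-attention: q = k = v = x
    assert out.shape == (2, 10, d_model)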