Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .attention import * | |
class ConvLayer(nn.Module):
    """Convolutional down-sampling block placed between encoder layers.

    Applies, in order: a width-3 circular-padded ``Conv1d`` that keeps the
    channel count at ``d_model``, ``BatchNorm1d``, ``ELU`` and a stride-2
    ``MaxPool1d``.

    Args:
        d_model: Number of input and output channels of the convolution.
    """

    def __init__(self, d_model):
        super().__init__()
        # Attribute names are kept as-is: they are the state_dict keys.
        self.downConv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=3,
            padding=1,
            padding_mode='circular',
        )
        self.norm = nn.BatchNorm1d(d_model)
        self.activ = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        """Down-sample ``x`` along dimension 1.

        ``x`` is presumably ``[B, L, D]`` with ``D == d_model`` (TODO confirm
        against callers): dimensions 1 and -1 are swapped so that ``Conv1d``
        sees the feature axis as channels, and swapped back before returning.
        """
        channels_first = x.transpose(1, -1)
        pooled = self.maxPool(self.activ(self.norm(self.downConv(channels_first))))
        return pooled.transpose(1, -1)
class EncoderLayer(nn.Module):
    """Self-attention followed by a position-wise feed-forward sub-layer.

    Both sub-layers use a residual connection followed by ``LayerNorm``.
    The feed-forward part is implemented with two kernel-size-1 ``Conv1d``
    layers (``d_model -> d_fcn -> d_model``).

    Args:
        att: Callable invoked as ``att(x, x, x)``; its result is added to
            ``x``, so it must return a tensor that can be summed with ``x``.
        d_model: Feature size of the input and output.
        d_fcn: Hidden size of the feed-forward sub-layer.
        r_drop: Dropout probability; one ``Dropout`` module is shared by all
            three dropout sites in ``forward``.
        activ: ``"relu"`` selects ``F.relu``; any other value selects
            ``F.gelu``.
    """

    def __init__(self, att, d_model, d_fcn, r_drop, activ="relu"):
        super().__init__()
        self.att = att
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(r_drop)
        if activ == "relu":
            self.activ = F.relu
        else:
            # Every non-"relu" value falls back to GELU.
            self.activ = F.gelu

    def forward(self, x):
        """Return ``norm2(h + ffn(h))`` where ``h = norm1(x + dropout(att(x, x, x)))``."""
        # Attention sub-layer with residual connection.
        attended = self.att(x, x, x)
        x = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer; Conv1d wants the feature axis at dim 1.
        hidden = self.conv1(x.transpose(-1, 1))
        hidden = self.dropout(self.activ(hidden))
        projected = self.conv2(hidden).transpose(-1, 1)
        projected = self.dropout(projected)

        return self.norm2(x + projected)
class Encoder(nn.Module):
    """Stack of encoder layers with optional down-sampling layers in between.

    Every encoder layer runs exactly once, in order. When ``conv_layers`` is
    given, ``conv_layers[i]`` is applied after ``enc_layers[i]`` for every
    encoder layer except the last one, so the canonical configuration is
    ``len(conv_layers) == len(enc_layers) - 1``. Conv layers with no
    matching gap are ignored; gaps with no conv layer get none.

    Args:
        enc_layers: Iterable of encoder modules, applied in order.
        conv_layers: Optional iterable of modules inserted between
            consecutive encoder layers; ``None`` disables them.
        norm_layer: Optional module applied to the final output.
    """

    def __init__(self, enc_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.enc_layers = nn.ModuleList(enc_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x):
        """Run ``x`` through the stack and the optional final norm.

        ``x`` is presumably ``[B, L, D]`` (per the original inline comment);
        this method itself only passes it from layer to layer.
        """
        convs = self.conv_layers if self.conv_layers is not None else ()
        last_idx = len(self.enc_layers) - 1

        for idx, enc_layer in enumerate(self.enc_layers):
            x = enc_layer(x)
            # A conv layer sits only *between* two encoder layers. Driving
            # the loop from enc_layers (not zip) guarantees no encoder layer
            # is skipped when conv_layers is short/empty and that the last
            # encoder layer is not executed twice when conv_layers is long.
            if idx < last_idx and idx < len(convs):
                x = convs[idx](x)

        if self.norm is not None:
            x = self.norm(x)
        return x