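# NOTE: header captured from the repository web view this file was copied
# from (not part of the module):
#   Livia_Zaharia — "added code for the first time" (commit bacf16b), 2.12 kB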
import torch
import torch.nn as nn
import torch.nn.functional as F
from .attention import *
class ConvLayer(nn.Module):
    """Convolution + pooling block that roughly halves the sequence length.

    Expects ``x`` laid out as (batch, length, d_model). It is transposed to
    channels-first for the 1-D ops and transposed back on return. The
    kernel-3 circular convolution (padding 1) keeps the length unchanged.
    The stride-2 max-pool then maps a length of L to ceil(L / 2).
    """

    def __init__(self, d_model):
        super().__init__()
        # Submodules are created in the original order (conv, norm, activation,
        # pool) so parameter initialisation draws from the RNG identically.
        self.downConv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=3,
            padding=1,
            padding_mode='circular',
        )
        self.norm = nn.BatchNorm1d(d_model)
        self.activ = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        # Conv1d / BatchNorm1d / MaxPool1d operate on channels-first tensors.
        out = x.transpose(-1, 1)
        for stage in (self.downConv, self.norm, self.activ, self.maxPool):
            out = stage(out)
        return out.transpose(-1, 1)
class EncoderLayer(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Both sub-layers are residual and post-normalised:
    - The attention output is added back to the input and passed through
      ``norm1``.
    - The feed-forward path (two kernel-size-1 convolutions,
      d_model -> d_fcn -> d_model) is then added back and passed through
      ``norm2``.

    Args:
        att: callable invoked as ``att(x, x, x)``. Its return value is added
            to ``x``, so it must be a tensor broadcastable to ``x``
            (presumably an attention module from ``.attention`` — confirm).
        d_model: feature size of the input's last dimension.
        d_fcn: hidden width of the feed-forward path.
        r_drop: dropout probability (one ``nn.Dropout`` reused at three points).
        activ: ``"relu"`` selects ReLU; any other value selects GELU.
    """

    def __init__(self, att, d_model, d_fcn, r_drop, activ="relu"):
        super().__init__()
        self.att = att
        # Submodules are built in the original order so weight initialisation
        # draws from the RNG identically.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(r_drop)
        if activ == "relu":
            self.activ = F.relu
        else:
            self.activ = F.gelu

    def forward(self, x):
        # Attention sub-layer: residual connection, then norm1.
        attended = self.att(x, x, x)
        x = self.norm1(x + self.dropout(attended))

        # Feed-forward sub-layer. Conv1d wants channels-first, hence the
        # transpose in and back out.
        hidden = self.conv1(x.transpose(-1, 1))
        hidden = self.dropout(self.activ(hidden))
        hidden = self.dropout(self.conv2(hidden).transpose(-1, 1))
        return self.norm2(x + hidden)
class Encoder(nn.Module):
    """Stack of encoder layers with optional in-between conv layers and a final norm.

    Args:
        enc_layers: iterable of modules, applied to ``x`` in order. Every one
            of them is always applied exactly once.
        conv_layers: optional iterable of modules. ``conv_layers[i]`` is
            applied right after ``enc_layers[i]``. At most
            ``len(enc_layers) - 1`` entries are allowed, so no conv layer
            follows the last encoder layer. ``None`` or an empty iterable
            means no conv layers.
        norm_layer: optional module applied to the final output.

    Raises:
        ValueError: if ``conv_layers`` has ``len(enc_layers)`` or more entries.
            The previous implementation silently re-ran the last encoder
            layer in that case.
    """

    def __init__(self, enc_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.enc_layers = nn.ModuleList(enc_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        if self.conv_layers is not None and len(self.conv_layers) >= len(self.enc_layers):
            raise ValueError(
                f"Encoder expects at most len(enc_layers) - 1 = {len(self.enc_layers) - 1} "
                f"conv_layers, got {len(self.conv_layers)}"
            )
        self.norm = norm_layer

    def forward(self, x):
        # x [B, L, D]
        n_conv = len(self.conv_layers) if self.conv_layers is not None else 0
        for i, enc_layer in enumerate(self.enc_layers):
            x = enc_layer(x)
            # Interleave explicitly instead of zip() + enc_layers[-1]. The
            # old form skipped middle encoder layers whenever there were fewer
            # than len(enc_layers) - 1 conv layers, e.g. an empty list.
            if i < n_conv:
                x = self.conv_layers[i](x)
        if self.norm is not None:
            x = self.norm(x)
        return x