import torch
import torch.nn as nn
import torch.nn.functional as F

# Star import kept as-is: names it provides may be used or re-exported elsewhere.
from .attention import *


class ConvLayer(nn.Module):
    """Down-sampling block: circular Conv1d -> BatchNorm1d -> ELU -> MaxPool1d.

    The convolution keeps the channel count (``d_model``) and, with
    ``kernel_size=3`` / ``padding=1``, the sequence length. The max-pool
    (``kernel_size=3``, ``stride=2``, ``padding=1``) then roughly halves the
    sequence length.

    Args:
        d_model: Number of feature channels of the input and output.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.downConv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=3,
            padding=1,
            padding_mode='circular',
        )
        self.norm = nn.BatchNorm1d(d_model)
        self.activ = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the down-sampling block.

        Args:
            x: Input tensor with the ``d_model`` features in the last
                dimension (assumed ``[B, L, D]``, as noted in
                ``Encoder.forward`` -- confirm against callers).

        Returns:
            Tensor in the same layout as ``x`` with a shortened dimension 1.
        """
        # Conv1d / BatchNorm1d / MaxPool1d operate on channels in dimension 1,
        # so swap dimension 1 with the last dimension before and after.
        x = self.downConv(x.transpose(-1, 1))
        x = self.norm(x)
        x = self.activ(x)
        x = self.maxPool(x)
        x = x.transpose(-1, 1)
        return x


class EncoderLayer(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Both sub-layers use a residual connection followed by ``LayerNorm``. The
    feed-forward part is implemented with two ``kernel_size=1`` convolutions
    (``d_model -> d_fcn -> d_model``).

    Args:
        att: Callable invoked as ``att(x, x, x)``; its result is added to
            ``x``, so it must return a tensor broadcastable with ``x``
            (presumably a self-attention module from ``.attention`` --
            confirm).
        d_model: Feature size of the layer input and output.
        d_fcn: Hidden feature size of the feed-forward sub-layer.
        r_drop: Dropout probability.
        activ: ``"relu"`` selects ``F.relu``; any other value selects
            ``F.gelu``.
    """

    def __init__(self, att, d_model: int, d_fcn: int, r_drop: float, activ: str = "relu"):
        super().__init__()
        self.att = att
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(r_drop)
        # NOTE: every value other than "relu" falls through to GELU.
        self.activ = F.relu if activ == "relu" else F.gelu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the attention and feed-forward sub-layers.

        Args:
            x: Input tensor with ``d_model`` features in the last dimension.

        Returns:
            Tensor produced by the second ``LayerNorm``.
        """
        # Attention sub-layer: residual connection, then LayerNorm.
        new_x = self.att(x, x, x)
        x = x + self.dropout(new_x)
        res = x = self.norm1(x)

        # Feed-forward sub-layer. Conv1d expects channels in dimension 1, so
        # transpose going in and transpose back coming out.
        res = self.dropout(self.activ(self.conv1(res.transpose(-1, 1))))
        res = self.dropout(self.conv2(res).transpose(-1, 1))

        return self.norm2(x + res)


class Encoder(nn.Module):
    """Stack of encoder layers, optionally interleaved with conv layers.

    Args:
        enc_layers: Iterable of encoder-layer modules.
        conv_layers: Optional iterable of modules applied after the encoder
            layers they are paired with. ``None`` disables interleaving.
        norm_layer: Optional module applied to the final output.
    """

    def __init__(self, enc_layers, conv_layers=None, norm_layer=None):
        super().__init__()
        self.enc_layers = nn.ModuleList(enc_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``.

        Args:
            x: Input tensor of shape ``[B, L, D]``.

        Returns:
            Output of the layer stack, passed through ``self.norm`` when one
            was supplied.
        """
        # x [B, L, D]
        if self.conv_layers is not None:
            # zip() stops at the shorter list; afterwards the last encoder
            # layer is applied unconditionally. This yields a clean
            # enc/conv alternation only when there is exactly one fewer conv
            # layer than encoder layers.
            for enc_layer, conv_layer in zip(self.enc_layers, self.conv_layers):
                x = enc_layer(x)
                x = conv_layer(x)
            x = self.enc_layers[-1](x)
        else:
            for enc_layer in self.enc_layers:
                x = enc_layer(x)

        if self.norm is not None:
            x = self.norm(x)

        return x