import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import *


class DecoderLayer(nn.Module):
    """A single decoder block: self-attention over the decoder input,
    cross-attention over the encoder output, and a position-wise feed-forward
    network built from two kernel-size-1 convolutions. Each sub-block uses a
    residual connection followed by layer normalization (post-norm)."""

    def __init__(self, self_att, cross_att, d_model, d_fcn, r_drop, activ="relu"):
        super(DecoderLayer, self).__init__()
        self.self_att = self_att
        self.cross_att = cross_att
        # Position-wise feed-forward network: d_model -> d_fcn -> d_model.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(r_drop)
        self.activ = F.relu if activ == "relu" else F.gelu

    def forward(self, x_dec, x_enc):
        # Self-attention over the decoder sequence, with a residual connection.
        x_dec = x_dec + self.self_att(x_dec, x_dec, x_dec)
        x_dec = self.norm1(x_dec)
        # Cross-attention: decoder queries attend to the encoder output.
        x_dec = x_dec + self.cross_att(x_dec, x_enc, x_enc)
        res = x_dec = self.norm2(x_dec)
        # Feed-forward network; Conv1d expects (batch, channels, length),
        # hence the transposes around the convolutions.
        res = self.dropout(self.activ(self.conv1(res.transpose(-1, 1))))
        res = self.dropout(self.conv2(res).transpose(-1, 1))
        return self.norm3(x_dec + res)


class Decoder(nn.Module):
    """A stack of DecoderLayer blocks with an optional final normalization."""

    def __init__(self, layers, norm_layer=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer

    def forward(self, x_dec, x_enc):
        for layer in self.layers:
            x_dec = layer(x_dec, x_enc)
        if self.norm is not None:
            x_dec = self.norm(x_dec)
        return x_dec
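

# -----------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the library API). _StubAttention is
# a hypothetical stand-in assumed to take (queries, keys, values) and return a
# tensor shaped like the queries; in real use, pass attention modules from
# .attention instead. Because of the relative import above, run this as a
# module (e.g. `python -m <package>.<this_module>`) rather than as a script.
# -----------------------------------------------------------------------------
if __name__ == "__main__":

    class _StubAttention(nn.Module):
        """Hypothetical attention stand-in built on nn.MultiheadAttention."""

        def __init__(self, d_model, n_heads=4):
            super().__init__()
            self.mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True)

        def forward(self, queries, keys, values):
            out, _ = self.mha(queries, keys, values)
            return out

    d_model, d_fcn = 64, 256
    decoder = Decoder(
        layers=[
            DecoderLayer(_StubAttention(d_model), _StubAttention(d_model),
                         d_model, d_fcn, r_drop=0.1)
            for _ in range(2)
        ],
        norm_layer=nn.LayerNorm(d_model),
    )

    x_dec = torch.randn(8, 48, d_model)  # (batch, decoder length, d_model)
    x_enc = torch.randn(8, 96, d_model)  # (batch, encoder length, d_model)
    print(decoder(x_dec, x_enc).shape)   # expected: torch.Size([8, 48, 64])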