Livia_Zaharia
added code for the first time
bacf16b
raw
history blame
1.44 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from .attention import *
class DecoderLayer(nn.Module):
    """A single post-norm decoder layer.

    Processing order:
      1. self-attention on the decoder input, residual add, ``norm1``;
      2. cross-attention (decoder state as query, encoder output as key and
         value), residual add, ``norm2``;
      3. position-wise feed-forward network made of two kernel-size-1
         ``Conv1d`` layers, residual add, ``norm3``.

    Args:
        self_att: Callable invoked as ``self_att(query, key, value)`` with the
            decoder input in all three positions. Its result is added to the
            decoder input, so it must be shape-compatible with it.
        cross_att: Callable invoked as ``cross_att(query, key, value)`` with
            the decoder state as query and the encoder output as key/value.
            Its result is likewise added to the decoder state.
        d_model: Size of the decoder state's last dimension (the LayerNorms
            normalise over it and the first conv consumes it as channels).
        d_fcn: Hidden channel width of the feed-forward network.
        r_drop: Dropout probability; used only inside the feed-forward branch.
        activ: ``"relu"`` selects ``F.relu``; every other value selects
            ``F.gelu``.
    """

    def __init__(self, self_att, cross_att, d_model, d_fcn,
                 r_drop, activ="relu"):
        super().__init__()
        # Caller-supplied attention sub-layers.
        self.self_att = self_att
        self.cross_att = cross_att
        # NOTE: keep this registration order unchanged; it fixes parameter
        # ordering / state_dict keys and the sequence of random draws used
        # when the convolutions initialise their weights.
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(r_drop)
        if activ == "relu":
            self.activ = F.relu
        else:
            self.activ = F.gelu

    def forward(self, x_dec, x_enc):
        """Run the decoder state through attention and feed-forward sub-layers.

        Args:
            x_dec: Decoder state whose last dimension is ``d_model``.
                Presumably shaped (batch, length, d_model) — TODO confirm;
                ``transpose(-1, 1)`` below assumes the channel axis can be
                swapped into position 1 for ``Conv1d``.
            x_enc: Encoder output, passed as key and value to ``cross_att``.

        Returns:
            The layer output after the final residual add and ``norm3``.
        """
        # Self-attention sub-layer: residual connection, then normalisation.
        x_dec = self.norm1(x_dec + self.self_att(x_dec, x_dec, x_dec))
        # Cross-attention sub-layer against the encoder output.
        x_dec = self.norm2(x_dec + self.cross_att(x_dec, x_enc, x_enc))

        # Feed-forward sub-layer. Conv1d wants channels in dim 1, so move the
        # feature axis there and back again afterwards.
        hidden = self.conv1(x_dec.transpose(-1, 1))
        hidden = self.dropout(self.activ(hidden))
        ffn_out = self.conv2(hidden).transpose(-1, 1)
        ffn_out = self.dropout(ffn_out)

        return self.norm3(x_dec + ffn_out)
class Decoder(nn.Module):
    """Stack of decoder layers followed by an optional final normalisation.

    Args:
        layers: Iterable of modules; each is called as ``layer(x_dec, x_enc)``
            and its return value becomes the next decoder state.
        norm_layer: Optional callable applied to the output of the last layer.
            When ``None``, the stacked output is returned unchanged.
    """

    def __init__(self, layers, norm_layer=None):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer

    def forward(self, x_dec, x_enc):
        """Feed ``x_dec`` through every layer, each also seeing ``x_enc``.

        Args:
            x_dec: Initial decoder state.
            x_enc: Encoder output, passed unchanged to every layer.

        Returns:
            The final decoder state, normalised by ``self.norm`` if one is set.
        """
        state = x_dec
        for dec_layer in self.layers:
            state = dec_layer(state, x_enc)
        return state if self.norm is None else self.norm(state)