# ReactSeq/onmt/utils/cnn_factory.py
"""
Implementation of "Convolutional Sequence to Sequence Learning"
"""
import torch
import torch.nn as nn
import torch.nn.init as init
import onmt.modules
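
# sqrt(0.5): residual sums below are rescaled by this factor so the variance
# of the sum stays roughly constant across layers, as in the ConvS2S paper.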
SCALE_WEIGHT = 0.5**0.5


def shape_transform(x):
    """Transform the tensor size to fit convolution input."""
    return torch.unsqueeze(torch.transpose(x, 1, 2), 3)
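
# A minimal shape sketch (the batch-first reading is an assumption): given x
# of shape (batch, seq_len, hidden), shape_transform swaps dims 1 and 2 and
# appends a trailing unit dim, yielding (batch, hidden, seq_len, 1) -- the
# (N, C, H, W) layout the (width, 1) convolutions below operate on.
#
#     x = torch.randn(8, 20, 512)   # (batch, seq_len, hidden)
#     shape_transform(x).shape      # torch.Size([8, 512, 20, 1])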


class GatedConv(nn.Module):
    """Gated convolution with a GLU-style sigmoid gate."""

    def __init__(self, input_size, width=3, dropout=0.2, nopad=False):
        super(GatedConv, self).__init__()
        self.conv = onmt.modules.WeightNormConv2d(
            input_size,
            2 * input_size,
            kernel_size=(width, 1),
            stride=(1, 1),
            padding=(width // 2 * (1 - nopad), 0),
        )
        # This weight init is overridden by model_builder, so it has no
        # effect when the model is built through that path.
        init.xavier_uniform_(self.conv.weight, gain=(4 * (1 - dropout)) ** 0.5)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_var):
        x_var = self.dropout(x_var)
        x_var = self.conv(x_var)
        # The conv doubles the channel dim; half the channels gate the other half.
        out, gate = x_var.split(x_var.size(1) // 2, 1)
        out = out * torch.sigmoid(gate)
        return out
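
# A small usage sketch (sizes are hypothetical): with the default width=3 and
# padding of width // 2, the sequence length is preserved, and the GLU gate
# halves the doubled channel dim back to input_size.
#
#     gconv = GatedConv(input_size=512, width=3, dropout=0.2)
#     x = shape_transform(torch.randn(8, 20, 512))  # (8, 512, 20, 1)
#     gconv(x).shape                                # torch.Size([8, 512, 20, 1])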


class StackedCNN(nn.Module):
    """Stack of residual GatedConv layers."""

    def __init__(self, num_layers, input_size, cnn_kernel_width=3, dropout=0.2):
        super(StackedCNN, self).__init__()
        self.dropout = dropout
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(GatedConv(input_size, cnn_kernel_width, dropout))

    def forward(self, x):
        for conv in self.layers:
            # Residual connection around each gated conv, rescaled so the
            # variance of the sum stays stable as depth grows.
            x = x + conv(x)
            x *= SCALE_WEIGHT
        return x
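
# A minimal end-to-end sketch (sizes are hypothetical): stacking several gated
# conv layers leaves the tensor shape unchanged, so the stack can consume
# shape_transform-ed embeddings directly.
#
#     stack = StackedCNN(num_layers=3, input_size=512)
#     x = shape_transform(torch.randn(8, 20, 512))  # (8, 512, 20, 1)
#     stack(x).shape                                # torch.Size([8, 512, 20, 1])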