"""
Implementation of "Convolutional Sequence to Sequence Learning"
"""

import torch
import torch.nn as nn
import torch.nn.init as init

import onmt.modules

# 1 / sqrt(2): residual sums below are rescaled by this factor so that the
# variance of the sum of two similarly scaled signals stays roughly constant,
# as in the ConvS2S paper.
SCALE_WEIGHT = 0.5**0.5


def shape_transform(x):
    """Transform the size of the tensors to fit for conv input."""
    return torch.unsqueeze(torch.transpose(x, 1, 2), 3)
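
# Shape sketch (illustrative, not part of the original module): for an
# RNN-style batch of shape (batch, seq_len, hidden),
#
#     x = shape_transform(torch.zeros(8, 20, 64))  # x.shape == (8, 64, 20, 1)
#
# i.e. the (N, C, H, W) layout that the (width, 1) convolutions below expect.

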
class GatedConv(nn.Module):
    """Gated convolution (GLU) layer for the CNN encoder/decoder."""

    def __init__(self, input_size, width=3, dropout=0.2, nopad=False):
        super(GatedConv, self).__init__()
        # Emit 2 * input_size channels: one half is the candidate output,
        # the other half drives the sigmoid gate.
        self.conv = onmt.modules.WeightNormConv2d(
            input_size,
            2 * input_size,
            kernel_size=(width, 1),
            stride=(1, 1),
            # bool arithmetic: pad by width // 2 unless nopad is True
            padding=(width // 2 * (1 - nopad), 0),
        )
        # Xavier init with a gain compensating for dropout, following the
        # ConvS2S weight-initialization scheme.
        init.xavier_uniform_(self.conv.weight, gain=(4 * (1 - dropout)) ** 0.5)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_var):
        x_var = self.dropout(x_var)
        x_var = self.conv(x_var)
        # Gated linear unit: split the channels into content and gate halves,
        # then compute out = content * sigmoid(gate).
        out, gate = x_var.split(x_var.size(1) // 2, 1)
        out = out * torch.sigmoid(gate)
        return out
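
# Usage sketch (hypothetical values, for illustration): a GatedConv preserves
# its input shape, since the doubled channel count is halved again by the GLU
# and same-padding keeps the sequence length (when nopad=False):
#
#     conv = GatedConv(input_size=64)
#     y = conv(torch.zeros(8, 64, 20, 1))  # y.shape == (8, 64, 20, 1)

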
class StackedCNN(nn.Module):
    """Stack of residual GatedConv layers."""

    def __init__(self, num_layers, input_size, cnn_kernel_width=3, dropout=0.2):
        super(StackedCNN, self).__init__()
        self.dropout = dropout
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(GatedConv(input_size, cnn_kernel_width, dropout))

    def forward(self, x):
        for conv in self.layers:
            # Residual connection around each gated conv, rescaled by
            # SCALE_WEIGHT to keep the activation variance stable.
            x = x + conv(x)
            x *= SCALE_WEIGHT
        return x
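

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; not part of the original
    # module). Assumes onmt.modules.WeightNormConv2d is importable.
    batch, seq_len, hidden = 4, 10, 32
    x = shape_transform(torch.randn(batch, seq_len, hidden))
    cnn = StackedCNN(num_layers=2, input_size=hidden)
    out = cnn(x)
    # The residual stack preserves the (batch, hidden, seq_len, 1) shape.
    assert out.shape == (batch, hidden, seq_len, 1)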