""" Implementation of ONMT RNN for Input Feeding Decoding """

import torch
import torch.nn as nn

class StackedLSTM(nn.Module):
    """
    Our own implementation of stacked LSTM.
    Needed for the decoder, because we do input feeding.
    """

    def __init__(self, num_layers, input_size, hidden_size, dropout):
        super(StackedLSTM, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        for _ in range(num_layers):
            self.layers.append(nn.LSTMCell(input_size, hidden_size))
            # Every layer after the first consumes the previous layer's
            # hidden state, so its input size is the hidden size.
            input_size = hidden_size
    def forward(self, input_feed, hidden):
        h_0, c_0 = hidden
        h_1, c_1 = [], []
        for i, layer in enumerate(self.layers):
            h_1_i, c_1_i = layer(input_feed, (h_0[i], c_0[i]))
            input_feed = h_1_i
            # Apply inter-layer dropout everywhere except after the top layer.
            if i + 1 != self.num_layers:
                input_feed = self.dropout(input_feed)
            h_1 += [h_1_i]
            c_1 += [c_1_i]

        # Stack the per-layer states into (num_layers, batch, hidden_size).
        h_1 = torch.stack(h_1)
        c_1 = torch.stack(c_1)

        return input_feed, (h_1, c_1)
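

# A minimal usage sketch, not part of the original ONMT module: it runs one
# input-feeding decode step through StackedLSTM and checks the tensor shapes.
# All sizes below are illustrative assumptions.
def _example_stacked_lstm():
    num_layers, batch, input_size, hidden_size = 2, 5, 10, 20
    rnn = StackedLSTM(num_layers, input_size, hidden_size, dropout=0.3)
    # Here `input_feed` stands in for the decoder input at one time step
    # (in ONMT, the target embedding concatenated with the previous
    # attentional output, built upstream of this module).
    input_feed = torch.randn(batch, input_size)
    h_0 = torch.zeros(num_layers, batch, hidden_size)
    c_0 = torch.zeros(num_layers, batch, hidden_size)
    output, (h_1, c_1) = rnn(input_feed, (h_0, c_0))
    assert output.shape == (batch, hidden_size)
    assert h_1.shape == (num_layers, batch, hidden_size)
    assert c_1.shape == (num_layers, batch, hidden_size)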


class StackedGRU(nn.Module):
    """
    Our own implementation of stacked GRU.
    Needed for the decoder, because we do input feeding.
    """

    def __init__(self, num_layers, input_size, hidden_size, dropout):
        super(StackedGRU, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        for _ in range(num_layers):
            self.layers.append(nn.GRUCell(input_size, hidden_size))
            # Layers after the first take the previous layer's hidden state
            # as input, so their input size is the hidden size.
            input_size = hidden_size
    def forward(self, input_feed, hidden):
        h_1 = []
        for i, layer in enumerate(self.layers):
            h_1_i = layer(input_feed, hidden[0][i])
            input_feed = h_1_i
            # Apply inter-layer dropout everywhere except after the top layer.
            if i + 1 != self.num_layers:
                input_feed = self.dropout(input_feed)
            h_1 += [h_1_i]

        # Stack the per-layer states into (num_layers, batch, hidden_size).
        h_1 = torch.stack(h_1)
        # Return the state in a one-element tuple so the decoder can treat
        # the LSTM's (h, c) and the GRU's h uniformly.
        return input_feed, (h_1,)
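

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original ONMT module: one
    # input-feeding decode step through StackedGRU with assumed sizes.
    num_layers, batch, input_size, hidden_size = 2, 5, 10, 20
    rnn = StackedGRU(num_layers, input_size, hidden_size, dropout=0.3)
    input_feed = torch.randn(batch, input_size)
    # The GRU state travels in a one-element tuple, mirroring the (h, c)
    # tuple interface of StackedLSTM.
    h_0 = torch.zeros(num_layers, batch, hidden_size)
    output, (h_1,) = rnn(input_feed, (h_0,))
    assert output.shape == (batch, hidden_size)
    assert h_1.shape == (num_layers, batch, hidden_size)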