File size: 4,745 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
"""Implementation of the CNN Decoder part of
"Convolutional Sequence to Sequence Learning"
"""
import torch
import torch.nn as nn
from onmt.modules import ConvMultiStepAttention, GlobalAttention
from onmt.utils.cnn_factory import shape_transform, GatedConv
from onmt.decoders.decoder import DecoderBase
# Equals sqrt(0.5). CNNDecoder multiplies it into sums of two tensors
# (the source state built in init_state and the residual update in forward);
# presumably this keeps the magnitude of a two-term sum comparable to a
# single term, as in the ConvS2S paper -- TODO confirm.
SCALE_WEIGHT = 0.5**0.5
class CNNDecoder(DecoderBase):
    """Decoder based on "Convolutional Sequence to Sequence Learning"
    :cite:`DBLP:journals/corr/GehringAGYD17`.

    Consists of residual convolutional layers, with ConvMultiStepAttention.

    Args:
        num_layers (int): number of (GatedConv, ConvMultiStepAttention)
            layer pairs to stack.
        hidden_size (int): width the embeddings are projected to and the
            size handed to every conv / attention layer.
        attn_type: accepted for signature compatibility; not read here.
        copy_attn (bool): must be falsy -- the constructor asserts that
            copy attention is not requested.
        cnn_kernel_width (int): kernel width passed to ``GatedConv``; also
            determines how many zero positions are prepended in ``forward``.
        dropout (float): dropout value passed to ``GatedConv``.
        embeddings: callable module exposing ``embedding_size``; applied to
            ``tgt`` in ``forward``.
        copy_attn_type: attention type forwarded to ``GlobalAttention`` if a
            copy-attention layer is ever built.
    """

    def __init__(
        self,
        num_layers,
        hidden_size,
        attn_type,
        copy_attn,
        cnn_kernel_width,
        dropout,
        embeddings,
        copy_attn_type,
    ):
        super().__init__()

        self.cnn_kernel_width = cnn_kernel_width
        self.embeddings = embeddings

        # Decoder state; populated by init_state() and updated by forward().
        self.state = {}

        # Project embeddings to the hidden width used by the conv stack.
        input_size = self.embeddings.embedding_size
        self.linear = nn.Linear(input_size, hidden_size)
        self.conv_layers = nn.ModuleList(
            [
                GatedConv(hidden_size, cnn_kernel_width, dropout, True)
                for _ in range(num_layers)
            ]
        )
        self.attn_layers = nn.ModuleList(
            [ConvMultiStepAttention(hidden_size) for _ in range(num_layers)]
        )

        # CNNDecoder has its own attention mechanism.
        # Set up a separate copy attention layer if needed.
        # NOTE(review): the assert is kept so callers still see an
        # AssertionError; the ``if`` branch below is therefore only
        # reachable when asserts are disabled (``python -O``).
        assert not copy_attn, "Copy mechanism not yet tested in conv2conv"
        if copy_attn:
            self.copy_attn = GlobalAttention(hidden_size, attn_type=copy_attn_type)
        else:
            self.copy_attn = None

    @classmethod
    def from_opt(cls, opt, embeddings):
        """Alternate constructor building the decoder from an options object.

        Reads ``dec_layers``, ``dec_hid_size``, ``global_attention``,
        ``copy_attn``, ``cnn_kernel_width``, ``dropout`` and
        ``copy_attn_type`` from ``opt``. When ``opt.dropout`` is a list its
        first element is used.
        """
        dropout = opt.dropout[0] if isinstance(opt.dropout, list) else opt.dropout
        return cls(
            opt.dec_layers,
            opt.dec_hid_size,
            opt.global_attention,
            opt.copy_attn,
            opt.cnn_kernel_width,
            dropout,
            embeddings,
            opt.copy_attn_type,
        )

    def init_state(self, _, enc_out, enc_hidden):
        """Init decoder state.

        Stores ``(enc_out + enc_hidden) * SCALE_WEIGHT`` under ``"src"`` and
        clears the cached decoder input. The first argument is ignored.
        """
        self.state["src"] = (enc_out + enc_hidden) * SCALE_WEIGHT
        self.state["previous_input"] = None

    def map_state(self, fn):
        """Apply ``fn(tensor, 0)`` to every tensor held in the state.

        ``"previous_input"`` is skipped while it is still ``None``.
        """
        self.state["src"] = fn(self.state["src"], 0)
        if self.state["previous_input"] is not None:
            self.state["previous_input"] = fn(self.state["previous_input"], 0)

    def detach_state(self):
        """Detach the cached decoder input from the autograd graph.

        Does nothing when no input has been cached yet (``init_state`` sets
        it to ``None``); previously that case raised ``AttributeError``.
        """
        previous_input = self.state.get("previous_input")
        if previous_input is not None:
            self.state["previous_input"] = previous_input.detach()

    def forward(self, tgt, enc_out, step=None, **kwargs):
        """See :obj:`onmt.modules.RNNDecoderBase.forward()`

        Args:
            tgt: target input handed to ``self.embeddings``. If a previous
                input is cached it is prepended along dim 1 first.
            enc_out: the output of the encoder, passed to every attention
                layer together with ``self.state["src"]``.
            step: unused.
            **kwargs: unused.

        Returns:
            tuple: ``(dec_outs, attns)`` where ``attns`` is a dict holding
            the last layer's attention under ``"std"`` (and under ``"copy"``
            when a copy-attention layer exists). When a previous input was
            cached, only the positions after it are returned.
        """
        previous_input = self.state["previous_input"]
        if previous_input is not None:
            tgt = torch.cat([previous_input, tgt], 1)

        emb = self.embeddings(tgt)
        assert emb.dim() == 3  # batch x len x embedding_dim

        # The combination of output of CNNEncoder and source embeddings.
        enc_out_c = self.state["src"]

        # Flatten the first two dims for the projection, then restore them.
        dim0, dim1 = emb.size(0), emb.size(1)
        linear_out = self.linear(emb.view(dim0 * dim1, -1))
        x = shape_transform(linear_out.view(dim0, dim1, -1))

        # Zeros prepended along dim 2 before each convolution (presumably so
        # a position never sees later positions -- TODO confirm). Allocated
        # directly with x's dtype and device.
        pad = x.new_zeros(x.size(0), x.size(1), self.cnn_kernel_width - 1, 1)

        base_target_emb = x
        for conv, attention in zip(self.conv_layers, self.attn_layers):
            out = conv(torch.cat([pad, x], 2))
            c, attn = attention(base_target_emb, out, enc_out, enc_out_c)
            # Residual connection around conv + attention, rescaled twice.
            x = (x + (c + out) * SCALE_WEIGHT) * SCALE_WEIGHT

        dec_outs = x.squeeze(3).transpose(1, 2)

        # Process the result and update the attentions.
        if previous_input is not None:
            # Keep only the positions that are new in this call.
            prev_len = previous_input.size(1)
            dec_outs = dec_outs[:, prev_len:, :]
            # NOTE(review): a bare squeeze() drops every size-1 dim, which
            # would include a batch dim of size 1 -- confirm the intended
            # attention shape before narrowing this.
            attn = attn[:, prev_len:].squeeze()
            attn = torch.stack([attn])

        attns = {"std": attn}
        if self.copy_attn is not None:
            attns["copy"] = attn

        # Update the state.
        self.state["previous_input"] = tgt
        # TODO change the way attns is returned dict => list or tuple (onnx)
        return dec_outs, attns

    def update_dropout(self, dropout, attention_dropout=None):
        """Set the dropout probability on every convolutional layer.

        ``attention_dropout`` is accepted but not used.
        """
        for layer in self.conv_layers:
            layer.dropout.p = dropout
|