|
""" ContextGate module """ |
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
def context_gate_factory(
    gate_type, embeddings_size, decoder_size, attention_size, output_size
):
    """Return an instantiated ContextGate variant matching ``gate_type``.

    Args:
        gate_type (str): one of ``"source"``, ``"target"`` or ``"both"``.
        embeddings_size (int): size of the previous word embedding.
        decoder_size (int): size of the decoder hidden state.
        attention_size (int): size of the attention (source context) state.
        output_size (int): size of the gated output.

    Returns:
        nn.Module: the corresponding context-gate module.

    Raises:
        ValueError: if ``gate_type`` is not one of the known gate types.
    """
    gate_types = {
        "source": SourceContextGate,
        "target": TargetContextGate,
        "both": BothContextGate,
    }
    # Raise explicitly rather than ``assert``: assertions are stripped when
    # Python runs with -O, which would turn a bad gate_type into an opaque
    # KeyError below instead of a clear error message.
    if gate_type not in gate_types:
        raise ValueError("Not valid ContextGate type: {0}".format(gate_type))
    return gate_types[gate_type](
        embeddings_size, decoder_size, attention_size, output_size
    )
|
|
|
|
|
class ContextGate(nn.Module):
    """
    Context gate is a decoder module that takes as input the previous word
    embedding, the current decoder state and the attention state, and
    produces a gate.
    The gate can be used to select the input from the target side context
    (decoder state), from the source context (attention state) or both.
    """

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        super(ContextGate, self).__init__()
        # The gate itself is computed from all three inputs concatenated.
        total_size = embeddings_size + decoder_size + attention_size
        self.gate = nn.Linear(total_size, output_size, bias=True)
        self.sig = nn.Sigmoid()
        # Separate projections for the source half (attention state) and
        # the target half (embedding + decoder state) of the context.
        self.source_proj = nn.Linear(attention_size, output_size)
        self.target_proj = nn.Linear(embeddings_size + decoder_size, output_size)

    def forward(self, prev_emb, dec_state, attn_state):
        """Return the gate value and the projected source/target contexts."""
        target_inputs = torch.cat((prev_emb, dec_state), dim=1)
        gate_inputs = torch.cat((target_inputs, attn_state), dim=1)
        z = self.sig(self.gate(gate_inputs))
        proj_source = self.source_proj(attn_state)
        proj_target = self.target_proj(target_inputs)
        return z, proj_source, proj_target
|
|
|
|
|
class SourceContextGate(nn.Module):
    """Apply the context gate only to the source context"""

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        super().__init__()
        self.context_gate = ContextGate(
            embeddings_size, decoder_size, attention_size, output_size
        )
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        """Gate the source context; the target context passes through unscaled."""
        gate, proj_source, proj_target = self.context_gate(
            prev_emb, dec_state, attn_state
        )
        gated_source = gate * proj_source
        return self.tanh(proj_target + gated_source)
|
|
|
|
|
class TargetContextGate(nn.Module):
    """Apply the context gate only to the target context"""

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        super().__init__()
        self.context_gate = ContextGate(
            embeddings_size, decoder_size, attention_size, output_size
        )
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        """Gate the target context; the source context passes through unscaled."""
        gate, proj_source, proj_target = self.context_gate(
            prev_emb, dec_state, attn_state
        )
        gated_target = gate * proj_target
        return self.tanh(gated_target + proj_source)
|
|
|
|
|
class BothContextGate(nn.Module):
    """Apply the context gate to both contexts"""

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        super().__init__()
        self.context_gate = ContextGate(
            embeddings_size, decoder_size, attention_size, output_size
        )
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        """Convexly mix source and target contexts: z weights the source,
        (1 - z) weights the target."""
        gate, proj_source, proj_target = self.context_gate(
            prev_emb, dec_state, attn_state
        )
        mixed = gate * proj_source + (1.0 - gate) * proj_target
        return self.tanh(mixed)
|
|