Oopstom's picture
Upload 313 files
c668e80 verified
raw
history blame
3.4 kB
""" ContextGate module """
import torch
import torch.nn as nn
def context_gate_factory(
    gate_type, embeddings_size, decoder_size, attention_size, output_size
):
    """Return a constructed ContextGate variant.

    Args:
        gate_type (str): one of ``"source"``, ``"target"`` or ``"both"``,
            selecting which context the gate is applied to.
        embeddings_size (int): size of the previous word embedding.
        decoder_size (int): size of the decoder state.
        attention_size (int): size of the attention state.
        output_size (int): size of the gated output.

    Returns:
        nn.Module: an instance of the selected ContextGate subclass.

    Raises:
        ValueError: if ``gate_type`` is not a recognized gate name.
    """
    gate_types = {
        "source": SourceContextGate,
        "target": TargetContextGate,
        "both": BothContextGate,
    }
    # Raise explicitly instead of `assert`: assertions are stripped under
    # `python -O`, which would surface as an opaque KeyError below.
    if gate_type not in gate_types:
        raise ValueError("Not valid ContextGate type: {0}".format(gate_type))
    return gate_types[gate_type](
        embeddings_size, decoder_size, attention_size, output_size
    )
class ContextGate(nn.Module):
    """Shared gate computation for the context-gate decoder modules.

    Given the previous word embedding, the current decoder state and the
    attention (source) state, this module produces:

    * ``z`` — a sigmoid gate over the concatenation of all three inputs,
    * a linear projection of the source (attention) state,
    * a linear projection of the target side (embedding + decoder state).

    Subclasses decide how ``z`` combines the two projections.
    """

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        super(ContextGate, self).__init__()
        total_size = embeddings_size + decoder_size + attention_size
        self.gate = nn.Linear(total_size, output_size, bias=True)
        self.sig = nn.Sigmoid()
        self.source_proj = nn.Linear(attention_size, output_size)
        self.target_proj = nn.Linear(embeddings_size + decoder_size, output_size)

    def forward(self, prev_emb, dec_state, attn_state):
        # The target-side concat is reused: appending attn_state to it is
        # equivalent to concatenating all three tensors at once.
        target_side = torch.cat((prev_emb, dec_state), dim=1)
        z = self.sig(self.gate(torch.cat((target_side, attn_state), dim=1)))
        return z, self.source_proj(attn_state), self.target_proj(target_side)
class SourceContextGate(nn.Module):
    """Context gate applied to the source context only.

    The source (attention) projection is scaled by the gate while the
    target projection passes through unscaled; the sum goes through tanh.
    """

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        super(SourceContextGate, self).__init__()
        self.context_gate = ContextGate(
            embeddings_size, decoder_size, attention_size, output_size
        )
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        gate, src, tgt = self.context_gate(prev_emb, dec_state, attn_state)
        # Only the source term is modulated by the gate.
        return self.tanh(tgt + gate * src)
class TargetContextGate(nn.Module):
    """Context gate applied to the target context only.

    The target projection is scaled by the gate while the source
    (attention) projection passes through unscaled; the sum goes
    through tanh.
    """

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        super(TargetContextGate, self).__init__()
        self.context_gate = ContextGate(
            embeddings_size, decoder_size, attention_size, output_size
        )
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        gate, src, tgt = self.context_gate(prev_emb, dec_state, attn_state)
        # Only the target term is modulated by the gate.
        return self.tanh(gate * tgt + src)
class BothContextGate(nn.Module):
    """Context gate applied to both contexts.

    The gate interpolates between the target projection (weight ``1 - z``)
    and the source projection (weight ``z``); the result goes through tanh.
    """

    def __init__(self, embeddings_size, decoder_size, attention_size, output_size):
        super(BothContextGate, self).__init__()
        self.context_gate = ContextGate(
            embeddings_size, decoder_size, attention_size, output_size
        )
        self.tanh = nn.Tanh()

    def forward(self, prev_emb, dec_state, attn_state):
        gate, src, tgt = self.context_gate(prev_emb, dec_state, attn_state)
        # Convex-combination style mix of the two projected contexts.
        return self.tanh((1.0 - gate) * tgt + gate * src)