ReactSeq/onmt/decoders/ensemble.py
"""Ensemble decoding.
Decodes using multiple models simultaneously,
combining their prediction distributions by averaging.
All models in the ensemble must share a target vocabulary.
"""
import torch
import torch.nn as nn
from onmt.encoders.encoder import EncoderBase
from onmt.decoders.decoder import DecoderBase
from onmt.models import NMTModel
import onmt.model_builder
class EnsembleDecoderOutput(object):
"""Wrapper around multiple decoder final hidden states."""
def __init__(self, model_dec_outs):
self.model_dec_outs = tuple(model_dec_outs)
def squeeze(self, dim=None):
"""Delegate squeeze to avoid modifying
:func:`onmt.translate.translator.Translator.translate_batch()`
"""
return EnsembleDecoderOutput([x.squeeze(dim) for x in self.model_dec_outs])
def __getitem__(self, index):
return self.model_dec_outs[index]
class EnsembleEncoder(EncoderBase):
"""Dummy Encoder that delegates to individual real Encoders."""
def __init__(self, model_encoders):
super(EnsembleEncoder, self).__init__()
self.model_encoders = nn.ModuleList(model_encoders)
def forward(self, src, src_len=None):
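        # The same src and src_len are fed to every member encoder; the
        # zipped results are tuples with one entry per ensemble member,
        # while src_len is shared by all models and passed through unchanged.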
enc_out, enc_final_hs, _ = zip(
*[model_encoder(src, src_len) for model_encoder in self.model_encoders]
)
return enc_out, enc_final_hs, src_len
class EnsembleDecoder(DecoderBase):
"""Dummy Decoder that delegates to individual real Decoders."""
def __init__(self, model_decoders):
model_decoders = nn.ModuleList(model_decoders)
        attentional = any(dec.attentional for dec in model_decoders)
super(EnsembleDecoder, self).__init__(attentional)
self.model_decoders = model_decoders
def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
"""See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
# src_len is a single tensor shared between all models.
# This assumption will not hold if Translator is modified
# to calculate src_len as something other than the length
# of the input.
dec_outs, attns = zip(
*[
model_decoder(tgt, enc_out[i], src_len=src_len, step=step, **kwargs)
for i, model_decoder in enumerate(self.model_decoders)
]
)
mean_attns = self.combine_attns(attns)
return EnsembleDecoderOutput(dec_outs), mean_attns
    def combine_attns(self, attns):
        """Average each attention key over the models that produced it."""
        result = {}
        for key in attns[0].keys():
            attn_list = [attn[key] for attn in attns if attn[key] is not None]
            # Guard against keys for which no model returned attention;
            # torch.stack() on an empty list would raise otherwise.
            result[key] = torch.stack(attn_list).mean(0) if attn_list else None
        return result
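    # Illustrative shapes (hypothetical): with two models each returning
    # attns = {"std": Tensor[batch, tgt_len, src_len]}, combine_attns stacks
    # them into shape [2, batch, tgt_len, src_len] and averages over dim 0.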
def init_state(self, src, enc_out, enc_hidden):
"""See :obj:`RNNDecoderBase.init_state()`"""
for i, model_decoder in enumerate(self.model_decoders):
model_decoder.init_state(src, enc_out[i], enc_hidden[i])
def map_state(self, fn):
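        """Apply ``fn`` to each member decoder's cached state (used, e.g.,
        by beam search to reorder state between decoding steps)."""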
for model_decoder in self.model_decoders:
model_decoder.map_state(fn)
class EnsembleGenerator(nn.Module):
"""
Dummy Generator that delegates to individual real Generators,
and then averages the resulting target distributions.
"""
def __init__(self, model_generators, raw_probs=False):
super(EnsembleGenerator, self).__init__()
self.model_generators = nn.ModuleList(model_generators)
self._raw_probs = raw_probs
def forward(self, hidden, attn=None, src_map=None):
"""
Compute a distribution over the target dictionary
by averaging distributions from models in the ensemble.
All models in the ensemble must share a target vocabulary.
"""
distributions = torch.stack(
[
mg(h) if attn is None else mg(h, attn, src_map)
for h, mg in zip(hidden, self.model_generators)
]
)
if self._raw_probs:
return torch.log(torch.exp(distributions).mean(0))
else:
return distributions.mean(0)
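# Note on the two averaging modes above: the generator outputs are
# log-probabilities, so `distributions.mean(0)` averages in log space (a
# geometric mean of the probabilities), while the `_raw_probs` branch
# exponentiates first and averages in probability space. For two models:
#
#   _raw_probs=True:   log((exp(lp_a) + exp(lp_b)) / 2)   (arithmetic mean)
#   _raw_probs=False:  (lp_a + lp_b) / 2                  (geometric mean)
#
# A numerically stabler equivalent of the first form would be
# torch.logsumexp(distributions, 0) - math.log(len(self.model_generators)).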
class EnsembleModel(NMTModel):
"""Dummy NMTModel wrapping individual real NMTModels."""
def __init__(self, models, raw_probs=False):
        encoder = EnsembleEncoder([model.encoder for model in models])
        decoder = EnsembleDecoder([model.decoder for model in models])
super(EnsembleModel, self).__init__(encoder, decoder)
self.generator = EnsembleGenerator(
[model.generator for model in models], raw_probs
)
self.models = nn.ModuleList(models)
def load_test_model(opt, device_id=0):
"""Read in multiple models for ensemble."""
shared_vocabs = None
shared_model_opt = None
models = []
for model_path in opt.models:
vocabs, model, model_opt = onmt.model_builder.load_test_model(
opt, device_id, model_path=model_path
)
if shared_vocabs is None:
shared_vocabs = vocabs
        else:
            # The wrapped models must agree on both vocabularies for the
            # averaged distributions to be meaningful.
            assert (
                shared_vocabs["src"].tokens_to_ids == vocabs["src"].tokens_to_ids
                and shared_vocabs["tgt"].tokens_to_ids == vocabs["tgt"].tokens_to_ids
            ), "Ensemble models must use the same vocabs."
models.append(model)
if shared_model_opt is None:
shared_model_opt = model_opt
ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
return shared_vocabs, ensemble_model, shared_model_opt
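# A hedged usage sketch (the checkpoint paths are illustrative; `models` and
# `avg_raw_probs` are the fields this function actually reads from `opt`):
#
#   from argparse import Namespace
#   opt = Namespace(models=["model_a.pt", "model_b.pt"], avg_raw_probs=False)
#   vocabs, model, model_opt = load_test_model(opt, device_id=0)
#
# In practice `opt` comes from OpenNMT's translate option parser, which also
# sets the additional fields consumed by onmt.model_builder.load_test_model.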