"""Ensemble decoding.

Decodes using multiple models simultaneously,
combining their prediction distributions by averaging.
All models in the ensemble must share a target vocabulary.
"""

import torch
import torch.nn as nn

from onmt.encoders.encoder import EncoderBase
from onmt.decoders.decoder import DecoderBase
from onmt.models import NMTModel
import onmt.model_builder


class EnsembleDecoderOutput(object):
    """Wrapper around multiple decoder final hidden states."""

    def __init__(self, model_dec_outs):
        self.model_dec_outs = tuple(model_dec_outs)

    def squeeze(self, dim=None):
        """Delegate squeeze to avoid modifying
        :func:`onmt.translate.translator.Translator.translate_batch()`
        """
        return EnsembleDecoderOutput([x.squeeze(dim) for x in self.model_dec_outs])

    def __getitem__(self, index):
        return self.model_dec_outs[index]


class EnsembleEncoder(EncoderBase):
    """Dummy Encoder that delegates to individual real Encoders."""

    def __init__(self, model_encoders):
        super(EnsembleEncoder, self).__init__()
        self.model_encoders = nn.ModuleList(model_encoders)

    def forward(self, src, src_len=None):
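        # Run every wrapped encoder on the same batch. Per-model outputs are
        # returned as tuples indexed by model; src_len stays a single shared
        # tensor, matching the assumption noted in EnsembleDecoder.forward.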
        enc_out, enc_final_hs, _ = zip(
            *[model_encoder(src, src_len) for model_encoder in self.model_encoders]
        )
        return enc_out, enc_final_hs, src_len


class EnsembleDecoder(DecoderBase):
    """Dummy Decoder that delegates to individual real Decoders."""

    def __init__(self, model_decoders):
        model_decoders = nn.ModuleList(model_decoders)
        attentional = any(dec.attentional for dec in model_decoders)
        super(EnsembleDecoder, self).__init__(attentional)
        self.model_decoders = model_decoders

    def forward(self, tgt, enc_out, src_len=None, step=None, **kwargs):
        """See :func:`onmt.decoders.decoder.DecoderBase.forward()`."""
        # src_len is a single tensor shared between all models.
        # This assumption will not hold if Translator is modified
        # to calculate src_len as something other than the length
        # of the input.
        dec_outs, attns = zip(
            *[
                model_decoder(tgt, enc_out[i], src_len=src_len, step=step, **kwargs)
                for i, model_decoder in enumerate(self.model_decoders)
            ]
        )
        mean_attns = self.combine_attns(attns)
        return EnsembleDecoderOutput(dec_outs), mean_attns

    def combine_attns(self, attns):
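        # Average each attention type over the models that produced it;
        # models returning None for a key are left out of the mean.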
        result = {}
        for key in attns[0].keys():
            result[key] = torch.stack(
                [attn[key] for attn in attns if attn[key] is not None]
            ).mean(0)
        return result

    def init_state(self, src, enc_out, enc_hidden):
        """See :obj:`RNNDecoderBase.init_state()`"""
        for i, model_decoder in enumerate(self.model_decoders):
            model_decoder.init_state(src, enc_out[i], enc_hidden[i])

    def map_state(self, fn):
        for model_decoder in self.model_decoders:
            model_decoder.map_state(fn)


class EnsembleGenerator(nn.Module):
    """
    Dummy Generator that delegates to individual real Generators,
    and then averages the resulting target distributions.
    """

    def __init__(self, model_generators, raw_probs=False):
        super(EnsembleGenerator, self).__init__()
        self.model_generators = nn.ModuleList(model_generators)
        self._raw_probs = raw_probs

    def forward(self, hidden, attn=None, src_map=None):
        """
        Compute a distribution over the target dictionary
        by averaging distributions from models in the ensemble.
        All models in the ensemble must share a target vocabulary.
        """
        distributions = torch.stack(
            [
                mg(h) if attn is None else mg(h, attn, src_map)
                for h, mg in zip(hidden, self.model_generators)
            ]
        )
        if self._raw_probs:
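            # Average in probability space: log(mean(exp(log_probs))).
            # (torch.logsumexp(distributions, 0) - log(n) would be the more
            # numerically stable form of the same quantity.)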
            return torch.log(torch.exp(distributions).mean(0))
        else:
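            # Average log-probabilities directly, i.e. a geometric mean of
            # the per-model distributions (up to renormalization). E.g. for
            # token probs 0.8 and 0.2 across two models, this scores
            # (log 0.8 + log 0.2) / 2 = log(0.4) versus log(0.5) for the
            # raw-prob average above, so it penalizes disagreement more.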
            return distributions.mean(0)


class EnsembleModel(NMTModel):
    """Dummy NMTModel wrapping individual real NMTModels."""

    def __init__(self, models, raw_probs=False):
        encoder = EnsembleEncoder(model.encoder for model in models)
        decoder = EnsembleDecoder(model.decoder for model in models)
        super(EnsembleModel, self).__init__(encoder, decoder)
        self.generator = EnsembleGenerator(
            [model.generator for model in models], raw_probs
        )
        self.models = nn.ModuleList(models)


def load_test_model(opt, device_id=0):
    """Read in multiple models for ensemble."""
    shared_vocabs = None
    shared_model_opt = None
    models = []
    for model_path in opt.models:
        vocabs, model, model_opt = onmt.model_builder.load_test_model(
            opt, device_id, model_path=model_path
        )
        if shared_vocabs is None:
            shared_vocabs = vocabs
        else:
            # The generator averages target distributions, so both source
            # and target vocabularies must match across models.
            assert (
                shared_vocabs["src"].tokens_to_ids == vocabs["src"].tokens_to_ids
                and shared_vocabs["tgt"].tokens_to_ids == vocabs["tgt"].tokens_to_ids
            ), "Ensemble models must use the same vocabularies."
        models.append(model)
        if shared_model_opt is None:
            shared_model_opt = model_opt
    ensemble_model = EnsembleModel(models, opt.avg_raw_probs)
    return shared_vocabs, ensemble_model, shared_model_opt
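

# Usage sketch (illustrative; `opt` is the usual translation options
# namespace, with opt.models listing two or more checkpoint paths and
# opt.avg_raw_probs set):
#
#     vocabs, model, model_opt = load_test_model(opt)
#
# The returned EnsembleModel behaves like any NMTModel, so the Translator
# can decode with it unchanged.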