File size: 1,461 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
""" Misc classes """
import torch
import torch.nn as nn


# At the moment this class is only used by embeddings.Embeddings look-up tables
class Elementwise(nn.ModuleList):
    """
    A simple network container that applies one module per feature.

    The container holds a list of modules. In ``forward``, ``emb`` is
    expected to be a 3d Tensor whose last dimension has the same length as
    that list: it is split along dimension 2 and the i-th module is applied
    to the i-th slice (with the split dimension squeezed away).

    Args:
        merge (str or None): how the per-module outputs are combined.
            ``"first"`` returns only the first output, ``"concat"`` and
            ``"mlp"`` concatenate the outputs along dimension 2, ``"sum"``
            adds them together and ``None`` returns the list of outputs
            unchanged.
        *args: forwarded unchanged to ``nn.ModuleList.__init__``.

    Raises:
        ValueError: if ``merge`` is not one of the supported values.
    """

    # Accepted values for the ``merge`` constructor argument.
    _MERGE_MODES = (None, "first", "concat", "sum", "mlp")

    def __init__(self, merge=None, *args):
        # Explicit raise instead of ``assert`` so that the check still runs
        # when Python is started with -O.
        if merge not in self._MERGE_MODES:
            raise ValueError(
                "Unsupported merge mode %r; expected one of %r"
                % (merge, self._MERGE_MODES)
            )
        super(Elementwise, self).__init__(*args)
        self.merge = merge

    def forward(self, emb):
        """
        Apply each contained module to its own slice of ``emb``.

        Args:
            emb (Tensor): tensor whose dimension 2 holds one entry per
                contained module.

        Returns:
            The merged Tensor, or the list of per-module outputs when
            ``merge`` is ``None``.

        Raises:
            ValueError: if the size of dimension 2 of ``emb`` differs from
                the number of contained modules.
        """
        # One slice per feature: cut dimension 2 into width-1 chunks and
        # drop the resulting singleton dimension.
        feats = [feat.squeeze(2) for feat in emb.split(1, dim=2)]
        # Checked explicitly because ``zip`` below would otherwise silently
        # ignore the surplus features or modules.
        if len(self) != len(feats):
            raise ValueError(
                "Expected %d features along dimension 2 of emb, got %d"
                % (len(self), len(feats))
            )
        emb_out = [f(x) for f, x in zip(self, feats)]
        if self.merge == "first":
            return emb_out[0]
        elif self.merge == "concat" or self.merge == "mlp":
            return torch.cat(emb_out, 2)
        elif self.merge == "sum":
            return sum(emb_out)
        else:
            return emb_out


class Cast(nn.Module):
    """
    Minimal layer converting its input to a fixed data type. When the input
    already has that data type, the very same tensor is handed back.
    """

    def __init__(self, dtype):
        super().__init__()
        # Target data type applied to every input passed to forward.
        self._dtype = dtype

    def forward(self, x):
        target_dtype = self._dtype
        converted = x.to(target_dtype)
        return converted