""" Misc classes """
import torch
import torch.nn as nn
# At the moment this class is only used by embeddings.Embeddings look-up tables
class Elementwise(nn.ModuleList):
"""
A simple network container.
Parameters are a list of modules.
emb is a 3d Tensor whose last dimension is the same length
as the list.
emb_out is the result of applying modules to emb elementwise.
An optional merge parameter allows the emb_out to be reduced to a
single Tensor.
"""
def __init__(self, merge=None, *args):
assert merge in [None, "first", "concat", "sum", "mlp"]
self.merge = merge
super(Elementwise, self).__init__(*args)
def forward(self, emb):
emb_ = [feat.squeeze(2) for feat in emb.split(1, dim=2)]
assert len(self) == len(emb_)
emb_out = [f(x) for f, x in zip(self, emb_)]
if self.merge == "first":
return emb_out[0]
elif self.merge == "concat" or self.merge == "mlp":
return torch.cat(emb_out, 2)
elif self.merge == "sum":
return sum(emb_out)
else:
return emb_out
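

# A minimal usage sketch (illustrative, not part of the original module): the
# look-up tables, vocabulary sizes and shapes below are assumptions about how
# embeddings.Embeddings typically wires this container together.
#
#     word_lut = nn.Embedding(1000, 64)             # word look-up table
#     feat_lut = nn.Embedding(10, 8)                # feature look-up table
#     emb_luts = Elementwise("concat", [word_lut, feat_lut])
#     ids = torch.zeros(5, 2, 2, dtype=torch.long)  # (seq_len, batch, nfeat)
#     out = emb_luts(ids)                           # shape (5, 2, 64 + 8)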


class Cast(nn.Module):
    """
    Basic layer that casts its input to a specific data type. The same tensor
    is returned if the data type is already correct.
    """

    def __init__(self, dtype):
        super(Cast, self).__init__()
        self._dtype = dtype

    def forward(self, x):
        return x.to(self._dtype)
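

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming only the classes above; the shapes,
    # vocabulary sizes and dtypes here are illustrative, not from the repo.
    cast_half = Cast(torch.float16)
    x = torch.randn(4, 3)
    assert cast_half(x).dtype == torch.float16
    y = x.half()
    assert cast_half(y) is y  # tensor with the correct dtype passes through unchanged

    word_lut = nn.Embedding(100, 16)
    feat_lut = nn.Embedding(10, 4)
    emb_luts = Elementwise("concat", [word_lut, feat_lut])
    ids = torch.zeros(5, 2, 2, dtype=torch.long)   # (seq_len, batch, nfeat)
    assert emb_luts(ids).shape == (5, 2, 20)       # 16 + 4 concatenated features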