# ReactSeq: onmt/modules/util_class.py
# Uploaded by Oopstom ("Upload 313 files"), commit c668e80 (verified), 1.46 kB.
""" Misc classes """
import torch
import torch.nn as nn
# At the moment this class is only used by embeddings.Embeddings look-up tables
class Elementwise(nn.ModuleList):
    """
    A simple network container that applies one module per feature slice.

    The container holds a list of modules. ``forward`` splits its input
    along dimension 2 into ``len(self)`` slices, applies the i-th module to
    the i-th slice, and optionally merges the per-slice outputs into a
    single Tensor.

    Args:
        merge (str or None): how to combine the per-slice outputs:
            ``None`` returns the list of outputs unchanged, ``"first"``
            returns only the first output, ``"concat"`` and ``"mlp"`` both
            concatenate the outputs along dimension 2, and ``"sum"`` adds
            them element-wise.
        *args: forwarded unchanged to :class:`torch.nn.ModuleList`
            (typically a single iterable of modules).

    Raises:
        ValueError: if ``merge`` is not one of the supported modes.
    """

    _MERGE_MODES = (None, "first", "concat", "sum", "mlp")

    def __init__(self, merge=None, *args):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if merge not in self._MERGE_MODES:
            raise ValueError(
                "merge must be one of {}, got {!r}".format(
                    self._MERGE_MODES, merge
                )
            )
        # nn.Module.__init__ must run before any attribute assignment.
        super(Elementwise, self).__init__(*args)
        self.merge = merge

    def forward(self, emb):
        """
        Apply each module to its slice of ``emb`` and merge the results.

        Args:
            emb (Tensor): tensor with at least 3 dimensions whose size along
                dimension 2 equals ``len(self)``.

        Returns:
            Tensor or list[Tensor]: the merged output, or the list of
            per-slice outputs when ``merge`` is ``None``.

        Raises:
            ValueError: if the size of ``emb`` along dimension 2 does not
                match the number of modules (``zip`` would otherwise
                silently drop the extra slices or modules).
        """
        # unbind(2) yields the same views as split(1, dim=2) + squeeze(2).
        emb_ = emb.unbind(dim=2)
        if len(emb_) != len(self):
            raise ValueError(
                "expected {} feature slices along dim 2, got {}".format(
                    len(self), len(emb_)
                )
            )
        emb_out = [f(x) for f, x in zip(self, emb_)]
        if self.merge == "first":
            return emb_out[0]
        if self.merge in ("concat", "mlp"):
            # "mlp" only concatenates here; this class applies no MLP
            # (presumably the caller projects the result — confirm).
            return torch.cat(emb_out, 2)
        if self.merge == "sum":
            return sum(emb_out)
        # merge is None: hand back the un-merged per-slice outputs.
        return emb_out
class Cast(nn.Module):
    """
    Layer that converts its input tensor to a fixed data type.

    Conversion is delegated to ``Tensor.to``. That call returns the input
    tensor itself, with no copy, when it already has the requested dtype.

    Args:
        dtype: target data type handed to ``Tensor.to``.
    """

    def __init__(self, dtype):
        super().__init__()
        self._dtype = dtype

    def forward(self, x):
        """Return ``x`` converted to the stored target dtype."""
        target = self._dtype
        return x.to(dtype=target)