|
""" Misc classes """ |
|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
|
|
class Elementwise(nn.ModuleList): |
|
""" |
|
A simple network container. |
|
Parameters are a list of modules. |
|
emb is a 3d Tensor whose last dimension is the same length |
|
as the list. |
|
emb_out is the result of applying modules to emb elementwise. |
|
An optional merge parameter allows the emb_out to be reduced to a |
|
single Tensor. |
|
""" |
|
|
|
def __init__(self, merge=None, *args): |
|
assert merge in [None, "first", "concat", "sum", "mlp"] |
|
self.merge = merge |
|
super(Elementwise, self).__init__(*args) |
|
|
|
def forward(self, emb): |
|
emb_ = [feat.squeeze(2) for feat in emb.split(1, dim=2)] |
|
assert len(self) == len(emb_) |
|
emb_out = [f(x) for f, x in zip(self, emb_)] |
|
if self.merge == "first": |
|
return emb_out[0] |
|
elif self.merge == "concat" or self.merge == "mlp": |
|
return torch.cat(emb_out, 2) |
|
elif self.merge == "sum": |
|
return sum(emb_out) |
|
else: |
|
return emb_out |
|
|
|
|
|
class Cast(nn.Module): |
|
""" |
|
Basic layer that casts its emb to a specific data type. The same tensor |
|
is returned if the data type is already correct. |
|
""" |
|
|
|
def __init__(self, dtype): |
|
super(Cast, self).__init__() |
|
self._dtype = dtype |
|
|
|
def forward(self, x): |
|
return x.to(self._dtype) |
|
|