ReactSeq / onmt /utils /rnn_factory.py
Oopstom's picture
Upload 313 files
c668e80 verified
raw
history blame
468 Bytes
"""
RNN tools
"""
import torch.nn as nn
import onmt.modules
def rnn_factory(rnn_type, **kwargs):
    """Build a batch-first recurrent layer of the requested type.

    Uses the PyTorch implementation whenever one exists; only ``"SRU"`` is
    served by the project's own ``onmt.modules.sru.SRU``.

    Args:
        rnn_type (str): ``"SRU"`` selects ``onmt.modules.sru.SRU``. Any
            other value is looked up by name on ``torch.nn``
            (e.g. ``"LSTM"``, ``"GRU"``, ``"RNN"``).
        **kwargs: Forwarded unchanged to the layer constructor. Must not
            include ``batch_first``: it is always passed as ``True`` here,
            so supplying it again raises ``TypeError``.

    Returns:
        tuple: ``(rnn, no_pack_padded_seq)`` where ``rnn`` is the constructed
        layer and ``no_pack_padded_seq`` is ``True`` only for ``"SRU"``.

    Raises:
        AttributeError: If ``rnn_type`` is not ``"SRU"`` and ``torch.nn``
            has no attribute with that name.
    """
    if rnn_type == "SRU":
        # SRU doesn't support PackedSequence, so report that through the
        # flag (presumably callers use it to skip packing -- confirm at
        # the call sites).
        return onmt.modules.sru.SRU(batch_first=True, **kwargs), True

    # Every other type resolves to the torch.nn class with the same name.
    return getattr(nn, rnn_type)(batch_first=True, **kwargs), False