# File size: 1,440 bytes
# ed0e769
import torch
import torch.nn as nn
class LSTMClassifier(nn.Module):
    """Sequence classifier: token embedding -> (bi)LSTM -> small MLP head.

    ``rnn_conf`` must provide ``vocab_size``, ``embedding_dim``,
    ``hidden_size``, ``bidirectional`` and ``n_layers``. The head ends in a
    ``Linear(32, 1)`` with no activation, so ``forward`` returns one raw,
    unbounded score per sequence.
    """

    def __init__(self, rnn_conf) -> None:
        super().__init__()
        self.embedding_dim = rnn_conf.embedding_dim
        self.hidden_size = rnn_conf.hidden_size
        self.bidirectional = rnn_conf.bidirectional
        self.n_layers = rnn_conf.n_layers

        self.embedding = nn.Embedding(rnn_conf.vocab_size, self.embedding_dim)
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.hidden_size,
            bidirectional=self.bidirectional,
            batch_first=True,
            num_layers=self.n_layers,
        )
        # Number of LSTM directions; the top-layer summary fed to the head is
        # hidden_size * bidirect_factor wide.
        self.bidirect_factor = 2 if self.bidirectional else 1
        self.clf = nn.Sequential(
            nn.Linear(self.hidden_size * self.bidirect_factor, 32),
            nn.Tanh(),
            nn.Dropout(),  # PyTorch default p=0.5
            nn.Linear(32, 1),
        )

    def model_description(self):
        """Return a short tag such as ``'lstm_bidirect_2'`` for naming runs."""
        direction = 'bidirect' if self.bidirectional else 'onedirect'
        return f'lstm_{direction}_{self.n_layers}'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of token-id sequences.

        Args:
            x: integer token ids of shape ``(batch, seq_len)`` (the LSTM is
                built with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, 1)``: one raw score per sequence.
        """
        embeddings = self.embedding(x)
        _, (h_n, _) = self.lstm(embeddings)
        # h_n is (num_layers * num_directions, batch, hidden_size); batch_first
        # does not apply to it. Its last `bidirect_factor` rows are the top
        # layer's final forward (and, if bidirectional, backward) states.
        # Reading h_n instead of out[:, -1, :] matters for the backward
        # direction: its output at the last time step has seen only the last
        # token, whereas its final state has read the whole sequence.
        # NOTE(review): if batches are right-padded, these final states also
        # consume the padding steps; consider pack_padded_sequence if so.
        top_layer = h_n[-self.bidirect_factor:]
        # Concatenate [forward, backward] along features -> (batch, D * hidden).
        features = torch.cat(top_layer.unbind(0), dim=-1)
        # No .squeeze() here: it used to drop the batch dim when batch == 1.
        return self.clf(features)
|