import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, rnn_conf) -> None:
        super().__init__()

        self.embedding_dim   = rnn_conf.embedding_dim
        self.hidden_size     = rnn_conf.hidden_size
        self.bidirectional   = rnn_conf.bidirectional
        self.n_layers        = rnn_conf.n_layers
        
        # Token embeddings feed a (possibly bidirectional) LSTM encoder.
        self.embedding = nn.Embedding(rnn_conf.vocab_size, self.embedding_dim)
        self.lstm = nn.LSTM(
            input_size    = self.embedding_dim,
            hidden_size   = self.hidden_size,
            bidirectional = self.bidirectional,
            batch_first   = True,
            num_layers    = self.n_layers
        )
        # A bidirectional LSTM doubles the feature dimension of its outputs.
        self.bidirect_factor = 2 if self.bidirectional else 1
        # Classification head: project the last LSTM output to a single logit.
        self.clf = nn.Sequential(
            nn.Linear(self.hidden_size * self.bidirect_factor, 32),
            nn.Tanh(),
            nn.Dropout(),
            nn.Linear(32, 1)
        )
    
    def model_description(self):
        direction = 'bidirect' if self.bidirectional else 'onedirect'
        return f'lstm_{direction}_{self.n_layers}'

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embeddings = self.embedding(x)   # [batch, seq_len, embedding_dim]
        out, _ = self.lstm(embeddings)   # [batch, seq_len, hidden_size * bidirect_factor]
        out = out[:, -1, :]              # all batch elements, last time step, full feature vector
        out = self.clf(out)              # [batch, 1] raw logit per example
        return out
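

# A minimal usage sketch (not part of the original module). It assumes only that
# `rnn_conf` is any object exposing the attributes read in __init__, so a
# SimpleNamespace stands in for whatever config class the project actually uses;
# all numbers below are hypothetical.
if __name__ == "__main__":
    from types import SimpleNamespace

    conf = SimpleNamespace(
        vocab_size=10_000,
        embedding_dim=128,
        hidden_size=64,
        bidirectional=True,
        n_layers=2,
    )
    model = LSTMClassifier(conf)
    print(model.model_description())  # e.g. 'lstm_bidirect_2'

    tokens = torch.randint(0, conf.vocab_size, (4, 20))  # [batch=4, seq_len=20] token ids
    logits = model(tokens)                               # [4, 1] raw logits
    print(logits.shape)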