File size: 1,894 Bytes
737e9a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from huggingface_hub import PyTorchModelHubMixin

from torch import nn
import torch

class BiLSTM(nn.Module, PyTorchModelHubMixin):
    """Bidirectional LSTM that emits one score vector per sequence position.

    Pipeline: embedding -> BiLSTM -> dropout -> linear + ELU -> linear.

    Args:
        vocab_size: Number of rows in the embedding table. Index 0 is the
            padding index.
        embed_dim: Size of each embedding vector.
        num_layers: Number of stacked LSTM layers.
        hidden_dim: Size of the concatenated (forward + backward) LSTM output.
            Each direction uses ``hidden_dim // 2`` units, so this should be
            even for the following linear layer to match.
        dropout: Dropout probability applied to the LSTM output.
        output_dim: Size of the intermediate linear projection.
        predict_output: Size of the final per-position output.
        device: Kept for backward compatibility and stored as ``device_``.
            Initial LSTM states are now allocated on the device of the data
            actually flowing through the module, so this value no longer has
            to match where the module lives.
    """

    def __init__(self, vocab_size=23626, embed_dim=100,
                 num_layers=1, hidden_dim=256, dropout=0.33,
                 output_dim=128, predict_output=10, device="cuda:0"):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.predict_output = predict_output

        self.embed_layer = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.biLSTM = nn.LSTM(input_size=embed_dim,
                              hidden_size=hidden_dim // 2,  # the two directions are concatenated back to hidden_dim
                              num_layers=num_layers,
                              bidirectional=True,
                              batch_first=True)
        self.linear = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.elu = nn.ELU()
        self.fc = nn.Linear(output_dim, predict_output)
        self.device_ = device

    def forward(self, input):
        """Run the network on a batch of index sequences.

        Args:
            input: Integer tensor of token indices, shape (batch_size, seq_len).

        Returns:
            A tuple ``(out, hidden)`` where ``out`` has shape
            (batch_size, seq_len, predict_output) and ``hidden`` is the
            ``(h_n, c_n)`` pair returned by the LSTM.
        """
        x = self.embed_layer(input)                     # batch_size, seq_len, embed_dim (batch_first=True)

        # Allocate the initial states next to the embedded input so the module
        # keeps working after .to(...) / .cpu() / .half() regardless of the
        # ``device`` string given at construction time.
        hidden, cell = self.init_hidden(x.size(0), device=x.device, dtype=x.dtype)

        out, hidden = self.biLSTM(x, (hidden, cell))    # batch_size, seq_len, (hidden_dim//2) * 2

        out = self.dropout(out)

        out = self.elu(self.linear(out))                # batch_size, seq_len, output_dim

        out = self.fc(out)                              # batch_size, seq_len, predict_output

        return out, hidden

    def init_hidden(self, batch_size, device=None, dtype=None):
        """Create zero-filled initial hidden and cell states for the LSTM.

        Args:
            batch_size: Number of sequences in the batch.
            device: Device for the new tensors. Defaults to the device of the
                embedding weights, i.e. wherever the module currently lives.
            dtype: Dtype for the new tensors. Defaults to the dtype of the
                embedding weights.

        Returns:
            A tuple ``(hidden, cell)``, each of shape
            (2 * num_layers, batch_size, hidden_dim // 2).
        """
        if device is None or dtype is None:
            reference = self.embed_layer.weight
            if device is None:
                device = reference.device
            if dtype is None:
                dtype = reference.dtype

        # nn.LSTM expects (num_directions * num_layers, batch, hidden_size);
        # a fixed leading 2 is only valid for a single-layer network.
        shape = (2 * self.biLSTM.num_layers, batch_size, self.biLSTM.hidden_size)
        hidden = torch.zeros(shape, device=device, dtype=dtype)
        cell = torch.zeros(shape, device=device, dtype=dtype)
        return hidden, cell