File size: 1,894 Bytes
737e9a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
import torch
class BiLSTM(nn.Module, PyTorchModelHubMixin):
    """Bidirectional LSTM that produces a vector of scores for every input token.

    Pipeline: Embedding -> BiLSTM -> Dropout -> Linear -> ELU -> Linear.
    Each LSTM direction has ``hidden_dim // 2`` units, so the concatenated
    bidirectional output is ``hidden_dim`` wide.

    Args:
        vocab_size: Number of rows in the embedding table. Index 0 is the
            padding index (``nn.Embedding(padding_idx=0)``).
        embed_dim: Embedding width, i.e. the LSTM input size.
        num_layers: Number of stacked LSTM layers.
        hidden_dim: Combined width of both LSTM directions; must be even.
        dropout: Dropout probability applied to the LSTM output.
        output_dim: Width of the intermediate ELU-activated projection.
        predict_output: Number of scores produced per token (presumably the
            number of tag classes — confirm against the training code).
        device: Retained for backward compatibility (it is part of the
            constructor signature recorded by the Hub mixin). The initial LSTM
            state is created on the device of the model's parameters / input,
            so this value is no longer used to place tensors.

    Raises:
        ValueError: If ``hidden_dim`` is odd; the concatenated LSTM output
            (``2 * (hidden_dim // 2)``) would then not match the input size
            of the first Linear layer.
    """

    def __init__(self, vocab_size=23626, embed_dim=100,
                 num_layers=1, hidden_dim=256, dropout=0.33,
                 output_dim=128, predict_output=10, device="cuda:0"):
        super().__init__()
        if hidden_dim % 2:
            raise ValueError(
                f"hidden_dim must be even (it is split across 2 LSTM directions), got {hidden_dim}"
            )
        self.hidden_dim = hidden_dim
        self.predict_output = predict_output
        self.embed_layer = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.biLSTM = nn.LSTM(input_size=embed_dim,
                              hidden_size=hidden_dim // 2,  # the 2 directions are concatenated
                              num_layers=num_layers,
                              bidirectional=True,
                              batch_first=True)
        self.linear = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)
        self.elu = nn.ELU()
        self.fc = nn.Linear(output_dim, predict_output)
        # Kept so existing code/configs that read it still work; tensor
        # placement now follows the model's parameters instead.
        self.device_ = device

    def forward(self, input):  # name kept for callers passing input=...
        """Score every token in a batch of index sequences.

        Args:
            input: Integer tensor of token indices, shape (batch_size, seq_len).

        Returns:
            ``(scores, (h_n, c_n))``. ``scores`` has shape
            (batch_size, seq_len, predict_output). ``(h_n, c_n)`` is the final
            LSTM state, each of shape
            (2 * num_layers, batch_size, hidden_dim // 2).
        """
        x = self.embed_layer(input)  # (batch_size, seq_len, embed_dim)
        # Build the initial state on the same device/dtype as the LSTM input
        # so the model works wherever it has been moved (CPU, any GPU, half).
        hidden, cell = self.init_hidden(x.size(0), device=x.device, dtype=x.dtype)
        # batch_first=True, so out is (batch_size, seq_len, hidden_dim).
        out, hidden = self.biLSTM(x, (hidden, cell))
        out = self.dropout(out)
        out = self.elu(self.linear(out))  # (batch_size, seq_len, output_dim)
        out = self.fc(out)  # (batch_size, seq_len, predict_output)
        return out, hidden

    def init_hidden(self, batch_size, device=None, dtype=None):
        """Return zero-initialised ``(h_0, c_0)`` for ``self.biLSTM``.

        Args:
            batch_size: Number of sequences in the batch.
            device: Device for the new tensors. Defaults to the device of the
                embedding weights, i.e. wherever the model currently lives.
            dtype: Dtype for the new tensors. Defaults to the dtype of the
                embedding weights.

        Returns:
            Two zero tensors, each shaped
            (num_layers * num_directions, batch_size, hidden_size), the layout
            ``nn.LSTM`` requires for its initial hidden and cell state.
        """
        lstm = self.biLSTM
        weight = self.embed_layer.weight
        if device is None:
            device = weight.device
        if dtype is None:
            dtype = weight.dtype
        num_directions = 2 if lstm.bidirectional else 1
        # The first dimension must account for every stacked layer, not just
        # the two directions, or nn.LSTM rejects the state when num_layers > 1.
        shape = (lstm.num_layers * num_directions, batch_size, lstm.hidden_size)
        hidden = torch.zeros(shape, device=device, dtype=dtype)
        cell = torch.zeros(shape, device=device, dtype=dtype)
        return hidden, cell