|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import pickle |
|
from safetensors.torch import load_file |
|
import logging |
|
|
|
|
|
# Log INFO and above, each line prefixed with a timestamp and the level name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Architecture sizes. LSTMModel is built from these and then receives the saved
# state dict, so they have to agree with the tensor shapes in the checkpoint.
embedding_dim, hidden_dim, num_layers = 8, 16, 1

# Maximum number of trailing prompt tokens fed to the model per prediction.
sequence_length = 64

# Default sampling settings (softmax temperature, number of candidates).
# The demo section later in this script rebinds both before generating.
temp, top_k = 1.0, 10
|
|
|
|
|
class LSTMModel(nn.Module):
    """Next-token classifier: embedding -> LSTM -> linear head on the last time step.

    Args:
        vocab_size: number of token ids; sizes both the embedding table and the
            output logits.
        embedding_dim: width of each token embedding.
        hidden_dim: LSTM hidden-state size.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super().__init__()
        # The attribute names below become the state-dict key prefixes
        # ("embedding.", "lstm.", "fc."), so they must stay stable for saved
        # checkpoints to load.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x):
        """Return next-token logits of shape (batch, vocab_size).

        ``x`` is a batch of token ids laid out (batch, seq_len), because the LSTM
        uses ``batch_first=True``. Only the LSTM output at the final position
        feeds the classifier head.
        """
        hidden_seq = self.lstm(self.embedding(x))[0]
        final_step = hidden_seq[:, -1, :]
        return self.fc(final_step)
|
|
|
|
|
def _unpickle(path):
    """Return the object stored in the pickle file at *path*."""
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    # only load vocabulary files that come from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)


logging.info('Loading the model and vocabulary...')

model_state_dict = load_file('lstm_model.safetensors')
word2idx = _unpickle('word2idx.pkl')
idx2word = _unpickle('idx2word.pkl')

# The embedding table and output layer are both sized by the vocabulary.
vocab_size = len(word2idx)
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
model.load_state_dict(model_state_dict)
model.eval()  # switch to inference mode before generating

logging.info('Model and vocabulary loaded successfully.')
|
|
|
|
|
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
    """Sample one next word for a whitespace-tokenized prompt.

    The prompt is split on whitespace and each word is mapped to its id.
    Out-of-vocabulary words map to ``word2idx['<UNK>']``. The ids are truncated
    to the last ``sequence_length`` and fed to ``model`` as a (1, seq_len) batch.
    The resulting logits are divided by ``temp`` and softmaxed. The next id is
    then drawn from the ``top_k`` most probable ids, in proportion to their
    probabilities.

    Args:
        model: module mapping a (1, seq_len) LongTensor of ids to
            (1, vocab_size) logits (as LSTMModel does).
        word2idx: word -> id mapping. Must contain '<UNK>' if the prompt has
            out-of-vocabulary words.
        idx2word: id -> word mapping, indexed with Python ints.
        sequence: prompt text.
        sequence_length: maximum number of trailing tokens fed to the model (>= 1).
        temp: softmax temperature (> 0); lower values sharpen the distribution.
        top_k: number of candidates to sample from (>= 1). It is clamped to the
            vocabulary size.

    Returns:
        The sampled word.

    Raises:
        ValueError: if ``sequence`` has no words, or ``sequence_length``,
            ``temp`` or ``top_k`` is out of range.
        KeyError: if the prompt contains an out-of-vocabulary word and
            ``word2idx`` has no '<UNK>' entry.
    """
    if sequence_length < 1:
        # seq_idx[-0:] would silently keep the whole sequence.
        raise ValueError(f'sequence_length must be >= 1, got {sequence_length}')
    if temp <= 0:
        # Zero yields inf/NaN logits; negative values invert the distribution.
        raise ValueError(f'temp must be > 0, got {temp}')
    if top_k < 1:
        raise ValueError(f'top_k must be >= 1, got {top_k}')

    words = sequence.split()
    if not words:
        raise ValueError('sequence must contain at least one word')

    # Evaluate the '<UNK>' fallback only for out-of-vocabulary words, so a
    # vocabulary without '<UNK>' still works for fully in-vocabulary prompts.
    seq_idx = [word2idx[w] if w in word2idx else word2idx['<UNK>'] for w in words]
    seq_idx = seq_idx[-sequence_length:]

    model.eval()
    with torch.no_grad():
        seq_tensor = torch.tensor(seq_idx, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
        logits = model(seq_tensor) / temp
        # squeeze(0) rather than squeeze(): stays 1-D even for a 1-word vocabulary.
        probs = F.softmax(logits, dim=-1).squeeze(0)
        # torch.topk raises if k exceeds the number of available candidates.
        k = min(top_k, probs.numel())
        top_k_probs, top_k_idx = torch.topk(probs, k)
        # multinomial accepts unnormalized non-negative weights.
        choice = torch.multinomial(top_k_probs, 1).item()
        return idx2word[top_k_idx[choice].item()]
|
|
|
|
|
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length=50,
                      stop_words=('<pad>', 'User')):
    """Extend ``start_sequence`` one sampled word at a time.

    Each step calls ``predict_next_word`` on the full text generated so far and
    appends the sampled word after a single space. Generation ends after
    ``max_length`` words, or as soon as a word in ``stop_words`` is produced.
    That stop word is still included in the returned text.

    Args:
        model, word2idx, idx2word, sequence_length, temp, top_k: forwarded
            unchanged to ``predict_next_word``.
        start_sequence: prompt text to continue.
        max_length: maximum number of words to append.
        stop_words: iterable of words (or a single word) that end generation.
            The default matches the previously hard-coded '<pad>' and 'User';
            'User' is presumably the next speaker tag in the
            "User : ... Bot :" prompt format — confirm against the training data.

    Returns:
        ``start_sequence`` followed by the generated words, space-separated.
    """
    # A bare string would otherwise be treated as a set of single characters.
    if isinstance(stop_words, str):
        stop_words = (stop_words,)
    stop = frozenset(stop_words)

    sentence = start_sequence
    for _ in range(max_length):
        next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
        sentence += ' ' + next_word
        if next_word in stop:
            break
    return sentence
|
|
|
|
|
# Demo run: continue a chat-style prompt and log the generated text.
start_sequence = "User : What is the capital of France ? Bot :"

# Sampling settings for this run; these rebind the module-level temp/top_k.
temp, top_k = 0.5, 32

logging.info(f'Starting sequence: {start_sequence}')
logging.info(f'Temperature: {temp}, Top-k: {top_k}')

generated_sentence = generate_sentence(
    model,
    word2idx,
    idx2word,
    start_sequence,
    sequence_length,
    temp,
    top_k,
)
logging.info(f'Generated sentence: {generated_sentence}')
|
|