# LSTM-1225 / inference.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from safetensors.torch import load_file
import logging
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Hyperparameters
embedding_dim = 8     # size of each word-embedding vector
hidden_dim = 16       # LSTM hidden-state size
num_layers = 1        # number of stacked LSTM layers
sequence_length = 64  # maximum number of context tokens fed to the model
temp = 1.0   # Temperature parameter (default; overridden in the example below)
top_k = 10   # Top-k sampling parameter (default; overridden in the example below)
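
# NOTE: the architecture hyperparameters above must match the values used at
# training time; otherwise load_state_dict below fails with shape-mismatch errors.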
# LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        # Only the final time step's hidden state is used to predict the next word
        logits = self.fc(lstm_out[:, -1, :])
        return logits
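
# Shape walk-through for a batch of token indices x with shape (batch, seq_len):
#   embedding -> (batch, seq_len, embedding_dim)
#   lstm      -> (batch, seq_len, hidden_dim)
#   fc on the last time step -> (batch, vocab_size) next-word logits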
# Load the model and vocabulary
logging.info('Loading the model and vocabulary...')
model_state_dict = load_file('lstm_model.safetensors')
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)
with open('idx2word.pkl', 'rb') as f:
    idx2word = pickle.load(f)
vocab_size = len(word2idx)
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
model.load_state_dict(model_state_dict)
model.eval()
logging.info('Model and vocabulary loaded successfully.')
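
# Optional sanity check (an illustrative sketch, assuming idx2word maps integer
# ids back to words, as its use in predict_next_word implies):
#   assert all(idx2word[i] == w for w, i in word2idx.items())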
# Function to predict the next word with temperature and top-k sampling
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
    model.eval()
    with torch.no_grad():
        # Map words to indices, falling back to the <UNK> token for unknown words
        seq_idx = [word2idx.get(word, word2idx['<UNK>']) for word in sequence.split()]
        seq_idx = seq_idx[-sequence_length:]  # Keep only the most recent sequence_length tokens
        seq_tensor = torch.tensor(seq_idx, dtype=torch.long).unsqueeze(0)
        outputs = model(seq_tensor)
        outputs = outputs / temp  # Apply temperature: <1 sharpens, >1 flattens the distribution
        probs = F.softmax(outputs, dim=1).squeeze()
        # Restrict sampling to the k most probable words
        top_k_probs, top_k_idx = torch.topk(probs, top_k)
        predicted_idx = torch.multinomial(top_k_probs, 1).item()
        predicted_word = idx2word[top_k_idx[predicted_idx].item()]
    return predicted_word
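
# Worked example of the sampling step with hypothetical logits [2.0, 1.0, 0.5]:
# at temp = 0.5 they become [4.0, 2.0, 1.0], and softmax gives roughly
# [0.84, 0.11, 0.04]; with top_k = 2 only the first two words stay eligible,
# and torch.multinomial renormalizes their probabilities before drawing.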
# Function to generate a sentence
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length=50):
    sentence = start_sequence
    for _ in range(max_length):
        next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
        sentence += ' ' + next_word
        # Stop when the model emits padding or starts a new 'User' turn
        if next_word == '<pad>' or next_word == 'User':
            break
    return sentence
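
# Note: each step re-tokenizes the whole running sentence and keeps only the
# last sequence_length words, so the context window slides as text is generated.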
# Example usage
start_sequence = "User : What is the capital of France ? Bot :"
temp = 0.5 # Adjust temperature
top_k = 32 # Adjust top-k
logging.info(f'Starting sequence: {start_sequence}')
logging.info(f'Temperature: {temp}, Top-k: {top_k}')
generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k)
logging.info(f'Generated sentence: {generated_sentence}')
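
# Illustrative follow-up turn (a sketch reusing the functions above; the prompt
# text is hypothetical):
#   follow_up = generated_sentence + ' User : And of Germany ? Bot :'
#   logging.info(generate_sentence(model, word2idx, idx2word, follow_up, sequence_length, temp, top_k))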