# LSTM-1225 / inference.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from safetensors.torch import load_file
import logging
import argparse

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hyperparameters (must match the values used during training)
embedding_dim = 128
hidden_dim = 256
num_layers = 2
sequence_length = 10

# LSTM language model: embedding -> multi-layer LSTM -> linear projection to vocabulary logits
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)            # (batch, seq_len, embedding_dim)
        lstm_out, _ = self.lstm(embeds)       # (batch, seq_len, hidden_dim)
        logits = self.fc(lstm_out[:, -1, :])  # next-word logits from the last time step
        return logits
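
# Illustrative shape check, not part of the original script. The vocab size of 1000
# below is arbitrary; the hyperparameters come from the module-level constants:
#
#     dummy_model = LSTMModel(1000, embedding_dim, hidden_dim, num_layers)
#     dummy_batch = torch.randint(0, 1000, (2, sequence_length))  # batch of 2 index sequences
#     assert dummy_model(dummy_batch).shape == (2, 1000)          # one logit per vocab entry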

# Predict the next word with temperature scaling and top-k sampling
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
    model.eval()
    with torch.no_grad():
        # Map words to indices, falling back to <UNK> for out-of-vocabulary words
        seq_idx = [word2idx.get(word, word2idx['<UNK>']) for word in sequence.split()]
        seq_idx = seq_idx[-sequence_length:]  # Keep only the most recent sequence_length words
        seq_tensor = torch.tensor(seq_idx, dtype=torch.long).unsqueeze(0)
        outputs = model(seq_tensor)
        outputs = outputs / temp  # Apply temperature: <1 sharpens, >1 flattens the distribution
        probs = F.softmax(outputs, dim=1).squeeze()
        # Restrict sampling to the k most probable words
        top_k_probs, top_k_idx = torch.topk(probs, top_k)
        # multinomial accepts unnormalized weights, so the truncated probs need no renormalization
        predicted_idx = torch.multinomial(top_k_probs, 1).item()
        predicted_word = idx2word[top_k_idx[predicted_idx].item()]
    return predicted_word
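
# Illustrative usage, assuming a model and vocabulary have already been loaded
# (the variable names here are hypothetical, not defined at module level):
#
#     word = predict_next_word(model, word2idx, idx2word,
#                              'Once upon a time', sequence_length, temp=0.8, top_k=10)
#     # 'word' is a single string sampled from the 10 most likely next words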

# Generate a sentence by repeatedly sampling the next word
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length):
    sentence = start_sequence
    for _ in range(max_length):
        next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
        if next_word == '<pad>':  # Stop on the padding token without appending it
            break
        sentence += ' ' + next_word
    return sentence

# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser(description='LSTM Next Word Prediction Chatbot')
    parser.add_argument('--temp', type=float, default=1.0, help='Temperature parameter')
    parser.add_argument('--top_k', type=int, default=10, help='Top-k sampling parameter')
    parser.add_argument('--model_file', type=str, default='lstm_model.safetensors', help='Path to the safetensors model file')
    parser.add_argument('--start_sequence', type=str, default='Once upon a time', help='Starting sequence for sentence generation')
    parser.add_argument('--max_length', type=int, default=50, help='Maximum number of words to generate')
    return parser.parse_args()
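
# Example invocation using the flags defined above (the model path is the default
# and assumes the file sits in the working directory):
#
#     python inference.py --model_file lstm_model.safetensors \
#         --start_sequence 'Once upon a time' --temp 0.8 --top_k 20 --max_length 30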

# Main function
def main():
    args = parse_args()
    temp = args.temp
    top_k = args.top_k
    model_file = args.model_file
    start_sequence = args.start_sequence
    max_length = args.max_length

    logging.info(f'Loading the model and vocabulary from {model_file}...')
    model_state_dict = load_file(model_file)
    # The vocabulary path is hardcoded; word2idx.pkl must sit in the working directory
    with open('word2idx.pkl', 'rb') as f:
        word2idx = pickle.load(f)
    # Build the inverse mapping (index -> word) from word2idx
    idx2word = {idx: word for word, idx in word2idx.items()}
    vocab_size = len(word2idx)

    model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
    model.load_state_dict(model_state_dict)
    model.eval()
    logging.info('Model and vocabulary loaded successfully.')

    logging.info(f'Starting sequence: {start_sequence}')
    logging.info(f'Temperature: {temp}, Top-k: {top_k}, Max Length: {max_length}')
    generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length)
    logging.info(f'Generated sentence: {generated_sentence}')

if __name__ == '__main__':
    main()
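
# For reference, a minimal sketch of how the two artifacts loaded above could be
# produced after training; this is an assumption, not this repository's training code:
#
#     from safetensors.torch import save_file
#     save_file(model.state_dict(), 'lstm_model.safetensors')  # weights as safetensors
#     with open('word2idx.pkl', 'wb') as f:
#         pickle.dump(word2idx, f)                             # vocabulary mapping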