import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from safetensors.torch import load_file
import logging
import argparse

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hyperparameters (must match the values used when the model was trained)
embedding_dim = 128
hidden_dim = 256
num_layers = 2
sequence_length = 10


# LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        # Project only the last time step's hidden state to vocabulary logits
        logits = self.fc(lstm_out[:, -1, :])
        return logits


# Function to predict the next word with temperature and top-k sampling
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
    model.eval()
    with torch.no_grad():
        # Map words to indices, falling back to the unknown token.
        # NOTE: '<unk>' (and '<eos>' below) are assumed token names; use
        # whatever special tokens your vocabulary was actually built with.
        seq_idx = [word2idx.get(word, word2idx['<unk>']) for word in sequence.split()]
        seq_idx = seq_idx[-sequence_length:]  # Truncate to the model's context window
        seq_tensor = torch.tensor(seq_idx, dtype=torch.long).unsqueeze(0)
        outputs = model(seq_tensor)
        outputs = outputs / temp  # Apply temperature
        probs = F.softmax(outputs, dim=1).squeeze()
        # Restrict sampling to the k most probable words; torch.multinomial
        # normalizes its weights internally, so no renormalization is needed
        top_k_probs, top_k_idx = torch.topk(probs, top_k)
        predicted_idx = torch.multinomial(top_k_probs, 1).item()
        predicted_word = idx2word[top_k_idx[predicted_idx].item()]
        return predicted_word


# Function to generate a sentence
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length):
    sentence = start_sequence
    for _ in range(max_length):
        next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
        sentence += ' ' + next_word
        if next_word == '<eos>':  # Stop at the end-of-sequence token (assumed name)
            break
    return sentence


# Parse command-line arguments
def parse_args():
    parser = argparse.ArgumentParser(description='LSTM Next Word Prediction Chatbot')
    parser.add_argument('--temp', type=float, default=1.0, help='Temperature parameter')
    parser.add_argument('--top_k', type=int, default=10, help='Top-k sampling parameter')
    parser.add_argument('--model_file', type=str, default='lstm_model.safetensors', help='Path to the safetensors model file')
    parser.add_argument('--start_sequence', type=str, default='Once upon a time', help='Starting sequence for sentence generation')
    parser.add_argument('--max_length', type=int, default=50, help='Maximum number of words to generate')
    return parser.parse_args()


# Main function
def main():
    args = parse_args()
    temp = args.temp
    top_k = args.top_k
    model_file = args.model_file
    start_sequence = args.start_sequence
    max_length = args.max_length

    logging.info(f'Loading the model and vocabulary from {model_file}...')
    model_state_dict = load_file(model_file)
    with open('word2idx.pkl', 'rb') as f:
        word2idx = pickle.load(f)

    # Generate idx2word from word2idx
    idx2word = {idx: word for word, idx in word2idx.items()}
    vocab_size = len(word2idx)

    model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
    model.load_state_dict(model_state_dict)
    model.eval()
    logging.info('Model and vocabulary loaded successfully.')

    logging.info(f'Starting sequence: {start_sequence}')
    logging.info(f'Temperature: {temp}, Top-k: {top_k}, Max Length: {max_length}')
    generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length)
    logging.info(f'Generated sentence: {generated_sentence}')


if __name__ == '__main__':
    main()
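
# A hedged usage sketch: the script filename below is an assumption (save this
# file as, e.g., lstm_chatbot.py), and it presumes that lstm_model.safetensors
# and word2idx.pkl were produced by a matching training script and sit in the
# working directory. Lower temperatures and smaller top_k values make the
# output more deterministic; higher values make it more varied. Example:
#
#   python lstm_chatbot.py --temp 0.8 --top_k 5 --max_length 30 \
#       --start_sequence "Once upon a time"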