English
Fishfishfishfishfish committed on
Commit
cc68a7c
·
verified ·
1 Parent(s): ac28b1a

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +44 -31
inference.py CHANGED
@@ -4,18 +4,16 @@ import torch.nn.functional as F
4
  import pickle
5
  from safetensors.torch import load_file
6
  import logging
 
7
 
8
  # Set up logging
9
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10
 
11
  # Hyperparameters
12
- embedding_dim = 256
13
- #change H number according to model used
14
  hidden_dim = 256
15
- num_layers = 4
16
- sequence_length = 64
17
- temp = 1.0 # Temperature parameter
18
- top_k = 10 # Top-k sampling parameter
19
 
20
  # LSTM Model
21
  class LSTMModel(nn.Module):
@@ -31,21 +29,6 @@ class LSTMModel(nn.Module):
31
  logits = self.fc(lstm_out[:, -1, :])
32
  return logits
33
 
34
- # Load the model and vocabulary
35
- logging.info('Loading the model and vocabulary...')
36
- model_state_dict = load_file('lstm_H256.safetensors')
37
- with open('word2idx.pkl', 'rb') as f:
38
- word2idx = pickle.load(f)
39
- with open('idx2word.pkl', 'rb') as f:
40
- idx2word = pickle.load(f)
41
-
42
- vocab_size = len(word2idx)
43
- model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
44
- model.load_state_dict(model_state_dict)
45
- model.eval()
46
-
47
- logging.info('Model and vocabulary loaded successfully.')
48
-
49
  # Function to predict the next word with temperature and top-k sampling
50
  def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
51
  model.eval()
@@ -62,22 +45,52 @@ def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp
62
  return predicted_word
63
 
64
  # Function to generate a sentence
65
- def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length=50):
66
  sentence = start_sequence
67
  for _ in range(max_length):
68
  next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
69
  sentence += ' ' + next_word
70
- if next_word == '<pad>' or next_word == 'User':
71
  break
72
  return sentence
73
 
74
- # Example usage
75
- start_sequence = "User : What is the capital of France ? Bot :"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
 
 
 
 
 
77
 
78
- temp = 0.5 # Adjust temperature
79
- top_k = 32 # Adjust top-k
80
- logging.info(f'Starting sequence: {start_sequence}')
81
- logging.info(f'Temperature: {temp}, Top-k: {top_k}')
82
- generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k)
83
- logging.info(f'Generated sentence: {generated_sentence}')
 
4
  import pickle
5
  from safetensors.torch import load_file
6
  import logging
7
+ import argparse
8
 
9
  # Set up logging
10
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11
 
12
  # Hyperparameters
13
+ embedding_dim = 128
 
14
  hidden_dim = 256
15
+ num_layers = 2
16
+ sequence_length = 10
 
 
17
 
18
  # LSTM Model
19
  class LSTMModel(nn.Module):
 
29
  logits = self.fc(lstm_out[:, -1, :])
30
  return logits
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # Function to predict the next word with temperature and top-k sampling
33
  def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
34
  model.eval()
 
45
  return predicted_word
46
 
47
# Function to generate a sentence
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length):
    """Extend *start_sequence* by repeatedly sampling the next word.

    Args:
        model: Trained LSTM language model.
        word2idx: Mapping from word to vocabulary index.
        idx2word: Mapping from vocabulary index to word.
        start_sequence: Seed text to extend.
        sequence_length: Context-window length forwarded to predict_next_word.
        temp: Sampling temperature forwarded to predict_next_word.
        top_k: Top-k cutoff forwarded to predict_next_word.
        max_length: Maximum number of words to append.

    Returns:
        The seed text with up to max_length sampled words appended.
    """
    sentence = start_sequence
    for _ in range(max_length):
        next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
        # Bug fix: stop BEFORE appending the padding token, so the literal
        # '<pad>' marker never leaks into the returned text (the original
        # appended it first and then broke out of the loop).
        if next_word == '<pad>':
            break
        sentence += ' ' + next_word
    return sentence
56
 
57
# Parse command-line arguments
def parse_args():
    """Build the CLI parser and return the parsed argument namespace."""
    cli = argparse.ArgumentParser(description='LSTM Next Word Prediction Chatbot')
    # Flag, value type, default, help text — one row per option.
    option_table = (
        ('--temp', float, 1.0, 'Temperature parameter'),
        ('--top_k', int, 10, 'Top-k sampling parameter'),
        ('--model_file', str, 'lstm_model.safetensors', 'Path to the safetensors model file'),
        ('--start_sequence', str, 'Once upon a time', 'Starting sequence for sentence generation'),
        ('--max_length', int, 50, 'Maximum number of words to generate'),
    )
    for flag, value_type, default, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)
    return cli.parse_args()
66
+
67
# Main function
def main():
    """Entry point: parse CLI options, load the model and vocabulary, then generate text."""
    args = parse_args()

    logging.info(f'Loading the model and vocabulary from {args.model_file}...')
    model_state_dict = load_file(args.model_file)
    # NOTE(review): pickle.load executes arbitrary code if word2idx.pkl is
    # untrusted — only load vocabularies shipped with the model.
    with open('word2idx.pkl', 'rb') as f:
        word2idx = pickle.load(f)

    # Generate idx2word from word2idx
    idx2word = {idx: word for word, idx in word2idx.items()}

    # Rebuild the architecture with the file-level hyperparameters and
    # restore the trained weights before switching to inference mode.
    model = LSTMModel(len(word2idx), embedding_dim, hidden_dim, num_layers)
    model.load_state_dict(model_state_dict)
    model.eval()

    logging.info('Model and vocabulary loaded successfully.')
    logging.info(f'Starting sequence: {args.start_sequence}')
    logging.info(f'Temperature: {args.temp}, Top-k: {args.top_k}, Max Length: {args.max_length}')
    generated_sentence = generate_sentence(
        model, word2idx, idx2word,
        args.start_sequence, sequence_length, args.temp, args.top_k, args.max_length,
    )
    logging.info(f'Generated sentence: {generated_sentence}')

if __name__ == '__main__':
    main()