Fishfishfishfishfish
commited on
Update inference.py
Browse files- inference.py +44 -31
inference.py
CHANGED
@@ -4,18 +4,16 @@ import torch.nn.functional as F
|
|
4 |
import pickle
|
5 |
from safetensors.torch import load_file
|
6 |
import logging
|
|
|
7 |
|
8 |
# Set up logging
|
9 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
10 |
|
11 |
# Hyperparameters
|
12 |
-
embedding_dim =
|
13 |
-
#change H number according to model used
|
14 |
hidden_dim = 256
|
15 |
-
num_layers =
|
16 |
-
sequence_length =
|
17 |
-
temp = 1.0 # Temperature parameter
|
18 |
-
top_k = 10 # Top-k sampling parameter
|
19 |
|
20 |
# LSTM Model
|
21 |
class LSTMModel(nn.Module):
|
@@ -31,21 +29,6 @@ class LSTMModel(nn.Module):
|
|
31 |
logits = self.fc(lstm_out[:, -1, :])
|
32 |
return logits
|
33 |
|
34 |
-
# Load the model and vocabulary
|
35 |
-
logging.info('Loading the model and vocabulary...')
|
36 |
-
model_state_dict = load_file('lstm_H256.safetensors')
|
37 |
-
with open('word2idx.pkl', 'rb') as f:
|
38 |
-
word2idx = pickle.load(f)
|
39 |
-
with open('idx2word.pkl', 'rb') as f:
|
40 |
-
idx2word = pickle.load(f)
|
41 |
-
|
42 |
-
vocab_size = len(word2idx)
|
43 |
-
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
|
44 |
-
model.load_state_dict(model_state_dict)
|
45 |
-
model.eval()
|
46 |
-
|
47 |
-
logging.info('Model and vocabulary loaded successfully.')
|
48 |
-
|
49 |
# Function to predict the next word with temperature and top-k sampling
|
50 |
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
|
51 |
model.eval()
|
@@ -62,22 +45,52 @@ def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp
|
|
62 |
return predicted_word
|
63 |
|
64 |
# Function to generate a sentence
|
65 |
-
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length
|
66 |
sentence = start_sequence
|
67 |
for _ in range(max_length):
|
68 |
next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
|
69 |
sentence += ' ' + next_word
|
70 |
-
if next_word == '<pad>'
|
71 |
break
|
72 |
return sentence
|
73 |
|
74 |
-
#
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
|
|
|
|
|
|
|
|
|
|
77 |
|
78 |
-
|
79 |
-
|
80 |
-
logging.info(f'Starting sequence: {start_sequence}')
|
81 |
-
logging.info(f'Temperature: {temp}, Top-k: {top_k}')
|
82 |
-
generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k)
|
83 |
-
logging.info(f'Generated sentence: {generated_sentence}')
|
|
|
4 |
import pickle
|
5 |
from safetensors.torch import load_file
|
6 |
import logging
|
7 |
+
import argparse
|
8 |
|
9 |
# Set up logging
|
10 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
11 |
|
12 |
# Hyperparameters
|
13 |
+
embedding_dim = 128
|
|
|
14 |
hidden_dim = 256
|
15 |
+
num_layers = 2
|
16 |
+
sequence_length = 10
|
|
|
|
|
17 |
|
18 |
# LSTM Model
|
19 |
class LSTMModel(nn.Module):
|
|
|
29 |
logits = self.fc(lstm_out[:, -1, :])
|
30 |
return logits
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
# Function to predict the next word with temperature and top-k sampling
|
33 |
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
|
34 |
model.eval()
|
|
|
45 |
return predicted_word
|
46 |
|
47 |
# Function to generate a sentence
|
48 |
+
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length):
|
49 |
sentence = start_sequence
|
50 |
for _ in range(max_length):
|
51 |
next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
|
52 |
sentence += ' ' + next_word
|
53 |
+
if next_word == '<pad>':
|
54 |
break
|
55 |
return sentence
|
56 |
|
57 |
+
# Parse command-line arguments
|
58 |
+
def parse_args():
|
59 |
+
parser = argparse.ArgumentParser(description='LSTM Next Word Prediction Chatbot')
|
60 |
+
parser.add_argument('--temp', type=float, default=1.0, help='Temperature parameter')
|
61 |
+
parser.add_argument('--top_k', type=int, default=10, help='Top-k sampling parameter')
|
62 |
+
parser.add_argument('--model_file', type=str, default='lstm_model.safetensors', help='Path to the safetensors model file')
|
63 |
+
parser.add_argument('--start_sequence', type=str, default='Once upon a time', help='Starting sequence for sentence generation')
|
64 |
+
parser.add_argument('--max_length', type=int, default=50, help='Maximum number of words to generate')
|
65 |
+
return parser.parse_args()
|
66 |
+
|
67 |
+
# Main function
|
68 |
+
def main():
|
69 |
+
args = parse_args()
|
70 |
+
temp = args.temp
|
71 |
+
top_k = args.top_k
|
72 |
+
model_file = args.model_file
|
73 |
+
start_sequence = args.start_sequence
|
74 |
+
max_length = args.max_length
|
75 |
+
|
76 |
+
logging.info(f'Loading the model and vocabulary from {model_file}...')
|
77 |
+
model_state_dict = load_file(model_file)
|
78 |
+
with open('word2idx.pkl', 'rb') as f:
|
79 |
+
word2idx = pickle.load(f)
|
80 |
+
|
81 |
+
# Generate idx2word from word2idx
|
82 |
+
idx2word = {idx: word for word, idx in word2idx.items()}
|
83 |
+
|
84 |
+
vocab_size = len(word2idx)
|
85 |
+
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
|
86 |
+
model.load_state_dict(model_state_dict)
|
87 |
+
model.eval()
|
88 |
|
89 |
+
logging.info('Model and vocabulary loaded successfully.')
|
90 |
+
logging.info(f'Starting sequence: {start_sequence}')
|
91 |
+
logging.info(f'Temperature: {temp}, Top-k: {top_k}, Max Length: {max_length}')
|
92 |
+
generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length)
|
93 |
+
logging.info(f'Generated sentence: {generated_sentence}')
|
94 |
|
95 |
+
if __name__ == '__main__':
|
96 |
+
main()
|
|
|
|
|
|
|
|