Fishfishfishfishfish committed
Commit a1b5703 · verified · 1 Parent(s): 55e948c

Upload 4 files

Files changed (4)
  1. continue.py +117 -0
  2. inference.py +83 -0
  3. tokenizer.js +26 -0
  4. trainer.py +113 -0
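Taken together, the four files appear to form a small word-level LSTM pipeline: tokenizer.js turns raw text into a JSON list of tokens, trainer.py trains the model and saves lstm_model.safetensors plus the word2idx.pkl/idx2word.pkl vocabulary pickles, continue.py resumes training from those artifacts, and inference.py loads them to generate text.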
continue.py ADDED
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn
import torch.optim as optim
import pickle
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import load_file, save_file
import logging
import json

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hyperparameters (the model dimensions must match the checkpoint being loaded)
sequence_length = 16
batch_size = 32
num_epochs = 1  # Continue training for 1 more epoch
learning_rate = 0.00001
embedding_dim = 256
hidden_dim = 512
num_layers = 2

# LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        logits = self.fc(lstm_out[:, -1, :])
        return logits

# Load the model and vocabulary
logging.info('Loading the model and vocabulary...')
model_state_dict = load_file('lstm_model.safetensors')
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)
with open('idx2word.pkl', 'rb') as f:
    idx2word = pickle.load(f)

vocab_size = len(word2idx)
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
model.load_state_dict(model_state_dict)
model.train()

logging.info('Model and vocabulary loaded successfully.')

# Output the total number of parameters
total_params = sum(p.numel() for p in model.parameters())
logging.info(f'Total number of parameters: {total_params}')

# Read the text file (expected to contain a JSON array of word tokens)
logging.info('Reading the text file...')
with open('text.txt', 'r') as file:
    text = file.read()
logging.info('Text file read successfully.')

# Preprocess the text
logging.info('Preprocessing the text...')
words = json.loads(text)
sequences = []
for i in range(len(words) - sequence_length):
    seq = words[i:i + sequence_length]
    label = words[i + sequence_length]
    sequences.append((seq, label))

logging.info(f'Number of sequences: {len(sequences)}')

# Dataset and DataLoader
class TextDataset(Dataset):
    def __init__(self, sequences, word2idx):
        self.sequences = sequences
        self.word2idx = word2idx

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq, label = self.sequences[idx]
        seq_idx = [self.word2idx.get(word, self.word2idx['<UNK>']) for word in seq]
        label_idx = self.word2idx.get(label, self.word2idx['<UNK>'])
        return torch.tensor(seq_idx, dtype=torch.long), torch.tensor(label_idx, dtype=torch.long)

logging.info('Creating dataset and dataloader...')
dataset = TextDataset(sequences, word2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Continue training
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

logging.info('Starting continued training...')
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(dataloader):
        inputs, targets = batch
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            logging.info(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}')

# Save the updated model
logging.info('Saving the updated model...')
save_file(model.state_dict(), 'lstm_model.safetensors')
with open('word2idx.pkl', 'wb') as f:
    pickle.dump(word2idx, f)
with open('idx2word.pkl', 'wb') as f:
    pickle.dump(idx2word, f)

logging.info('Updated model and vocabulary saved successfully.')
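Note that continue.py rebuilds the Adam optimizer from scratch on every resume, so its running moment estimates are not carried over between runs. If that matters, a minimal sketch of persisting them alongside the safetensors checkpoint (the optimizer_state.pt filename is an assumption, not part of this commit):

# Hypothetical addition: keep the Adam state next to lstm_model.safetensors.
import torch
# after training, alongside save_file(...):
torch.save(optimizer.state_dict(), 'optimizer_state.pt')
# on the next resume, after recreating the optimizer:
optimizer.load_state_dict(torch.load('optimizer_state.pt'))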
inference.py ADDED
@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from safetensors.torch import load_file
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hyperparameters (embedding_dim, hidden_dim and num_layers must match the checkpoint being loaded)
embedding_dim = 8
hidden_dim = 16
num_layers = 1
sequence_length = 64
temp = 1.0  # Temperature parameter
top_k = 10  # Top-k sampling parameter

# LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        logits = self.fc(lstm_out[:, -1, :])
        return logits

# Load the model and vocabulary
logging.info('Loading the model and vocabulary...')
model_state_dict = load_file('lstm_model.safetensors')
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)
with open('idx2word.pkl', 'rb') as f:
    idx2word = pickle.load(f)

vocab_size = len(word2idx)
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
model.load_state_dict(model_state_dict)
model.eval()

logging.info('Model and vocabulary loaded successfully.')

# Function to predict the next word with temperature and top-k sampling
def predict_next_word(model, word2idx, idx2word, sequence, sequence_length, temp, top_k):
    model.eval()
    with torch.no_grad():
        seq_idx = [word2idx.get(word, word2idx['<UNK>']) for word in sequence.split()]
        seq_idx = seq_idx[-sequence_length:]  # Keep only the most recent sequence_length tokens
        seq_tensor = torch.tensor(seq_idx, dtype=torch.long).unsqueeze(0)
        outputs = model(seq_tensor)
        outputs = outputs / temp  # Apply temperature
        probs = F.softmax(outputs, dim=1).squeeze()
        top_k_probs, top_k_idx = torch.topk(probs, top_k)
        predicted_idx = torch.multinomial(top_k_probs, 1).item()
        predicted_word = idx2word[top_k_idx[predicted_idx].item()]
        return predicted_word

# Function to generate a sentence
def generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k, max_length=50):
    sentence = start_sequence
    for _ in range(max_length):
        next_word = predict_next_word(model, word2idx, idx2word, sentence, sequence_length, temp, top_k)
        sentence += ' ' + next_word
        if next_word == '<pad>' or next_word == 'User':
            break
    return sentence

# Example usage
start_sequence = "User : What is the capital of France ? Bot :"

temp = 0.5  # Adjust temperature
top_k = 32  # Adjust top-k
logging.info(f'Starting sequence: {start_sequence}')
logging.info(f'Temperature: {temp}, Top-k: {top_k}')
generated_sentence = generate_sentence(model, word2idx, idx2word, start_sequence, sequence_length, temp, top_k)
logging.info(f'Generated sentence: {generated_sentence}')
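The sampling step in predict_next_word relies on torch.multinomial accepting unnormalized weights, so the top-k probabilities do not need to be renormalized after torch.topk. A self-contained sketch of the same temperature-plus-top-k step on dummy logits (the values here are illustrative only):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 100)                       # stand-in for the model's output over a 100-word vocab
probs = F.softmax(logits / 0.5, dim=1).squeeze()   # temperature 0.5 sharpens the distribution
top_k_probs, top_k_idx = torch.topk(probs, 10)     # keep the 10 most likely words
choice = torch.multinomial(top_k_probs, 1).item()  # multinomial renormalizes the kept weights
print('sampled vocab index:', top_k_idx[choice].item())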
tokenizer.js ADDED
@@ -0,0 +1,26 @@
const fs = require('fs');

// Split on whitespace and punctuation, keeping each punctuation mark as its own token
// and dropping empty or whitespace-only tokens.
function tokenizeText(text) {
    return text.split(/([\s,.!?:;()*-])/).filter(token => token.trim() !== '');
}

fs.readFile('text.txt', 'utf8', (err, data) => {
    if (err) {
        console.error('Error reading file:', err);
        return;
    }

    const tokens = tokenizeText(data);

    const jsonData = JSON.stringify(tokens);

    fs.writeFile('tokens.json', jsonData, (err) => {
        if (err) {
            console.error('Error writing file:', err);
        } else {
            console.log('Tokens written to tokens.json');
        }
    });
});
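tokenizer.js writes its output to tokens.json, while the Python scripts read text.txt and parse it with json.loads, so the two paths need to point at the same file. If it is more convenient to tokenize inside the training script itself, a sketch of the same splitting rule translated to Python (an assumption for illustration, not part of this commit):

import json
import re

with open('text.txt', 'r') as f:  # raw, untokenized text
    raw = f.read()
# same pattern as tokenizer.js: split on whitespace/punctuation, keep the delimiters, drop blanks
tokens = [t for t in re.split(r'([\s,.!?:;()*-])', raw) if t.strip() != '']
with open('tokens.json', 'w') as f:
    json.dump(tokens, f)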
trainer.py ADDED
@@ -0,0 +1,113 @@
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pickle
from torch.utils.data import Dataset, DataLoader
from safetensors.torch import save_file
import logging
import json

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Hyperparameters
sequence_length = 64
batch_size = 1
num_epochs = 1
learning_rate = 0.00001
embedding_dim = 256
hidden_dim = 800
num_layers = 4

# Read the text file (expected to contain a JSON array of word tokens)
logging.info('Reading the text file...')
with open('text.txt', 'r') as file:
    text = file.read()
logging.info('Text file read successfully.')

# Preprocess the text
logging.info('Preprocessing the text...')
words = json.loads(text)
vocab = set(words)
vocab.add('<pad>')
vocab.add('<UNK>')
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for idx, word in enumerate(vocab)}
vocab_size = len(vocab)

logging.info(f'Vocabulary size: {vocab_size}')
#logging.info(f'Word to index mapping: {word2idx}')

# Create sequences
logging.info('Creating sequences...')
sequences = []
for i in range(len(words) - sequence_length):
    seq = words[i:i + sequence_length]
    label = words[i + sequence_length]
    sequences.append((seq, label))

logging.info(f'Number of sequences: {len(sequences)}')

# Dataset and DataLoader
class TextDataset(Dataset):
    def __init__(self, sequences, word2idx):
        self.sequences = sequences
        self.word2idx = word2idx

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq, label = self.sequences[idx]
        seq_idx = [self.word2idx.get(word, self.word2idx['<UNK>']) for word in seq]
        label_idx = self.word2idx.get(label, self.word2idx['<UNK>'])
        return torch.tensor(seq_idx, dtype=torch.long), torch.tensor(label_idx, dtype=torch.long)

logging.info('Creating dataset and dataloader...')
dataset = TextDataset(sequences, word2idx)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# LSTM Model
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        logits = self.fc(lstm_out[:, -1, :])
        return logits

logging.info('Initializing the LSTM model...')
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
logging.info('Starting training...')
for epoch in range(num_epochs):
    for batch_idx, batch in enumerate(dataloader):
        inputs, targets = batch
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            logging.info(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}')

# Save the model
logging.info('Saving the model...')
save_file(model.state_dict(), 'lstm_model.safetensors')
with open('word2idx.pkl', 'wb') as f:
    pickle.dump(word2idx, f)
with open('idx2word.pkl', 'wb') as f:
    pickle.dump(idx2word, f)

logging.info('Model and vocabulary saved successfully.')
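Because word2idx and idx2word are built by enumerating a Python set, the word-to-index assignment is not reproducible across runs; the pickled mappings saved above are therefore the only reliable link between a checkpoint and its vocabulary. A quick sanity check that a saved checkpoint and the pickled vocabulary still agree (a minimal sketch, assuming the filenames used above):

import pickle
from safetensors.torch import load_file

state = load_file('lstm_model.safetensors')
with open('word2idx.pkl', 'rb') as f:
    word2idx = pickle.load(f)
# embedding.weight has shape (vocab_size, embedding_dim); its first dimension must match the vocabulary
assert state['embedding.weight'].shape[0] == len(word2idx), 'checkpoint and vocabulary are out of sync'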