import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

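# Project-local module; TransformerModel is assumed to follow the nn.Transformer
# convention of taking (src, tgt) and returning per-token logits over the vocabulary.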
from super_large_language_model import TransformerModel


class TextDataset(Dataset):
    """Character-level dataset mapping each string to a tensor of vocabulary indices."""

    def __init__(self, texts, vocab):
        self.texts = texts
        self.vocab = vocab

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        text_indices = [self.vocab[char] for char in text]
        return torch.tensor(text_indices, dtype=torch.long)

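
def collate_batch(batch):
    # The sample texts differ in length, so a batch must be padded before it can
    # be stacked into one tensor. A minimal collate sketch: it assumes index 0 is
    # reserved for padding (see the vocabulary construction below) and that the
    # loss ignores that index.
    return pad_sequence(batch, batch_first=True, padding_value=0)
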

def train_model(model, dataset, num_epochs=10, batch_size=32, learning_rate=0.001):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            collate_fn=collate_batch)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # skip padded positions
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batch in dataloader:
            optimizer.zero_grad()
            # Shift along the sequence dimension (dim 1), not the batch dimension:
            # the model sees tokens up to t and is trained to predict token t+1.
            src, tgt = batch[:, :-1], batch[:, 1:]
            output = model(src, tgt)
            # reshape (not view): the sliced tensors are non-contiguous.
            loss = criterion(output.reshape(-1, output.size(-1)), tgt.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Report the mean loss over the epoch rather than the last batch only.
        print(f'Epoch {epoch+1}, Loss: {total_loss / len(dataloader):.4f}')


if __name__ == "__main__":
    texts = ["hello world", "pytorch is great"]
    # Sort the characters so the vocabulary is deterministic across runs, and
    # start indices at 1 so that 0 stays free as the padding index.
    vocab = {char: idx + 1 for idx, char in enumerate(sorted(set("".join(texts))))}

    dataset = TextDataset(texts, vocab)
    # vocab_size is len(vocab) + 1 to account for the padding index 0.
    model = TransformerModel(vocab_size=len(vocab) + 1, d_model=512, nhead=8,
                             num_encoder_layers=6, num_decoder_layers=6,
                             dim_feedforward=2048)

    train_model(model, dataset)