# bad-gpt / main.py
# Last commit: 002443a ("Add prompt") by Mike Gabriel
from typing import Literal
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import os
from encoder import encode, decode
from bigram import BigramLanguageModel, Batcher, estimate_loss
# ======================= Hyperparameters =======================
# --- Settings that impact performance ---
BATCH_SIZE = 64                 # number of token sequences processed in parallel
BLOCK_SIZE = 256                # length of one token sequence (the context length)
LEARNING_RATE = 1e-4            # default AdamW learning rate used by train_model
NUM_EMBEDDING_DIMENSIONS = 384  # passed to the model as n_embd
NUM_HEADS = 6                   # passed to the model as n_head
NUM_LAYERS = 6                  # passed to the model as n_layers
MAX_ITERS = 5000                # default number of training steps in train_model

# --- Everything else ---
EVAL_INTERVAL = 500             # steps between train/val loss reports
DROPOUT_RATE = 0.2              # passed to the model as dropout
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
# ===============================================================
def train_model(model: nn.Module, batcher: Batcher, iterations=MAX_ITERS, lr=LEARNING_RATE,
                *, eval_interval=EVAL_INTERVAL, eval_iters=EVAL_INTERVAL, device=DEVICE):
    """Train ``model`` with AdamW on training batches drawn from ``batcher``.

    Train/val losses are printed every ``eval_interval`` steps and once more
    after the last step, so the final line of output describes the fully
    trained model rather than its state ``eval_interval`` steps earlier.

    Args:
        model: module called as ``model(context, targets)``; the second element
            of its return value is the loss that is back-propagated.
        batcher: provides training batches via ``get_batch(split='train')``.
        iterations: number of optimizer steps to take.
        lr: AdamW learning rate.
        eval_interval: number of steps between loss reports; must be >= 1.
        eval_iters: forwarded as the third argument of ``estimate_loss``
            (presumably the number of batches averaged per split -- confirm in
            ``bigram.estimate_loss``). Defaults to EVAL_INTERVAL, which is what
            was previously hard-coded.
        device: device each batch is moved to before the forward pass.

    Returns:
        The AdamW optimizer, so its state can be saved alongside the model.

    Raises:
        ValueError: if ``eval_interval`` is less than 1.
    """
    if eval_interval < 1:
        raise ValueError(f"eval_interval must be >= 1, got {eval_interval}")

    def report(step):
        # `step` is the number of optimizer updates applied so far.
        losses = estimate_loss(model, batcher, eval_iters, device)
        print(
            f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for i in range(iterations):
        if i % eval_interval == 0:
            report(i)
        context_stack, answer_stack = batcher.get_batch(split='train')
        _, loss = model(context_stack.to(device), answer_stack.to(device))
        # set_to_none=True frees gradient tensors instead of zero-filling them.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    # The in-loop report fires *before* an update, so without this the model's
    # final state would never be evaluated.
    if iterations > 0:
        report(iterations)
    return optimizer
# Batch source: supplies training batches (get_batch) and the vocabulary (vocab).
b = Batcher(batch_size=BATCH_SIZE, block_size=BLOCK_SIZE, device=DEVICE)

# The model's vocabulary size is taken from the batcher, so it is built second.
m = BigramLanguageModel(
    vocab_size=len(b.vocab),
    block_size=BLOCK_SIZE,
    n_embd=NUM_EMBEDDING_DIMENSIONS,
    n_head=NUM_HEADS,
    n_layers=NUM_LAYERS,
    dropout=DROPOUT_RATE,
    device=DEVICE,
)
# nn.Module.to moves parameters in place and returns the same module.
m = m.to(DEVICE)
def run_model(model: nn.Module, response_size: int = BLOCK_SIZE, query: str = ''):
    """Generate text with ``model``, continuing from the prompt ``query``.

    For the duration of the call the model is put in eval mode, which disables
    dropout (DROPOUT_RATE is passed to the model at construction). Generation
    also runs under ``torch.no_grad()``. The model's previous train/eval mode
    is restored afterwards, even if generation raises.

    Args:
        model: module exposing ``generate(idx=..., max_new_tokens=...)``; row 0
            of its result is decoded.
        response_size: passed to ``generate`` as ``max_new_tokens``.
        query: prompt text, tokenized with ``encode``. May be empty.

    Returns:
        The decoded token ids returned by ``generate`` (assumed to be the prompt
        followed by the new tokens -- TODO confirm in
        ``bigram.BigramLanguageModel.generate``).
    """
    start_ids = encode(query)
    # An empty prompt would give a (1, 0) context, and a generate loop that reads
    # the last position's logits cannot index that. Seed with token id 0 (valid
    # for any non-empty vocabulary) and drop it from the output below.
    seeded = len(start_ids) == 0
    if seeded:
        start_ids = [0]
    context = torch.tensor(start_ids, dtype=torch.long, device=DEVICE)
    # add batch dimension: a single sequence, but the model expects (batch, time)
    context = context[None, ...]

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            encoded = model.generate(
                idx=context, max_new_tokens=response_size)[0]
    finally:
        model.train(was_training)

    token_ids = encoded.tolist()
    if seeded:
        # Remove the synthetic seed so the result is "prompt + continuation".
        token_ids = token_ids[1:]
    return decode(token_ids)
# Reuse a saved checkpoint when present; otherwise train from scratch and save one.
if os.path.exists('model.pth'):
    print("Loading model from file...")
    # torch.load unpickles the file. weights_only=True (available since
    # PyTorch 1.13, and the default from 2.6) restricts unpickling to tensors
    # and primitive containers, so a tampered model.pth cannot run arbitrary
    # code. The checkpoint written below contains only state dicts, which load
    # fine in this mode.
    checkpoint = torch.load('model.pth', map_location=DEVICE, weights_only=True)
    m.load_state_dict(checkpoint['model_state_dict'])
    print("Model loaded!")
else:
    print("Training model...")
    optimizer = train_model(m, b)
    # The optimizer state is saved too; the load branch above only reads
    # 'model_state_dict'.
    torch.save({
        'model_state_dict': m.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, 'model.pth')
    print("Training complete!")

print("Generating response...\n")
resp = run_model(m, 256, 'JULIET:\nRomeo, Romeo, wherefore art thou Romeo?')
print("Response:\n" + resp)