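"""Train (or load) a BigramLanguageModel and generate a sample completion.

The model, batching, and loss-estimation utilities are imported from
bigram.py, and the tokenizer (encode/decode) from encoder.py. A trained
checkpoint is cached in model.pth and reused on subsequent runs.
"""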
import os

import torch
import torch.nn as nn

from encoder import encode, decode
from bigram import BigramLanguageModel, Batcher, estimate_loss

# HYPERPARAMETERS #
### Impacts performance ###
BATCH_SIZE = 64  # how many sequences of tokens we process in parallel
BLOCK_SIZE = 256  # length of a single token sequence (context length)
LEARNING_RATE = 1e-4  # AdamW learning rate
NUM_EMBEDDING_DIMENSIONS = 384  # embedding width (passed as n_embd)
NUM_HEADS = 6  # number of attention heads (passed as n_head)
NUM_LAYERS = 6  # number of layers (passed as n_layers)
MAX_ITERS = 5000  # total training iterations
### Others ###
EVAL_INTERVAL = 500  # evaluate and report train/val loss every this many steps
DROPOUT_RATE = 0.2  # dropout probability
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# --------------- #


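# Standard training loop: sample a training batch, run a forward pass to get
# the loss, then backpropagate and step an AdamW optimizer. Train/val loss is
# reported every EVAL_INTERVAL steps via estimate_loss.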
def train_model(model: nn.Module, batcher: Batcher, iterations=MAX_ITERS, lr=LEARNING_RATE):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    for i in range(iterations):
        if i % EVAL_INTERVAL == 0:
            losses = estimate_loss(model, batcher, EVAL_INTERVAL, DEVICE)
            print(
                f"step {i}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        context_stack, answer_stack = batcher.get_batch(split='train')
        _, loss = model(context_stack.to(DEVICE), answer_stack.to(DEVICE))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    return optimizer


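# Build the data batcher and the model from the hyperparameters above; the
# vocabulary size comes from the batcher's vocab.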
b = Batcher(
    device=DEVICE,
    batch_size=BATCH_SIZE,
    block_size=BLOCK_SIZE
)
m = BigramLanguageModel(
    device=DEVICE,
    block_size=BLOCK_SIZE,
    vocab_size=len(b.vocab),
    n_embd=NUM_EMBEDDING_DIMENSIONS,
    n_head=NUM_HEADS,
    n_layers=NUM_LAYERS,
    dropout=DROPOUT_RATE
).to(DEVICE)


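# Encode the prompt, autoregressively generate up to response_size new tokens,
# and decode the result back into text.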
def run_model(model: nn.Module, response_size: int = BLOCK_SIZE, query: str = ''):
    model.eval()  # disable dropout while generating
    start_ids = encode(query)
    context = torch.tensor(start_ids, dtype=torch.long, device=DEVICE)
    # add a batch dimension: generate expects a (batch, time) tensor,
    # even though there is only a single sequence here
    context = context[None, ...]
    encoded = model.generate(
        idx=context, max_new_tokens=response_size)[0]
    return decode(encoded.tolist())

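# Reuse a saved checkpoint when one exists; otherwise train from scratch and
# save both model and optimizer state for later runs.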
if os.path.exists('model.pth'):
    print("Loading model from file...")
    checkpoint = torch.load('model.pth', map_location=DEVICE)
    m.load_state_dict(checkpoint['model_state_dict'])
    print("Model loaded!")
else:
    print("Training model...")
    optimizer = train_model(m, b)
    torch.save({
        'model_state_dict': m.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, 'model.pth')
    print("Training complete!")
print("Generating response...\n")
resp = run_model(m, 256, 'JULIET:\nRomeo, Romeo, wherefore art thou Romeo?')
print("Response:\n" + resp)