import logging
import os
from typing import Literal

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from encoder import encode, decode
from self_attention import Block
from dataset import Batcher

logger = logging.getLogger('bad_gpt').getChild(__name__)


class BadGPTModel(nn.Module):
    def __init__(
        self,
        device: Literal['cuda', 'cpu'],
        block_size: int,
        vocab_size: int,
        n_embd: int,
        n_head: int = 4,
        n_layers: int = 3,
        dropout: float = 0.2
    ):
        super().__init__()
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.device = device
        # Create tables to embed both token identity and position within the block
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        # Loss expected from a uniform random guess: -ln(1/vocab_size)
        self.expected_loss: np.float64 = np.log(1 / vocab_size) * -1
        self.blocks = nn.Sequential(
            *[
                Block(n_embd, block_size, n_head, dropout)
                for _ in range(n_layers)
            ],
            nn.LayerNorm(n_embd)
        )

    def forward(self, idx: torch.Tensor, targets: torch.Tensor = None):
        # Predict logits for the next token at every position.
        # targets is currently unused; the trainer computes the loss
        # from the returned logits.
        B, T = idx.shape
        tok_emb: torch.Tensor = self.token_embedding_table(idx)  # (B, T, n_embd)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=self.device))                 # (T, n_embd)
        x: torch.Tensor = tok_emb + pos_emb
        x = self.blocks(x)
        logits: torch.Tensor = self.lm_head(x)                   # (B, T, vocab_size)
        return logits

    # Given a (batch, time) tensor of token indices, append up to
    # max_new_tokens newly sampled tokens to each sequence
    def generate(self, ctx: torch.Tensor, max_new_tokens: int):
        for index in range(max_new_tokens):
            # Log progress so I don't go insane
            if index % 16 == 0:
                logger.debug(f'Iteration {index} of {max_new_tokens}')
            # Crop the context to the last block_size tokens
            cropped_ctx = ctx[:, -self.block_size:]
            logits = self(cropped_ctx)
            # Logits has dimensions (batch, time, vocab);
            # we only need the predictions at the last position
            logits = logits[:, -1, :]
            # Turn the logits into a distribution and sample the next token
            probabilities = F.softmax(logits, dim=-1)
            ctx_next = torch.multinomial(probabilities, num_samples=1)
            # Append the new token to the running context
            ctx = torch.cat((ctx, ctx_next), dim=1)
        return ctx


@torch.no_grad()
def estimate_loss(gpt: BadGPTModel,
                  batcher: Batcher,
                  eval_interval: int,
                  device: Literal['cuda', 'cpu'] = 'cuda'):
    out = {}
    gpt.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_interval)
        for epoch in range(eval_interval):
            context, answer = batcher.get_batch(split=split)
            logits = gpt(context)
            # Flatten prediction and answer so each entry can be compared
            batch, block, vocab = logits.shape
            logits = logits.view(batch * block, vocab)
            answer = answer.view(batch * block)
            # Cross-entropy between the predicted distribution and the actual next token
            loss = F.cross_entropy(logits, answer).item()
            losses[epoch] = loss
        out[split] = losses.mean()
    gpt.train()
    return out


class BadGPTTrainer():
    def __init__(self,
                 model: BadGPTModel,
                 batcher: Batcher,
                 eval_interval: int,
                 iterations: int,
                 learning_rate: float):
        self.model = model
        self.batcher = batcher
        self.eval_interval = eval_interval
        self.iterations = iterations
        self.learning_rate = learning_rate
        self.device = self.model.device
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.learning_rate)

    def train(self):
        if os.path.exists('model.pth'):
            logger.debug("Loading model from file...")
            checkpoint = torch.load('model.pth', map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            logger.debug("Model loaded!")
        else:
            logger.debug("Training model...")
            self._train()
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict()
            }, 'model.pth')
            logger.debug("Training complete!")

    def _train(self):
        for i in range(self.iterations):
            if i % self.eval_interval == 0:
                losses = estimate_loss(
                    self.model, self.batcher, self.eval_interval, self.device)
                logger.debug(
                    f"step {i}: train loss {losses['train']:.4f}, "
                    f"val loss {losses['val']:.4f}")
            context_stack, answer_stack = self.batcher.get_batch(split='train')
            logits = self.model(context_stack.to(self.device),
                                answer_stack.to(self.device))
            batch, block, vocab = logits.shape
            # Flatten logits and answers so each entry can be compared
            logits = logits.view(batch * block, vocab).to(self.device)
            answer_stack = answer_stack.view(batch * block).to(self.device)
            # Compare predicted tokens to the actual next tokens
            loss = F.cross_entropy(logits, answer_stack)
            self.optimizer.zero_grad(set_to_none=True)
            loss.backward()
            self.optimizer.step()


class BadGPT():
    def __init__(
        self,
        device: Literal['cuda', 'cpu'],
        block_size: int,
        batch_size: int,
        n_embd: int,
        n_head: int,
        n_layers: int,
        dropout: float,
        eval_interval: int,
        iterations: int,
        lr: float
    ):
        self.device = device
        self._batcher = Batcher(
            device=device,
            batch_size=batch_size,
            block_size=block_size
        )
        self._model = BadGPTModel(
            device=device,
            block_size=block_size,
            vocab_size=len(self._batcher.vocab),
            n_embd=n_embd,
            n_head=n_head,
            n_layers=n_layers,
            dropout=dropout
        ).to(device)
        self._trainer = BadGPTTrainer(
            model=self._model,
            batcher=self._batcher,
            eval_interval=eval_interval,
            iterations=iterations,
            learning_rate=lr
        )
        self._trainer.train()
        # Set to eval phase since we're only taking user input from here on
        self._model.eval()

    def generate(self, prompt: str, response_size: int):
        start_ids = encode(prompt)
        context = torch.tensor(
            start_ids, dtype=torch.long, device=self.device)
        # Add a batch dimension; it's just 1 batch, but the model expects (B, T)
        context = context[None, ...]
        encoded = self._model.generate(
            ctx=context, max_new_tokens=response_size)[0]
        return decode(encoded.tolist())
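

# Minimal usage sketch. The hyperparameter values below are illustrative
# assumptions, not values taken from this repo; only the BadGPT constructor
# and generate() signatures come from the classes above.
if __name__ == '__main__':
    logging.basicConfig(level=logging.DEBUG)
    gpt = BadGPT(
        device='cuda' if torch.cuda.is_available() else 'cpu',
        block_size=64,
        batch_size=32,
        n_embd=128,
        n_head=4,
        n_layers=3,
        dropout=0.2,
        eval_interval=200,
        iterations=5000,
        lr=3e-4
    )
    # Trains (or loads model.pth) on construction, then samples a completion
    print(gpt.generate(prompt='Once upon a time', response_size=128))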