import torch
from torch import nn


class FeedForward(nn.Module):
    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        self.net = nn.Sequential(
            # Project up to 4x the embedding size before the ReLU,
            # giving the nonlinearity a wider space to work in
            nn.Linear(n_embd, n_embd * 4),
            nn.ReLU(),
            # Project back down to n_embd so the residual connection lines up
            nn.Linear(n_embd * 4, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor):
        return self.net(x)


class Block(nn.Module):
    def __init__(self, n_embd: int, block_size: int, n_head: int, dropout: float):
        super().__init__()
        head_size = n_embd // n_head
        self.sa_head = MultiHead(
            n_head, block_size, n_embd, head_size, dropout)
        self.ffwd = FeedForward(n_embd, dropout)
        self.norm1 = nn.LayerNorm(n_embd)
        self.norm2 = nn.LayerNorm(n_embd)

    def forward(self, x: torch.Tensor):
        # Pre-norm residual connections: normalize, transform, then add back
        x = x + self.sa_head(self.norm1(x))
        x = x + self.ffwd(self.norm2(x))
        return x


class MultiHead(nn.Module):
    def __init__(self, num_heads: int, block_size: int, n_embd: int, head_size: int, dropout: float):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(block_size, n_embd, head_size, dropout) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        # Run each head independently, then concatenate along the channel dimension
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.proj(out)
        return self.drop(out)


class Head(nn.Module):
    def __init__(self, block_size: int, n_embd: int, head_size: int, dropout: float):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular causal mask; a buffer moves with the module but is not trained
        self.register_buffer('tril', torch.tril(
            torch.ones(block_size, block_size)))
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        # From Attention Is All You Need:
        #   Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
        # Here Q, K, and V are all projections of the same input (self-attention),
        # and d_k is head_size.
        q: torch.Tensor = self.query(x)                  # (B, T, head_size)
        # Transpose the last two dims so K^T lines up for the matmul
        k: torch.Tensor = self.key(x).transpose(-2, -1)  # (B, head_size, T)
        v = self.value(x)                                # (B, T, head_size)

        B, T, C = x.shape

        # Compute attention scores: Q @ K^T -> (B, T, T)
        wei = q @ k
        # Scale by 1/sqrt(d_k) = 1/sqrt(head_size) to keep the softmax well-behaved
        wei: torch.Tensor = wei * (q.shape[-1] ** -0.5)
        # Ignore future tokens
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

        # Aggregate the values, weighted by the attention scores
        wei = torch.softmax(wei, dim=-1)
        wei = self.drop(wei)
        out: torch.Tensor = wei @ v                      # (B, T, head_size)
        return out
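

# A minimal smoke test for the block above, showing how the pieces fit together.
# The hyperparameters here (n_embd, block_size, n_head, dropout, and the batch
# size) are illustrative assumptions, not values taken from the rest of the project.
if __name__ == "__main__":
    n_embd, block_size, n_head, dropout = 64, 32, 4, 0.1
    block = Block(n_embd, block_size, n_head, dropout)

    # A batch of 2 sequences, each block_size tokens of n_embd-dim embeddings
    x = torch.randn(2, block_size, n_embd)
    out = block(x)

    # The residual structure keeps the shape unchanged: (B, T, n_embd)
    print(out.shape)  # torch.Size([2, 32, 64])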