|
import torch |
|
from torch import nn |
|
|
|
|
|
class MultiHead(nn.Module):
    """Multi-head causal self-attention built from independent ``Head`` modules.

    Every head sees the same input. The head outputs are concatenated along
    the last (feature) axis, projected back to ``n_embd`` features and passed
    through dropout.

    Args:
        num_heads: Number of ``Head`` modules run in parallel.
        block_size: Maximum sequence length, forwarded to each ``Head``.
        n_embd: Feature size of the input and of the projected output.
        head_size: Output feature size of each individual head.
        dropout: Dropout probability, used inside each head and after the
            output projection.
    """

    def __init__(self, num_heads: int, block_size: int, n_embd: int, head_size: int, dropout: float):
        super().__init__()
        self.heads = nn.ModuleList(
            [Head(block_size, n_embd, head_size, dropout) for _ in range(num_heads)])
        # The concatenated head outputs carry num_heads * head_size features.
        # That equals n_embd only when head_size == n_embd // num_heads, so the
        # projection must take the concatenated width as its input size.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every head to ``x``, concatenate, project and apply dropout.

        Args:
            x: Input whose last dimension is ``n_embd``. Each ``Head``
                unpacks it as ``(B, T, C)``, so it must be rank 3.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``n_embd``.
        """
        concatenated = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.drop(self.proj(concatenated))
|
|
|
|
|
class Head(nn.Module):
    """A single head of causal (masked) self-attention.

    Args:
        block_size: Maximum sequence length supported. It sizes the causal
            mask, so inputs with more than ``block_size`` positions fail.
        n_embd: Feature size of the input (last dimension of ``x``).
        head_size: Size of the query/key/value projections and of the output.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, block_size: int, n_embd: int, head_size: int, dropout: float):
        super().__init__()
        # Creation order (key, query, value) is kept so parameter
        # initialisation and state_dict layout are unchanged.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: a zero at (i, j) means position i may not
        # attend to position j (> i). Stored as a buffer rather than a
        # parameter, so it moves with the module but is not trained.
        self.register_buffer('tril', torch.tril(
            torch.ones(block_size, block_size)))
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute causal self-attention for ``x``.

        Args:
            x: Rank-3 tensor ``(B, T, n_embd)`` with ``T <= block_size``.

        Returns:
            Tensor of shape ``(B, T, head_size)``. Row ``t`` depends only
            on input positions ``0..t``.
        """
        _, T, _ = x.shape
        q = self.query(x)  # (B, T, head_size)
        k = self.key(x)    # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # key dimension (head_size), not the input embedding width.
        scores = (q @ k.transpose(-2, -1)) * k.size(-1) ** -0.5  # (B, T, T)
        # Hide future positions; exp(-inf) == 0 after the softmax.
        scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        weights = self.drop(weights)
        return weights @ v
|
|