johnowhitaker commited on
Commit
eee3c6d
·
1 Parent(s): 6e64f14

from minGPT

Browse files
Files changed (1) hide show
  1. model.py +199 -0
model.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPT model:
3
+ - the initial stem consists of a combination of token encoding and a positional encoding
4
+ - the meat of it is a uniform sequence of Transformer blocks
5
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
6
+ - all blocks feed into a central residual pathway similar to resnets
7
+ - the final decoder is a linear projection into a vanilla Softmax classifier
8
+ """
9
+
10
+ import math
11
+ import logging
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from torch.nn import functional as F
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ class GPTConfig:
20
+ """ base GPT config, params common to all GPT versions """
21
+ embd_pdrop = 0.1
22
+ resid_pdrop = 0.1
23
+ attn_pdrop = 0.1
24
+
25
+ def __init__(self, vocab_size, block_size, **kwargs):
26
+ self.vocab_size = vocab_size
27
+ self.block_size = block_size
28
+ for k,v in kwargs.items():
29
+ setattr(self, k, v)
30
+
31
+ class GPT1Config(GPTConfig):
32
+ """ GPT-1 like network roughly 125M params """
33
+ n_layer = 12
34
+ n_head = 12
35
+ n_embd = 768
36
+
37
+ class CausalSelfAttention(nn.Module):
38
+ """
39
+ A vanilla multi-head masked self-attention layer with a projection at the end.
40
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
41
+ explicit implementation here to show that there is nothing too scary here.
42
+ """
43
+
44
+ def __init__(self, config):
45
+ super().__init__()
46
+ assert config.n_embd % config.n_head == 0
47
+ # key, query, value projections for all heads
48
+ self.key = nn.Linear(config.n_embd, config.n_embd)
49
+ self.query = nn.Linear(config.n_embd, config.n_embd)
50
+ self.value = nn.Linear(config.n_embd, config.n_embd)
51
+ # regularization
52
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
53
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
54
+ # output projection
55
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
56
+ # causal mask to ensure that attention is only applied to the left in the input sequence
57
+ self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
58
+ .view(1, 1, config.block_size, config.block_size))
59
+ self.n_head = config.n_head
60
+
61
+ def forward(self, x):
62
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
63
+
64
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
65
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
66
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
67
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
68
+
69
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
70
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
71
+ att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
72
+ att = F.softmax(att, dim=-1)
73
+ att = self.attn_drop(att)
74
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
75
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
76
+
77
+ # output projection
78
+ y = self.resid_drop(self.proj(y))
79
+ return y
80
+
81
+ class Block(nn.Module):
82
+ """ an unassuming Transformer block """
83
+
84
+ def __init__(self, config):
85
+ super().__init__()
86
+ self.ln1 = nn.LayerNorm(config.n_embd)
87
+ self.ln2 = nn.LayerNorm(config.n_embd)
88
+ self.attn = CausalSelfAttention(config)
89
+ self.mlp = nn.Sequential(
90
+ nn.Linear(config.n_embd, 4 * config.n_embd),
91
+ nn.GELU(),
92
+ nn.Linear(4 * config.n_embd, config.n_embd),
93
+ nn.Dropout(config.resid_pdrop),
94
+ )
95
+
96
+ def forward(self, x):
97
+ x = x + self.attn(self.ln1(x))
98
+ x = x + self.mlp(self.ln2(x))
99
+ return x
100
+
101
+ class GPT(nn.Module):
102
+ """ the full GPT language model, with a context size of block_size """
103
+
104
+ def __init__(self, config):
105
+ super().__init__()
106
+
107
+ # input embedding stem
108
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
109
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
110
+ self.drop = nn.Dropout(config.embd_pdrop)
111
+ # transformer
112
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
113
+ # decoder head
114
+ self.ln_f = nn.LayerNorm(config.n_embd)
115
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
116
+
117
+ self.block_size = config.block_size
118
+ self.apply(self._init_weights)
119
+
120
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
121
+
122
+ def get_block_size(self):
123
+ return self.block_size
124
+
125
+ def _init_weights(self, module):
126
+ if isinstance(module, (nn.Linear, nn.Embedding)):
127
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
128
+ if isinstance(module, nn.Linear) and module.bias is not None:
129
+ torch.nn.init.zeros_(module.bias)
130
+ elif isinstance(module, nn.LayerNorm):
131
+ torch.nn.init.zeros_(module.bias)
132
+ torch.nn.init.ones_(module.weight)
133
+ elif isinstance(module, GPT):
134
+ torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
135
+
136
+ def configure_optimizers(self, train_config):
137
+ """
138
+ This long function is unfortunately doing something very simple and is being very defensive:
139
+ We are separating out all parameters of the model into two buckets: those that will experience
140
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
141
+ We are then returning the PyTorch optimizer object.
142
+ """
143
+
144
+ # separate out all parameters to those that will and won't experience regularizing weight decay
145
+ decay = set()
146
+ no_decay = set()
147
+ whitelist_weight_modules = (torch.nn.Linear, )
148
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
149
+ for mn, m in self.named_modules():
150
+ for pn, p in m.named_parameters():
151
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
152
+
153
+ if pn.endswith('bias'):
154
+ # all biases will not be decayed
155
+ no_decay.add(fpn)
156
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
157
+ # weights of whitelist modules will be weight decayed
158
+ decay.add(fpn)
159
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
160
+ # weights of blacklist modules will NOT be weight decayed
161
+ no_decay.add(fpn)
162
+
163
+ # special case the position embedding parameter in the root GPT module as not decayed
164
+ no_decay.add('pos_emb')
165
+
166
+ # validate that we considered every parameter
167
+ param_dict = {pn: p for pn, p in self.named_parameters()}
168
+ inter_params = decay & no_decay
169
+ union_params = decay | no_decay
170
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
171
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
172
+ % (str(param_dict.keys() - union_params), )
173
+
174
+ # create the pytorch optimizer object
175
+ optim_groups = [
176
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
177
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
178
+ ]
179
+ optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
180
+ return optimizer
181
+
182
+ def forward(self, idx, targets=None):
183
+ b, t = idx.size()
184
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
185
+
186
+ # forward the GPT model
187
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
188
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
189
+ x = self.drop(token_embeddings + position_embeddings)
190
+ x = self.blocks(x)
191
+ x = self.ln_f(x)
192
+ logits = self.head(x)
193
+
194
+ # if we are given some desired targets also calculate the loss
195
+ loss = None
196
+ if targets is not None:
197
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
198
+
199
+ return logits, loss