MasaakiKotera committed
Upload model.py with huggingface_hub

model.py CHANGED
@@ -145,7 +145,7 @@ class GPT(nn.Module):
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
-            torch.nn.init.
+            torch.nn.init.kaiming_normal_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
@@ -285,90 +285,80 @@ class GPT(nn.Module):
         flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
         mfu = flops_achieved / flops_promised
         return mfu
 
     @torch.no_grad()
     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, strategy='sampling', beam_size=3, eos_token_id=0, repetition_penalty=1.0):
-        """
-        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
-        the sequence max_new_tokens times, feeding the predictions back into the model each time.
-        Strategy can be 'greedy', 'sampling' or 'top-k'.
-        """
-        # check strategy valid
         assert strategy in ['greedy_search', 'sampling', 'top_k', 'beam_search']
 
-                all_candidates = []
-
-                for i, seq in enumerate(beam_seqs):
-                    # Get next token probabilities
-                    idx_cond = seq if seq.size(1) <= self.config.block_size else seq[:, -self.config.block_size:]
-                    logits, __ = self(idx_cond)
-                    logits = logits[:, -1, :]
-                    probs = F.log_softmax(logits, dim=-1) # Use log probs to avoid numerical instability
-
-                    # Get top sequences for this beam (we could use more than beam_size here for diversity)
-                    scores, indices = torch.topk(probs, beam_size)
-                    for j in range(beam_size):
-                        candidate_seq = torch.cat([seq, indices[:, j:j+1]], dim=1)
-                        candidate_score = beam_scores[:, i] + scores[:, j]
-
+        batch_size = idx.size(0)
+        if strategy == 'beam_search':
+            # Initialize beams
+            beam_seqs = [idx.clone() for _ in range(beam_size)]
+            beam_scores = torch.zeros((batch_size, beam_size), device=idx.device)
+            completed_seqs = []
+
+            for _ in range(max_new_tokens):
+                all_candidates = []
+                for i in range(beam_size):
+                    idx_cond = beam_seqs[i] if beam_seqs[i].size(1) <= self.config.block_size else beam_seqs[i][:, -self.config.block_size:]
+                    logits, _ = self(idx_cond)
+                    logits = logits[:, -1, :] / temperature
+                    if repetition_penalty != 1.0:
+                        for j in range(idx_cond.size(1)):
+                            logits[:, idx_cond[:, j]] /= repetition_penalty
+                    probs = F.log_softmax(logits, dim=-1)
+                    scores, indices = torch.topk(probs, beam_size, dim=-1)
+
+                    for j in range(beam_size):
+                        candidate_seq = torch.cat([beam_seqs[i], indices[:, j:j+1]], dim=1)
+                        candidate_score = beam_scores[:, i] + scores[:, j]
+                        if indices[0, j] == eos_token_id:
+                            completed_seqs.append((candidate_score, candidate_seq))
+                        else:
                             all_candidates.append((candidate_score, candidate_seq))
 
+                # add random noise when sorting because beam_search outputs would otherwise stay identical whenever they share the same prefix
+                all_candidates.sort(key=lambda x: x[0].mean().item() + torch.rand(1).item() * 5e-1, reverse=True)
+
+                beam_seqs = [all_candidates[i][1] for i in range(min(beam_size, len(all_candidates)))]
+                beam_scores = torch.stack([all_candidates[i][0] for i in range(min(beam_size, len(all_candidates)))], dim=1)
+                if len(completed_seqs) >= beam_size:
+                    break
+
+            if not completed_seqs:
+                completed_seqs = [(beam_scores[:, i], beam_seqs[i]) for i in range(beam_size)]
+
+            completed_seqs.sort(key=lambda x: x[0].mean().item(), reverse=True)
+            return completed_seqs[0][1]
+
+        else:
+            for _ in range(max_new_tokens):
+                idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+                logits, _ = self(idx_cond)
+                logits = logits[:, -1, :] / temperature
+
+                if repetition_penalty != 1.0:
+                    for i in range(idx.size(0)):
+                        for j in range(idx.size(1)):
+                            logits[i, idx[i, j]] /= repetition_penalty
+
+                if strategy == 'greedy_search':
+                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)
+
+                elif strategy == 'sampling':
                     probs = F.softmax(logits, dim=-1)
                     idx_next = torch.multinomial(probs, num_samples=1)
+
+                elif strategy == 'top_k':
+                    if top_k is not None:
+                        logits, indices = torch.topk(logits, min(top_k, logits.size(-1)))
+                        probs = F.softmax(logits, dim=-1)
+                        idx_next = torch.multinomial(probs, num_samples=1)
+                        idx_next = torch.gather(indices, dim=-1, index=idx_next)
+
                 if idx_next == eos_token_id:
                     break
                 idx = torch.cat((idx, idx_next), dim=1)
 
         return idx if idx[0][0] != eos_token_id else idx[:, 1:]
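A minimal usage sketch of the updated generate signature follows; `model` (a GPT built from this model.py) and `prompt_ids` (a LongTensor of shape (1, t) holding the encoded prompt) are placeholder names for illustration, not part of the commit.

# Usage sketch (assumes `model` is a GPT instance from this model.py and
# `prompt_ids` is a LongTensor of shape (1, t) with the encoded prompt).
model.eval()
out_sampling = model.generate(prompt_ids, max_new_tokens=100, temperature=0.8,
                              strategy='sampling', repetition_penalty=1.2)
out_top_k = model.generate(prompt_ids, max_new_tokens=100, temperature=0.8,
                           top_k=50, strategy='top_k')
out_beam = model.generate(prompt_ids, max_new_tokens=100,
                          strategy='beam_search', beam_size=3, eos_token_id=0)

Each call returns the prompt with the newly generated tokens appended (with a leading eos_token_id stripped if one is present), per the return statement in the diff above.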