MasaakiKotera committed
Commit 106d320 · verified · 1 Parent(s): 8a7f15e

Upload model.py with huggingface_hub

Files changed (1): model.py +65 -75
model.py CHANGED
@@ -145,7 +145,7 @@ class GPT(nn.Module):
 
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+            torch.nn.init.kaiming_normal_(module.weight, a=0, mode='fan_in', nonlinearity='relu')
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
@@ -285,90 +285,80 @@ class GPT(nn.Module):
         flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
         mfu = flops_achieved / flops_promised
         return mfu
-
+
     @torch.no_grad()
     def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, strategy='sampling', beam_size=3, eos_token_id=0, repetition_penalty=1.0):
-        """
-        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
-        the sequence max_new_tokens times, feeding the predictions back into the model each time.
-        Strategy can be 'greedy', 'sampling' or 'top-k'.
-        """
-        # check strategy valid
         assert strategy in ['greedy_search', 'sampling', 'top_k', 'beam_search']
 
-        for _ in range(max_new_tokens):
-            # if the sequence context is growing too long we must crop it at block_size
-            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
-            # forward the model to get the logits for the index in the sequence
-            logits, _ = self(idx_cond)
-            # pluck the logits at the final step and scale by desired temperature
-            logits = logits[:, -1, :] / temperature
-
-            # Apply repetition penalty
-            if repetition_penalty != 1.0:
-                for i in range(idx.size(0)):
-                    for j in range(idx.size(1)):
-                        logits[i, idx[i, j]] /= repetition_penalty
-
-            if strategy == 'greedy_search':
-                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
-
-            elif strategy == 'beam_search':
-                # If beam search is the selected strategy
-
-                # First, initialize the sequences and their scores
-                beam_seqs = [idx] * beam_size
-                beam_scores = torch.zeros((idx.size(0), beam_size), device=idx.device)
-
-                for k in range(max_new_tokens):
-                    all_candidates = []
-
-                    for i, seq in enumerate(beam_seqs):
-                        # Get next token probabilities
-                        idx_cond = seq if seq.size(1) <= self.config.block_size else seq[:, -self.config.block_size:]
-                        logits, __ = self(idx_cond)
-                        logits = logits[:, -1, :]
-                        probs = F.log_softmax(logits, dim=-1) # Use log probs to avoid numerical instability
-
-                        # Get top sequences for this beam (we could use more than beam_size here for diversity)
-                        scores, indices = torch.topk(probs, beam_size)
-                        for j in range(beam_size):
-                            candidate_seq = torch.cat([seq, indices[:, j:j+1]], dim=1)
-                            candidate_score = beam_scores[:, i] + scores[:, j]
-
+        batch_size = idx.size(0)
+        if strategy == 'beam_search':
+            # Initialize beams
+            beam_seqs = [idx.clone() for _ in range(beam_size)]
+            beam_scores = torch.zeros((batch_size, beam_size), device=idx.device)
+            completed_seqs = []
+
+            for _ in range(max_new_tokens):
+                all_candidates = []
+                for i in range(beam_size):
+                    idx_cond = beam_seqs[i] if beam_seqs[i].size(1) <= self.config.block_size else beam_seqs[i][:, -self.config.block_size:]
+                    logits, _ = self(idx_cond)
+                    logits = logits[:, -1, :] / temperature
+                    if repetition_penalty != 1.0:
+                        for j in range(idx_cond.size(1)):
+                            logits[:, idx_cond[:, j]] /= repetition_penalty
+                    probs = F.log_softmax(logits, dim=-1)
+                    scores, indices = torch.topk(probs, beam_size, dim=-1)
+
+                    for j in range(beam_size):
+                        candidate_seq = torch.cat([beam_seqs[i], indices[:, j:j+1]], dim=1)
+                        candidate_score = beam_scores[:, i] + scores[:, j]
+                        if indices[0, j] == eos_token_id:
+                            completed_seqs.append((candidate_score, candidate_seq))
+                        else:
                             all_candidates.append((candidate_score, candidate_seq))
 
-                    # Sort all candidates by score
-                    all_candidates.sort(key=lambda x: -x[0].mean().item()) # Average score over the batch
-
-                    # Get the top sequences
-                    beam_seqs = [all_candidates[i][1] for i in range(beam_size)]
-                    beam_scores = torch.stack([all_candidates[i][0] for i in range(beam_size)], dim=1)
-
-                # At the end, choose the sequence with the highest score
-                idx = beam_seqs[0]
-                return idx if idx[0][0] != eos_token_id else idx[:, 1:]
-
-            elif strategy == 'sampling':
-                # apply softmax to convert logits to (normalized) probabilities
-                probs = F.softmax(logits, dim=-1)
-                # sample from the distribution
-                idx_next = torch.multinomial(probs, num_samples=1)
-
-            elif strategy == 'top_k':
-                if top_k is not None:
-                    logits, indices = torch.topk(logits, min(top_k, logits.size(-1)))
-                probs = F.softmax(logits, dim=-1)
-                idx_next = torch.multinomial(probs, num_samples=1)
-                idx_next = torch.gather(indices, dim=-1, index=idx_next)
-
-            # append sampled index to the running sequence and continue
-            if strategy != 'beam_search':
+                # add random noise when sorting because beam_search sequences that share a prefix would otherwise remain identical
+                all_candidates.sort(key=lambda x: x[0].mean().item() + torch.rand(1).item() * 5e-1, reverse=True)
+
+                beam_seqs = [all_candidates[i][1] for i in range(min(beam_size, len(all_candidates)))]
+                beam_scores = torch.stack([all_candidates[i][0] for i in range(min(beam_size, len(all_candidates)))], dim=1)
+                if len(completed_seqs) >= beam_size:
+                    break
+
+            if not completed_seqs:
+                completed_seqs = [(beam_scores[:, i], beam_seqs[i]) for i in range(beam_size)]
+
+            completed_seqs.sort(key=lambda x: x[0].mean().item(), reverse=True)
+            return completed_seqs[0][1]
+
+
+        else:
+            for _ in range(max_new_tokens):
+                idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+                logits, _ = self(idx_cond)
+                logits = logits[:, -1, :] / temperature
+
+                if repetition_penalty != 1.0:
+                    for i in range(idx.size(0)):
+                        for j in range(idx.size(1)):
+                            logits[i, idx[i, j]] /= repetition_penalty
+
+                if strategy == 'greedy_search':
+                    idx_next = torch.argmax(logits, dim=-1, keepdim=True)
+
+                elif strategy == 'sampling':
+                    probs = F.softmax(logits, dim=-1)
+                    idx_next = torch.multinomial(probs, num_samples=1)
+
+                elif strategy == 'top_k':
+                    if top_k is not None:
+                        logits, indices = torch.topk(logits, min(top_k, logits.size(-1)))
+                    probs = F.softmax(logits, dim=-1)
+                    idx_next = torch.multinomial(probs, num_samples=1)
+                    idx_next = torch.gather(indices, dim=-1, index=idx_next)
+
                 if idx_next == eos_token_id:
                     break
                 idx = torch.cat((idx, idx_next), dim=1)
 
-
         return idx if idx[0][0] != eos_token_id else idx[:, 1:]
-
-
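Note on the init change in the first hunk: kaiming_normal_ with mode='fan_in' and nonlinearity='relu' draws weights from N(0, 2/fan_in), so the spread now shrinks with layer width instead of staying at the fixed 0.02 of the old init. A minimal standalone sketch (not part of the commit) showing the difference:

import torch
import torch.nn as nn

lin = nn.Linear(768, 768)

# old init: fixed spread, independent of layer width
torch.nn.init.normal_(lin.weight, mean=0.0, std=0.02)
print(lin.weight.std().item())   # ~0.02

# new init: std = sqrt(2 / fan_in), here sqrt(2/768) ~ 0.051
torch.nn.init.kaiming_normal_(lin.weight, a=0, mode='fan_in', nonlinearity='relu')
print(lin.weight.std().item())   # ~0.051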
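The new beam_search branch ranks candidates by mean log-probability plus a small uniform noise term, so beams that tie on score (for example, because they share a prefix) do not all collapse into the same continuation. A toy sketch of just that pruning step; prune_candidates and noise_scale are illustrative names, not from the commit:

import torch

def prune_candidates(all_candidates, beam_size, noise_scale=5e-1):
    # all_candidates: list of (score, seq) pairs, as built inside generate()
    # the noisy sort mirrors the commit's tie-breaking trick
    all_candidates.sort(key=lambda x: x[0].mean().item() + torch.rand(1).item() * noise_scale,
                        reverse=True)
    return all_candidates[:min(beam_size, len(all_candidates))]

cands = [(torch.tensor([-1.2]), torch.tensor([[5, 7]])),
         (torch.tensor([-0.4]), torch.tensor([[5, 8]])),
         (torch.tensor([-0.4]), torch.tensor([[5, 9]]))]   # two tied candidates
print([seq.tolist() for _, seq in prune_candidates(cands, beam_size=2)])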
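Both branches apply the same repetition penalty: the logit of every token id already in the context is divided by repetition_penalty once per occurrence, damping repeats when the penalty is above 1 and the logit is positive (a known caveat of this scheme is that dividing a negative logit raises it instead). A standalone sketch of the arithmetic, not part of the diff:

import torch

logits = torch.tensor([[2.0, 1.0, 3.0, 0.5]])
context = torch.tensor([[2, 2, 0]])   # token 2 appears twice, token 0 once
penalty = 1.2

for j in range(context.size(1)):
    logits[0, context[0, j]] /= penalty

print(logits)   # token 0: 2.0/1.2 ~ 1.67; token 2: 3.0/1.2**2 ~ 2.08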
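For context, a hypothetical call site for the updated generate(); the GPTConfig construction and token ids here are assumptions for illustration, not shown in this diff:

import torch

model = GPT(GPTConfig())   # assumes the GPTConfig defined elsewhere in model.py
model.eval()
prompt = torch.tensor([[1]], dtype=torch.long)   # (b, t) conditioning tokens

out = model.generate(prompt, max_new_tokens=100, temperature=0.8,
                     strategy='sampling', repetition_penalty=1.2)
out = model.generate(prompt, max_new_tokens=100, top_k=40, strategy='top_k')
out = model.generate(prompt, max_new_tokens=100, strategy='beam_search', beam_size=3)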