shamashel commited on
Commit
a3df5a4
·
1 Parent(s): bb4e416
Files changed (1) hide show
  1. bad_gpt.py +1 -1
bad_gpt.py CHANGED
@@ -82,7 +82,7 @@ def estimate_loss(gpt: BadGPTModel, batcher: Batcher, eval_interval: int, device
82
  losses = torch.zeros(eval_interval)
83
  for epoch in range(eval_interval):
84
  train, answer = batcher.get_batch(split='train')
85
- logits = gpt.forward(train)
86
  # Reformat pediction and answer so each entry can be compared
87
  batch, block, vocab = logits.shape
88
  logits = logits.view(batch * block, vocab)
 
82
  losses = torch.zeros(eval_interval)
83
  for epoch in range(eval_interval):
84
  train, answer = batcher.get_batch(split='train')
85
+ logits = gpt(train)
86
  # Reformat pediction and answer so each entry can be compared
87
  batch, block, vocab = logits.shape
88
  logits = logits.view(batch * block, vocab)