Whoops
Browse files- bad_gpt.py +1 -1
bad_gpt.py
CHANGED
@@ -82,7 +82,7 @@ def estimate_loss(gpt: BadGPTModel, batcher: Batcher, eval_interval: int, device
|
|
82 |
losses = torch.zeros(eval_interval)
|
83 |
for epoch in range(eval_interval):
|
84 |
train, answer = batcher.get_batch(split='train')
|
85 |
-
logits = gpt
|
86 |
# Reformat pediction and answer so each entry can be compared
|
87 |
batch, block, vocab = logits.shape
|
88 |
logits = logits.view(batch * block, vocab)
|
|
|
82 |
losses = torch.zeros(eval_interval)
|
83 |
for epoch in range(eval_interval):
|
84 |
train, answer = batcher.get_batch(split='train')
|
85 |
+
logits = gpt(train)
|
86 |
# Reformat pediction and answer so each entry can be compared
|
87 |
batch, block, vocab = logits.shape
|
88 |
logits = logits.view(batch * block, vocab)
|