shamashel commited on
Commit
b4d4e2a
·
1 Parent(s): 070fe06

Update logging

Browse files
Files changed (2) hide show
  1. bad_gpt.py +3 -1
  2. main.py +7 -4
bad_gpt.py CHANGED
@@ -57,7 +57,9 @@ class BadGPTModel(nn.Module):
57
  # generate new tokens in the next sentence
58
  def generate(self, idx: torch.Tensor, max_new_tokens: int):
59
  for _ in range(max_new_tokens):
60
- print(f'Iteration {_} of {max_new_tokens}')
 
 
61
  # Crop out the last block_size tokens
62
  cropped_idx = idx[:, -self.block_size:]
63
  logits = self(cropped_idx)
 
57
  # generate new tokens in the next sentence
58
  def generate(self, idx: torch.Tensor, max_new_tokens: int):
59
  for _ in range(max_new_tokens):
60
+ # Log progress so I don't go insane
61
+ if _ % 16 == 0:
62
+ logger.debug(f'Iteration {_} of {max_new_tokens}')
63
  # Crop out the last block_size tokens
64
  cropped_idx = idx[:, -self.block_size:]
65
  logits = self(cropped_idx)
main.py CHANGED
@@ -1,6 +1,8 @@
1
- import torch
2
-
3
  from bad_gpt import BadGPT
 
 
 
 
4
 
5
  # HYPERPARAMETERS #
6
  ### Impacts performance ###
@@ -19,6 +21,7 @@ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
19
 
20
 
21
  if __name__ == '__main__':
 
22
  bad_gpt = BadGPT(
23
  device=DEVICE,
24
  batch_size=BATCH_SIZE,
@@ -32,7 +35,7 @@ if __name__ == '__main__':
32
  lr=LEARNING_RATE
33
  )
34
 
35
- print("Generating response...\n")
36
  resp = bad_gpt.generate(
37
  'JULIET:\nRomeo, Romeo, wherefore art thou Romeo?', 256)
38
- print("Response:\n" + resp)
 
 
 
1
  from bad_gpt import BadGPT
2
+ import torch
3
+ import logging
4
+ logging.basicConfig()
5
+ logger = logging.getLogger('bad_gpt').getChild(__name__)
6
 
7
  # HYPERPARAMETERS #
8
  ### Impacts performance ###
 
21
 
22
 
23
  if __name__ == '__main__':
24
+ logging.getLogger('bad_gpt').setLevel(logging.DEBUG)
25
  bad_gpt = BadGPT(
26
  device=DEVICE,
27
  batch_size=BATCH_SIZE,
 
35
  lr=LEARNING_RATE
36
  )
37
 
38
+ logger.info("Generating response...")
39
  resp = bad_gpt.generate(
40
  'JULIET:\nRomeo, Romeo, wherefore art thou Romeo?', 256)
41
+ logger.info("Response:\n" + resp)