shamashel committed on
Commit
e0a30a3
·
1 Parent(s): 4e9451d

Add actual prompt instead of randomly generating text

Browse files
Files changed (1) hide show
  1. main.py +6 -3
main.py CHANGED
@@ -55,8 +55,11 @@ m = BigramLanguageModel(
55
  ).to(DEVICE)
56
 
57
 
58
- def run_model(model: nn.Module, response_size: int = BLOCK_SIZE):
59
- context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
 
 
 
60
  encoded = model.generate(
61
  idx=context, max_new_tokens=response_size)[0]
62
  return decode(encoded.tolist())
@@ -76,5 +79,5 @@ else:
76
  }, 'model.pth')
77
  print("Training complete!")
78
  print("Generating response...\n")
79
- resp = run_model(m, 256)
80
  print("Response:", resp)
 
55
  ).to(DEVICE)
56
 
57
 
58
def run_model(model: nn.Module, response_size: int = BLOCK_SIZE, query: str = ''):
    """Generate text from the model, optionally seeded with a prompt.

    Args:
        model: trained language model exposing a ``generate(idx, max_new_tokens)``
            method that returns a (batch, time) tensor of token ids.
        response_size: number of new tokens to generate (defaults to BLOCK_SIZE).
        query: optional prompt text used to seed generation. When empty,
            generation starts from a single zero token (the pre-prompt behavior).

    Returns:
        The decoded string of the full generated sequence (prompt included).
    """
    start_ids = encode(query)
    if start_ids:
        context = torch.tensor(start_ids, dtype=torch.long, device=DEVICE)
        # Add a batch dimension: generate() expects (batch, time)-shaped input.
        context = context[None, ...]
    else:
        # An empty prompt would yield a (1, 0) tensor, which breaks
        # generate()'s last-token indexing — seed with one zero token instead.
        context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
    encoded = model.generate(
        idx=context, max_new_tokens=response_size)[0]
    return decode(encoded.tolist())
 
79
  }, 'model.pth')
80
  print("Training complete!")
81
  print("Generating response...\n")
82
+ resp = run_model(m, 256, 'To be or not to be, that is the question:')
83
  print("Response:", resp)