Add actual prompt instead of randomly generating text
main.py CHANGED
@@ -55,8 +55,11 @@ m = BigramLanguageModel(
 ).to(DEVICE)
 
 
-def run_model(model: nn.Module, response_size: int = BLOCK_SIZE):
-    context = torch.zeros((1, 1), dtype=torch.long, device=DEVICE)
+def run_model(model: nn.Module, response_size: int = BLOCK_SIZE, query: str = ''):
+    start_ids = encode(query)
+    context = torch.tensor(start_ids, dtype=torch.long, device=DEVICE)
+    # add batch dimension. it's just 1 batch, but we still need it cuz tensors
+    context = context[None, ...]
     encoded = model.generate(
         idx=context, max_new_tokens=response_size)[0]
     return decode(encoded.tolist())
@@ -76,5 +79,5 @@ else:
     }, 'model.pth')
     print("Training complete!")
     print("Generating response...\n")
-    resp = run_model(m, 256)
+    resp = run_model(m, 256, 'To be or not to be, that is the question:')
     print("Response:", resp)
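The change boils down to two steps: `encode` turns the prompt string into a list of token ids, and `model.generate` consumes a (B, T) batch of ids, so the 1-D tensor needs a leading batch axis of size 1 before generation. Below is a minimal, self-contained sketch of that shape handling; the character-level vocabulary is invented for illustration, since the real `encode`/`decode` in main.py are built from the model's training corpus:

import torch

# Toy character vocabulary -- an assumption for this sketch; main.py
# derives its own encode/decode mapping from the training data.
chars = sorted(set("To be or not to be, that is the question: abcdefghijklmnopqrstuvwxyz"))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s: str) -> list[int]:
    return [stoi[c] for c in s]

def decode(ids: list[int]) -> str:
    return ''.join(itos[i] for i in ids)

query = 'To be or not to be, that is the question:'
context = torch.tensor(encode(query), dtype=torch.long)
print(context.shape)   # (T,) -- one id per prompt character

# generate() expects a batch, so prepend a batch axis: (T,) -> (1, T)
context = context[None, ...]
print(context.shape)   # (1, T)

# Round-trip check: decoding the ids recovers the prompt
assert decode(context[0].tolist()) == query

The `[0]` index on the `model.generate(...)` output in `run_model` then mirrors this step, stripping the batch axis back off before decoding.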