Spaces: Running on Zero
Commit 3098025 · da03 committed
1 Parent(s): ae9daf1
app.py CHANGED
@@ -40,15 +40,15 @@ def predict_product(num1, num2):
     generated_ids = inputs['input_ids']
     past_key_values = None
     for _ in range(MAX_PRODUCT_DIGITS): # Set a maximum limit to prevent infinite loops
-        outputs = model(
+        outputs = model.generate(
             input_ids=generated_ids,
+            max_new_tokens=1,
             past_key_values=past_key_values,
+            return_dict_in_generate=True,
             use_cache=True
         )
-        logits = outputs.logits
-        past_key_values = outputs.past_key_values
-        next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
-        generated_ids = torch.cat((generated_ids, next_token_id.view(1,-1)), dim=-1)
+        generated_ids = outputs.sequences
+        next_token_id = generated_ids[0, -1]
         print (next_token_id)

         if next_token_id.item() == eos_token_id:
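For readers following the change: this hunk replaces a manual decoding step (forward pass, argmax over the logits, torch.cat onto generated_ids) with a one-token-at-a-time call to model.generate, reading the running sequence back from outputs.sequences. Below is a minimal, self-contained sketch of that pattern under stated assumptions, not the Space's actual code: gpt2, the prompt string, and the MAX_PRODUCT_DIGITS value are placeholders, since app.py's own model, tokenizer, and setup lie outside this hunk.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholders: the real app.py defines its own model, tokenizer, prompt,
# MAX_PRODUCT_DIGITS, and eos_token_id.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
MAX_PRODUCT_DIGITS = 20
eos_token_id = tokenizer.eos_token_id

inputs = tokenizer("12 * 34 =", return_tensors="pt")
generated_ids = inputs["input_ids"]
past_key_values = None

for _ in range(MAX_PRODUCT_DIGITS):  # cap iterations to avoid an infinite loop
    # Ask generate() for exactly one new token. return_dict_in_generate=True
    # returns an output object whose .sequences is the prompt plus everything
    # generated so far, so it can be fed straight back in on the next pass.
    outputs = model.generate(
        input_ids=generated_ids,
        max_new_tokens=1,
        past_key_values=past_key_values,  # stays None here, as in the hunk
        return_dict_in_generate=True,
        use_cache=True,
    )
    generated_ids = outputs.sequences
    next_token_id = generated_ids[0, -1]
    if next_token_id.item() == eos_token_id:
        break

print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

Because past_key_values is never reassigned, each call re-encodes the whole running sequence rather than reusing the KV cache across iterations; that mirrors the hunk as committed and keeps the sketch simple.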