da03 commited on
Commit
3098025
·
1 Parent(s): ae9daf1
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -40,15 +40,15 @@ def predict_product(num1, num2):
40
  generated_ids = inputs['input_ids']
41
  past_key_values = None
42
  for _ in range(MAX_PRODUCT_DIGITS): # Set a maximum limit to prevent infinite loops
43
- outputs = model(
44
  input_ids=generated_ids,
 
45
  past_key_values=past_key_values,
 
46
  use_cache=True
47
  )
48
- logits = outputs.logits
49
-
50
- next_token_id = torch.argmax(logits[:, -1, :], dim=-1)
51
- generated_ids = torch.cat((generated_ids, next_token_id.view(1,-1)), dim=-1)
52
  print (next_token_id)
53
 
54
  if next_token_id.item() == eos_token_id:
 
40
  generated_ids = inputs['input_ids']
41
  past_key_values = None
42
  for _ in range(MAX_PRODUCT_DIGITS): # Set a maximum limit to prevent infinite loops
43
+ outputs = model.generate(
44
  input_ids=generated_ids,
45
+ max_new_tokens=1,
46
  past_key_values=past_key_values,
47
+ return_dict_in_generate=True,
48
  use_cache=True
49
  )
50
+ generated_ids = outputs.sequences
51
+ next_token_id = generated_ids[0, -1]
 
 
52
  print (next_token_id)
53
 
54
  if next_token_id.item() == eos_token_id: