BAAI
/

ldwang commited on
Commit
e738e43
·
1 Parent(s): 4e307b5

Upload predict.py

Browse files
Files changed (1) hide show
  1. predict.py +9 -5
predict.py CHANGED
@@ -310,6 +310,9 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,
310
 
311
  example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
312
 
 
 
 
313
  while(len(history) > 0 and (len(example) < max_token)):
314
  tmp = history.pop()
315
  if tmp[0] == 'ASSISTANT':
@@ -333,7 +336,7 @@ def predict(model, text, tokenizer=None,
333
  sft=True, convo_template = "",
334
  device = "cuda",
335
  model_name="AquilaChat2-7B",
336
- history=[],
337
  **kwargs):
338
 
339
  vocab = tokenizer.get_vocab()
@@ -353,7 +356,7 @@ def predict(model, text, tokenizer=None,
353
  topk = 1
354
  temperature = 1.0
355
  if sft:
356
- tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=1000000, convo_template=convo_template)
357
  tokens = torch.tensor(tokens)[None,].to(device)
358
  else :
359
  tokens = tokenizer.encode_plus(text)["input_ids"]
@@ -435,8 +438,9 @@ def predict(model, text, tokenizer=None,
435
  convert_tokens = convert_tokens[1:]
436
  probs = probs[1:]
437
 
438
- # Update history
439
- history.insert(0, ('ASSISTANT', out))
440
- history.insert(0, ('USER', text))
 
441
 
442
  return out
 
310
 
311
  example = tokenizer.encode_plus(f"{conv.get_prompt()} ", None, max_length=None)['input_ids']
312
 
313
+ if history is None or not isinstance(history, list):
314
+ history = []
315
+
316
  while(len(history) > 0 and (len(example) < max_token)):
317
  tmp = history.pop()
318
  if tmp[0] == 'ASSISTANT':
 
336
  sft=True, convo_template = "",
337
  device = "cuda",
338
  model_name="AquilaChat2-7B",
339
+ history=None,
340
  **kwargs):
341
 
342
  vocab = tokenizer.get_vocab()
 
356
  topk = 1
357
  temperature = 1.0
358
  if sft:
359
+ tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
360
  tokens = torch.tensor(tokens)[None,].to(device)
361
  else :
362
  tokens = tokenizer.encode_plus(text)["input_ids"]
 
438
  convert_tokens = convert_tokens[1:]
439
  probs = probs[1:]
440
 
441
+ if isinstance(history, list):
442
+ # Update history
443
+ history.insert(0, ('ASSISTANT', out))
444
+ history.insert(0, ('USER', text))
445
 
446
  return out