BAAI
/

ldwang commited on
Commit
c9ddf70
·
1 Parent(s): bd2b1bd

Upload predict.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. predict.py +33 -5
predict.py CHANGED
@@ -185,6 +185,8 @@ def get_conversation_template(model_path: str) -> Conversation:
185
  """Get the default conversation template."""
186
  if "aquila-v1" in model_path:
187
  return get_conv_template("aquila-v1")
 
 
188
  elif "aquila-chat" in model_path:
189
  return get_conv_template("aquila-chat")
190
  elif "aquila-legacy" in model_path:
@@ -252,6 +254,21 @@ register_conv_template(
252
  )
253
  )
254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
 
256
  if __name__ == "__main__":
257
  print("aquila template:")
@@ -294,6 +311,17 @@ if __name__ == "__main__":
294
 
295
  print("\n")
296
 
 
 
 
 
 
 
 
 
 
 
 
297
  def set_random_seed(seed):
298
  """Set random seed for reproducability."""
299
  if seed is not None and seed > 0:
@@ -330,9 +358,9 @@ def covert_prompt_to_input_ids_with_history(text, history, tokenizer, max_token,
330
  return example
331
 
332
  def predict(model, text, tokenizer=None,
333
- max_gen_len=200, top_p=0.95,
334
- seed=1234, topk=100,
335
- temperature=0.9,
336
  sft=True, convo_template = "",
337
  device = "cuda",
338
  model_name="AquilaChat2-7B",
@@ -346,8 +374,8 @@ def predict(model, text, tokenizer=None,
346
 
347
  template_map = {"AquilaChat2-7B": "aquila-v1",
348
  "AquilaChat2-34B": "aquila-legacy",
349
- "AquilaChat2-7B-16K": "aquila",
350
  "AquilaChat2-70B-Expr": "aquila-v2",
 
351
  "AquilaChat2-34B-16K": "aquila"}
352
  if not convo_template:
353
  convo_template=template_map.get(model_name, "aquila-chat")
@@ -357,7 +385,7 @@ def predict(model, text, tokenizer=None,
357
  topk = 1
358
  temperature = 1.0
359
  if sft:
360
- tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=2048, convo_template=convo_template)
361
  tokens = torch.tensor(tokens)[None,].to(device)
362
  else :
363
  tokens = tokenizer.encode_plus(text)["input_ids"]
 
185
  """Get the default conversation template."""
186
  if "aquila-v1" in model_path:
187
  return get_conv_template("aquila-v1")
188
+ elif "aquila-v2" in model_path:
189
+ return get_conv_template("aquila-v2")
190
  elif "aquila-chat" in model_path:
191
  return get_conv_template("aquila-chat")
192
  elif "aquila-legacy" in model_path:
 
254
  )
255
  )
256
 
257
+ register_conv_template(
258
+ Conversation(
259
+ name="aquila-v2",
260
+ system_message="A chat between a curious human and an artificial intelligence assistant. "
261
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
262
+ roles=("<|startofpiece|>", "<|endofpiece|>", ""),
263
+ messages=(),
264
+ offset=0,
265
+ sep_style=SeparatorStyle.NO_COLON_TWO,
266
+ sep="",
267
+ sep2="</s>",
268
+ stop_str=["</s>", "<|endoftext|>", "<|startofpiece|>", "<|endofpiece|>"],
269
+ )
270
+ )
271
+
272
 
273
  if __name__ == "__main__":
274
  print("aquila template:")
 
311
 
312
  print("\n")
313
 
314
+ print("aquila-v2 template:")
315
+ conv = get_conv_template("aquila-v2")
316
+ conv.append_message(conv.roles[0], "Hello!")
317
+ conv.append_message(conv.roles[1], "Hi!")
318
+ conv.append_message(conv.roles[0], "How are you?")
319
+ conv.append_message(conv.roles[1], None)
320
+ print(conv.get_prompt())
321
+
322
+ print("\n")
323
+
324
+
325
  def set_random_seed(seed):
326
  """Set random seed for reproducability."""
327
  if seed is not None and seed > 0:
 
358
  return example
359
 
360
  def predict(model, text, tokenizer=None,
361
+ max_gen_len=200, top_p=0.9,
362
+ seed=123, topk=15,
363
+ temperature=1.0,
364
  sft=True, convo_template = "",
365
  device = "cuda",
366
  model_name="AquilaChat2-7B",
 
374
 
375
  template_map = {"AquilaChat2-7B": "aquila-v1",
376
  "AquilaChat2-34B": "aquila-legacy",
 
377
  "AquilaChat2-70B-Expr": "aquila-v2",
378
+ "AquilaChat2-7B-16K": "aquila",
379
  "AquilaChat2-34B-16K": "aquila"}
380
  if not convo_template:
381
  convo_template=template_map.get(model_name, "aquila-chat")
 
385
  topk = 1
386
  temperature = 1.0
387
  if sft:
388
+ tokens = covert_prompt_to_input_ids_with_history(text, history=history, tokenizer=tokenizer, max_token=20480, convo_template=convo_template)
389
  tokens = torch.tensor(tokens)[None,].to(device)
390
  else :
391
  tokens = tokenizer.encode_plus(text)["input_ids"]