Nanobit commited on
Commit
dc77c8e
·
1 Parent(s): 51a4c12

chore: Refactor inf_kwargs out

Browse files
Files changed (1) hide show
  1. scripts/finetune.py +5 -5
scripts/finetune.py CHANGED
@@ -63,7 +63,7 @@ def get_multi_line_input() -> Optional[str]:
63
  return instruction
64
 
65
 
66
- def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
67
  default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
68
 
69
  for token, symbol in default_tokens.items():
@@ -257,13 +257,13 @@ def train(
257
 
258
  if cfg.inference:
259
  logging.info("calling do_inference function")
260
- inf_kwargs: Dict[str, Any] = {}
261
  if "prompter" in kwargs:
262
  if kwargs["prompter"] == "None":
263
- inf_kwargs["prompter"] = None
264
  else:
265
- inf_kwargs["prompter"] = kwargs["prompter"]
266
- do_inference(cfg, model, tokenizer, **inf_kwargs)
267
  return
268
 
269
  if "shard" in kwargs:
 
63
  return instruction
64
 
65
 
66
+ def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
67
  default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
68
 
69
  for token, symbol in default_tokens.items():
 
257
 
258
  if cfg.inference:
259
  logging.info("calling do_inference function")
260
+ prompter: Optional[str] = "AlpacaPrompter"
261
  if "prompter" in kwargs:
262
  if kwargs["prompter"] == "None":
263
+ prompter = None
264
  else:
265
+ prompter = kwargs["prompter"]
266
+ do_inference(cfg, model, tokenizer, prompter=prompter)
267
  return
268
 
269
  if "shard" in kwargs: