chore: Refactor inf_kwargs out
Browse files- scripts/finetune.py +5 -5
scripts/finetune.py
CHANGED
@@ -63,7 +63,7 @@ def get_multi_line_input() -> Optional[str]:
|
|
63 |
return instruction
|
64 |
|
65 |
|
66 |
-
def do_inference(cfg, model, tokenizer, prompter
|
67 |
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
68 |
|
69 |
for token, symbol in default_tokens.items():
|
@@ -257,13 +257,13 @@ def train(
|
|
257 |
|
258 |
if cfg.inference:
|
259 |
logging.info("calling do_inference function")
|
260 |
-
|
261 |
if "prompter" in kwargs:
|
262 |
if kwargs["prompter"] == "None":
|
263 |
-
|
264 |
else:
|
265 |
-
|
266 |
-
do_inference(cfg, model, tokenizer,
|
267 |
return
|
268 |
|
269 |
if "shard" in kwargs:
|
|
|
63 |
return instruction
|
64 |
|
65 |
|
66 |
+
def do_inference(cfg, model, tokenizer, prompter: Optional[str]):
|
67 |
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
68 |
|
69 |
for token, symbol in default_tokens.items():
|
|
|
257 |
|
258 |
if cfg.inference:
|
259 |
logging.info("calling do_inference function")
|
260 |
+
prompter: Optional[str] = "AlpacaPrompter"
|
261 |
if "prompter" in kwargs:
|
262 |
if kwargs["prompter"] == "None":
|
263 |
+
prompter = None
|
264 |
else:
|
265 |
+
prompter = kwargs["prompter"]
|
266 |
+
do_inference(cfg, model, tokenizer, prompter=prompter)
|
267 |
return
|
268 |
|
269 |
if "shard" in kwargs:
|