Angainor Development
committed on
Feed cfg.inference
Browse files- scripts/finetune.py +7 -5
scripts/finetune.py
CHANGED
@@ -182,6 +182,9 @@ def train(
|
|
182 |
if cfg.bf16:
|
183 |
cfg.fp16 = True
|
184 |
cfg.bf16 = False
|
|
|
|
|
|
|
185 |
|
186 |
# load the tokenizer first
|
187 |
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
@@ -189,8 +192,8 @@ def train(
|
|
189 |
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
190 |
|
191 |
if check_not_in(
|
192 |
-
["
|
193 |
-
): # don't need to load dataset for these
|
194 |
train_dataset, eval_dataset = load_prepare_datasets(
|
195 |
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
196 |
)
|
@@ -216,8 +219,7 @@ def train(
|
|
216 |
cfg.model_type,
|
217 |
tokenizer,
|
218 |
cfg,
|
219 |
-
adapter=cfg.adapter
|
220 |
-
inference=("inference" in kwargs),
|
221 |
)
|
222 |
|
223 |
if "merge_lora" in kwargs and cfg.adapter is not None:
|
@@ -230,7 +232,7 @@ def train(
|
|
230 |
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
231 |
return
|
232 |
|
233 |
-
if
|
234 |
logging.info("calling do_inference function")
|
235 |
do_inference(cfg, model, tokenizer)
|
236 |
return
|
|
|
182 |
if cfg.bf16:
|
183 |
cfg.fp16 = True
|
184 |
cfg.bf16 = False
|
185 |
+
|
186 |
+
# Store inference mode into cfg when passed via args
|
187 |
+
cfg.inference = True if "inference" in kwargs else cfg.get("inference", False)
|
188 |
|
189 |
# load the tokenizer first
|
190 |
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
|
|
192 |
tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
|
193 |
|
194 |
if check_not_in(
|
195 |
+
["shard", "merge_lora"], kwargs
|
196 |
+
) and not cfg.inference: # don't need to load dataset for these
|
197 |
train_dataset, eval_dataset = load_prepare_datasets(
|
198 |
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
199 |
)
|
|
|
219 |
cfg.model_type,
|
220 |
tokenizer,
|
221 |
cfg,
|
222 |
+
adapter=cfg.adapter
|
|
|
223 |
)
|
224 |
|
225 |
if "merge_lora" in kwargs and cfg.adapter is not None:
|
|
|
232 |
model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
|
233 |
return
|
234 |
|
235 |
+
if cfg.inference:
|
236 |
logging.info("calling do_inference function")
|
237 |
do_inference(cfg, model, tokenizer)
|
238 |
return
|