Angainor Development commited on
Commit
bd3b537
·
unverified ·
1 Parent(s): 813cfa4

Feed cfg.inference

Browse files
Files changed (1) hide show
  1. scripts/finetune.py +7 -5
scripts/finetune.py CHANGED
@@ -182,6 +182,9 @@ def train(
182
  if cfg.bf16:
183
  cfg.fp16 = True
184
  cfg.bf16 = False
 
 
 
185
 
186
  # load the tokenizer first
187
  tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
@@ -189,8 +192,8 @@ def train(
189
  tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
190
 
191
  if check_not_in(
192
- ["inference", "shard", "merge_lora"], kwargs
193
- ): # don't need to load dataset for these
194
  train_dataset, eval_dataset = load_prepare_datasets(
195
  tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
196
  )
@@ -216,8 +219,7 @@ def train(
216
  cfg.model_type,
217
  tokenizer,
218
  cfg,
219
- adapter=cfg.adapter,
220
- inference=("inference" in kwargs),
221
  )
222
 
223
  if "merge_lora" in kwargs and cfg.adapter is not None:
@@ -230,7 +232,7 @@ def train(
230
  model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
231
  return
232
 
233
- if "inference" in kwargs:
234
  logging.info("calling do_inference function")
235
  do_inference(cfg, model, tokenizer)
236
  return
 
182
  if cfg.bf16:
183
  cfg.fp16 = True
184
  cfg.bf16 = False
185
+
186
+ # Store inference mode into cfg when passed via args
187
+ cfg.inference = True if "inference" in kwargs else cfg.get("inference", False)
188
 
189
  # load the tokenizer first
190
  tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
 
192
  tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
193
 
194
  if check_not_in(
195
+ ["shard", "merge_lora"], kwargs
196
+ ) and not cfg.inference: # don't need to load dataset for these
197
  train_dataset, eval_dataset = load_prepare_datasets(
198
  tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
199
  )
 
219
  cfg.model_type,
220
  tokenizer,
221
  cfg,
222
+ adapter=cfg.adapter
 
223
  )
224
 
225
  if "merge_lora" in kwargs and cfg.adapter is not None:
 
232
  model.save_pretrained(str(Path(cfg.output_dir) / "merged"))
233
  return
234
 
235
+ if cfg.inference:
236
  logging.info("calling do_inference function")
237
  do_inference(cfg, model, tokenizer)
238
  return