zetavg committed
Commit 8788753 · unverified · 1 Parent(s): 5929f2a

update finetune

Files changed (1)
  1. llama_lora/ui/finetune_ui.py +15 -6
llama_lora/ui/finetune_ui.py CHANGED
@@ -261,6 +261,11 @@ def do_train(
     progress=gr.Progress(track_tqdm=True),
 ):
     try:
+        # If model has been used in inference, we need to unload it first.
+        # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
+        # gradient at index 1 - expected device meta but got cuda:0' error.
+        unload_models_if_already_used()
+
         prompter = Prompter(template)
         variable_names = prompter.get_variable_names()
 
@@ -392,16 +397,20 @@ Train data (first 10):
 
         training_callbacks = [UiTrainerCallback]
 
-        # If model has been used in inference, we need to unload it first.
-        # Otherwise, we'll get a 'Function MmBackward0 returned an invalid
-        # gradient at index 1 - expected device meta but got cuda:0' error.
+        Global.should_stop_training = False
+
+        # Do this again right before training to make sure the model is not used in inference.
         unload_models_if_already_used()
 
-        Global.should_stop_training = False
+        base_model = get_base_model()
+        tokenizer = get_tokenizer()
+
+        # Do not let other tqdm iterations interfere the progress reporting after training starts.
+        progress.track_tqdm = False
 
         results = Global.train_fn(
-            get_base_model(),  # base_model
-            get_tokenizer(),  # tokenizer
+            base_model,  # base_model
+            tokenizer,  # tokenizer
             os.path.join(Global.data_dir, "lora_models",
                          model_name),  # output_dir
             train_data,
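
For context, gr.Progress(track_tqdm=True) makes Gradio mirror any tqdm loop that runs inside the handler to the UI progress bar; the added progress.track_tqdm = False line flips that flag off right before Global.train_fn is called so that, per the commit's comment, other tqdm iterations do not interfere with progress reporting once training starts. A minimal standalone sketch of that pattern (a hypothetical handler for illustration, not code from this repo):

import time

import gradio as gr
from tqdm import tqdm


def train_handler(progress=gr.Progress(track_tqdm=True)):
    # While track_tqdm is True, tqdm loops inside this handler are
    # mirrored to the Gradio progress bar.
    for _ in tqdm(range(5), desc="preparing data"):
        time.sleep(0.1)

    # The commit flips this flag right before training so that tqdm
    # loops started later (e.g. inside the trainer) do not take over
    # the progress display.
    progress.track_tqdm = False

    for _ in tqdm(range(5), desc="training"):
        time.sleep(0.1)

    return "done"


demo = gr.Interface(fn=train_handler, inputs=None, outputs="text")

if __name__ == "__main__":
    demo.launch()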