Nanobit commited on
Commit
ddb86ea
·
1 Parent(s): 1a2bd7f

Lint trainer.py

Browse files
Files changed (1) hide show
  1. src/axolotl/utils/trainer.py +20 -11
src/axolotl/utils/trainer.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import importlib
2
  import math
3
  import os
@@ -17,12 +19,19 @@ from axolotl.utils.callbacks import SavePeftModelCallback
17
 
18
 
19
  class OneCycleLRSchedulerTrainer(Trainer):
 
 
 
 
 
 
 
 
20
  def create_scheduler(
21
  self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
22
  ):
23
  optimizer = self.optimizer if optimizer is None else optimizer
24
  num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
25
- num_training_steps = num_training_steps
26
  pct_start = num_warmup_steps / num_training_steps
27
 
28
  self.lr_scheduler = OneCycleLR(
@@ -58,11 +67,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
58
  training_arguments_kwargs["bf16_full_eval"] = True
59
  else:
60
  training_arguments_kwargs["bf16"] = cfg.bf16
61
- training_arguments_kwargs["fp16"] = True if cfg.fp16 and not cfg.bf16 else False
62
  training_arguments_kwargs["tf32"] = cfg.tf32
63
  training_arguments_kwargs["warmup_steps"] = warmup_steps
64
  training_arguments_kwargs["logging_steps"] = logging_steps
65
- if cfg.gradient_checkpointing is not None:
66
  if cfg.gptq:
67
  from alpaca_lora_4bit.gradient_checkpointing import (
68
  apply_gradient_checkpointing,
@@ -112,13 +121,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
112
  save_steps=save_steps,
113
  output_dir=cfg.output_dir,
114
  save_total_limit=3,
115
- load_best_model_at_end=True
116
- if cfg.load_best_model_at_end is not False # if explicitly set to False, it should be resort to False
117
- and cfg.val_set_size > 0
118
- and save_steps is not None
119
- and save_steps % eval_steps == 0
120
- and cfg.load_in_8bit is not True
121
- else False,
122
  ddp_find_unused_parameters=False if cfg.ddp else None,
123
  group_by_length=cfg.group_by_length,
124
  report_to="wandb" if cfg.use_wandb else None,
@@ -140,7 +149,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
140
  if (
141
  cfg.optimizer == "adamw_bnb_8bit"
142
  and not cfg.gptq
143
- and not "deepspeed" in training_arguments_kwargs
144
  and not cfg.fsdp
145
  ):
146
  decay_parameters = get_parameter_names(model, [nn.LayerNorm])
 
1
+ """Module containing the Trainer class and related functions"""
2
+
3
  import importlib
4
  import math
5
  import os
 
19
 
20
 
21
  class OneCycleLRSchedulerTrainer(Trainer):
22
+ """
23
+ Trainer subclass that uses the OneCycleLR scheduler
24
+ """
25
+
26
+ def __init__(self, *args, **kwargs):
27
+ super().__init__(*args, **kwargs)
28
+ self.lr_scheduler = None
29
+
30
  def create_scheduler(
31
  self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
32
  ):
33
  optimizer = self.optimizer if optimizer is None else optimizer
34
  num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
 
35
  pct_start = num_warmup_steps / num_training_steps
36
 
37
  self.lr_scheduler = OneCycleLR(
 
67
  training_arguments_kwargs["bf16_full_eval"] = True
68
  else:
69
  training_arguments_kwargs["bf16"] = cfg.bf16
70
+ training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
71
  training_arguments_kwargs["tf32"] = cfg.tf32
72
  training_arguments_kwargs["warmup_steps"] = warmup_steps
73
  training_arguments_kwargs["logging_steps"] = logging_steps
74
+ if cfg.gradient_checkpointing:
75
  if cfg.gptq:
76
  from alpaca_lora_4bit.gradient_checkpointing import (
77
  apply_gradient_checkpointing,
 
121
  save_steps=save_steps,
122
  output_dir=cfg.output_dir,
123
  save_total_limit=3,
124
+ load_best_model_at_end=(
125
+ cfg.val_set_size > 0
126
+ and save_steps
127
+ and save_steps % eval_steps == 0
128
+ and cfg.load_in_8bit is not True
129
+ )
130
+ or False,
131
  ddp_find_unused_parameters=False if cfg.ddp else None,
132
  group_by_length=cfg.group_by_length,
133
  report_to="wandb" if cfg.use_wandb else None,
 
149
  if (
150
  cfg.optimizer == "adamw_bnb_8bit"
151
  and not cfg.gptq
152
+ and "deepspeed" not in training_arguments_kwargs
153
  and not cfg.fsdp
154
  ):
155
  decay_parameters = get_parameter_names(model, [nn.LayerNorm])