Lint trainer.py
Browse files- src/axolotl/utils/trainer.py +20 -11
src/axolotl/utils/trainer.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import importlib
|
2 |
import math
|
3 |
import os
|
@@ -17,12 +19,19 @@ from axolotl.utils.callbacks import SavePeftModelCallback
|
|
17 |
|
18 |
|
19 |
class OneCycleLRSchedulerTrainer(Trainer):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
def create_scheduler(
|
21 |
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
22 |
):
|
23 |
optimizer = self.optimizer if optimizer is None else optimizer
|
24 |
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
25 |
-
num_training_steps = num_training_steps
|
26 |
pct_start = num_warmup_steps / num_training_steps
|
27 |
|
28 |
self.lr_scheduler = OneCycleLR(
|
@@ -58,11 +67,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
58 |
training_arguments_kwargs["bf16_full_eval"] = True
|
59 |
else:
|
60 |
training_arguments_kwargs["bf16"] = cfg.bf16
|
61 |
-
training_arguments_kwargs["fp16"] =
|
62 |
training_arguments_kwargs["tf32"] = cfg.tf32
|
63 |
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
64 |
training_arguments_kwargs["logging_steps"] = logging_steps
|
65 |
-
if cfg.gradient_checkpointing
|
66 |
if cfg.gptq:
|
67 |
from alpaca_lora_4bit.gradient_checkpointing import (
|
68 |
apply_gradient_checkpointing,
|
@@ -112,13 +121,13 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
112 |
save_steps=save_steps,
|
113 |
output_dir=cfg.output_dir,
|
114 |
save_total_limit=3,
|
115 |
-
load_best_model_at_end=
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
ddp_find_unused_parameters=False if cfg.ddp else None,
|
123 |
group_by_length=cfg.group_by_length,
|
124 |
report_to="wandb" if cfg.use_wandb else None,
|
@@ -140,7 +149,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
|
140 |
if (
|
141 |
cfg.optimizer == "adamw_bnb_8bit"
|
142 |
and not cfg.gptq
|
143 |
-
and
|
144 |
and not cfg.fsdp
|
145 |
):
|
146 |
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
|
|
1 |
+
"""Module containing the Trainer class and related functions"""
|
2 |
+
|
3 |
import importlib
|
4 |
import math
|
5 |
import os
|
|
|
19 |
|
20 |
|
21 |
class OneCycleLRSchedulerTrainer(Trainer):
|
22 |
+
"""
|
23 |
+
Trainer subclass that uses the OneCycleLR scheduler
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, *args, **kwargs):
|
27 |
+
super().__init__(*args, **kwargs)
|
28 |
+
self.lr_scheduler = None
|
29 |
+
|
30 |
def create_scheduler(
|
31 |
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
32 |
):
|
33 |
optimizer = self.optimizer if optimizer is None else optimizer
|
34 |
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
|
|
35 |
pct_start = num_warmup_steps / num_training_steps
|
36 |
|
37 |
self.lr_scheduler = OneCycleLR(
|
|
|
67 |
training_arguments_kwargs["bf16_full_eval"] = True
|
68 |
else:
|
69 |
training_arguments_kwargs["bf16"] = cfg.bf16
|
70 |
+
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
|
71 |
training_arguments_kwargs["tf32"] = cfg.tf32
|
72 |
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
73 |
training_arguments_kwargs["logging_steps"] = logging_steps
|
74 |
+
if cfg.gradient_checkpointing:
|
75 |
if cfg.gptq:
|
76 |
from alpaca_lora_4bit.gradient_checkpointing import (
|
77 |
apply_gradient_checkpointing,
|
|
|
121 |
save_steps=save_steps,
|
122 |
output_dir=cfg.output_dir,
|
123 |
save_total_limit=3,
|
124 |
+
load_best_model_at_end=(
|
125 |
+
cfg.val_set_size > 0
|
126 |
+
and save_steps
|
127 |
+
and save_steps % eval_steps == 0
|
128 |
+
and cfg.load_in_8bit is not True
|
129 |
+
)
|
130 |
+
or False,
|
131 |
ddp_find_unused_parameters=False if cfg.ddp else None,
|
132 |
group_by_length=cfg.group_by_length,
|
133 |
report_to="wandb" if cfg.use_wandb else None,
|
|
|
149 |
if (
|
150 |
cfg.optimizer == "adamw_bnb_8bit"
|
151 |
and not cfg.gptq
|
152 |
+
and "deepspeed" not in training_arguments_kwargs
|
153 |
and not cfg.fsdp
|
154 |
):
|
155 |
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|