fix BetterTransformers save: force the trainer loop to skip its own save after the callback has saved correctly
src/axolotl/utils/callbacks.py
CHANGED
@@ -9,7 +9,7 @@ from transformers import (
     TrainerState,
     TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
+from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy


 class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-methods
@@ -36,21 +36,33 @@ class SavePeftModelCallback(TrainerCallback):  # pylint: disable=too-few-public-
 class SaveBetterTransformerModelCallback(
     TrainerCallback
 ):  # pylint: disable=too-few-public-methods
-    """Callback to save the
+    """Callback to save the BetterTransformer wrapped model"""

-    def
+    def on_step_end(
         self,
         args: TrainingArguments,
         state: TrainerState,
         control: TrainerControl,
         **kwargs,
     ):
-        checkpoint_folder = os.path.join(
-            args.output_dir,
-            f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
-        )
+        # Save
+        if (
+            args.save_strategy == IntervalStrategy.STEPS
+            and args.save_steps > 0
+            and state.global_step % args.save_steps == 0
+        ):
+            control.should_save = True
+
+        if control.should_save:
+            checkpoint_folder = os.path.join(
+                args.output_dir,
+                f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
+            )

-        model = BetterTransformer.reverse(kwargs["model"])
-        model.save_pretrained(checkpoint_folder)
+            model = BetterTransformer.reverse(kwargs["model"])
+            model.save_pretrained(checkpoint_folder)

+            # since we're saving here, we don't need the trainer loop to attempt to save too b/c
+            # the trainer will raise an exception since it can't save a BetterTransformer wrapped model
+            control.should_save = False
         return control
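For context on the callbacks.py change: save_pretrained() fails on a model that is still wrapped by BetterTransformer, which is why the callback reverses the wrapper before writing the checkpoint and then clears control.should_save so the trainer loop does not try to save the still-wrapped model a second time. A minimal sketch of that save path outside the Trainer (not part of this commit; the model name and checkpoint directory are placeholders):

from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
model = BetterTransformer.transform(model)            # swap in the fused-attention wrapper

# ... training happens on the wrapped model ...

model = BetterTransformer.reverse(model)              # unwrap back to a plain transformers model
model.save_pretrained("checkpoint-42")                # safe to serialize once unwrapped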
src/axolotl/utils/trainer.py
CHANGED
@@ -232,6 +232,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     callbacks.append(SavePeftModelCallback)

     if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
+        logging.info("Setting up SaveBetterTransformerModelCallback.")
         callbacks.append(SaveBetterTransformerModelCallback)

     data_collator_kwargs = {
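The trainer.py change only adds a log line; the appended callback class reaches the transformers Trainer through its callbacks argument, which accepts either classes or instances. A rough sketch of that wiring, assuming the standard Trainer API (the model, output directory, and save settings below are placeholders, not values from this commit):

from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
from axolotl.utils.callbacks import SaveBetterTransformerModelCallback

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model

training_args = TrainingArguments(
    output_dir="./out",     # placeholder
    save_strategy="steps",
    save_steps=100,
)

trainer = Trainer(
    model=model,
    args=training_args,
    callbacks=[SaveBetterTransformerModelCallback],  # a class is fine; Trainer instantiates it
)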
src/axolotl/utils/validation.py
CHANGED
@@ -66,9 +66,10 @@ def validate_config(cfg):
             )
         if cfg.fp16 or cfg.bf16:
             raise ValueError("AMP is not supported with BetterTransformer")
-        if cfg.float16 is not True:
+        if cfg.float16 is not True and cfg.bloat16 is not True:
             logging.warning(
-                "You should probably set float16 to true to load the model in float16 for BetterTransformers"
+                "You should probably set bfloat16 or float16 to true to "
+                "load the model in float16 for BetterTransformers"
             )
         if int(torch.__version__.split(".")[0]) < 2:
             logging.warning("torch>=2.0.0 required")
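On the validation.py change: the warning steers users toward loading the weights directly in half precision (the float16/bfloat16 flags) rather than relying on fp16/bf16 AMP, which the check above rejects for BetterTransformer. A hedged sketch of what loading in half precision looks like at the transformers level (the model name is a placeholder, and mapping axolotl's flags to torch_dtype is an assumption here, not something this commit shows):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",                     # placeholder model
    torch_dtype=torch.float16,  # assumed effect of setting float16: true in the axolotl config
)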