Jan Philipp Harries
commited on
Save Axolotl config as WandB artifact (#716)
Browse files
src/axolotl/cli/__init__.py
CHANGED
@@ -194,6 +194,7 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|
194 |
# load the config from the yaml file
|
195 |
with open(config, encoding="utf-8") as file:
|
196 |
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
|
|
197 |
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
198 |
# then overwrite the value
|
199 |
cfg_keys = cfg.keys()
|
|
|
194 |
# load the config from the yaml file
|
195 |
with open(config, encoding="utf-8") as file:
|
196 |
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
197 |
+
cfg.axolotl_config_path = config
|
198 |
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
199 |
# then overwrite the value
|
200 |
cfg_keys = cfg.keys()
|
src/axolotl/utils/callbacks.py
CHANGED
@@ -514,3 +514,27 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
|
514 |
return control
|
515 |
|
516 |
return LogPredictionCallback
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
514 |
return control
|
515 |
|
516 |
return LogPredictionCallback
|
517 |
+
|
518 |
+
|
519 |
+
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
520 |
+
"""Callback to save axolotl config to wandb"""
|
521 |
+
|
522 |
+
def __init__(self, axolotl_config_path):
|
523 |
+
self.axolotl_config_path = axolotl_config_path
|
524 |
+
|
525 |
+
def on_train_begin(
|
526 |
+
self,
|
527 |
+
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
528 |
+
state: TrainerState, # pylint: disable=unused-argument
|
529 |
+
control: TrainerControl,
|
530 |
+
**kwargs, # pylint: disable=unused-argument
|
531 |
+
):
|
532 |
+
if is_main_process():
|
533 |
+
try:
|
534 |
+
artifact = wandb.Artifact(name="axolotl-config", type="config")
|
535 |
+
artifact.add_file(local_path=self.axolotl_config_path)
|
536 |
+
wandb.run.log_artifact(artifact)
|
537 |
+
LOG.info("Axolotl config has been saved to WandB as an artifact.")
|
538 |
+
except (FileNotFoundError, ConnectionError) as err:
|
539 |
+
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
540 |
+
return control
|
src/axolotl/utils/trainer.py
CHANGED
@@ -30,6 +30,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
|
30 |
from axolotl.utils.callbacks import (
|
31 |
EvalFirstStepCallback,
|
32 |
GPUStatsCallback,
|
|
|
33 |
SaveBetterTransformerModelCallback,
|
34 |
bench_eval_callback_factory,
|
35 |
log_prediction_callback_factory,
|
@@ -775,6 +776,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|
775 |
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
|
776 |
trainer.add_callback(LogPredictionCallback(cfg))
|
777 |
|
|
|
|
|
|
|
778 |
if cfg.do_bench_eval:
|
779 |
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
780 |
|
|
|
30 |
from axolotl.utils.callbacks import (
|
31 |
EvalFirstStepCallback,
|
32 |
GPUStatsCallback,
|
33 |
+
SaveAxolotlConfigtoWandBCallback,
|
34 |
SaveBetterTransformerModelCallback,
|
35 |
bench_eval_callback_factory,
|
36 |
log_prediction_callback_factory,
|
|
|
776 |
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
|
777 |
trainer.add_callback(LogPredictionCallback(cfg))
|
778 |
|
779 |
+
if cfg.use_wandb:
|
780 |
+
trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
|
781 |
+
|
782 |
if cfg.do_bench_eval:
|
783 |
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
784 |
|