WIP: Support table logging for mlflow, too (#1506)
Browse files* WIP: Support table logging for mlflow, too
Create a `LogPredictionCallback` for both "wandb" and "mlflow" if
specified.
In `log_prediction_callback_factory`, create a generic table and make it
specific only if the newly added `logger` argument is set to "wandb"
resp. "mlflow".
See https://github.com/OpenAccess-AI-Collective/axolotl/issues/1505
* chore: lint
* add additional clause for mlflow as it's optional
* Fix circular imports
---------
Co-authored-by: Dave Farago <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
src/axolotl/core/trainer_builder.py
CHANGED
@@ -36,6 +36,7 @@ from trl.trainer.utils import pad_to_length
|
|
36 |
from axolotl.loraplus import create_loraplus_optimizer
|
37 |
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
38 |
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
|
|
39 |
from axolotl.utils.callbacks import (
|
40 |
EvalFirstStepCallback,
|
41 |
GPUStatsCallback,
|
@@ -71,10 +72,6 @@ except ImportError:
|
|
71 |
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
72 |
|
73 |
|
74 |
-
def is_mlflow_available():
|
75 |
-
return importlib.util.find_spec("mlflow") is not None
|
76 |
-
|
77 |
-
|
78 |
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
79 |
if isinstance(tag_names, str):
|
80 |
tag_names = [tag_names]
|
@@ -943,7 +940,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|
943 |
callbacks = []
|
944 |
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
945 |
LogPredictionCallback = log_prediction_callback_factory(
|
946 |
-
trainer, self.tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
947 |
)
|
948 |
callbacks.append(LogPredictionCallback(self.cfg))
|
949 |
|
|
|
36 |
from axolotl.loraplus import create_loraplus_optimizer
|
37 |
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
38 |
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
39 |
+
from axolotl.utils import is_mlflow_available
|
40 |
from axolotl.utils.callbacks import (
|
41 |
EvalFirstStepCallback,
|
42 |
GPUStatsCallback,
|
|
|
72 |
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
73 |
|
74 |
|
|
|
|
|
|
|
|
|
75 |
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
|
76 |
if isinstance(tag_names, str):
|
77 |
tag_names = [tag_names]
|
|
|
940 |
callbacks = []
|
941 |
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
942 |
LogPredictionCallback = log_prediction_callback_factory(
|
943 |
+
trainer, self.tokenizer, "wandb"
|
944 |
+
)
|
945 |
+
callbacks.append(LogPredictionCallback(self.cfg))
|
946 |
+
if (
|
947 |
+
self.cfg.use_mlflow
|
948 |
+
and is_mlflow_available()
|
949 |
+
and self.cfg.eval_table_size > 0
|
950 |
+
):
|
951 |
+
LogPredictionCallback = log_prediction_callback_factory(
|
952 |
+
trainer, self.tokenizer, "mlflow"
|
953 |
)
|
954 |
callbacks.append(LogPredictionCallback(self.cfg))
|
955 |
|
src/axolotl/utils/__init__.py
CHANGED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Basic utils for Axolotl
|
3 |
+
"""
|
4 |
+
import importlib
|
5 |
+
|
6 |
+
|
7 |
+
def is_mlflow_available():
|
8 |
+
return importlib.util.find_spec("mlflow") is not None
|
src/axolotl/utils/callbacks/__init__.py
CHANGED
@@ -6,7 +6,7 @@ import logging
|
|
6 |
import os
|
7 |
from shutil import copyfile
|
8 |
from tempfile import NamedTemporaryFile
|
9 |
-
from typing import TYPE_CHECKING, Dict, List
|
10 |
|
11 |
import evaluate
|
12 |
import numpy as np
|
@@ -27,7 +27,9 @@ from transformers import (
|
|
27 |
)
|
28 |
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
29 |
|
|
|
30 |
from axolotl.utils.bench import log_gpu_memory_usage
|
|
|
31 |
from axolotl.utils.distributed import (
|
32 |
barrier,
|
33 |
broadcast_dict,
|
@@ -540,7 +542,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
|
540 |
return CausalLMBenchEvalCallback
|
541 |
|
542 |
|
543 |
-
def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
544 |
class LogPredictionCallback(TrainerCallback):
|
545 |
"""Callback to log prediction values during each evaluation"""
|
546 |
|
@@ -597,15 +599,13 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
|
597 |
return ranges
|
598 |
|
599 |
def log_table_from_dataloader(name: str, table_dataloader):
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
]
|
608 |
-
)
|
609 |
row_index = 0
|
610 |
|
611 |
for batch in tqdm(table_dataloader):
|
@@ -709,16 +709,29 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
|
709 |
) in zip(
|
710 |
prompt_texts, completion_texts, predicted_texts, pred_step_texts
|
711 |
):
|
712 |
-
|
713 |
-
|
714 |
-
|
715 |
-
|
716 |
-
prediction_text
|
717 |
-
pred_step_text,
|
718 |
)
|
|
|
|
|
|
|
719 |
row_index += 1
|
720 |
-
|
721 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
722 |
|
723 |
if is_main_process():
|
724 |
log_table_from_dataloader("Eval", eval_dataloader)
|
|
|
6 |
import os
|
7 |
from shutil import copyfile
|
8 |
from tempfile import NamedTemporaryFile
|
9 |
+
from typing import TYPE_CHECKING, Any, Dict, List
|
10 |
|
11 |
import evaluate
|
12 |
import numpy as np
|
|
|
27 |
)
|
28 |
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
29 |
|
30 |
+
from axolotl.utils import is_mlflow_available
|
31 |
from axolotl.utils.bench import log_gpu_memory_usage
|
32 |
+
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
33 |
from axolotl.utils.distributed import (
|
34 |
barrier,
|
35 |
broadcast_dict,
|
|
|
542 |
return CausalLMBenchEvalCallback
|
543 |
|
544 |
|
545 |
+
def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
|
546 |
class LogPredictionCallback(TrainerCallback):
|
547 |
"""Callback to log prediction values during each evaluation"""
|
548 |
|
|
|
599 |
return ranges
|
600 |
|
601 |
def log_table_from_dataloader(name: str, table_dataloader):
|
602 |
+
table_data: Dict[str, List[Any]] = {
|
603 |
+
"id": [],
|
604 |
+
"Prompt": [],
|
605 |
+
"Correct Completion": [],
|
606 |
+
"Predicted Completion (model.generate)": [],
|
607 |
+
"Predicted Completion (trainer.prediction_step)": [],
|
608 |
+
}
|
|
|
|
|
609 |
row_index = 0
|
610 |
|
611 |
for batch in tqdm(table_dataloader):
|
|
|
709 |
) in zip(
|
710 |
prompt_texts, completion_texts, predicted_texts, pred_step_texts
|
711 |
):
|
712 |
+
table_data["id"].append(row_index)
|
713 |
+
table_data["Prompt"].append(prompt_text)
|
714 |
+
table_data["Correct Completion"].append(completion_text)
|
715 |
+
table_data["Predicted Completion (model.generate)"].append(
|
716 |
+
prediction_text
|
|
|
717 |
)
|
718 |
+
table_data[
|
719 |
+
"Predicted Completion (trainer.prediction_step)"
|
720 |
+
].append(pred_step_text)
|
721 |
row_index += 1
|
722 |
+
if logger == "wandb":
|
723 |
+
wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
|
724 |
+
elif logger == "mlflow" and is_mlflow_available():
|
725 |
+
import mlflow
|
726 |
+
|
727 |
+
tracking_uri = AxolotlInputConfig(
|
728 |
+
**self.cfg.to_dict()
|
729 |
+
).mlflow_tracking_uri
|
730 |
+
mlflow.log_table(
|
731 |
+
data=table_data,
|
732 |
+
artifact_file="PredictionsVsGroundTruth.json",
|
733 |
+
tracking_uri=tracking_uri,
|
734 |
+
)
|
735 |
|
736 |
if is_main_process():
|
737 |
log_table_from_dataloader("Eval", eval_dataloader)
|