DavidFarago Dave Farago winglian commited on
Commit
057fa44
·
unverified ·
1 Parent(s): 8fa0785

WIP: Support table logging for mlflow, too (#1506)

Browse files

* WIP: Support table logging for mlflow, too

Create a `LogPredictionCallback` for both "wandb" and "mlflow" if
specified.

In `log_prediction_callback_factory`, create a generic table and make it
specific only if the newly added `logger` argument is set to "wandb"
resp. "mlflow".

See https://github.com/OpenAccess-AI-Collective/axolotl/issues/1505

* chore: lint

* add additional clause for mlflow as it's optional

* Fix circular imports

---------

Co-authored-by: Dave Farago <[email protected]>
Co-authored-by: Wing Lian <[email protected]>

src/axolotl/core/trainer_builder.py CHANGED
@@ -36,6 +36,7 @@ from trl.trainer.utils import pad_to_length
36
  from axolotl.loraplus import create_loraplus_optimizer
37
  from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
38
  from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
 
39
  from axolotl.utils.callbacks import (
40
  EvalFirstStepCallback,
41
  GPUStatsCallback,
@@ -71,10 +72,6 @@ except ImportError:
71
  LOG = logging.getLogger("axolotl.core.trainer_builder")
72
 
73
 
74
- def is_mlflow_available():
75
- return importlib.util.find_spec("mlflow") is not None
76
-
77
-
78
  def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
79
  if isinstance(tag_names, str):
80
  tag_names = [tag_names]
@@ -943,7 +940,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
943
  callbacks = []
944
  if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
945
  LogPredictionCallback = log_prediction_callback_factory(
946
- trainer, self.tokenizer
 
 
 
 
 
 
 
 
 
947
  )
948
  callbacks.append(LogPredictionCallback(self.cfg))
949
 
 
36
  from axolotl.loraplus import create_loraplus_optimizer
37
  from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
38
  from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
39
+ from axolotl.utils import is_mlflow_available
40
  from axolotl.utils.callbacks import (
41
  EvalFirstStepCallback,
42
  GPUStatsCallback,
 
72
  LOG = logging.getLogger("axolotl.core.trainer_builder")
73
 
74
 
 
 
 
 
75
  def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
76
  if isinstance(tag_names, str):
77
  tag_names = [tag_names]
 
940
  callbacks = []
941
  if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
942
  LogPredictionCallback = log_prediction_callback_factory(
943
+ trainer, self.tokenizer, "wandb"
944
+ )
945
+ callbacks.append(LogPredictionCallback(self.cfg))
946
+ if (
947
+ self.cfg.use_mlflow
948
+ and is_mlflow_available()
949
+ and self.cfg.eval_table_size > 0
950
+ ):
951
+ LogPredictionCallback = log_prediction_callback_factory(
952
+ trainer, self.tokenizer, "mlflow"
953
  )
954
  callbacks.append(LogPredictionCallback(self.cfg))
955
 
src/axolotl/utils/__init__.py CHANGED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Basic utils for Axolotl
3
+ """
4
+ import importlib
5
+
6
+
7
+ def is_mlflow_available():
8
+ return importlib.util.find_spec("mlflow") is not None
src/axolotl/utils/callbacks/__init__.py CHANGED
@@ -6,7 +6,7 @@ import logging
6
  import os
7
  from shutil import copyfile
8
  from tempfile import NamedTemporaryFile
9
- from typing import TYPE_CHECKING, Dict, List
10
 
11
  import evaluate
12
  import numpy as np
@@ -27,7 +27,9 @@ from transformers import (
27
  )
28
  from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
29
 
 
30
  from axolotl.utils.bench import log_gpu_memory_usage
 
31
  from axolotl.utils.distributed import (
32
  barrier,
33
  broadcast_dict,
@@ -540,7 +542,7 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
540
  return CausalLMBenchEvalCallback
541
 
542
 
543
- def log_prediction_callback_factory(trainer: Trainer, tokenizer):
544
  class LogPredictionCallback(TrainerCallback):
545
  """Callback to log prediction values during each evaluation"""
546
 
@@ -597,15 +599,13 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
597
  return ranges
598
 
599
  def log_table_from_dataloader(name: str, table_dataloader):
600
- table = wandb.Table( # type: ignore[attr-defined]
601
- columns=[
602
- "id",
603
- "Prompt",
604
- "Correct Completion",
605
- "Predicted Completion (model.generate)",
606
- "Predicted Completion (trainer.prediction_step)",
607
- ]
608
- )
609
  row_index = 0
610
 
611
  for batch in tqdm(table_dataloader):
@@ -709,16 +709,29 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
709
  ) in zip(
710
  prompt_texts, completion_texts, predicted_texts, pred_step_texts
711
  ):
712
- table.add_data(
713
- row_index,
714
- prompt_text,
715
- completion_text,
716
- prediction_text,
717
- pred_step_text,
718
  )
 
 
 
719
  row_index += 1
720
-
721
- wandb.run.log({f"{name} - Predictions vs Ground Truth": table}) # type: ignore[attr-defined]
 
 
 
 
 
 
 
 
 
 
 
722
 
723
  if is_main_process():
724
  log_table_from_dataloader("Eval", eval_dataloader)
 
6
  import os
7
  from shutil import copyfile
8
  from tempfile import NamedTemporaryFile
9
+ from typing import TYPE_CHECKING, Any, Dict, List
10
 
11
  import evaluate
12
  import numpy as np
 
27
  )
28
  from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
29
 
30
+ from axolotl.utils import is_mlflow_available
31
  from axolotl.utils.bench import log_gpu_memory_usage
32
+ from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
33
  from axolotl.utils.distributed import (
34
  barrier,
35
  broadcast_dict,
 
542
  return CausalLMBenchEvalCallback
543
 
544
 
545
+ def log_prediction_callback_factory(trainer: Trainer, tokenizer, logger: str):
546
  class LogPredictionCallback(TrainerCallback):
547
  """Callback to log prediction values during each evaluation"""
548
 
 
599
  return ranges
600
 
601
  def log_table_from_dataloader(name: str, table_dataloader):
602
+ table_data: Dict[str, List[Any]] = {
603
+ "id": [],
604
+ "Prompt": [],
605
+ "Correct Completion": [],
606
+ "Predicted Completion (model.generate)": [],
607
+ "Predicted Completion (trainer.prediction_step)": [],
608
+ }
 
 
609
  row_index = 0
610
 
611
  for batch in tqdm(table_dataloader):
 
709
  ) in zip(
710
  prompt_texts, completion_texts, predicted_texts, pred_step_texts
711
  ):
712
+ table_data["id"].append(row_index)
713
+ table_data["Prompt"].append(prompt_text)
714
+ table_data["Correct Completion"].append(completion_text)
715
+ table_data["Predicted Completion (model.generate)"].append(
716
+ prediction_text
 
717
  )
718
+ table_data[
719
+ "Predicted Completion (trainer.prediction_step)"
720
+ ].append(pred_step_text)
721
  row_index += 1
722
+ if logger == "wandb":
723
+ wandb.run.log({f"{name} - Predictions vs Ground Truth": pd.DataFrame(table_data)}) # type: ignore[attr-defined]
724
+ elif logger == "mlflow" and is_mlflow_available():
725
+ import mlflow
726
+
727
+ tracking_uri = AxolotlInputConfig(
728
+ **self.cfg.to_dict()
729
+ ).mlflow_tracking_uri
730
+ mlflow.log_table(
731
+ data=table_data,
732
+ artifact_file="PredictionsVsGroundTruth.json",
733
+ tracking_uri=tracking_uri,
734
+ )
735
 
736
  if is_main_process():
737
  log_table_from_dataloader("Eval", eval_dataloader)