JohanWork Nanobit winglian commited on
Commit
d756534
·
unverified ·
1 Parent(s): c6b01e0

ADD: push checkpoints to mlflow artifact registry (#1295) [skip ci]

Browse files

* Add checkpoint logging to mlflow artifact registry

* clean up

* Update README.md

Co-authored-by: NanoCode012 <[email protected]>

* update pydantic config from rebase

---------

Co-authored-by: NanoCode012 <[email protected]>
Co-authored-by: Wing Lian <[email protected]>

README.md CHANGED
@@ -763,6 +763,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
763
  # mlflow configuration if you're using it
764
  mlflow_tracking_uri: # URI to mlflow
765
  mlflow_experiment_name: # Your experiment name
 
766
 
767
  # Where to save the full-finetuned model to
768
  output_dir: ./completed-model
 
763
  # mlflow configuration if you're using it
764
  mlflow_tracking_uri: # URI to mlflow
765
  mlflow_experiment_name: # Your experiment name
766
+ hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
767
 
768
  # Where to save the full-finetuned model to
769
  output_dir: ./completed-model
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -305,6 +305,7 @@ class MLFlowConfig(BaseModel):
305
  use_mlflow: Optional[str] = None
306
  mlflow_tracking_uri: Optional[str] = None
307
  mlflow_experiment_name: Optional[str] = None
 
308
 
309
 
310
  class WandbConfig(BaseModel):
 
305
  use_mlflow: Optional[str] = None
306
  mlflow_tracking_uri: Optional[str] = None
307
  mlflow_experiment_name: Optional[str] = None
308
+ hf_mlflow_log_artifacts: Optional[bool] = None
309
 
310
 
311
  class WandbConfig(BaseModel):
src/axolotl/utils/mlflow_.py CHANGED
@@ -7,7 +7,7 @@ from axolotl.utils.dict import DictDefault
7
 
8
  def setup_mlflow_env_vars(cfg: DictDefault):
9
  for key in cfg.keys():
10
- if key.startswith("mlflow_"):
11
  value = cfg.get(key, "")
12
 
13
  if value and isinstance(value, str) and len(value) > 0:
 
7
 
8
  def setup_mlflow_env_vars(cfg: DictDefault):
9
  for key in cfg.keys():
10
+ if key.startswith("mlflow_") or key.startswith("hf_mlflow_"):
11
  value = cfg.get(key, "")
12
 
13
  if value and isinstance(value, str) and len(value) > 0: