ADD: push checkpoints to mlflow artifact registry (#1295) [skip ci]
Browse files* Add checkpoint logging to mlflow artifact registry
* clean up
* Update README.md
Co-authored-by: NanoCode012 <[email protected]>
* update pydantic config from rebase
---------
Co-authored-by: NanoCode012 <[email protected]>
Co-authored-by: Wing Lian <[email protected]>
README.md
CHANGED
@@ -763,6 +763,7 @@ wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_step
|
|
763 |
# mlflow configuration if you're using it
|
764 |
mlflow_tracking_uri: # URI to mlflow
|
765 |
mlflow_experiment_name: # Your experiment name
|
|
|
766 |
|
767 |
# Where to save the full-finetuned model to
|
768 |
output_dir: ./completed-model
|
|
|
763 |
# mlflow configuration if you're using it
|
764 |
mlflow_tracking_uri: # URI to mlflow
|
765 |
mlflow_experiment_name: # Your experiment name
|
766 |
+
hf_mlflow_log_artifacts: # set to true to copy each saved checkpoint on each save to mlflow artifact registry
|
767 |
|
768 |
# Where to save the full-finetuned model to
|
769 |
output_dir: ./completed-model
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -305,6 +305,7 @@ class MLFlowConfig(BaseModel):
|
|
305 |
use_mlflow: Optional[str] = None
|
306 |
mlflow_tracking_uri: Optional[str] = None
|
307 |
mlflow_experiment_name: Optional[str] = None
|
|
|
308 |
|
309 |
|
310 |
class WandbConfig(BaseModel):
|
|
|
305 |
use_mlflow: Optional[str] = None
|
306 |
mlflow_tracking_uri: Optional[str] = None
|
307 |
mlflow_experiment_name: Optional[str] = None
|
308 |
+
hf_mlflow_log_artifacts: Optional[bool] = None
|
309 |
|
310 |
|
311 |
class WandbConfig(BaseModel):
|
src/axolotl/utils/mlflow_.py
CHANGED
@@ -7,7 +7,7 @@ from axolotl.utils.dict import DictDefault
|
|
7 |
|
8 |
def setup_mlflow_env_vars(cfg: DictDefault):
|
9 |
for key in cfg.keys():
|
10 |
-
if key.startswith("mlflow_"):
|
11 |
value = cfg.get(key, "")
|
12 |
|
13 |
if value and isinstance(value, str) and len(value) > 0:
|
|
|
7 |
|
8 |
def setup_mlflow_env_vars(cfg: DictDefault):
|
9 |
for key in cfg.keys():
|
10 |
+
if key.startswith("mlflow_") or key.startswith("hf_mlflow_"):
|
11 |
value = cfg.get(key, "")
|
12 |
|
13 |
if value and isinstance(value, str) and len(value) > 0:
|