improve handling of the prepared ds path and other cfg defaults (#701)
Browse files- src/axolotl/cli/inference.py +1 -0
- src/axolotl/cli/train.py +13 -0
- src/axolotl/common/const.py +5 -0
- src/axolotl/utils/data.py +2 -2
src/axolotl/cli/inference.py
CHANGED
@@ -14,6 +14,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|
14 |
# pylint: disable=duplicate-code
|
15 |
print_axolotl_text_art()
|
16 |
parsed_cfg = load_cfg(config, **kwargs)
|
|
|
17 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
18 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
19 |
return_remaining_strings=True
|
|
|
14 |
# pylint: disable=duplicate-code
|
15 |
print_axolotl_text_art()
|
16 |
parsed_cfg = load_cfg(config, **kwargs)
|
17 |
+
parsed_cfg.sample_packing = False
|
18 |
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
19 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
20 |
return_remaining_strings=True
|
src/axolotl/cli/train.py
CHANGED
@@ -1,10 +1,12 @@
|
|
1 |
"""
|
2 |
CLI to run training on a model
|
3 |
"""
|
|
|
4 |
from pathlib import Path
|
5 |
|
6 |
import fire
|
7 |
import transformers
|
|
|
8 |
|
9 |
from axolotl.cli import (
|
10 |
check_accelerate_default_config,
|
@@ -14,8 +16,11 @@ from axolotl.cli import (
|
|
14 |
print_axolotl_text_art,
|
15 |
)
|
16 |
from axolotl.common.cli import TrainerCliArgs
|
|
|
17 |
from axolotl.train import train
|
18 |
|
|
|
|
|
19 |
|
20 |
def do_cli(config: Path = Path("examples/"), **kwargs):
|
21 |
# pylint: disable=duplicate-code
|
@@ -27,6 +32,14 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|
27 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
28 |
return_remaining_strings=True
|
29 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
32 |
if parsed_cli_args.prepare_ds_only:
|
|
|
1 |
"""
|
2 |
CLI to run training on a model
|
3 |
"""
|
4 |
+
import logging
|
5 |
from pathlib import Path
|
6 |
|
7 |
import fire
|
8 |
import transformers
|
9 |
+
from colorama import Fore
|
10 |
|
11 |
from axolotl.cli import (
|
12 |
check_accelerate_default_config,
|
|
|
16 |
print_axolotl_text_art,
|
17 |
)
|
18 |
from axolotl.common.cli import TrainerCliArgs
|
19 |
+
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
20 |
from axolotl.train import train
|
21 |
|
22 |
+
LOG = logging.getLogger("axolotl.cli.train")
|
23 |
+
|
24 |
|
25 |
def do_cli(config: Path = Path("examples/"), **kwargs):
|
26 |
# pylint: disable=duplicate-code
|
|
|
32 |
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
33 |
return_remaining_strings=True
|
34 |
)
|
35 |
+
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
36 |
+
msg = (
|
37 |
+
Fore.RED
|
38 |
+
+ "--prepare_ds_only called without dataset_prepared_path set."
|
39 |
+
+ Fore.RESET
|
40 |
+
)
|
41 |
+
LOG.warning(msg)
|
42 |
+
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
43 |
|
44 |
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
45 |
if parsed_cli_args.prepare_ds_only:
|
src/axolotl/common/const.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Various shared constants
|
3 |
+
"""
|
4 |
+
|
5 |
+
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
src/axolotl/utils/data.py
CHANGED
@@ -16,6 +16,7 @@ from datasets import (
|
|
16 |
from huggingface_hub import hf_hub_download
|
17 |
from transformers import PreTrainedTokenizerBase
|
18 |
|
|
|
19 |
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
20 |
from axolotl.prompt_strategies import load
|
21 |
from axolotl.prompt_tokenizers import (
|
@@ -44,7 +45,6 @@ from axolotl.utils.trainer import (
|
|
44 |
)
|
45 |
|
46 |
LOG = logging.getLogger("axolotl")
|
47 |
-
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
48 |
|
49 |
|
50 |
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
@@ -357,7 +357,7 @@ def load_tokenized_prepared_datasets(
|
|
357 |
if len(datasets) > 1:
|
358 |
LOG.info("shuffle merged datasets")
|
359 |
dataset = dataset.shuffle(seed=seed)
|
360 |
-
if cfg.local_rank == 0
|
361 |
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
362 |
dataset.save_to_disk(prepared_ds_path)
|
363 |
if cfg.push_dataset_to_hub:
|
|
|
16 |
from huggingface_hub import hf_hub_download
|
17 |
from transformers import PreTrainedTokenizerBase
|
18 |
|
19 |
+
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
20 |
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
21 |
from axolotl.prompt_strategies import load
|
22 |
from axolotl.prompt_tokenizers import (
|
|
|
45 |
)
|
46 |
|
47 |
LOG = logging.getLogger("axolotl")
|
|
|
48 |
|
49 |
|
50 |
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
|
|
357 |
if len(datasets) > 1:
|
358 |
LOG.info("shuffle merged datasets")
|
359 |
dataset = dataset.shuffle(seed=seed)
|
360 |
+
if cfg.local_rank == 0:
|
361 |
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
362 |
dataset.save_to_disk(prepared_ds_path)
|
363 |
if cfg.push_dataset_to_hub:
|