casperhansen winglian commited on
Commit
d66b101
·
unverified ·
1 Parent(s): 304ea1b

Disable caching on `--disable_caching` in CLI (#1110)

Browse files

* Disable caching on `--disable_caching` in CLI

* chore: lint

---------

Co-authored-by: Wing Lian <[email protected]>

src/axolotl/cli/preprocess.py CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
7
  import fire
8
  import transformers
9
  from colorama import Fore
 
10
 
11
  from axolotl.cli import (
12
  check_accelerate_default_config,
@@ -28,9 +29,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
28
  check_accelerate_default_config()
29
  check_user_token()
30
  parser = transformers.HfArgumentParser((PreprocessCliArgs))
31
- parsed_cli_args, _ = parser.parse_args_into_dataclasses(
32
  return_remaining_strings=True
33
  )
 
 
 
 
 
 
34
  if not parsed_cfg.dataset_prepared_path:
35
  msg = (
36
  Fore.RED
 
7
  import fire
8
  import transformers
9
  from colorama import Fore
10
+ from datasets import disable_caching
11
 
12
  from axolotl.cli import (
13
  check_accelerate_default_config,
 
29
  check_accelerate_default_config()
30
  check_user_token()
31
  parser = transformers.HfArgumentParser((PreprocessCliArgs))
32
+ parsed_cli_args, remaining_args = parser.parse_args_into_dataclasses(
33
  return_remaining_strings=True
34
  )
35
+
36
+ if (
37
+ remaining_args.get("disable_caching") is not None
38
+ and remaining_args["disable_caching"]
39
+ ):
40
+ disable_caching()
41
  if not parsed_cfg.dataset_prepared_path:
42
  msg = (
43
  Fore.RED
src/axolotl/cli/train.py CHANGED
@@ -6,6 +6,7 @@ from pathlib import Path
6
 
7
  import fire
8
  import transformers
 
9
 
10
  from axolotl.cli import (
11
  check_accelerate_default_config,
@@ -28,9 +29,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
28
  check_accelerate_default_config()
29
  check_user_token()
30
  parser = transformers.HfArgumentParser((TrainerCliArgs))
31
- parsed_cli_args, _ = parser.parse_args_into_dataclasses(
32
  return_remaining_strings=True
33
  )
 
 
 
 
 
 
34
  if parsed_cfg.rl:
35
  dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
36
  else:
 
6
 
7
  import fire
8
  import transformers
9
+ from datasets import disable_caching
10
 
11
  from axolotl.cli import (
12
  check_accelerate_default_config,
 
29
  check_accelerate_default_config()
30
  check_user_token()
31
  parser = transformers.HfArgumentParser((TrainerCliArgs))
32
+ parsed_cli_args, remaining_args = parser.parse_args_into_dataclasses(
33
  return_remaining_strings=True
34
  )
35
+
36
+ if (
37
+ remaining_args.get("disable_caching") is not None
38
+ and remaining_args["disable_caching"]
39
+ ):
40
+ disable_caching()
41
  if parsed_cfg.rl:
42
  dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
43
  else: