tmm1 committed
Commit efb3b2c · 1 parent: 7b55fe6

simplify `load_tokenizer`
scripts/finetune.py CHANGED

@@ -177,9 +177,8 @@ def train(
     setup_wandb_env_vars(cfg)
 
     # load the tokenizer first
-    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
-    LOG.info(f"loading tokenizer... {tokenizer_config}")
-    tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
+    LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
+    tokenizer = load_tokenizer(cfg)
 
     if (
         check_not_in(["shard", "merge_lora"], kwargs) and not cfg.inference
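With this commit the call site hands `load_tokenizer` only the `cfg` object; the function now derives the tokenizer path and class itself. A minimal sketch of the config fields the new signature reads (field names are taken from the diff below; treating missing keys as None is an assumption about axolotl's `DictDefault` helper):

from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_tokenizer

# tokenizer_config (or base_model_config as fallback) names the checkpoint;
# tokenizer_type, tokenizer_use_fast, tokenizer_legacy and trust_remote_code
# are optional and fall back to defaults when the keys are absent.
cfg = DictDefault(
    {
        "base_model_config": "huggyllama/llama-7b",
        "tokenizer_type": "LlamaTokenizer",  # optional; AutoTokenizer otherwise
    }
)
tokenizer = load_tokenizer(cfg)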
src/axolotl/utils/models.py CHANGED

@@ -32,32 +32,27 @@ if TYPE_CHECKING:
     from axolotl.utils.dict import DictDefault  # noqa: F401
 
 
-def load_tokenizer(
-    tokenizer_config,
-    tokenizer_type,
-    cfg,
-):
+def load_tokenizer(cfg):
     tokenizer_kwargs = {}
     use_fast = True  # this is the default
+
     if cfg.tokenizer_use_fast is not None:
         use_fast = cfg.tokenizer_use_fast
     if cfg.tokenizer_legacy is not None:
         # True is the default w/ https://github.com/huggingface/transformers/pull/25224
         tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
-    if tokenizer_type:
-        tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
-            tokenizer_config,
-            trust_remote_code=cfg.trust_remote_code or False,
-            use_fast=use_fast,
-            **tokenizer_kwargs,
-        )
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_config,
-            trust_remote_code=cfg.trust_remote_code or False,
-            use_fast=use_fast,
-            **tokenizer_kwargs,
-        )
+
+    tokenizer_cls = AutoTokenizer
+    if cfg.tokenizer_type:
+        tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
+
+    tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
+    tokenizer = tokenizer_cls.from_pretrained(
+        tokenizer_config,
+        trust_remote_code=cfg.trust_remote_code or False,
+        use_fast=use_fast,
+        **tokenizer_kwargs,
+    )
 
     if tokenizer.__class__.__name__ in [
         "LlamaTokenizer",
tests/test_tokenizers.py CHANGED

@@ -13,17 +13,22 @@ class TestTokenizers(unittest.TestCase):
     """
 
     def test_default_use_fast(self):
-        cfg = DictDefault({})
-        tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
+        cfg = DictDefault(
+            {
+                "tokenizer_config": "huggyllama/llama-7b",
+            }
+        )
+        tokenizer = load_tokenizer(cfg)
         assert "Fast" in tokenizer.__class__.__name__
 
     def test_dont_use_fast(self):
         cfg = DictDefault(
             {
+                "tokenizer_config": "huggyllama/llama-7b",
                 "tokenizer_use_fast": False,
             }
         )
-        tokenizer = load_tokenizer("huggyllama/llama-7b", None, cfg)
+        tokenizer = load_tokenizer(cfg)
         assert "Fast" not in tokenizer.__class__.__name__
 
 
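The updated tests build a `DictDefault` config instead of passing positional arguments. A small sketch for running just this module with the stdlib runner (assuming the repo-root layout the paths above suggest; the project's actual test runner may differ):

import unittest

# Discover and run only tests/test_tokenizers.py; the class above is a plain
# unittest.TestCase, so no extra test framework is required.
suite = unittest.defaultTestLoader.discover("tests", pattern="test_tokenizers.py")
unittest.TextTestRunner(verbosity=2).run(suite)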