winglian commited on
Commit
39a208c
·
1 Parent(s): 2520ecd

fix up tokenizer config, isort fix

Browse files
Files changed (2) hide show
  1. scripts/finetune.py +3 -2
  2. src/axolotl/utils/models.py +10 -5
scripts/finetune.py CHANGED
@@ -171,8 +171,9 @@ def train(
171
  validate_config(cfg)
172
 
173
  # load the tokenizer first
174
- logging.info("loading tokenizer...")
175
- tokenizer = load_tokenizer(cfg.base_model_config, cfg.tokenizer_type, cfg)
 
176
 
177
  if check_not_in(
178
  ["inference", "shard", "merge_lora"], kwargs
 
171
  validate_config(cfg)
172
 
173
  # load the tokenizer first
174
+ tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
175
+ logging.info(f"loading tokenizer... {tokenizer_config}")
176
+ tokenizer = load_tokenizer(tokenizer_config, cfg.tokenizer_type, cfg)
177
 
178
  if check_not_in(
179
  ["inference", "shard", "merge_lora"], kwargs
src/axolotl/utils/models.py CHANGED
@@ -10,9 +10,14 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401
10
  import bitsandbytes as bnb
11
  import torch
12
  import transformers
13
- from transformers import AutoModelForCausalLM, LlamaConfig # noqa: F401
14
  from transformers import PreTrainedModel # noqa: F401
15
- from transformers import AutoConfig, AutoTokenizer, BitsAndBytesConfig
 
 
 
 
 
 
16
 
17
  try:
18
  from transformers import LlamaForCausalLM
@@ -31,18 +36,18 @@ if TYPE_CHECKING:
31
 
32
 
33
  def load_tokenizer(
34
- base_model_config,
35
  tokenizer_type,
36
  cfg,
37
  ):
38
  if tokenizer_type:
39
  tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
40
- base_model_config,
41
  trust_remote_code=cfg.trust_remote_code or False,
42
  )
43
  else:
44
  tokenizer = AutoTokenizer.from_pretrained(
45
- base_model_config,
46
  trust_remote_code=cfg.trust_remote_code or False,
47
  )
48
 
 
10
  import bitsandbytes as bnb
11
  import torch
12
  import transformers
 
13
  from transformers import PreTrainedModel # noqa: F401
14
+ from transformers import ( # noqa: F401
15
+ AutoConfig,
16
+ AutoModelForCausalLM,
17
+ AutoTokenizer,
18
+ BitsAndBytesConfig,
19
+ LlamaConfig,
20
+ )
21
 
22
  try:
23
  from transformers import LlamaForCausalLM
 
36
 
37
 
38
  def load_tokenizer(
39
+ tokenizer_config,
40
  tokenizer_type,
41
  cfg,
42
  ):
43
  if tokenizer_type:
44
  tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
45
+ tokenizer_config,
46
  trust_remote_code=cfg.trust_remote_code or False,
47
  )
48
  else:
49
  tokenizer = AutoTokenizer.from_pretrained(
50
+ tokenizer_config,
51
  trust_remote_code=cfg.trust_remote_code or False,
52
  )
53