Nanobit winglian committed on
Commit
a581e9f
·
unverified ·
1 Parent(s): 992e742

feat: add check for quantized model (#913)

Browse files

* feat: add check for quantized model

* chore: refactor and add another check

* Update src/axolotl/utils/models.py

---------

Co-authored-by: Wing Lian <[email protected]>

Files changed (1) hide show
  1. src/axolotl/utils/models.py +23 -0
src/axolotl/utils/models.py CHANGED
@@ -28,6 +28,27 @@ from axolotl.utils.dict import DictDefault
28
  LOG = logging.getLogger("axolotl")
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def load_model_config(cfg):
32
  model_config_name = cfg.base_model_config or cfg.base_model
33
  trust_remote_code = cfg.trust_remote_code is True
@@ -38,6 +59,8 @@ def load_model_config(cfg):
38
  for key, val in cfg.model_config.items():
39
  setattr(model_config, key, val)
40
 
 
 
41
  return model_config
42
 
43
 
 
28
  LOG = logging.getLogger("axolotl")
29
 
30
 
31
def check_model_config(cfg: DictDefault, model_config: AutoConfig) -> None:
    """Validate that the `gptq` flag agrees with the model's quantization config.

    Args:
        cfg: Training configuration; only `cfg.gptq` is read here.
        model_config: Loaded model config; its optional `quantization_config`
            attribute is inspected.

    Raises:
        ValueError: If `cfg.gptq` is truthy but the model config does not
            declare `quant_method == "gptq"`, or if `cfg.gptq` is falsy while
            the model config carries a quantization config.
    """
    # A missing attribute and an explicit None both mean "no quantization
    # config". hasattr() alone would count None as set and then blow up with
    # a TypeError on the membership test below.
    quant_config = getattr(model_config, "quantization_config", None)
    quant_config_exists = quant_config is not None

    try:
        # Usual case: the quantization config is a plain mapping.
        quant_method = quant_config["quant_method"]
    except (KeyError, TypeError):
        # Key absent, config is None, or config is an object rather than a
        # mapping; fall back to an attribute of the same name if there is one.
        quant_method = getattr(quant_config, "quant_method", None)

    quant_config_method_is_gptq = quant_config_exists and quant_method == "gptq"

    if cfg.gptq and not quant_config_method_is_gptq:
        raise ValueError(
            "model_config.quantization_config is not set or quant_method is not set to gptq. "
            "Please make sure to point to a GPTQ model."
        )

    if not cfg.gptq and quant_config_exists:
        raise ValueError(
            "model_config.quantization_config is set but `gptq` flag is not. "
            "Please use the `gptq` flag to train quantized model or point to a non-quantized model."
        )
50
+
51
+
52
  def load_model_config(cfg):
53
  model_config_name = cfg.base_model_config or cfg.base_model
54
  trust_remote_code = cfg.trust_remote_code is True
 
59
  for key, val in cfg.model_config.items():
60
  setattr(model_config, key, val)
61
 
62
+ check_model_config(cfg, model_config)
63
+
64
  return model_config
65
 
66