feat: add check for quantized model (#913)
Browse files* feat: add check for quantized model
* chore: refactor and add another check
* Update src/axolotl/utils/models.py
---------
Co-authored-by: Wing Lian <[email protected]>
- src/axolotl/utils/models.py +23 -0
src/axolotl/utils/models.py
CHANGED
@@ -28,6 +28,27 @@ from axolotl.utils.dict import DictDefault
|
|
28 |
LOG = logging.getLogger("axolotl")
|
29 |
|
30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
def load_model_config(cfg):
|
32 |
model_config_name = cfg.base_model_config or cfg.base_model
|
33 |
trust_remote_code = cfg.trust_remote_code is True
|
@@ -38,6 +59,8 @@ def load_model_config(cfg):
|
|
38 |
for key, val in cfg.model_config.items():
|
39 |
setattr(model_config, key, val)
|
40 |
|
|
|
|
|
41 |
return model_config
|
42 |
|
43 |
|
|
|
28 |
LOG = logging.getLogger("axolotl")
|
29 |
|
30 |
|
31 |
+
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
32 |
+
quant_config_exists = hasattr(model_config, "quantization_config")
|
33 |
+
quant_config_method_is_gptq = (
|
34 |
+
quant_config_exists
|
35 |
+
and "quant_method" in model_config.quantization_config
|
36 |
+
and model_config.quantization_config["quant_method"] == "gptq"
|
37 |
+
)
|
38 |
+
|
39 |
+
if cfg.gptq and not quant_config_method_is_gptq:
|
40 |
+
raise ValueError(
|
41 |
+
"model_config.quantization_config is not set or quant_method is not set to gptq. "
|
42 |
+
"Please make sure to point to a GPTQ model."
|
43 |
+
)
|
44 |
+
|
45 |
+
if not cfg.gptq and quant_config_exists:
|
46 |
+
raise ValueError(
|
47 |
+
"model_config.quantization_config is set but `gptq` flag is not. "
|
48 |
+
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
def load_model_config(cfg):
|
53 |
model_config_name = cfg.base_model_config or cfg.base_model
|
54 |
trust_remote_code = cfg.trust_remote_code is True
|
|
|
59 |
for key, val in cfg.model_config.items():
|
60 |
setattr(model_config, key, val)
|
61 |
|
62 |
+
check_model_config(cfg, model_config)
|
63 |
+
|
64 |
return model_config
|
65 |
|
66 |
|