feat: expose bnb kwargs (#1018)
* feat: expose bnb kwargs
* chore: added examples and link per suggestion
* Uncomment defaults per suggestion for readability
Co-authored-by: Hamel Husain <[email protected]>
---------
Co-authored-by: Hamel Husain <[email protected]>
- README.md +8 -0
- src/axolotl/utils/models.py +13 -6
README.md
CHANGED
@@ -520,6 +520,14 @@ model_config:
|
|
520 |
type: # linear | dynamic
|
521 |
factor: # float
|
522 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
523 |
|
524 |
# Whether you are training a 4-bit GPTQ quantized model
|
525 |
gptq: true
|
|
|
520 |
type: # linear | dynamic
|
521 |
factor: # float
|
522 |
|
523 |
+
# optional overrides to the bnb 4bit quantization configuration
|
524 |
+
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
525 |
+
bnb_config_kwargs:
|
526 |
+
# These are default values
|
527 |
+
llm_int8_has_fp16_weight: false
|
528 |
+
bnb_4bit_quant_type: nf4
|
529 |
+
bnb_4bit_use_double_quant: true
|
530 |
+
|
531 |
|
532 |
# Whether you are training a 4-bit GPTQ quantized model
|
533 |
gptq: true
|
src/axolotl/utils/models.py
CHANGED
@@ -301,13 +301,20 @@ def load_model(
|
|
301 |
**model_config.quantization_config
|
302 |
)
|
303 |
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
304 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
305 |
-
|
306 |
-
llm_int8_threshold=6.0,
|
307 |
-
llm_int8_has_fp16_weight=False,
|
308 |
-
bnb_4bit_compute_dtype=cfg.torch_dtype,
|
309 |
-
bnb_4bit_use_double_quant=True,
|
310 |
-
bnb_4bit_quant_type="nf4",
|
311 |
)
|
312 |
# sample packing uses custom FA2 patch
|
313 |
if cfg.flash_attention:
|
|
|
301 |
**model_config.quantization_config
|
302 |
)
|
303 |
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
304 |
+
bnb_config = {
|
305 |
+
"load_in_4bit": True,
|
306 |
+
"llm_int8_threshold": 6.0,
|
307 |
+
"llm_int8_has_fp16_weight": False,
|
308 |
+
"bnb_4bit_compute_dtype": cfg.torch_dtype,
|
309 |
+
"bnb_4bit_use_double_quant": True,
|
310 |
+
"bnb_4bit_quant_type": "nf4",
|
311 |
+
}
|
312 |
+
|
313 |
+
if cfg.bnb_config_kwargs:
|
314 |
+
bnb_config.update(cfg.bnb_config_kwargs)
|
315 |
+
|
316 |
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
317 |
+
**bnb_config,
|
|
|
|
|
|
|
|
|
|
|
318 |
)
|
319 |
# sample packing uses custom FA2 patch
|
320 |
if cfg.flash_attention:
|