Qwen2 (#1166)
* qwen2 multipack support
* fix qwen derived model check so it doesn't break qwen2
* fixes to ensure qwen2 packing works
* bump requirements for qwen2
* requirements typo
- requirements.txt +2 -2
- src/axolotl/core/trainer_builder.py +1 -1
- src/axolotl/monkeypatch/qwen2/__init__.py +12 -0
- src/axolotl/utils/config.py +6 -11
- src/axolotl/utils/models.py +10 -2
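
No new config options are introduced; the new path is driven entirely by existing axolotl settings. As a hypothetical illustration (the base_model value is only an example; any Qwen2-architecture checkpoint applies), the combination this PR is meant to support looks like:

```python
# Hypothetical minimal set of axolotl config options exercised by this PR.
qwen2_multipack_cfg = {
    "base_model": "Qwen/Qwen1.5-7B",  # example only; any qwen2-architecture model id
    "sample_packing": True,           # pack several samples into one sequence
    "flash_attention": True,          # needed for the multipack attention path
}
```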
requirements.txt
CHANGED
@@ -1,10 +1,10 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
 peft==0.7.0
-transformers
+transformers==4.37.0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
-accelerate
+accelerate==0.26.1
 deepspeed
 addict
 fire
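
Since the monkeypatch reaches into transformers.models.qwen2.modeling_qwen2, the bumped pins matter: the Qwen2 modeling code first shipped in transformers 4.37.0. A quick, optional sanity check (not part of the PR) that an environment matches the new pins:

```python
# Optional check, not part of the PR: confirm the pinned versions from
# requirements.txt are what is actually installed.
from importlib.metadata import version

for pkg, pinned in {"transformers": "4.37.0", "accelerate": "0.26.1"}.items():
    installed = version(pkg)
    assert installed == pinned, f"{pkg}: expected {pinned}, found {installed}"
```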
src/axolotl/core/trainer_builder.py
CHANGED
@@ -905,7 +905,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 ]
             ]
         if use_batch_sampler_collator:
-            if self.cfg.model_config_type == "mixtral":
+            if self.cfg.model_config_type in ["mixtral", "qwen2"]:
                 collator = V2BatchSamplerDataCollatorForSeq2Seq
             else:
                 collator = BatchSamplerDataCollatorForSeq2Seq
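
The only change here is widening the gate that picks the V2 batch-sampler collator so qwen2 follows the same packing path as mixtral. A standalone sketch of that selection logic (not axolotl code; the collator classes are reduced to strings and cfg is stood in by a namespace):

```python
# Toy sketch of the collator selection above.
from types import SimpleNamespace

def pick_collator(cfg):
    # Model types whose packing path expects the V2 batch-sampler collator;
    # this PR adds "qwen2" next to "mixtral".
    if cfg.model_config_type in ["mixtral", "qwen2"]:
        return "V2BatchSamplerDataCollatorForSeq2Seq"
    return "BatchSamplerDataCollatorForSeq2Seq"

print(pick_collator(SimpleNamespace(model_config_type="qwen2")))  # V2BatchSampler...
print(pick_collator(SimpleNamespace(model_config_type="llama")))  # BatchSampler...
```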
src/axolotl/monkeypatch/qwen2/__init__.py
ADDED
@@ -0,0 +1,12 @@
+"""
+Patches to support multipack for qwen2
+"""
+import transformers
+
+from axolotl.monkeypatch.utils import get_unpad_data
+
+
+def replace_qwen2_attn_with_multipack_flash_attn():
+    transformers.models.qwen2.modeling_qwen2._get_unpad_data = (  # pylint: disable=protected-access
+        get_unpad_data
+    )
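
The patch is just a swap of `_get_unpad_data` inside the Qwen2 modeling module, the same approach already used for mixtral. The point of the swap: with sample packing, axolotl's attention_mask carries a distinct integer id per packed sub-sequence rather than a plain 0/1 mask, so the flash-attention unpadding bookkeeping has to count each sub-sequence separately. A toy illustration of that idea (this is not axolotl's get_unpad_data, only a sketch of the concept):

```python
import torch

def toy_unpad_data(attention_mask: torch.Tensor):
    """Toy multipack unpadding: attention_mask holds per-sample ids
    (1, 2, 3, ...) per row, with 0 marking padding."""
    seqlens = []
    for row in attention_mask:
        ids = row[row != 0]
        # length of each packed sub-sequence in this row
        _, counts = torch.unique_consecutive(ids, return_counts=True)
        seqlens.extend(counts.tolist())
    seqlens = torch.tensor(seqlens, dtype=torch.int32)
    # cumulative sequence lengths, as varlen flash-attention kernels expect
    cu_seqlens = torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )
    indices = torch.nonzero(attention_mask.flatten() != 0).flatten()
    return indices, cu_seqlens, int(seqlens.max())

mask = torch.tensor([[1, 1, 1, 2, 2, 0]])  # two packed samples, one pad token
print(toy_unpad_data(mask))  # lengths 3 and 2 -> cu_seqlens [0, 3, 5]
```

In models.py below, replace_qwen2_attn_with_multipack_flash_attn() is called before the model is loaded so the patched function is in place when flash attention runs.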
src/axolotl/utils/config.py
CHANGED
@@ -142,17 +142,12 @@ def normalize_config(cfg):
     )

     cfg.is_qwen_derived_model = (
-        (
-            hasattr(model_config, "model_type")
-            and model_config.model_type
-            in [
-                "qwen",
-            ]
-        )
-        or cfg.is_qwen_derived_model
-        or "qwen" in cfg.base_model.lower()
-        or (cfg.model_type and "qwen" in cfg.model_type.lower())
-    )
+        hasattr(model_config, "model_type")
+        and model_config.model_type
+        in [
+            "qwen",
+        ]
+    ) or cfg.is_qwen_derived_model

     if isinstance(cfg.learning_rate, str):
         cfg.learning_rate = float(cfg.learning_rate)
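
For context on the "fix qwen derived model check so it doesn't break qwen2" commit: the old expression also matched any base_model name containing "qwen", so Qwen2 checkpoints were flagged as Qwen-1-derived and presumably picked up Qwen-1-specific handling elsewhere. The new check keys only on model_config.model_type plus an explicit user override. A minimal sketch of the new behavior, with plain namespaces standing in for axolotl's cfg and model_config:

```python
from types import SimpleNamespace

def is_qwen_derived(model_config, cfg):
    # Mirrors the updated check: only the Qwen-1 "qwen" model_type, or an
    # explicit override on cfg, marks a model as qwen-derived.
    return (
        hasattr(model_config, "model_type") and model_config.model_type in ["qwen"]
    ) or bool(cfg.is_qwen_derived_model)

cfg = SimpleNamespace(is_qwen_derived_model=False)
print(is_qwen_derived(SimpleNamespace(model_type="qwen"), cfg))   # True
print(is_qwen_derived(SimpleNamespace(model_type="qwen2"), cfg))  # False
```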
src/axolotl/utils/models.py
CHANGED
@@ -334,6 +334,14 @@ def load_model(
         LOG.info("patching mixtral with flash attention")
         replace_mixtral_attn_with_multipack_flash_attn()

+    if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
+        from axolotl.monkeypatch.qwen2 import (
+            replace_qwen2_attn_with_multipack_flash_attn,
+        )
+
+        LOG.info("patching qwen2 with flash attention")
+        replace_qwen2_attn_with_multipack_flash_attn()
+
     if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
         from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask

@@ -426,14 +434,14 @@ def load_model(
             cfg.is_llama_derived_model
             or cfg.is_falcon_derived_model
             or cfg.is_mistral_derived_model
-            or model_config.model_type == "mixtral"
+            or model_config.model_type in ["mixtral", "qwen2"]
         ):
             model_kwargs["attn_implementation"] = "flash_attention_2"
             model_config._attn_implementation = (  # pylint: disable=protected-access
                 "flash_attention_2"
             )
         else:
-            if model_config.model_type == "mixtral":
+            if model_config.model_type in ["mixtral", "qwen2"]:
                 model_kwargs["attn_implementation"] = "flash_attention_2"
                 model_config._attn_implementation = (  # pylint: disable=protected-access
                     "flash_attention_2"
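
Taken together, the two hunks mean a Qwen2 run with both flash_attention and sample_packing gets the multipack monkeypatch applied before loading, and flash_attention_2 is requested from transformers via the mixtral/qwen2 branch. A rough sketch of that gating (not axolotl code; the enclosing flash-attention condition around the second block is an assumption about surrounding context not shown in the hunk):

```python
from types import SimpleNamespace

def plan_flash_attention(cfg, model_config):
    """Return the load-time actions this diff would trigger for a given config."""
    actions = []
    if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
        actions.append("replace_qwen2_attn_with_multipack_flash_attn()")
    if cfg.flash_attention and (
        cfg.is_llama_derived_model
        or cfg.is_falcon_derived_model
        or cfg.is_mistral_derived_model
        or model_config.model_type in ["mixtral", "qwen2"]
    ):
        actions.append('attn_implementation = "flash_attention_2"')
    return actions

cfg = SimpleNamespace(
    model_config_type="qwen2",
    flash_attention=True,
    sample_packing=True,
    is_llama_derived_model=False,
    is_falcon_derived_model=False,
    is_mistral_derived_model=False,
)
print(plan_flash_attention(cfg, SimpleNamespace(model_type="qwen2")))
```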