Add support for GPTQ using native transformers/peft (#468)
* auto gptq support
* more tweaks and add yml
* remove old gptq docker
* don't need explicit peft install for tests
* fix setup.py to use extra index url
  - install torch for tests
  - fix cuda version for autogptq index
  - set torch in requirements so that it installs properly
  - move gptq install around to work with github cicd
* gptq doesn't play well with sample packing
* address pr feedback
* remove torch install for now
* set quantization_config from model config
* Fix the implementation for getting quant config from model config
Files changed:
- .github/workflows/main.yml +0 -10
- .github/workflows/tests.yml +1 -1
- docker/Dockerfile +2 -3
- examples/gptq-lora-7b/README.md +0 -8
- examples/{gptq-lora-7b/config.yml → llama-2/gptq-lora.yml} +44 -31
- requirements.txt +4 -0
- setup.py +23 -16
- src/axolotl/utils/config.py +1 -3
- src/axolotl/utils/models.py +34 -89
- src/axolotl/utils/trainer.py +1 -17
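
Summary: the GPTQ path no longer goes through `alpaca_lora_4bit` (no separate `gptq` Docker build, no monkeypatches in `models.py` and `trainer.py`). Instead, GPTQ checkpoints are loaded through the native transformers/peft integration: a `GPTQConfig` is built from the `quantization_config` stored in the model's own config, and `auto-gptq` is pulled in via requirements and the `[gptq]` extra. The old `gptq-lora-7b` example is replaced by `examples/llama-2/gptq-lora.yml`, and, per the commit notes, GPTQ does not currently work with sample packing.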
.github/workflows/main.yml
CHANGED
```diff
@@ -23,11 +23,6 @@ jobs:
             python_version: "3.10"
             pytorch: 2.0.1
             axolotl_extras:
-          - cuda: 118
-            cuda_version: 11.8.0
-            python_version: "3.9"
-            pytorch: 2.0.1
-            axolotl_extras: gptq
     runs-on: self-hosted
     steps:
       - name: Checkout
@@ -73,11 +68,6 @@ jobs:
             pytorch: 2.0.1
             axolotl_extras:
             is_latest: true
-          - cuda: 118
-            cuda_version: 11.8.0
-            python_version: "3.9"
-            pytorch: 2.0.1
-            axolotl_extras: gptq
     runs-on: self-hosted
     steps:
       - name: Checkout
```
.github/workflows/tests.yml
CHANGED
```diff
@@ -24,7 +24,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          pip install -e .[peft]
+          pip install -e .
           pip install -r requirements-tests.txt
 
       - name: Run tests
```
docker/Dockerfile
CHANGED
```diff
@@ -11,14 +11,13 @@ RUN apt-get update && \
 
 WORKDIR /workspace
 
-RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main"
 RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN cd axolotl && \
     if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
+        pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \
     else \
-        pip install -e .[flash-attn]; \
+        pip install -e .[flash-attn,gptq]; \
     fi
 
 # fix so that git fetch/pull from remote works
```
examples/gptq-lora-7b/README.md
DELETED
```diff
@@ -1,8 +0,0 @@
-# LLaMa 7B using LoRA
-
-This is a good place to start for beginners. This will run on an NVIDIA RTX4090 with no other changes needed.
-
-```shell
-accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml
-
-```
```
examples/{gptq-lora-7b/config.yml → llama-2/gptq-lora.yml}
RENAMED
```diff
@@ -1,63 +1,76 @@
-base_model:
-base_model_config:
-
-tokenizer_type: LlamaTokenizer
-trust_remote_code:
-load_in_8bit: true
+base_model: TheBloke/Llama-2-7B-GPTQ
+base_model_config: TheBloke/Llama-2-7B-GPTQ
+is_llama_derived_model: false
 gptq: true
+gptq_bits: 4
+model_type: AutoModelForCausalLM
+tokenizer_type: LlamaTokenizer
+tokenizer_use_fast: true
+tokenizer_legacy: true
+load_in_8bit: false
+load_in_4bit: false
+strict: false
+push_dataset_to_hub:
+hf_use_auth_token: true
 datasets:
-  - path:
+  - path: mhenrichsen/alpaca_2k_test
     type: alpaca
 dataset_prepared_path: last_run_prepared
-val_set_size: 0.
-adapter:
+val_set_size: 0.01
+adapter: lora
 lora_model_dir:
-sequence_len:
-
+sequence_len: 4096
+sample_packing:
 lora_r: 8
-lora_alpha:
+lora_alpha: 32
 lora_dropout: 0.05
 lora_target_modules:
+  - k_proj
+  - o_proj
   - q_proj
   - v_proj
-
-
-
+lora_target_linear:
+lora_fan_in_fan_out:
+wandb_project:
 wandb_watch:
 wandb_run_id:
 wandb_log_model:
-output_dir: ./
+output_dir: ./model-out
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 3
-optimizer:
+optimizer: adamw_torch
+adam_beta2: 0.95
+adam_eps: 0.00001
+max_grad_norm: 1.0
 torchdistx_path:
 lr_scheduler: cosine
-
+lr_quadratic_warmup: true
+learning_rate: 0.000017
 train_on_inputs: false
 group_by_length: false
-fp16: true
 bf16: false
+fp16: false
+float16: true
 tf32: true
+gradient_checkpointing: true
 early_stopping_patience:
 resume_from_checkpoint:
 local_rank:
-logging_steps:
+logging_steps: 1
 xformers_attention:
 flash_attention:
-
-
-
-
-
-
+sdp_attention:
+flash_optimum:
+gptq_groupsize:
+gptq_model_v1:
+warmup_steps: 100
+eval_steps:
+save_steps:
 debug:
 deepspeed:
-weight_decay: 0.
-
-fsdp_config:
-tokens:
-  pad_token: "<pad>"
+weight_decay: 0.1
+special_tokens:
   bos_token: "<s>"
   eos_token: "</s>"
   unk_token: "<unk>"
```
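
The deleted README launched the old example with `accelerate launch scripts/finetune.py examples/gptq-lora-7b/config.yml`; presumably the renamed config is launched the same way, e.g. `accelerate launch scripts/finetune.py examples/llama-2/gptq-lora.yml`. As a quick, purely illustrative sanity check of the GPTQ-related settings in this file (assumes PyYAML and a local checkout of the repo; the checks themselves are not part of the PR):

```python
# Minimal sketch: lint the renamed example config for the settings this PR cares about.
import yaml

with open("examples/llama-2/gptq-lora.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# GPTQ checkpoints are already quantized, so bitsandbytes loading stays off.
assert cfg["gptq"] is True
assert not cfg.get("load_in_8bit") and not cfg.get("load_in_4bit")

# Per the commit notes, GPTQ doesn't play well with sample packing; the example leaves it unset.
assert not cfg.get("sample_packing")

print("adapter:", cfg.get("adapter"), "| base model:", cfg.get("base_model"))
```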
requirements.txt
CHANGED
```diff
@@ -1,3 +1,7 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
+torch==2.0.1
+auto-gptq
 packaging
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
```
setup.py
CHANGED
```diff
@@ -2,15 +2,27 @@
 
 from setuptools import find_packages, setup
 
-
-
-
-
-
-
-
-
-
+
+def parse_requirements():
+    _install_requires = []
+    _dependency_links = []
+    with open("./requirements.txt", encoding="utf-8") as requirements_file:
+        lines = [
+            r.strip() for r in requirements_file.readlines() if "auto-gptq" not in r
+        ]
+        for line in lines:
+            if line.startswith("--extra-index-url"):
+                # Handle custom index URLs
+                _, url = line.split()
+                _dependency_links.append(url)
+            elif "flash-attn" not in line and line and line[0] != "#":
+                # Handle standard packages
+                _install_requires.append(line)
+    return _install_requires, _dependency_links
+
+
+install_requires, dependency_links = parse_requirements()
+
 
 setup(
     name="axolotl",
@@ -19,12 +31,10 @@ setup(
     package_dir={"": "src"},
     packages=find_packages(),
     install_requires=install_requires,
+    dependency_links=dependency_links,
     extras_require={
         "gptq": [
-            "alpaca_lora_4bit @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
-        ],
-        "gptq_triton": [
-            "alpaca_lora_4bit[triton] @ git+https://github.com/winglian/alpaca_lora_4bit.git@setup_pip",
+            "auto-gptq",
         ],
         "flash-attn": [
             "flash-attn==2.0.8",
@@ -32,8 +42,5 @@ setup(
         "extras": [
             "deepspeed",
         ],
-        "peft": [
-            "peft @ git+https://github.com/huggingface/peft.git",
-        ],
     },
 )
```
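
For a concrete picture of what the new parsing produces, the snippet below re-implements the `parse_requirements` logic inline against the updated `requirements.txt` lines (it mirrors the diff rather than importing `setup.py`, so treat it purely as illustration):

```python
# Sketch: mirror parse_requirements() from setup.py over the new requirements.txt lines.
requirements = [
    "--extra-index-url https://download.pytorch.org/whl/cu118",
    "--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/",
    "torch==2.0.1",
    "auto-gptq",
    "packaging",
    "peft @ git+https://github.com/huggingface/peft.git",
    "transformers @ git+https://github.com/huggingface/transformers.git",
]

install_requires, dependency_links = [], []
for line in (r for r in requirements if "auto-gptq" not in r):
    if line.startswith("--extra-index-url"):
        _, url = line.split()          # index URLs go to dependency_links
        dependency_links.append(url)
    elif "flash-attn" not in line and line and line[0] != "#":
        install_requires.append(line)  # everything else becomes a normal dependency

# auto-gptq is filtered out of install_requires and only ships via the [gptq] extra.
print(install_requires)
print(dependency_links)
```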
src/axolotl/utils/config.py
CHANGED
```diff
@@ -108,9 +108,7 @@ def validate_config(cfg):
             "To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
         )
     if cfg.load_4bit:
-        raise ValueError(
-            "cfg.load_4bit parameter has been deprecated and replaced by cfg.gptq"
-        )
+        raise ValueError("cfg.load_4bit parameter has been deprecated")
 
     if cfg.adapter == "qlora":
         if cfg.merge_lora:
```
src/axolotl/utils/models.py
CHANGED
```diff
@@ -4,19 +4,19 @@
 import logging
 import math
 import os
-from pathlib import Path
 from typing import Optional, Tuple  # noqa: F401
 
 import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
-from peft import PeftConfig
+from peft import PeftConfig, prepare_model_for_kbit_training
 from transformers import (  # noqa: F401
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
     BitsAndBytesConfig,
+    GPTQConfig,
     LlamaConfig,
     PreTrainedModel,
     PreTrainedTokenizerBase,
@@ -155,32 +155,17 @@ def load_model(
         LOG.info("patching _expand_mask")
         hijack_expand_mask()
 
-    try:
-        if cfg.gptq:
-            from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
-                replace_peft_model_with_int4_lora_model,
-            )
-
-            replace_peft_model_with_int4_lora_model()
-    except Exception as err:
-        LOG.exception(err)
-        raise err
-
-    if not cfg.gptq and (
-        (cfg.adapter == "lora" and load_in_8bit)
-        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
-    ):
-        try:
-            from peft import prepare_model_for_kbit_training
-        except ImportError:
-            # For backward compatibility
-            from peft import (
-                prepare_model_for_int8_training as prepare_model_for_kbit_training,
-            )
-
     model_kwargs = {}
     if cfg.model_revision:
         model_kwargs["revision"] = cfg.model_revision
+    if cfg.gptq:
+        model_config = load_model_config(cfg)
+        if not hasattr(model_config, "quantization_config"):
+            LOG.warning("model config does not contain quantization_config information")
+        else:
+            model_kwargs["quantization_config"] = GPTQConfig(
+                **model_config.quantization_config
+            )
     if cfg.adapter == "qlora" and cfg.load_in_4bit:
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -191,45 +176,7 @@ def load_model(
             bnb_4bit_quant_type="nf4",
         )
     try:
-        if cfg.gptq and cfg.is_llama_derived_model:
-            from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram
-            from huggingface_hub import snapshot_download
-
-            try:
-                snapshot_download_kwargs = {}
-                if cfg.base_model_ignore_patterns:
-                    snapshot_download_kwargs[
-                        "ignore_patterns"
-                    ] = cfg.base_model_ignore_patterns
-                cache_model_path = Path(
-                    snapshot_download(base_model, **snapshot_download_kwargs)
-                )
-                files = (
-                    list(cache_model_path.glob("*.pt"))
-                    + list(cache_model_path.glob("*.safetensors"))
-                    + list(cache_model_path.glob("*.bin"))
-                )
-                if len(files) > 0:
-                    model_path = str(files[0])
-                else:
-                    LOG.warning(
-                        "unable to find a cached model file, this will likely fail..."
-                    )
-                    model_path = str(cache_model_path)
-            except Exception:  # pylint: disable=broad-exception-caught
-                model_path = cfg.base_model
-            model, _ = load_llama_model_4bit_low_ram(
-                base_model_config if base_model_config else base_model,
-                model_path,
-                device_map=cfg.device_map,
-                half=cfg.fp16,
-                groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1,
-                is_v1_model=cfg.gptq_model_v1
-                if cfg.gptq_model_v1 is not None
-                else True,
-            )
-            load_in_8bit = False
-        elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
+        if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
             from transformers import LlamaForCausalLM
 
             config_kwargs = {}
@@ -275,15 +222,24 @@ def load_model(
             # )
             # model.train() # sets to train instead of eval mode
         elif model_type and not cfg.trust_remote_code:
-            model = getattr(transformers, model_type).from_pretrained(
-                base_model,
-                device_map=cfg.device_map,
-                load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
-                load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
-                torch_dtype=cfg.torch_dtype,
-                trust_remote_code=cfg.trust_remote_code or False,
-                **model_kwargs,
-            )
+            if cfg.gptq:
+                model = AutoModelForCausalLM.from_pretrained(
+                    base_model,
+                    device_map=cfg.device_map,
+                    torch_dtype=cfg.torch_dtype,
+                    trust_remote_code=cfg.trust_remote_code or False,
+                    **model_kwargs,
+                )
+            else:
+                model = getattr(transformers, model_type).from_pretrained(
+                    base_model,
+                    device_map=cfg.device_map,
+                    load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
+                    load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
+                    torch_dtype=cfg.torch_dtype,
+                    trust_remote_code=cfg.trust_remote_code or False,
+                    **model_kwargs,
+                )
         else:
             config = AutoConfig.from_pretrained(
                 base_model,
@@ -359,11 +315,12 @@ def load_model(
             module.to(torch.float32)
 
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
-    if not cfg.gptq and (
-        (cfg.adapter == "lora" and load_in_8bit)
-        or (cfg.adapter == "qlora" and cfg.load_in_4bit)
+    if (cfg.adapter == "lora" and load_in_8bit) or (
+        cfg.adapter == "qlora" and cfg.load_in_4bit
     ):
         LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
+        if cfg.gradient_checkpointing:
+            model.gradient_checkpointing_enable()
         model = prepare_model_for_kbit_training(
             model, use_gradient_checkpointing=cfg.gradient_checkpointing
         )
@@ -385,22 +342,10 @@ def load_model(
     if cfg.ddp and not load_in_8bit:
         model.to(f"cuda:{cfg.local_rank}")
 
-    if cfg.gptq:
-        # Scales to half
-        LOG.info("Fitting 4bit scales and zeros to half")
-        for _, module in model.named_modules():
-            if "Autograd4bitQuantLinear" in str(type(module)) or "Linear4bitLt" in str(
-                type(module)
-            ):
-                if hasattr(module, "is_v1_model") and module.is_v1_model:
-                    module.zeros = module.zeros.half()
-                module.scales = module.scales.half()
-                module.bias = module.bias.half()
-
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
-        and (cfg.gptq or cfg.load_in_4bit)
+        and (cfg.load_in_4bit)
     ):
         # llama is PROBABLY model parallelizable, but the default isn't that it is
         # so let's only set it for the 4bit, see
```
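
Outside of axolotl's `load_model`, the same native path boils down to: let `transformers` pick up the `quantization_config` already stored in the GPTQ checkpoint's config, then wrap the loaded model with a PEFT LoRA adapter. A minimal sketch, assuming `transformers` with GPTQ support, `optimum`, `auto-gptq`, `accelerate`, and `peft` are installed and the checkpoint's config carries a `quantization_config` block; the model name and LoRA hyperparameters are copied from the example config, and the rest is illustrative rather than axolotl's actual code:

```python
# Minimal sketch of the native GPTQ + LoRA loading path this PR switches to.
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

base_model = "TheBloke/Llama-2-7B-GPTQ"  # base_model from the example config

# The PR builds a GPTQConfig from the settings stored in the model's own config;
# here we only inspect them, since from_pretrained picks them up automatically
# for an already-quantized checkpoint.
model_config = AutoConfig.from_pretrained(base_model)
print(getattr(model_config, "quantization_config", None))

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map="auto",
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)

# LoRA settings mirroring examples/llama-2/gptq-lora.yml.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["k_proj", "o_proj", "q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
```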
src/axolotl/utils/trainer.py
CHANGED
```diff
@@ -514,23 +514,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     training_arguments_kwargs["seed"] = cfg.seed
 
     if cfg.gradient_checkpointing:
-        if cfg.gptq:
-            from alpaca_lora_4bit.gradient_checkpointing import (
-                apply_gradient_checkpointing,
-            )
-
-            gradient_checkpointing_ratio = (
-                cfg.gradient_checkpointing_ratio
-                if cfg.gradient_checkpointing_ratio
-                else 1.0
-            )
-            apply_gradient_checkpointing(
-                model, checkpoint_ratio=gradient_checkpointing_ratio
-            )
-        else:
-            training_arguments_kwargs[
-                "gradient_checkpointing"
-            ] = cfg.gradient_checkpointing
+        training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
     if cfg.fsdp:
         training_arguments_kwargs["fsdp"] = cfg.fsdp
     if cfg.fsdp_config:
```
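
With the `alpaca_lora_4bit` checkpointing hook gone, gradient checkpointing is just a flag on `transformers`' `TrainingArguments`. A minimal sketch using values from the example config; the output directory and step counts are placeholders, not axolotl's actual trainer setup:

```python
# Sketch: gradient checkpointing now flows through TrainingArguments directly.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./model-out",          # matches the example config's output_dir
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    num_train_epochs=3,
    gradient_checkpointing=True,       # previously routed through alpaca_lora_4bit for GPTQ
    logging_steps=1,
    warmup_steps=100,
)
```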