Commit
•
22ae21a
1
Parent(s):
ba45531
Add KTO support (#1640)
Browse files* add kto support
* test cleanup
* fix outdated comment
* fix llama3 ultra
* chore: lint
* update to use rl_beta instead of dpo_beta
---------
Co-authored-by: Wing Lian <[email protected]>
- src/axolotl/core/trainer_builder.py +29 -2
- src/axolotl/prompt_strategies/kto/__init__.py +9 -0
- src/axolotl/prompt_strategies/kto/chatml.py +105 -0
- src/axolotl/prompt_strategies/kto/llama3.py +105 -0
- src/axolotl/prompt_strategies/kto/user_defined.py +39 -0
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +42 -2
- src/axolotl/utils/data/rl.py +28 -12
- src/axolotl/utils/models.py +5 -1
- src/axolotl/utils/trainer.py +1 -1
- tests/e2e/test_dpo.py +63 -0
- tests/test_validation.py +9 -0
src/axolotl/core/trainer_builder.py
CHANGED
@@ -30,7 +30,7 @@ from transformers import (
|
|
30 |
)
|
31 |
from transformers.trainer_utils import seed_worker
|
32 |
from transformers.utils import is_sagemaker_mp_enabled
|
33 |
-
from trl import DPOTrainer, ORPOConfig, ORPOTrainer
|
34 |
from trl.trainer.utils import pad_to_length
|
35 |
|
36 |
from axolotl.loraplus import create_loraplus_optimizer
|
@@ -826,6 +826,14 @@ class AxolotlORPOTrainer(ORPOTrainer):
|
|
826 |
tag_names = ["axolotl", "orpo"]
|
827 |
|
828 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
829 |
class TrainerBuilderBase(abc.ABC):
|
830 |
"""
|
831 |
Base class for trainer builder
|
@@ -1532,6 +1540,22 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
1532 |
if self.cfg.max_prompt_len:
|
1533 |
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
1534 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1535 |
training_args = training_args_cls(
|
1536 |
per_device_train_batch_size=self.cfg.micro_batch_size,
|
1537 |
max_steps=self.cfg.max_steps or total_num_steps,
|
@@ -1567,7 +1591,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
1567 |
] = self.cfg.precompute_ref_log_probs
|
1568 |
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
1569 |
trainer_cls = AxolotlDPOTrainer
|
1570 |
-
dpo_trainer_kwargs["beta"] = self.cfg.
|
1571 |
trainer_cls_args = [self.model, self.model_ref]
|
1572 |
|
1573 |
# these aren't used for the ORPO trainer
|
@@ -1580,6 +1604,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|
1580 |
elif self.cfg.rl == "orpo":
|
1581 |
trainer_cls = AxolotlORPOTrainer
|
1582 |
trainer_cls_args = [self.model]
|
|
|
|
|
|
|
1583 |
else:
|
1584 |
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
1585 |
dpo_trainer = trainer_cls(
|
|
|
30 |
)
|
31 |
from transformers.trainer_utils import seed_worker
|
32 |
from transformers.utils import is_sagemaker_mp_enabled
|
33 |
+
from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
|
34 |
from trl.trainer.utils import pad_to_length
|
35 |
|
36 |
from axolotl.loraplus import create_loraplus_optimizer
|
|
|
826 |
tag_names = ["axolotl", "orpo"]
|
827 |
|
828 |
|
829 |
+
class AxolotlKTOTrainer(KTOTrainer):
|
830 |
+
"""
|
831 |
+
Extend the base KTOTrainer for axolotl helpers
|
832 |
+
"""
|
833 |
+
|
834 |
+
tag_names = ["axolotl", "kto"]
|
835 |
+
|
836 |
+
|
837 |
class TrainerBuilderBase(abc.ABC):
|
838 |
"""
|
839 |
Base class for trainer builder
|
|
|
1540 |
if self.cfg.max_prompt_len:
|
1541 |
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
1542 |
|
1543 |
+
if self.cfg.rl == "kto":
|
1544 |
+
training_args_cls = KTOConfig
|
1545 |
+
|
1546 |
+
training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
1547 |
+
training_args_kwargs["desirable_weight"] = (
|
1548 |
+
self.cfg.kto_desirable_weight or 1.0
|
1549 |
+
)
|
1550 |
+
training_args_kwargs["undesirable_weight"] = (
|
1551 |
+
self.cfg.kto_undesirable_weight or 1.0
|
1552 |
+
)
|
1553 |
+
|
1554 |
+
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
1555 |
+
training_args_kwargs["max_length"] = self.cfg.sequence_len
|
1556 |
+
if self.cfg.max_prompt_len:
|
1557 |
+
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
|
1558 |
+
|
1559 |
training_args = training_args_cls(
|
1560 |
per_device_train_batch_size=self.cfg.micro_batch_size,
|
1561 |
max_steps=self.cfg.max_steps or total_num_steps,
|
|
|
1591 |
] = self.cfg.precompute_ref_log_probs
|
1592 |
if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
|
1593 |
trainer_cls = AxolotlDPOTrainer
|
1594 |
+
dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
|
1595 |
trainer_cls_args = [self.model, self.model_ref]
|
1596 |
|
1597 |
# these aren't used for the ORPO trainer
|
|
|
1604 |
elif self.cfg.rl == "orpo":
|
1605 |
trainer_cls = AxolotlORPOTrainer
|
1606 |
trainer_cls_args = [self.model]
|
1607 |
+
elif self.cfg.rl == "kto":
|
1608 |
+
trainer_cls = AxolotlKTOTrainer
|
1609 |
+
trainer_cls_args = [self.model]
|
1610 |
else:
|
1611 |
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
1612 |
dpo_trainer = trainer_cls(
|
src/axolotl/prompt_strategies/kto/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
module for KTO style dataset transform strategies
|
3 |
+
"""
|
4 |
+
|
5 |
+
from functools import partial
|
6 |
+
|
7 |
+
from ..base import load as load_base
|
8 |
+
|
9 |
+
load = partial(load_base, module_base="axolotl.prompt_strategies.kto")
|
src/axolotl/prompt_strategies/kto/chatml.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
KTO strategies for chatml
|
3 |
+
"""
|
4 |
+
# pylint: disable=duplicate-code
|
5 |
+
|
6 |
+
|
7 |
+
def argilla(
|
8 |
+
cfg,
|
9 |
+
**kwargs,
|
10 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
11 |
+
def transform_fn(sample):
|
12 |
+
if "system" in sample and sample["system"]:
|
13 |
+
sample["prompt"] = (
|
14 |
+
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
15 |
+
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
16 |
+
)
|
17 |
+
else:
|
18 |
+
sample[
|
19 |
+
"prompt"
|
20 |
+
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
|
21 |
+
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
22 |
+
return sample
|
23 |
+
|
24 |
+
return transform_fn
|
25 |
+
|
26 |
+
|
27 |
+
def argilla_chat(
|
28 |
+
cfg,
|
29 |
+
**kwargs,
|
30 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
31 |
+
"""
|
32 |
+
for argilla/kto-mix-15k conversations
|
33 |
+
"""
|
34 |
+
|
35 |
+
def transform_fn(sample):
|
36 |
+
sample[
|
37 |
+
"prompt"
|
38 |
+
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
39 |
+
sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
|
40 |
+
return sample
|
41 |
+
|
42 |
+
return transform_fn
|
43 |
+
|
44 |
+
|
45 |
+
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
46 |
+
"""
|
47 |
+
For Intel Orca KTO
|
48 |
+
ex: argilla/distilabel-intel-orca-kto
|
49 |
+
"""
|
50 |
+
|
51 |
+
def transform_fn(sample):
|
52 |
+
if "system" in sample and sample["system"]:
|
53 |
+
sample["prompt"] = (
|
54 |
+
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
55 |
+
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
sample[
|
59 |
+
"prompt"
|
60 |
+
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
61 |
+
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
62 |
+
return sample
|
63 |
+
|
64 |
+
return transform_fn
|
65 |
+
|
66 |
+
|
67 |
+
def prompt_pairs(
|
68 |
+
cfg, **kwargs
|
69 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
70 |
+
def transform_fn(sample):
|
71 |
+
if "system" in sample and sample["system"]:
|
72 |
+
sample["prompt"] = (
|
73 |
+
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
74 |
+
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
75 |
+
)
|
76 |
+
else:
|
77 |
+
sample[
|
78 |
+
"prompt"
|
79 |
+
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
80 |
+
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
81 |
+
return sample
|
82 |
+
|
83 |
+
return transform_fn
|
84 |
+
|
85 |
+
|
86 |
+
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
87 |
+
"""
|
88 |
+
for ultrafeedback binarized conversations
|
89 |
+
ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
|
90 |
+
"""
|
91 |
+
|
92 |
+
def transform_fn(sample):
|
93 |
+
if "system" in sample and sample["system"]:
|
94 |
+
sample["prompt"] = (
|
95 |
+
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
|
96 |
+
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
97 |
+
)
|
98 |
+
else:
|
99 |
+
sample[
|
100 |
+
"prompt"
|
101 |
+
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
102 |
+
sample["completion"] = f"{sample['completion']}<|im_end|>"
|
103 |
+
return sample
|
104 |
+
|
105 |
+
return transform_fn
|
src/axolotl/prompt_strategies/kto/llama3.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
KTO strategies for llama-3 chat template
|
3 |
+
"""
|
4 |
+
# pylint: disable=duplicate-code
|
5 |
+
|
6 |
+
|
7 |
+
def argilla(
|
8 |
+
cfg,
|
9 |
+
**kwargs,
|
10 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
11 |
+
def transform_fn(sample):
|
12 |
+
if "system" in sample and sample["system"]:
|
13 |
+
sample["prompt"] = (
|
14 |
+
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
15 |
+
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
16 |
+
)
|
17 |
+
else:
|
18 |
+
sample[
|
19 |
+
"prompt"
|
20 |
+
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
21 |
+
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
22 |
+
return sample
|
23 |
+
|
24 |
+
return transform_fn
|
25 |
+
|
26 |
+
|
27 |
+
def argilla_chat(
|
28 |
+
cfg,
|
29 |
+
**kwargs,
|
30 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
31 |
+
"""
|
32 |
+
for argilla/kto-mix-15k conversations
|
33 |
+
"""
|
34 |
+
|
35 |
+
def transform_fn(sample):
|
36 |
+
sample[
|
37 |
+
"prompt"
|
38 |
+
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
39 |
+
sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
|
40 |
+
return sample
|
41 |
+
|
42 |
+
return transform_fn
|
43 |
+
|
44 |
+
|
45 |
+
def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
46 |
+
"""
|
47 |
+
For Intel Orca KTO
|
48 |
+
ex: argilla/distilabel-intel-orca-kto
|
49 |
+
"""
|
50 |
+
|
51 |
+
def transform_fn(sample):
|
52 |
+
if "system" in sample and sample["system"]:
|
53 |
+
sample["prompt"] = (
|
54 |
+
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
55 |
+
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
56 |
+
)
|
57 |
+
else:
|
58 |
+
sample[
|
59 |
+
"prompt"
|
60 |
+
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
61 |
+
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
62 |
+
return sample
|
63 |
+
|
64 |
+
return transform_fn
|
65 |
+
|
66 |
+
|
67 |
+
def prompt_pairs(
|
68 |
+
cfg, **kwargs
|
69 |
+
): # pylint: disable=possibly-unused-variable,unused-argument
|
70 |
+
def transform_fn(sample):
|
71 |
+
if "system" in sample and sample["system"]:
|
72 |
+
sample["prompt"] = (
|
73 |
+
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
74 |
+
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
75 |
+
)
|
76 |
+
else:
|
77 |
+
sample[
|
78 |
+
"prompt"
|
79 |
+
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
80 |
+
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
81 |
+
return sample
|
82 |
+
|
83 |
+
return transform_fn
|
84 |
+
|
85 |
+
|
86 |
+
def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
|
87 |
+
"""
|
88 |
+
for ultrafeedback binarized conversations
|
89 |
+
ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
|
90 |
+
"""
|
91 |
+
|
92 |
+
def transform_fn(sample):
|
93 |
+
if "system" in sample and sample["system"]:
|
94 |
+
sample["prompt"] = (
|
95 |
+
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
|
96 |
+
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
97 |
+
)
|
98 |
+
else:
|
99 |
+
sample[
|
100 |
+
"prompt"
|
101 |
+
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
102 |
+
sample["completion"] = f"{sample['completion']}<|eot_id|>"
|
103 |
+
return sample
|
104 |
+
|
105 |
+
return transform_fn
|
src/axolotl/prompt_strategies/kto/user_defined.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
User-defined KTO strategies
|
3 |
+
"""
|
4 |
+
# pylint: disable=duplicate-code
|
5 |
+
|
6 |
+
|
7 |
+
def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
|
8 |
+
ds_cfg = cfg["datasets"][dataset_idx]["type"]
|
9 |
+
if not isinstance(ds_cfg, dict):
|
10 |
+
raise ValueError(
|
11 |
+
f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
|
12 |
+
)
|
13 |
+
field_prompt = ds_cfg.get("field_prompt", "prompt")
|
14 |
+
field_system = ds_cfg.get("field_system", "system")
|
15 |
+
field_completion = ds_cfg.get("field_completion", "completion")
|
16 |
+
field_label = ds_cfg.get("field_label", "label")
|
17 |
+
prompt_format = ds_cfg.get("prompt_format")
|
18 |
+
if not prompt_format:
|
19 |
+
prompt_format = "{" + field_prompt + "}"
|
20 |
+
completion_format = ds_cfg.get("completion_format")
|
21 |
+
if not completion_format:
|
22 |
+
chosen_format = "{" + field_completion + "}"
|
23 |
+
|
24 |
+
def transform_fn(sample):
|
25 |
+
if (
|
26 |
+
"{" + field_system + "}" in prompt_format
|
27 |
+
and field_system in sample
|
28 |
+
and sample[field_system]
|
29 |
+
):
|
30 |
+
sample["prompt"] = prompt_format.format(
|
31 |
+
system=sample[field_system], prompt=sample[field_prompt]
|
32 |
+
)
|
33 |
+
else:
|
34 |
+
sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
|
35 |
+
sample["completion"] = chosen_format.format(chosen=sample[field_completion])
|
36 |
+
sample["label"] = sample[field_label]
|
37 |
+
return sample
|
38 |
+
|
39 |
+
return transform_fn
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -24,6 +24,7 @@ class DeprecatedParameters(BaseModel):
|
|
24 |
max_packed_sequence_len: Optional[int] = None
|
25 |
rope_scaling: Optional[Any] = None
|
26 |
noisy_embedding_alpha: Optional[float] = None
|
|
|
27 |
|
28 |
@field_validator("max_packed_sequence_len")
|
29 |
@classmethod
|
@@ -48,6 +49,13 @@ class DeprecatedParameters(BaseModel):
|
|
48 |
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
49 |
return noisy_embedding_alpha
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
class RemappedParameters(BaseModel):
|
53 |
"""parameters that have been remapped to other names"""
|
@@ -126,6 +134,26 @@ class DPODataset(BaseModel):
|
|
126 |
data_files: Optional[List[str]] = None
|
127 |
|
128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
class RLType(str, Enum):
|
130 |
"""RL trainer type configuration subset"""
|
131 |
|
@@ -133,6 +161,7 @@ class RLType(str, Enum):
|
|
133 |
ipo = "ipo" # pylint: disable=invalid-name
|
134 |
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
135 |
orpo = "orpo" # pylint: disable=invalid-name
|
|
|
136 |
|
137 |
|
138 |
class ChatTemplate(str, Enum):
|
@@ -450,8 +479,8 @@ class AxolotlInputConfig(
|
|
450 |
|
451 |
rl: Optional[RLType] = None
|
452 |
|
453 |
-
datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
454 |
-
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
|
455 |
shuffle_merged_datasets: Optional[bool] = True
|
456 |
dataset_prepared_path: Optional[str] = None
|
457 |
dataset_shard_num: Optional[int] = None
|
@@ -585,6 +614,10 @@ class AxolotlInputConfig(
|
|
585 |
|
586 |
orpo_alpha: Optional[float] = None
|
587 |
|
|
|
|
|
|
|
|
|
588 |
max_memory: Optional[
|
589 |
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
590 |
] = None
|
@@ -884,6 +917,13 @@ class AxolotlInputConfig(
|
|
884 |
raise ValueError("neftune_noise_alpha must be > 0.0")
|
885 |
return neftune_noise_alpha
|
886 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
887 |
@model_validator(mode="before")
|
888 |
@classmethod
|
889 |
def check_frozen(cls, data):
|
|
|
24 |
max_packed_sequence_len: Optional[int] = None
|
25 |
rope_scaling: Optional[Any] = None
|
26 |
noisy_embedding_alpha: Optional[float] = None
|
27 |
+
dpo_beta: Optional[float] = None
|
28 |
|
29 |
@field_validator("max_packed_sequence_len")
|
30 |
@classmethod
|
|
|
49 |
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
50 |
return noisy_embedding_alpha
|
51 |
|
52 |
+
@field_validator("dpo_beta")
|
53 |
+
@classmethod
|
54 |
+
def validate_dpo_beta(cls, dpo_beta):
|
55 |
+
if dpo_beta is not None:
|
56 |
+
LOG.warning("dpo_beta is deprecated, use rl_beta instead")
|
57 |
+
return dpo_beta
|
58 |
+
|
59 |
|
60 |
class RemappedParameters(BaseModel):
|
61 |
"""parameters that have been remapped to other names"""
|
|
|
134 |
data_files: Optional[List[str]] = None
|
135 |
|
136 |
|
137 |
+
class UserDefinedKTOType(BaseModel):
|
138 |
+
"""User defined typing for KTO"""
|
139 |
+
|
140 |
+
field_system: Optional[str] = None
|
141 |
+
field_prompt: Optional[str] = None
|
142 |
+
field_completion: Optional[str] = None
|
143 |
+
field_label: Optional[bool] = None
|
144 |
+
prompt_format: Optional[str] = None
|
145 |
+
completion_format: Optional[str] = None
|
146 |
+
|
147 |
+
|
148 |
+
class KTODataset(BaseModel):
|
149 |
+
"""KTO configuration subset"""
|
150 |
+
|
151 |
+
path: Optional[str] = None
|
152 |
+
split: Optional[str] = None
|
153 |
+
type: Optional[Union[UserDefinedKTOType, str]] = None
|
154 |
+
data_files: Optional[List[str]] = None
|
155 |
+
|
156 |
+
|
157 |
class RLType(str, Enum):
|
158 |
"""RL trainer type configuration subset"""
|
159 |
|
|
|
161 |
ipo = "ipo" # pylint: disable=invalid-name
|
162 |
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
163 |
orpo = "orpo" # pylint: disable=invalid-name
|
164 |
+
kto = "kto" # pylint: disable=invalid-name
|
165 |
|
166 |
|
167 |
class ChatTemplate(str, Enum):
|
|
|
479 |
|
480 |
rl: Optional[RLType] = None
|
481 |
|
482 |
+
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
483 |
+
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
484 |
shuffle_merged_datasets: Optional[bool] = True
|
485 |
dataset_prepared_path: Optional[str] = None
|
486 |
dataset_shard_num: Optional[int] = None
|
|
|
614 |
|
615 |
orpo_alpha: Optional[float] = None
|
616 |
|
617 |
+
kto_desirable_weight: Optional[float] = None
|
618 |
+
kto_undesirable_weight: Optional[float] = None
|
619 |
+
rl_beta: Optional[float] = None
|
620 |
+
|
621 |
max_memory: Optional[
|
622 |
Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
|
623 |
] = None
|
|
|
917 |
raise ValueError("neftune_noise_alpha must be > 0.0")
|
918 |
return neftune_noise_alpha
|
919 |
|
920 |
+
@model_validator(mode="after")
|
921 |
+
def check(self):
|
922 |
+
if self.dpo_beta and not self.rl_beta:
|
923 |
+
self.rl_beta = self.dpo_beta
|
924 |
+
del self.dpo_beta
|
925 |
+
return self
|
926 |
+
|
927 |
@model_validator(mode="before")
|
928 |
@classmethod
|
929 |
def check_frozen(cls, data):
|
src/axolotl/utils/data/rl.py
CHANGED
@@ -10,6 +10,7 @@ from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_
|
|
10 |
|
11 |
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
12 |
from axolotl.prompt_strategies.dpo import load as load_dpo
|
|
|
13 |
from axolotl.prompt_strategies.orpo import load as load_orpo
|
14 |
from axolotl.utils.data.utils import md5
|
15 |
from axolotl.utils.dict import DictDefault
|
@@ -55,6 +56,22 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
|
55 |
dataset.save_to_disk(str(prepared_ds_path))
|
56 |
|
57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
def load_prepare_dpo_datasets(cfg):
|
59 |
def load_split(dataset_cfgs, _cfg):
|
60 |
split_datasets: List[Any] = []
|
@@ -76,6 +93,7 @@ def load_prepare_dpo_datasets(cfg):
|
|
76 |
split_datasets.insert(i, ds)
|
77 |
|
78 |
tokenizer = None
|
|
|
79 |
for i, data_set in enumerate(split_datasets):
|
80 |
_type = dataset_cfgs[i]["type"]
|
81 |
if _type:
|
@@ -83,21 +101,19 @@ def load_prepare_dpo_datasets(cfg):
|
|
83 |
_type = "user_defined.default"
|
84 |
if _cfg.rl == "orpo":
|
85 |
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
|
|
|
|
86 |
else:
|
87 |
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
ds_transform_fn,
|
96 |
-
desc="Mapping RL Dataset",
|
97 |
)
|
98 |
-
if isinstance(data_set, DatasetDict):
|
99 |
-
data_set = data_set["train"]
|
100 |
-
split_datasets[i] = data_set
|
101 |
else:
|
102 |
# If no `type` is provided, assume the dataset is already in the expected format with
|
103 |
# "prompt", "chosen" and "rejected" already preprocessed
|
|
|
10 |
|
11 |
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
12 |
from axolotl.prompt_strategies.dpo import load as load_dpo
|
13 |
+
from axolotl.prompt_strategies.kto import load as load_kto
|
14 |
from axolotl.prompt_strategies.orpo import load as load_orpo
|
15 |
from axolotl.utils.data.utils import md5
|
16 |
from axolotl.utils.dict import DictDefault
|
|
|
56 |
dataset.save_to_disk(str(prepared_ds_path))
|
57 |
|
58 |
|
59 |
+
def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
|
60 |
+
sig = inspect.signature(ds_transform_fn)
|
61 |
+
if "tokenizer" in sig.parameters:
|
62 |
+
if not tokenizer:
|
63 |
+
tokenizer = load_tokenizer(cfg)
|
64 |
+
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
|
65 |
+
|
66 |
+
data_set = data_set.map(
|
67 |
+
ds_transform_fn,
|
68 |
+
desc="Mapping RL Dataset",
|
69 |
+
)
|
70 |
+
if isinstance(data_set, DatasetDict):
|
71 |
+
data_set = data_set["train"]
|
72 |
+
return data_set
|
73 |
+
|
74 |
+
|
75 |
def load_prepare_dpo_datasets(cfg):
|
76 |
def load_split(dataset_cfgs, _cfg):
|
77 |
split_datasets: List[Any] = []
|
|
|
93 |
split_datasets.insert(i, ds)
|
94 |
|
95 |
tokenizer = None
|
96 |
+
|
97 |
for i, data_set in enumerate(split_datasets):
|
98 |
_type = dataset_cfgs[i]["type"]
|
99 |
if _type:
|
|
|
101 |
_type = "user_defined.default"
|
102 |
if _cfg.rl == "orpo":
|
103 |
ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
|
104 |
+
elif _cfg.rl == "kto":
|
105 |
+
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
106 |
else:
|
107 |
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
108 |
+
|
109 |
+
split_datasets[i] = map_dataset(
|
110 |
+
cfg, data_set, ds_transform_fn, tokenizer
|
111 |
+
)
|
112 |
+
elif _cfg.rl == "kto":
|
113 |
+
ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
|
114 |
+
split_datasets[i] = map_dataset(
|
115 |
+
cfg, data_set, ds_transform_fn, tokenizer
|
|
|
116 |
)
|
|
|
|
|
|
|
117 |
else:
|
118 |
# If no `type` is provided, assume the dataset is already in the expected format with
|
119 |
# "prompt", "chosen" and "rejected" already preprocessed
|
src/axolotl/utils/models.py
CHANGED
@@ -803,7 +803,11 @@ def load_model(
|
|
803 |
if not reference_model or cfg.lora_model_dir:
|
804 |
# if we're not loading the reference model, then we're loading the model for training
|
805 |
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
806 |
-
if
|
|
|
|
|
|
|
|
|
807 |
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
808 |
else:
|
809 |
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
|
|
803 |
if not reference_model or cfg.lora_model_dir:
|
804 |
# if we're not loading the reference model, then we're loading the model for training
|
805 |
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
806 |
+
if (
|
807 |
+
cfg.adapter
|
808 |
+
and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"]
|
809 |
+
and not cfg.merge_lora
|
810 |
+
):
|
811 |
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
812 |
else:
|
813 |
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
src/axolotl/utils/trainer.py
CHANGED
@@ -428,7 +428,7 @@ def prepare_optim_env(cfg):
|
|
428 |
|
429 |
|
430 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
431 |
-
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
|
432 |
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
433 |
trainer_builder.model_ref = model[1]
|
434 |
trainer_builder.peft_config = model[2]
|
|
|
428 |
|
429 |
|
430 |
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
431 |
+
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "kto"]:
|
432 |
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
433 |
trainer_builder.model_ref = model[1]
|
434 |
trainer_builder.peft_config = model[2]
|
tests/e2e/test_dpo.py
CHANGED
@@ -205,3 +205,66 @@ class TestDPOLlamaLora(unittest.TestCase):
|
|
205 |
|
206 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
207 |
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
207 |
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
208 |
+
|
209 |
+
@with_temp_dir
|
210 |
+
def test_kto_lora(self, temp_dir):
|
211 |
+
# pylint: disable=duplicate-code
|
212 |
+
cfg = DictDefault(
|
213 |
+
{
|
214 |
+
"base_model": "JackFram/llama-68m",
|
215 |
+
"tokenizer_type": "LlamaTokenizer",
|
216 |
+
"sequence_len": 1024,
|
217 |
+
"load_in_8bit": True,
|
218 |
+
"adapter": "lora",
|
219 |
+
"lora_r": 64,
|
220 |
+
"lora_alpha": 32,
|
221 |
+
"lora_dropout": 0.1,
|
222 |
+
"lora_target_linear": True,
|
223 |
+
"special_tokens": {},
|
224 |
+
"rl": "kto",
|
225 |
+
"rl_beta": 0.5,
|
226 |
+
"kto_desirable_weight": 1.0,
|
227 |
+
"kto_undesirable_weight": 1.0,
|
228 |
+
"remove_unused_columns": False,
|
229 |
+
"datasets": [
|
230 |
+
# {
|
231 |
+
# "path": "argilla/kto-mix-15k",
|
232 |
+
# "type": "chatml.argilla_chat",
|
233 |
+
# "split": "train",
|
234 |
+
# },
|
235 |
+
{
|
236 |
+
"path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
|
237 |
+
"type": "chatml.ultra",
|
238 |
+
"split": "train",
|
239 |
+
},
|
240 |
+
# {
|
241 |
+
# "path": "argilla/kto-mix-15k",
|
242 |
+
# "type": "llama3.argilla_chat",
|
243 |
+
# "split": "train",
|
244 |
+
# },
|
245 |
+
{
|
246 |
+
"path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
|
247 |
+
"type": "llama3.ultra",
|
248 |
+
"split": "train",
|
249 |
+
},
|
250 |
+
],
|
251 |
+
"num_epochs": 1,
|
252 |
+
"micro_batch_size": 4,
|
253 |
+
"gradient_accumulation_steps": 1,
|
254 |
+
"output_dir": temp_dir,
|
255 |
+
"learning_rate": 0.00001,
|
256 |
+
"optimizer": "paged_adamw_8bit",
|
257 |
+
"lr_scheduler": "cosine",
|
258 |
+
"max_steps": 20,
|
259 |
+
"save_steps": 10,
|
260 |
+
"warmup_steps": 5,
|
261 |
+
"gradient_checkpointing": True,
|
262 |
+
"gradient_checkpointing_kwargs": {"use_reentrant": True},
|
263 |
+
}
|
264 |
+
)
|
265 |
+
normalize_config(cfg)
|
266 |
+
cli_args = TrainerCliArgs()
|
267 |
+
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
268 |
+
|
269 |
+
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
270 |
+
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
tests/test_validation.py
CHANGED
@@ -1117,6 +1117,15 @@ class TestValidation(BaseValidation):
|
|
1117 |
validate_config(cfg)
|
1118 |
assert len(self._caplog.records) == 0
|
1119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1120 |
|
1121 |
class TestValidationCheckModelConfig(BaseValidation):
|
1122 |
"""
|
|
|
1117 |
validate_config(cfg)
|
1118 |
assert len(self._caplog.records) == 0
|
1119 |
|
1120 |
+
def test_dpo_beta_deprecation(self, minimal_cfg):
|
1121 |
+
cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg
|
1122 |
+
|
1123 |
+
with self._caplog.at_level(logging.WARNING):
|
1124 |
+
new_cfg = validate_config(cfg)
|
1125 |
+
assert new_cfg["rl_beta"] == 0.2
|
1126 |
+
assert new_cfg["dpo_beta"] is None
|
1127 |
+
assert len(self._caplog.records) == 1
|
1128 |
+
|
1129 |
|
1130 |
class TestValidationCheckModelConfig(BaseValidation):
|
1131 |
"""
|