winglian committed
Commit 7d1d22f • 1 Parent(s): 0e8f340

ORPO Trainer replacement (#1551)


* WIP use trl ORPOTrainer

* fixes to make orpo work with trl

* fix the chat template loading

* make sure to handle the special tokens and add_generation_prompt for the assistant turn too
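
For orientation, a rough sketch (not part of this commit) of the config values the new ORPO path reads: cfg.rl, cfg.orpo_alpha, and cfg.chat_template all appear in the diffs below, while the dataset path and `type` string here are illustrative assumptions.

# Hypothetical cfg exercising the new ORPO path; keys taken from the diff below,
# dataset path/type are assumptions for illustration only.
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "rl": "orpo",  # routes dataset loading and setup_trainer to the RL/ORPO path
        "orpo_alpha": 0.1,  # mapped onto ORPOConfig's beta field by the trainer builder
        "chat_template": "chatml",  # consumed by chat_templates(cfg.chat_template)
        "datasets": [
            {
                # path and type are assumed examples, not taken from this commit
                "path": "argilla/ultrafeedback-binarized-preferences-cleaned",
                "type": "chat_template.argilla",
            }
        ],
    }
)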

requirements.txt CHANGED
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs
 
-trl @ git+https://github.com/huggingface/trl.git@0ee349dcd43b0f4b3169449f16751c38ac4a609f
+trl==0.8.5
 zstandard==0.22.0
 fastcore
src/axolotl/cli/preprocess.py CHANGED
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
 
-    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
+    if parsed_cfg.rl:  # and parsed_cfg.rl != "orpo":
         load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
src/axolotl/cli/train.py CHANGED
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()
 
-    if cfg.rl and cfg.rl != "orpo":
+    if cfg.rl:  # and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
src/axolotl/core/trainer_builder.py CHANGED
@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer
+from trl import DPOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer
@@ -810,6 +810,14 @@ class AxolotlDPOTrainer(DPOTrainer):
         return res
 
 
+class AxolotlORPOTrainer(ORPOTrainer):
+    """
+    Extend the base ORPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "orpo"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -1404,7 +1412,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )
 
 
-class HFDPOTrainerBuilder(TrainerBuilderBase):
+class HFRLTrainerBuilder(TrainerBuilderBase):
     """
     Trainer factory class for DPO Trainer
     """
@@ -1497,7 +1505,15 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             # default to saving each epoch if not defined
             training_args_kwargs["save_strategy"] = "epoch"
 
-        training_args = TrainingArguments(
+        if self.cfg.orpo_alpha:
+            # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
+            training_args_kwargs["beta"] = self.cfg.orpo_alpha
+
+        training_args_cls = TrainingArguments
+        if self.cfg.rl == "orpo":
+            training_args_cls = ORPOConfig
+
+        training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
@@ -1530,17 +1546,26 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        dpo_trainer = AxolotlDPOTrainer(
-            self.model,
-            self.model_ref,
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
+            trainer_cls = AxolotlDPOTrainer
+            dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
+            trainer_cls_args = [self.model, self.model_ref]
+
+            # these aren't used for the ORPO trainer
+            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["max_target_length"] = None
+            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["generate_during_eval"] = True
+        elif self.cfg.rl == "orpo":
+            trainer_cls = AxolotlORPOTrainer
+            trainer_cls_args = [self.model]
+        else:
+            raise ValueError(f"Unsupported RL: {self.cfg.rl}")
+        dpo_trainer = trainer_cls(
+            *trainer_cls_args,
             args=training_args,
-            beta=self.cfg.dpo_beta or 0.1,
             train_dataset=self.train_dataset,
             tokenizer=self.tokenizer,
-            max_length=self.cfg.sequence_len,
-            max_target_length=None,
-            max_prompt_length=self.cfg.sequence_len,
-            generate_during_eval=True,
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
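
A minimal sketch of the alpha-to-beta mapping the comment above refers to, assuming the trl 0.8.5 API pinned in requirements.txt; output_dir and the batch size are placeholder values.

# ORPOConfig subclasses TrainingArguments and reuses a `beta` field for ORPO's
# alpha/lambda weighting, which is why the builder writes cfg.orpo_alpha into
# training_args_kwargs["beta"] before constructing the config.
from trl import ORPOConfig

args = ORPOConfig(
    output_dir="./outputs",  # placeholder
    per_device_train_batch_size=1,  # placeholder
    beta=0.1,  # what axolotl fills from cfg.orpo_alpha
)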
src/axolotl/prompt_strategies/orpo/__init__.py CHANGED
@@ -6,4 +6,4 @@ from functools import partial
 
 from ..base import load as load_base
 
-load = partial(load_base, module="axolotl.prompt_strategies.orpo")
+load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
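
Presumably (an assumption about load_base's contract, not something this commit spells out), the module_base fix lets a dataset type such as "chat_template.argilla" resolve to the argilla() factory added in the next file; the call shape below mirrors load_orpo(_type, _cfg, dataset_idx=i) in src/axolotl/utils/data/rl.py.

# Hypothetical resolution example; the type string and cfg contents are illustrative.
from axolotl.prompt_strategies.orpo import load as load_orpo
from axolotl.utils.dict import DictDefault

cfg = DictDefault({"chat_template": "chatml"})  # the argilla() factory reads cfg.chat_template
transform_fn = load_orpo("chat_template.argilla", cfg, dataset_idx=0)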
src/axolotl/prompt_strategies/orpo/chat_template.py CHANGED
@@ -78,6 +78,57 @@ class ORPODatasetParsingStrategy:
         )
         return MessageList(messages=messages)
 
+    def get_prompt(self, prompt) -> MessageList:
+        """Map the data to extract everything up to the last turn"""
+        total_msg_len = len(prompt["chosen"])
+        total_msg_turns, remainder = divmod(total_msg_len, 2)
+        assert remainder == 0, "invalid number of turns"
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        for i in range(total_msg_turns):
+            if "prompt" in prompt:
+                messages.append(
+                    Message(role="user", content=prompt["prompt"], label=False)
+                )
+            else:
+                messages.append(
+                    Message(
+                        role="user",
+                        content=prompt["chosen"][i * 2]["content"],
+                        label=False,
+                    )
+                )
+            if i < total_msg_turns - 1:
+                messages.append(
+                    Message(
+                        role="assistant",
+                        content=prompt["chosen"][i * 2 + 1]["content"],
+                        label=False,
+                    )
+                )
+
+        return MessageList(messages=messages)
+
+    def get_chosen(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["chosen"][-1]["content"], label=True
+            )
+        )
+        return res
+
+    def get_rejected(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["rejected"][-1]["content"], label=True
+            )
+        )
+        return res
+
 
 class ORPOTokenizingStrategy(PromptTokenizingStrategy):
     """
@@ -186,3 +237,36 @@ class ORPOPrompter:
             chat_template=self.chat_template,
             tokenize=False,
         ), True
+
+
+def argilla(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    dataset_parser = ORPODatasetParsingStrategy()
+
+    chat_template_str = chat_templates(cfg.chat_template)
+
+    def transform_fn(sample, tokenizer=None):
+        res = {}
+
+        res["prompt"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
+            add_generation_prompt=True,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )
+        prompt_str_len = len(res["prompt"])
+        res["chosen"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+        res["rejected"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+
+        return res
+
+    return transform_fn
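
To make the data flow concrete, a worked example with made-up messages of the shape argilla() expects; the rendered strings depend on the configured chat template, so the outputs are only indicated in comments.

# Hypothetical sample: "chosen"/"rejected" are full conversations that share every
# turn except the final assistant reply; "system" and "prompt" are optional keys.
sample = {
    "system": "You are a helpful assistant.",
    "chosen": [
        {"role": "user", "content": "What is ORPO?"},
        {"role": "assistant", "content": "A reference-model-free preference tuning method."},
    ],
    "rejected": [
        {"role": "user", "content": "What is ORPO?"},
        {"role": "assistant", "content": "No idea."},
    ],
}
# transform_fn(sample, tokenizer=tokenizer) renders the shared prefix with
# add_generation_prompt=True into res["prompt"], then renders the full chosen and
# rejected conversations and strips that prefix via prompt_str_len, leaving only
# the competing assistant completions in res["chosen"] / res["rejected"].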
src/axolotl/utils/data/__init__.py CHANGED
@@ -1,11 +1,11 @@
 """
 Data processing modules
 """
-from axolotl.utils.data.dpo import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.pretraining import (  # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,
 )
+from axolotl.utils.data.rl import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.sft import (  # noqa: F401
     get_dataset_wrapper,
     load_prepare_datasets,
src/axolotl/utils/data/{dpo.py → rl.py} RENAMED
@@ -1,17 +1,20 @@
 """data handling specific to DPO"""
-
+import inspect
 import logging
+from functools import partial
 from pathlib import Path
 from typing import Any, List
 
 import yaml
-from datasets import concatenate_datasets, load_dataset, load_from_disk
+from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.orpo import load as load_orpo
 from axolotl.utils.data.utils import md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.models import load_tokenizer
 
 LOG = logging.getLogger("axolotl")
 
@@ -72,16 +75,29 @@ def load_prepare_dpo_datasets(cfg):
                 )
                 split_datasets.insert(i, ds)
 
+    tokenizer = None
     for i, data_set in enumerate(split_datasets):
         _type = dataset_cfgs[i]["type"]
         if _type:
             if isinstance(_type, DictDefault):
                 _type = "user_defined.default"
-            ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
-            split_datasets[i] = data_set.map(
+            if _cfg.rl == "orpo":
+                ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
+            else:
+                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
+            sig = inspect.signature(ds_transform_fn)
+            if "tokenizer" in sig.parameters:
+                if not tokenizer:
+                    tokenizer = load_tokenizer(_cfg)
+                ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
+
+            data_set = data_set.map(
                 ds_transform_fn,
                 desc="Mapping RL Dataset",
            )
+            if isinstance(data_set, DatasetDict):
+                data_set = data_set["train"]
+            split_datasets[i] = data_set
         else:
             # If no `type` is provided, assume the dataset is already in the expected format with
             # "prompt", "chosen" and "rejected" already preprocessed
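
A standalone sketch of the signature check above, with hypothetical transform functions: DPO-style transforms take only the sample, while the new ORPO chat-template transform declares a tokenizer parameter, so it gets one bound in with functools.partial before Dataset.map() runs.

import inspect
from functools import partial

def dpo_style_transform(sample):  # no tokenizer parameter: mapped as-is
    return sample

def orpo_style_transform(sample, tokenizer=None):  # tokenizer detected via its signature
    return sample

for fn in (dpo_style_transform, orpo_style_transform):
    if "tokenizer" in inspect.signature(fn).parameters:
        fn = partial(fn, tokenizer=object())  # stand-in for load_tokenizer(_cfg)
    print(fn)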
src/axolotl/utils/trainer.py CHANGED
@@ -13,7 +13,7 @@ from datasets import set_caching_enabled
 from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available
 
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
@@ -340,8 +340,8 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair"]:
-        trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
+        trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
     else:
tests/core/test_trainer_builder.py CHANGED
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
 
 import pytest
 
-from axolotl.core.trainer_builder import HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFRLTrainerBuilder
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
     return load_model(cfg, tokenizer)
 
 
-class TestHFDPOTrainerBuilder:
+class TestHFRLTrainerBuilder:
     """
     TestCase class for DPO trainer builder
     """
 
     def test_build_training_arguments(self, cfg, model, tokenizer):
-        builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
+        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
         training_arguments = builder.build_training_arguments(100)
         assert training_arguments.adam_beta1 == 0.998
         assert training_arguments.adam_beta2 == 0.9