winglian commited on
Commit
33e1170
·
unverified ·
1 Parent(s): b4ac96a

precompute dpo logprobs setting and fixes (#1199) [skip ci]

Browse files

* add support for precompute_ref_log_probs for dpo

* add chatml.icr type for argilla orca dpo

* update inline doc

* also set use_reentrant to false for dpo when not set

* don't set use_reentrant to true for rl

* make sure to set gradient checkpointing too

src/axolotl/core/trainer_builder.py CHANGED
@@ -651,7 +651,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
651
  training_arguments_kwargs[
652
  "gradient_checkpointing"
653
  ] = self.cfg.gradient_checkpointing
654
- if self.cfg.gradient_checkpointing_kwargs:
655
  training_arguments_kwargs[
656
  "gradient_checkpointing_kwargs"
657
  ] = self.cfg.gradient_checkpointing_kwargs
@@ -1028,6 +1028,18 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
1028
  training_args_kwargs[
1029
  "dataloader_prefetch_factor"
1030
  ] = self.cfg.dataloader_prefetch_factor
 
 
 
 
 
 
 
 
 
 
 
 
1031
 
1032
  training_args = TrainingArguments(
1033
  per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1038,9 +1050,6 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
1038
  save_steps=self.cfg.save_steps,
1039
  output_dir=self.cfg.output_dir,
1040
  warmup_steps=self.cfg.warmup_steps,
1041
- gradient_checkpointing=self.cfg.gradient_checkpointing,
1042
- gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
1043
- or {"use_reentrant": False},
1044
  logging_first_step=True,
1045
  logging_steps=1,
1046
  optim=self.cfg.optimizer,
@@ -1063,6 +1072,10 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
1063
  dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
1064
  if self.cfg.adapter and self.peft_config:
1065
  dpo_trainer_kwargs["peft_config"] = self.peft_config
 
 
 
 
1066
  dpo_trainer = DPOTrainer(
1067
  self.model,
1068
  self.model_ref,
 
651
  training_arguments_kwargs[
652
  "gradient_checkpointing"
653
  ] = self.cfg.gradient_checkpointing
654
+ if self.cfg.gradient_checkpointing_kwargs is not None:
655
  training_arguments_kwargs[
656
  "gradient_checkpointing_kwargs"
657
  ] = self.cfg.gradient_checkpointing_kwargs
 
1028
  training_args_kwargs[
1029
  "dataloader_prefetch_factor"
1030
  ] = self.cfg.dataloader_prefetch_factor
1031
+ if self.cfg.gradient_checkpointing:
1032
+ training_args_kwargs[
1033
+ "gradient_checkpointing"
1034
+ ] = self.cfg.gradient_checkpointing
1035
+ if self.cfg.gradient_checkpointing_kwargs is not None:
1036
+ training_args_kwargs[
1037
+ "gradient_checkpointing_kwargs"
1038
+ ] = self.cfg.gradient_checkpointing_kwargs
1039
+ else:
1040
+ training_args_kwargs["gradient_checkpointing_kwargs"] = {
1041
+ "use_reentrant": False
1042
+ }
1043
 
1044
  training_args = TrainingArguments(
1045
  per_device_train_batch_size=self.cfg.micro_batch_size,
 
1050
  save_steps=self.cfg.save_steps,
1051
  output_dir=self.cfg.output_dir,
1052
  warmup_steps=self.cfg.warmup_steps,
 
 
 
1053
  logging_first_step=True,
1054
  logging_steps=1,
1055
  optim=self.cfg.optimizer,
 
1072
  dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
1073
  if self.cfg.adapter and self.peft_config:
1074
  dpo_trainer_kwargs["peft_config"] = self.peft_config
1075
+ if self.cfg.precompute_ref_log_probs is not None:
1076
+ dpo_trainer_kwargs[
1077
+ "precompute_ref_log_probs"
1078
+ ] = self.cfg.precompute_ref_log_probs
1079
  dpo_trainer = DPOTrainer(
1080
  self.model,
1081
  self.model_ref,
src/axolotl/prompt_strategies/dpo/chatml.py CHANGED
@@ -23,6 +23,31 @@ def argilla(
23
  return transform_fn
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
27
  """
28
  For Intel Orca DPO Pairs
 
23
  return transform_fn
24
 
25
 
26
+ def icr(
27
+ cfg,
28
+ ): # pylint: disable=possibly-unused-variable,unused-argument
29
+ """
30
+ chatml transforms for datasets with system, input, chosen, rejected
31
+ ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
32
+ """
33
+
34
+ def transform_fn(sample):
35
+ if "system" in sample and sample["system"]:
36
+ sample["prompt"] = (
37
+ f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
38
+ f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
39
+ )
40
+ else:
41
+ sample[
42
+ "prompt"
43
+ ] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
44
+ sample["chosen"] = f"{sample['chosen']}<|im_end|>"
45
+ sample["rejected"] = f"{sample['rejected']}<|im_end|>"
46
+ return sample
47
+
48
+ return transform_fn
49
+
50
+
51
  def intel(cfg): # pylint: disable=possibly-unused-variable,unused-argument
52
  """
53
  For Intel Orca DPO Pairs
src/axolotl/utils/config.py CHANGED
@@ -163,6 +163,7 @@ def normalize_config(cfg):
163
  cfg.gradient_checkpointing
164
  and cfg.unfrozen_parameters is None
165
  and cfg.gradient_checkpointing_kwargs is None
 
166
  ):
167
  cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
168
 
 
163
  cfg.gradient_checkpointing
164
  and cfg.unfrozen_parameters is None
165
  and cfg.gradient_checkpointing_kwargs is None
166
+ and cfg.rl is None
167
  ):
168
  cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
169