winglian committed on
Commit
f243c21
·
1 Parent(s): 59b2d30

RL/DPO (#935)

Browse files

* ipo-dpo trainer

* fix missing abstract method

* chatml template, grad checkpointing kwargs support

* fix steps calc for RL and add dataloader kwargs

* wip to fix dpo and start ppo

* more fixes

* refactor to generalize map fn

* fix dataset loop and handle argilla pref dataset

* set training args

* load reference model on separate gpu if more than one device

* no auto upload to hub for dpo, don't add lora adapters to ref model for dpo

* fixes for rl training

* support for ipo from yaml

* set dpo training args from the config, add tests

* chore: lint

* set sequence_len for model in test

* add RLHF docs

docs/rlhf.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # RLHF (Beta)
2
+
3
+ ### Overview
4
+
5
+ Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
6
+ feedback. Various methods include, but are not limited to:
7
+
8
+ - Proximal Policy Optimization (PPO) (not yet supported in axolotl)
9
+ - Direct Preference Optimization (DPO)
10
+ - Identity Preference Optimization (IPO)
11
+
12
+
13
+ ### RLHF using Axolotl
14
+
15
+ > [!IMPORTANT]
16
+ > This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
17
+
18
+ The various RL training methods are implemented in trl and wrapped via axolotl. Below are examples of how you can use different preference datasets to train models that use ChatML.
19
+
20
+ #### DPO
21
+ ```yaml
22
+ rl: true
23
+ datasets:
24
+ - path: Intel/orca_dpo_pairs
25
+ split: train
26
+ type: intel_apply_chatml
27
+ - path: argilla/ultrafeedback-binarized-preferences
28
+ split: train
29
+ type: argilla_apply_chatml
30
+ ```
31
+
32
+ #### IPO
33
+ ```yaml
34
+ rl: ipo
35
+ ```
requirements.txt CHANGED
@@ -37,3 +37,5 @@ tensorboard
37
  s3fs
38
  gcsfs
39
  # adlfs
 
 
 
37
  s3fs
38
  gcsfs
39
  # adlfs
40
+
41
+ trl @ git+https://github.com/huggingface/trl.git@main
src/axolotl/cli/__init__.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import importlib
4
  import logging
 
5
  import os
6
  import random
7
  import sys
@@ -16,6 +17,7 @@ import yaml
16
  # add src to the pythonpath so we don't need to pip install this
17
  from accelerate.commands.config import config_args
18
  from art import text2art
 
19
  from huggingface_hub import HfApi
20
  from huggingface_hub.utils import LocalTokenNotFoundError
21
  from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
@@ -325,6 +327,94 @@ def load_datasets(
325
  )
326
 
327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  def check_accelerate_default_config():
329
  if Path(config_args.default_yaml_config_file).exists():
330
  LOG.warning(
 
2
 
3
  import importlib
4
  import logging
5
+ import math
6
  import os
7
  import random
8
  import sys
 
17
  # add src to the pythonpath so we don't need to pip install this
18
  from accelerate.commands.config import config_args
19
  from art import text2art
20
+ from datasets import concatenate_datasets, load_dataset
21
  from huggingface_hub import HfApi
22
  from huggingface_hub.utils import LocalTokenNotFoundError
23
  from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
 
327
  )
328
 
329
 
330
+ def load_rl_datasets(
331
+ *,
332
+ cfg: DictDefault,
333
+ cli_args: TrainerCliArgs, # pylint: disable=unused-argument
334
+ ) -> TrainDatasetMeta:
335
+ train_datasets: List[Any] = []
336
+ for i, ds_cfg in enumerate(cfg.datasets):
337
+ train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
338
+ # eval_dataset = load_dataset(
339
+ # cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
340
+ # )
341
+ eval_dataset = None
342
+
343
+ def argilla_apply_chatml(sample): # pylint: disable=possibly-unused-variable
344
+ if "system" in sample and sample["system"]:
345
+ sample["prompt"] = (
346
+ f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
347
+ f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
348
+ )
349
+ else:
350
+ sample[
351
+ "prompt"
352
+ ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
353
+ sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
354
+ sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
355
+ return sample
356
+
357
+ def intel_apply_chatml(sample): # pylint: disable=possibly-unused-variable
358
+ if "system" in sample and sample["system"]:
359
+ sample["prompt"] = (
360
+ f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
361
+ f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
362
+ )
363
+ else:
364
+ sample[
365
+ "prompt"
366
+ ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
367
+ sample["chosen"] = f"{sample['chosen']}<|im_end|>"
368
+ sample["rejected"] = f"{sample['rejected']}<|im_end|>"
369
+ return sample
370
+
371
+ def apply_chatml(sample): # pylint: disable=possibly-unused-variable
372
+ if "system" in sample and sample["system"]:
373
+ sample["prompt"] = (
374
+ f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
375
+ f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
376
+ )
377
+ else:
378
+ sample[
379
+ "prompt"
380
+ ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
381
+ sample["chosen"] = f"{sample['chosen']}<|im_end|>"
382
+ sample["rejected"] = f"{sample['rejected']}<|im_end|>"
383
+ return sample
384
+
385
+ def ultra_apply_chatml(sample): # pylint: disable=possibly-unused-variable
386
+ if "system" in sample and sample["system"]:
387
+ sample["prompt"] = (
388
+ f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
389
+ f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
390
+ )
391
+ else:
392
+ sample[
393
+ "prompt"
394
+ ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
395
+ sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
396
+ sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
397
+ return sample
398
+
399
+ for i, data_set in enumerate(train_datasets):
400
+ _type = cfg.datasets[i]["type"]
401
+ ds_type_fn = locals()[_type]
402
+ train_datasets[i] = data_set.map(ds_type_fn)
403
+ train_dataset = concatenate_datasets(train_datasets)
404
+
405
+ # eval_dataset = eval_dataset.map(intel_apply_chatml)
406
+
407
+ total_num_steps = int(
408
+ math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
409
+ )
410
+
411
+ return TrainDatasetMeta(
412
+ train_dataset=train_dataset,
413
+ eval_dataset=eval_dataset,
414
+ total_num_steps=total_num_steps,
415
+ )
416
+
417
+
418
  def check_accelerate_default_config():
419
  if Path(config_args.default_yaml_config_file).exists():
420
  LOG.warning(
src/axolotl/cli/train.py CHANGED
@@ -12,6 +12,7 @@ from axolotl.cli import (
12
  check_user_token,
13
  load_cfg,
14
  load_datasets,
 
15
  print_axolotl_text_art,
16
  )
17
  from axolotl.common.cli import TrainerCliArgs
@@ -30,7 +31,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
30
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
31
  return_remaining_strings=True
32
  )
33
- dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
 
 
 
34
  train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
35
 
36
 
 
12
  check_user_token,
13
  load_cfg,
14
  load_datasets,
15
+ load_rl_datasets,
16
  print_axolotl_text_art,
17
  )
18
  from axolotl.common.cli import TrainerCliArgs
 
31
  parsed_cli_args, _ = parser.parse_args_into_dataclasses(
32
  return_remaining_strings=True
33
  )
34
+ if parsed_cfg.rl:
35
+ dataset_meta = load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
36
+ else:
37
+ dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
38
  train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
39
 
40
 
src/axolotl/core/trainer_builder.py CHANGED
@@ -20,6 +20,7 @@ from torch.optim.lr_scheduler import OneCycleLR
20
  from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
21
  from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
22
  from transformers.trainer_utils import seed_worker
 
23
 
24
  from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
25
  from axolotl.utils.callbacks import (
@@ -420,12 +421,21 @@ class TrainerBuilderBase(abc.ABC):
420
 
421
  _train_dataset = None
422
  _eval_dataset = None
 
423
 
424
  def __init__(self, cfg, model, tokenizer):
425
  self.cfg = cfg
426
  self.model = model
427
  self.tokenizer = tokenizer
428
 
 
 
 
 
 
 
 
 
429
  @property
430
  def train_dataset(self):
431
  return self._train_dataset
@@ -827,3 +837,96 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
827
  return_tensors="pt",
828
  **kwargs,
829
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
21
  from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
22
  from transformers.trainer_utils import seed_worker
23
+ from trl import DPOTrainer
24
 
25
  from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
26
  from axolotl.utils.callbacks import (
 
421
 
422
  _train_dataset = None
423
  _eval_dataset = None
424
+ _model_ref = None
425
 
426
  def __init__(self, cfg, model, tokenizer):
427
  self.cfg = cfg
428
  self.model = model
429
  self.tokenizer = tokenizer
430
 
431
+ @property
432
+ def model_ref(self):
433
+ return self._model_ref
434
+
435
+ @model_ref.setter
436
+ def model_ref(self, model):
437
+ self._model_ref = model
438
+
439
  @property
440
  def train_dataset(self):
441
  return self._train_dataset
 
837
  return_tensors="pt",
838
  **kwargs,
839
  )
840
+
841
+
842
+ class HFDPOTrainerBuilder(TrainerBuilderBase):
843
+ """
844
+ Trainer factory class for DPO Trainer
845
+ """
846
+
847
+ def get_callbacks(self):
848
+ callbacks = []
849
+ return callbacks
850
+
851
+ def get_post_trainer_create_callbacks(self, trainer):
852
+ callbacks = []
853
+ return callbacks
854
+
855
+ def build_training_arguments(self, total_num_steps):
856
+ training_args_kwargs = {}
857
+ for arg in [
858
+ "adam_beta1",
859
+ "adam_beta2",
860
+ "adam_epsilon",
861
+ "dataloader_num_workers",
862
+ "dataloader_pin_memory",
863
+ ]:
864
+ if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
865
+ training_args_kwargs[arg] = getattr(self.cfg, arg)
866
+ training_args = TrainingArguments(
867
+ per_device_train_batch_size=self.cfg.micro_batch_size,
868
+ max_steps=total_num_steps,
869
+ remove_unused_columns=False,
870
+ gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
871
+ learning_rate=self.cfg.learning_rate,
872
+ evaluation_strategy="no",
873
+ # eval_steps=self.cfg.eval_steps,
874
+ save_strategy="steps",
875
+ save_steps=self.cfg.save_steps,
876
+ output_dir=self.cfg.output_dir,
877
+ warmup_steps=self.cfg.warmup_steps,
878
+ bf16=True,
879
+ gradient_checkpointing=self.cfg.gradient_checkpointing,
880
+ gradient_checkpointing_kwargs={"use_reentrant": False},
881
+ logging_first_step=True,
882
+ logging_steps=1,
883
+ optim=self.cfg.optimizer,
884
+ save_total_limit=self.cfg.save_total_limit or 5,
885
+ **training_args_kwargs,
886
+ )
887
+
888
+ return training_args
889
+
890
+ def build(self, total_num_steps):
891
+ training_args = self.build_training_arguments(total_num_steps)
892
+ dpo_trainer_kwargs = {}
893
+ if self.cfg.rl == "ipo":
894
+ dpo_trainer_kwargs["loss_type"] = "ipo"
895
+ if self.cfg.dpo_label_smoothing:
896
+ dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
897
+
898
+ dpo_trainer = DPOTrainer(
899
+ self.model,
900
+ self.model_ref,
901
+ args=training_args,
902
+ beta=self.cfg.dpo_beta or 0.1,
903
+ train_dataset=self.train_dataset,
904
+ # eval_dataset=self.eval_dataset,
905
+ eval_dataset=None,
906
+ tokenizer=self.tokenizer,
907
+ max_length=self.cfg.sequence_len,
908
+ max_target_length=None,
909
+ max_prompt_length=self.cfg.sequence_len,
910
+ generate_during_eval=True,
911
+ **dpo_trainer_kwargs,
912
+ )
913
+
914
+ return dpo_trainer
915
+
916
+
917
+ class HFPPOTrainerBuilder(TrainerBuilderBase):
918
+ """
919
+ HF Factory class for PPO Trainer
920
+ """
921
+
922
+ def get_callbacks(self):
923
+ callbacks = []
924
+ return callbacks
925
+
926
+ def get_post_trainer_create_callbacks(self, trainer):
927
+ callbacks = []
928
+ return callbacks
929
+
930
+ def build(self, total_num_steps):
931
+ # build PPOConfig
932
+ pass
src/axolotl/core/trainers/__init__.py ADDED
File without changes
src/axolotl/core/trainers/trl.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ module for TRL PPO training
3
+ """
4
+ import torch
5
+ from tqdm import tqdm
6
+ from trl import PPOTrainer
7
+
8
+
9
+ class TRLPPOTrainer(PPOTrainer):
10
+ """
11
+ wrapper for ppo trainer to handle customizations
12
+ """
13
+
14
+ def train(
15
+ self,
16
+ reward_pipe,
17
+ resume_from_checkpoint=None, # pylint: disable=unused-argument
18
+ ):
19
+ generation_kwargs = {
20
+ "min_length": -1,
21
+ "top_k": 0.0,
22
+ "top_p": 1.0,
23
+ "do_sample": True,
24
+ "pad_token_id": self.tokenizer.eos_token_id,
25
+ "max_new_tokens": 32,
26
+ }
27
+ sent_kwargs = {
28
+ "return_all_scores": True,
29
+ "function_to_apply": "none",
30
+ "batch_size": 16,
31
+ }
32
+
33
+ for epoch, batch in tqdm( # pylint: disable=unused-variable
34
+ enumerate(self.dataloader)
35
+ ):
36
+ query_tensors = batch["input_ids"]
37
+
38
+ # generate model response
39
+ response_tensors, ref_response_tensors = self.generate(
40
+ query_tensors,
41
+ return_prompt=False,
42
+ generate_ref_response=True,
43
+ **generation_kwargs
44
+ )
45
+ batch["response"] = self.tokenizer.batch_decode(response_tensors)
46
+ batch["ref_response"] = self.tokenizer.batch_decode(ref_response_tensors)
47
+
48
+ # Compute sentiment score
49
+ texts = [q + r for q, r in zip(batch["query"], batch["response"])]
50
+ pipe_outputs = reward_pipe(texts, **sent_kwargs)
51
+ rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
52
+ ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
53
+ ref_pipe_outputs = reward_pipe(ref_texts, **sent_kwargs)
54
+ ref_rewards = [
55
+ torch.tensor(output[1]["score"]) for output in ref_pipe_outputs
56
+ ]
57
+ batch["ref_rewards"] = ref_rewards
58
+
59
+ # Run PPO step
60
+ stats = self.step(query_tensors, response_tensors, rewards)
61
+ self.log_stats(
62
+ stats,
63
+ batch,
64
+ rewards,
65
+ columns_to_log=["query", "response", "ref_response", "ref_rewards"],
66
+ )
src/axolotl/train.py CHANGED
@@ -61,6 +61,12 @@ def train(
61
  msg += " and peft_config..."
62
  LOG.debug(msg)
63
  model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
 
 
 
 
 
 
64
 
65
  safe_serialization = cfg.save_safetensors is True
66
 
@@ -83,7 +89,7 @@ def train(
83
  freeze_parameters_except(model, cfg.unfrozen_parameters)
84
 
85
  trainer = setup_trainer(
86
- cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
87
  )
88
 
89
  if hasattr(model, "config"):
 
61
  msg += " and peft_config..."
62
  LOG.debug(msg)
63
  model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
64
+ model_ref = None
65
+ if cfg.rl:
66
+ # load the model again for model_ref/baseline
67
+ model_ref, _ = load_model(
68
+ cfg, tokenizer, inference=cli_args.inference, reference_model=True
69
+ )
70
 
71
  safe_serialization = cfg.save_safetensors is True
72
 
 
89
  freeze_parameters_except(model, cfg.unfrozen_parameters)
90
 
91
  trainer = setup_trainer(
92
+ cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps
93
  )
94
 
95
  if hasattr(model, "config"):
src/axolotl/utils/models.py CHANGED
@@ -200,6 +200,7 @@ def load_model(
200
  cfg: DictDefault,
201
  tokenizer: PreTrainedTokenizerBase,
202
  inference: bool = False,
 
203
  ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
204
  """
205
  Load a model for a given configuration and tokenizer.
@@ -290,6 +291,15 @@ def load_model(
290
  model_kwargs["device_map"] = cfg.device_map
291
  model_kwargs["max_memory"] = cfg.max_memory
292
  model_kwargs["torch_dtype"] = cfg.torch_dtype
 
 
 
 
 
 
 
 
 
293
 
294
  if is_deepspeed_zero3_enabled():
295
  del model_kwargs["device_map"]
@@ -560,9 +570,11 @@ def load_model(
560
  if hasattr(module, "weight"):
561
  module.to(cfg.torch_dtype)
562
 
563
- model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
 
564
 
565
- if cfg.ddp and not load_in_8bit:
566
  model.to(f"cuda:{cfg.local_rank}")
567
 
568
  if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
 
200
  cfg: DictDefault,
201
  tokenizer: PreTrainedTokenizerBase,
202
  inference: bool = False,
203
+ reference_model: bool = False,
204
  ) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
205
  """
206
  Load a model for a given configuration and tokenizer.
 
291
  model_kwargs["device_map"] = cfg.device_map
292
  model_kwargs["max_memory"] = cfg.max_memory
293
  model_kwargs["torch_dtype"] = cfg.torch_dtype
294
+ # TODO can we put the reference model on its own GPU? I think we have to move logits around to calculate loss
295
+ # if cfg.rl:
296
+ # if torch.cuda.device_count() > 1:
297
+ # if reference_model:
298
+ # model_kwargs["device_map"] = "cuda:" + str(
299
+ # torch.cuda.current_device() + 1
300
+ # )
301
+ # else:
302
+ # model_kwargs["device_map"] = "cuda:" + str(torch.cuda.current_device())
303
 
304
  if is_deepspeed_zero3_enabled():
305
  del model_kwargs["device_map"]
 
570
  if hasattr(module, "weight"):
571
  module.to(cfg.torch_dtype)
572
 
573
+ lora_config = None
574
+ if not reference_model or cfg.lora_model_dir:
575
+ model, lora_config = load_adapter(model, cfg, cfg.adapter)
576
 
577
+ if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
578
  model.to(f"cuda:{cfg.local_rank}")
579
 
580
  if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
src/axolotl/utils/trainer.py CHANGED
@@ -12,7 +12,7 @@ from accelerate.logging import get_logger
12
  from datasets import set_caching_enabled
13
  from torch.utils.data import DataLoader, RandomSampler
14
 
15
- from axolotl.core.trainer_builder import HFCausalTrainerBuilder
16
  from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
17
  from axolotl.utils.samplers import MultipackBatchSampler
18
 
@@ -280,7 +280,12 @@ def prepare_optim_env(cfg):
280
 
281
 
282
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
283
- trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
 
 
 
 
 
284
  trainer_builder.train_dataset = train_dataset
285
  trainer_builder.eval_dataset = eval_dataset
286
 
 
12
  from datasets import set_caching_enabled
13
  from torch.utils.data import DataLoader, RandomSampler
14
 
15
+ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
16
  from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
17
  from axolotl.utils.samplers import MultipackBatchSampler
18
 
 
280
 
281
 
282
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
283
+ if cfg.rl:
284
+ trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
285
+ trainer_builder.model_ref = model[1]
286
+ else:
287
+ trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)
288
+
289
  trainer_builder.train_dataset = train_dataset
290
  trainer_builder.eval_dataset = eval_dataset
291
 
tests/core/test_trainer_builder.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ unit tests for axolotl.core.trainer_builder
3
+ """
4
+ import pytest
5
+
6
+ from axolotl.core.trainer_builder import HFDPOTrainerBuilder
7
+ from axolotl.utils.dict import DictDefault
8
+ from axolotl.utils.models import load_model, load_tokenizer
9
+
10
+
11
+ @pytest.fixture(name="cfg")
12
+ def fixture_cfg():
13
+ return DictDefault(
14
+ {
15
+ "base_model": "TinyLlama/TinyLlama-1.1B-Chat-v0.6",
16
+ "model_type": "AutoModelForCausalLM",
17
+ "tokenizer_type": "LlamaTokenizer",
18
+ "micro_batch_size": 1,
19
+ "gradient_accumulation_steps": 1,
20
+ "learning_rate": 0.00005,
21
+ "save_steps": 100,
22
+ "output_dir": "./model-out",
23
+ "warmup_steps": 10,
24
+ "gradient_checkpointing": False,
25
+ "optimizer": "adamw_torch",
26
+ "sequence_len": 2048,
27
+ "rl": True,
28
+ "adam_beta1": 0.998,
29
+ "adam_beta2": 0.9,
30
+ "adam_epsilon": 0.00001,
31
+ "dataloader_num_workers": 1,
32
+ "dataloader_pin_memory": True,
33
+ }
34
+ )
35
+
36
+
37
+ @pytest.fixture(name="tokenizer")
38
+ def fixture_tokenizer(cfg):
39
+ return load_tokenizer(cfg)
40
+
41
+
42
+ @pytest.fixture(name="model")
43
+ def fixture_model(cfg, tokenizer):
44
+ return load_model(cfg, tokenizer)
45
+
46
+
47
+ class TestHFDPOTrainerBuilder:
48
+ """
49
+ TestCase class for DPO trainer builder
50
+ """
51
+
52
+ def test_build_training_arguments(self, cfg, model, tokenizer):
53
+ builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
54
+ training_arguments = builder.build_training_arguments(100)
55
+ assert training_arguments.adam_beta1 == 0.998
56
+ assert training_arguments.adam_beta2 == 0.9
57
+ assert training_arguments.adam_epsilon == 0.00001
58
+ assert training_arguments.dataloader_num_workers == 1
59
+ assert training_arguments.dataloader_pin_memory is True