ORPO Trainer replacement (#1551)
* WIP use trl ORPOTrainer
* fixes to make orpo work with trl
* fix the chat template loading
* make sure to handle the special tokens and add_generation for the assistant turn too
- requirements.txt +1 -1
- src/axolotl/cli/preprocess.py +1 -1
- src/axolotl/cli/train.py +1 -1
- src/axolotl/core/trainer_builder.py +36 -11
- src/axolotl/prompt_strategies/orpo/__init__.py +1 -1
- src/axolotl/prompt_strategies/orpo/chat_template.py +84 -0
- src/axolotl/utils/data/__init__.py +1 -1
- src/axolotl/utils/data/{dpo.py → rl.py} +20 -4
- src/axolotl/utils/trainer.py +3 -3
- tests/core/test_trainer_builder.py +3 -3
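Taken together, these changes make `orpo` another value of the `rl:` setting, reusing the DPO data path and the renamed `HFRLTrainerBuilder`. Below is a hedged sketch of a config fragment that would exercise the new path; key names mirror the `cfg.*` accesses in this diff, while the dataset `path` and the exact `type` string are illustrative assumptions, not values from the PR.

```python
# Hedged sketch only: keys mirror cfg.* accesses in this PR; dataset path/type are assumptions.
cfg_fragment = {
    "rl": "orpo",               # routes through HFRLTrainerBuilder -> AxolotlORPOTrainer
    "orpo_alpha": 0.1,          # mapped onto trl ORPOConfig's `beta` field by the builder
    "chat_template": "chatml",  # consumed by chat_templates() in the ORPO transform
    "datasets": [
        {
            "path": "some/preference-dataset",  # hypothetical; needs prompt/chosen/rejected turns
            "type": "chat_template.argilla",    # assumed to resolve via the orpo strategy loader
        }
    ],
}
```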
requirements.txt (changed)

```diff
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs
 
-trl
+trl==0.8.5
 zstandard==0.22.0
 fastcore
```
src/axolotl/cli/preprocess.py (changed)

```diff
@@ -54,7 +54,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         LOG.warning(msg)
         parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
 
-    if parsed_cfg.rl and parsed_cfg.rl != "orpo":
+    if parsed_cfg.rl:  # and parsed_cfg.rl != "orpo":
         load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
     else:
         load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
```
src/axolotl/cli/train.py (changed)

```diff
@@ -47,7 +47,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     else:
         register_chatml_template()
 
-    if cfg.rl and cfg.rl != "orpo":
+    if cfg.rl:  # and cfg.rl != "orpo":
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
```
src/axolotl/core/trainer_builder.py (changed)

```diff
@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer
+from trl import DPOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer
@@ -810,6 +810,14 @@ class AxolotlDPOTrainer(DPOTrainer):
         return res
 
 
+class AxolotlORPOTrainer(ORPOTrainer):
+    """
+    Extend the base ORPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "orpo"]
+
+
 class TrainerBuilderBase(abc.ABC):
     """
     Base class for trainer builder
@@ -1404,7 +1412,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         )
 
 
-class HFDPOTrainerBuilder(TrainerBuilderBase):
+class HFRLTrainerBuilder(TrainerBuilderBase):
     """
     Trainer factory class for DPO Trainer
     """
@@ -1497,7 +1505,15 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             # default to saving each epoch if not defined
             training_args_kwargs["save_strategy"] = "epoch"
 
-        training_args = TrainingArguments(
+        if self.cfg.orpo_alpha:
+            # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
+            training_args_kwargs["beta"] = self.cfg.orpo_alpha
+
+        training_args_cls = TrainingArguments
+        if self.cfg.rl == "orpo":
+            training_args_cls = ORPOConfig
+
+        training_args = training_args_cls(
             per_device_train_batch_size=self.cfg.micro_batch_size,
             max_steps=self.cfg.max_steps or total_num_steps,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
@@ -1530,17 +1546,26 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs[
                 "precompute_ref_log_probs"
             ] = self.cfg.precompute_ref_log_probs
-        dpo_trainer = AxolotlDPOTrainer(
-            self.model,
-            self.model_ref,
+        if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
+            trainer_cls = AxolotlDPOTrainer
+            dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
+            trainer_cls_args = [self.model, self.model_ref]
+
+            # these aren't used for the ORPO trainer
+            dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["max_target_length"] = None
+            dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
+            dpo_trainer_kwargs["generate_during_eval"] = True
+        elif self.cfg.rl == "orpo":
+            trainer_cls = AxolotlORPOTrainer
+            trainer_cls_args = [self.model]
+        else:
+            raise ValueError(f"Unsupported RL: {self.cfg.rl}")
+        dpo_trainer = trainer_cls(
+            *trainer_cls_args,
             args=training_args,
-            beta=self.cfg.dpo_beta or 0.1,
             train_dataset=self.train_dataset,
             tokenizer=self.tokenizer,
-            max_length=self.cfg.sequence_len,
-            max_target_length=None,
-            max_prompt_length=self.cfg.sequence_len,
-            generate_during_eval=True,
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
```
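For reference, here is a minimal sketch (not the builder's actual implementation) of the `orpo_alpha` → `beta` hand-off performed above. It assumes `trl==0.8.5` from `requirements.txt`, where `ORPOConfig` reuses the field name `beta` for the ORPO weighting term.

```python
# Minimal sketch, assuming trl==0.8.5; not the PR's code.
from transformers import TrainingArguments
from trl import ORPOConfig


def make_training_args(rl: str, orpo_alpha: float, output_dir: str) -> TrainingArguments:
    """Pick ORPOConfig for ORPO runs and reuse its `beta` field for orpo_alpha."""
    if rl == "orpo":
        # ORPOConfig subclasses TrainingArguments, so the usual kwargs carry over
        return ORPOConfig(output_dir=output_dir, beta=orpo_alpha or 0.1)
    return TrainingArguments(output_dir=output_dir)


args = make_training_args("orpo", orpo_alpha=0.1, output_dir="./outputs")
```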
src/axolotl/prompt_strategies/orpo/__init__.py (changed)

```diff
@@ -6,4 +6,4 @@ from functools import partial
 
 from ..base import load as load_base
 
-load = partial(load_base,
+load = partial(load_base, module_base="axolotl.prompt_strategies.orpo")
```
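As a rough sketch of what the `module_base` override accomplishes (an assumption about `load_base`'s behavior, not its actual code), a strategy string such as `chat_template.argilla` would now resolve inside `axolotl.prompt_strategies.orpo`:

```python
# Illustrative sketch only; load_base's real signature and behavior may differ.
import importlib


def resolve_strategy(strategy: str, module_base: str):
    """Resolve a string like 'chat_template.argilla' to a factory under module_base."""
    module_name, func_name = strategy.rsplit(".", 1)
    module = importlib.import_module(f"{module_base}.{module_name}")
    return getattr(module, func_name)


# resolve_strategy("chat_template.argilla", "axolotl.prompt_strategies.orpo")
# would return the argilla() factory added in orpo/chat_template.py below.
```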
src/axolotl/prompt_strategies/orpo/chat_template.py (changed)

```diff
@@ -78,6 +78,57 @@ class ORPODatasetParsingStrategy:
             )
         return MessageList(messages=messages)
 
+    def get_prompt(self, prompt) -> MessageList:
+        """Map the data to extract everything up to the last turn"""
+        total_msg_len = len(prompt["chosen"])
+        total_msg_turns, remainder = divmod(total_msg_len, 2)
+        assert remainder == 0, "invalid number of turns"
+
+        messages: List[Message] = []
+        if system := prompt.get("system", None):
+            messages.append(Message(role="system", content=system, label=False))
+        for i in range(total_msg_turns):
+            if "prompt" in prompt:
+                messages.append(
+                    Message(role="user", content=prompt["prompt"], label=False)
+                )
+            else:
+                messages.append(
+                    Message(
+                        role="user",
+                        content=prompt["chosen"][i * 2]["content"],
+                        label=False,
+                    )
+                )
+            if i < total_msg_turns - 1:
+                messages.append(
+                    Message(
+                        role="assistant",
+                        content=prompt["chosen"][i * 2 + 1]["content"],
+                        label=False,
+                    )
+                )
+
+        return MessageList(messages=messages)
+
+    def get_chosen(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["chosen"][-1]["content"], label=True
+            )
+        )
+        return res
+
+    def get_rejected(self, prompt) -> MessageList:
+        res = self.get_prompt(prompt)
+        res.messages.append(
+            Message(
+                role="assistant", content=prompt["rejected"][-1]["content"], label=True
+            )
+        )
+        return res
+
 
 class ORPOTokenizingStrategy(PromptTokenizingStrategy):
     """
@@ -186,3 +237,36 @@ class ORPOPrompter(Prompter):
             chat_template=self.chat_template,
             tokenize=False,
         ), True
+
+
+def argilla(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
+    dataset_parser = ORPODatasetParsingStrategy()
+
+    chat_template_str = chat_templates(cfg.chat_template)
+
+    def transform_fn(sample, tokenizer=None):
+        res = {}
+
+        res["prompt"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_prompt(sample).messages],
+            add_generation_prompt=True,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )
+        prompt_str_len = len(res["prompt"])
+        res["chosen"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_chosen(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+        res["rejected"] = tokenizer.apply_chat_template(
+            [msg.model_dump() for msg in dataset_parser.get_rejected(sample).messages],
+            add_generation_prompt=False,
+            chat_template=chat_template_str,
+            tokenize=False,
+        )[prompt_str_len:]
+
+        return res
+
+    return transform_fn
```
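The `argilla` transform renders the prompt with `add_generation_prompt=True`, then slices that prefix off the fully rendered chosen/rejected strings, so trl's ORPOTrainer receives response-only completions. A toy illustration of that slicing, with ChatML-style strings hand-written instead of calling `apply_chat_template`:

```python
# Toy illustration of the prompt/completion split; the strings are stand-ins
# for tokenizer.apply_chat_template output under a ChatML template.
prompt = "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\n"
chosen_full = prompt + "4<|im_end|>\n"
rejected_full = prompt + "5<|im_end|>\n"

prompt_str_len = len(prompt)
sample = {
    "prompt": prompt,
    "chosen": chosen_full[prompt_str_len:],      # -> "4<|im_end|>\n"
    "rejected": rejected_full[prompt_str_len:],  # -> "5<|im_end|>\n"
}
print(sample)
```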
src/axolotl/utils/data/__init__.py (changed)

```diff
@@ -1,11 +1,11 @@
 """
 Data processing modules
 """
-from axolotl.utils.data.dpo import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.pretraining import (  # noqa: F401
     encode_pretraining,
     wrap_pretraining_dataset,
 )
+from axolotl.utils.data.rl import load_prepare_dpo_datasets  # noqa: F401
 from axolotl.utils.data.sft import (  # noqa: F401
     get_dataset_wrapper,
     load_prepare_datasets,
```
src/axolotl/utils/data/dpo.py → rl.py (renamed)

```diff
@@ -1,17 +1,20 @@
 """data handling specific to DPO"""
-
+import inspect
 import logging
+from functools import partial
 from pathlib import Path
 from typing import Any, List
 
 import yaml
-from datasets import concatenate_datasets, load_dataset, load_from_disk
+from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.prompt_strategies.orpo import load as load_orpo
 from axolotl.utils.data.utils import md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
+from axolotl.utils.models import load_tokenizer
 
 LOG = logging.getLogger("axolotl")
 
@@ -72,16 +75,29 @@ def load_prepare_dpo_datasets(cfg):
             )
             split_datasets.insert(i, ds)
 
+    tokenizer = None
     for i, data_set in enumerate(split_datasets):
         _type = dataset_cfgs[i]["type"]
         if _type:
             if isinstance(_type, DictDefault):
                 _type = "user_defined.default"
-            ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
-            split_datasets[i] = data_set.map(
+            if _cfg.rl == "orpo":
+                ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
+            else:
+                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
+            sig = inspect.signature(ds_transform_fn)
+            if "tokenizer" in sig.parameters:
+                if not tokenizer:
+                    tokenizer = load_tokenizer(_cfg)
+                ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
+
+            data_set = data_set.map(
                 ds_transform_fn,
                 desc="Mapping RL Dataset",
             )
+            if isinstance(data_set, DatasetDict):
+                data_set = data_set["train"]
+            split_datasets[i] = data_set
         else:
             # If no `type` is provided, assume the dataset is already in the expected format with
             # "prompt", "chosen" and "rejected" already preprocessed
```
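The signature check above only loads a tokenizer when a transform (like the ORPO one) declares a `tokenizer` parameter. A self-contained sketch of that pattern, using hypothetical transforms and a dummy tokenizer so it runs standalone:

```python
# Self-contained sketch of the inspect/partial pattern; both transforms are hypothetical.
import inspect
from functools import partial


class DummyTokenizer:  # stand-in so the sketch runs without loading a real tokenizer
    eos_token = "</s>"


def orpo_like_transform(sample, tokenizer=None):
    return {"prompt": sample["prompt"] + tokenizer.eos_token}


def dpo_like_transform(sample):
    return {"prompt": sample["prompt"]}


def maybe_bind_tokenizer(fn, tokenizer):
    """Bind a tokenizer only if the transform's signature asks for one."""
    if "tokenizer" in inspect.signature(fn).parameters:
        return partial(fn, tokenizer=tokenizer)
    return fn


for transform in (orpo_like_transform, dpo_like_transform):
    bound = maybe_bind_tokenizer(transform, DummyTokenizer())
    print(bound({"prompt": "hi"}))
```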
src/axolotl/utils/trainer.py (changed)

```diff
@@ -13,7 +13,7 @@ from datasets import set_caching_enabled
 from torch.utils.data import DataLoader, RandomSampler
 from transformers.utils import is_torch_bf16_gpu_available
 
-from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
 from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 
@@ -340,8 +340,8 @@ def prepare_optim_env(cfg):
 
 
 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl in ["dpo", "ipo", "kto_pair"]:
-        trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
+    if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
+        trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
         trainer_builder.peft_config = model[2]
     else:
```
tests/core/test_trainer_builder.py (changed)

```diff
@@ -4,7 +4,7 @@ unit tests for axolotl.core.trainer_builder
 
 import pytest
 
-from axolotl.core.trainer_builder import HFDPOTrainerBuilder
+from axolotl.core.trainer_builder import HFRLTrainerBuilder
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -51,13 +51,13 @@ def fixture_model(cfg, tokenizer):
     return load_model(cfg, tokenizer)
 
 
-class TestHFDPOTrainerBuilder:
+class TestHFRLTrainerBuilder:
     """
     TestCase class for DPO trainer builder
     """
 
     def test_build_training_arguments(self, cfg, model, tokenizer):
-        builder = HFDPOTrainerBuilder(cfg, model, tokenizer)
+        builder = HFRLTrainerBuilder(cfg, model, tokenizer)
         training_arguments = builder.build_training_arguments(100)
         assert training_arguments.adam_beta1 == 0.998
         assert training_arguments.adam_beta2 == 0.9
```