winglian Nero10578 commited on
Commit
2147cf6
·
unverified ·
1 Parent(s): 50421c8

Llama3 dpo (#1610)

Browse files

* add dpo llama3

* fix dpo bos and eos

* bos token gets added automatically by the tokenizer

* explicit <|end_of_text|> not needed, as eot_id is sufficient

---------

Co-authored-by: Nero10578 <[email protected]>

src/axolotl/prompt_strategies/dpo/llama3.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DPO strategies for llama-3 chat template
3
+ """
4
+
5
+
6
+ def argilla(
7
+ cfg,
8
+ **kwargs,
9
+ ): # pylint: disable=possibly-unused-variable,unused-argument
10
+ def transform_fn(sample):
11
+ if "system" in sample and sample["system"]:
12
+ sample["prompt"] = (
13
+ f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
14
+ f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
15
+ )
16
+ else:
17
+ sample[
18
+ "prompt"
19
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
20
+ sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
21
+ sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
22
+ return sample
23
+
24
+ return transform_fn
25
+
26
+
27
+ def argilla_chat(
28
+ cfg,
29
+ **kwargs,
30
+ ): # pylint: disable=possibly-unused-variable,unused-argument
31
+ """
32
+ for argilla/dpo-mix-7k conversations
33
+ """
34
+
35
+ def transform_fn(sample):
36
+ sample[
37
+ "prompt"
38
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
39
+ sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
40
+ sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
41
+ return sample
42
+
43
+ return transform_fn
44
+
45
+
46
+ def icr(
47
+ cfg,
48
+ **kwargs,
49
+ ): # pylint: disable=possibly-unused-variable,unused-argument
50
+ """
51
+ chatml transforms for datasets with system, input, chosen, rejected
52
+ ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
53
+ """
54
+
55
+ def transform_fn(sample):
56
+ if "system" in sample and sample["system"]:
57
+ sample["prompt"] = (
58
+ f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
59
+ f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
60
+ )
61
+ else:
62
+ sample[
63
+ "prompt"
64
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
65
+ sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
66
+ sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
67
+ return sample
68
+
69
+ return transform_fn
70
+
71
+
72
+ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
73
+ """
74
+ For Intel Orca DPO Pairs
75
+ """
76
+
77
+ def transform_fn(sample):
78
+ if "system" in sample and sample["system"]:
79
+ sample["prompt"] = (
80
+ f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
81
+ f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
82
+ )
83
+ else:
84
+ sample[
85
+ "prompt"
86
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
87
+ sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
88
+ sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
89
+ return sample
90
+
91
+ return transform_fn
92
+
93
+
94
+ def prompt_pairs(
95
+ cfg, **kwargs
96
+ ): # pylint: disable=possibly-unused-variable,unused-argument
97
+ def transform_fn(sample):
98
+ if "system" in sample and sample["system"]:
99
+ sample["prompt"] = (
100
+ f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
101
+ f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
102
+ )
103
+ else:
104
+ sample[
105
+ "prompt"
106
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
107
+ sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
108
+ sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
109
+ return sample
110
+
111
+ return transform_fn
112
+
113
+
114
+ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
115
+ """
116
+ for ultrafeedback binarized conversations
117
+ """
118
+
119
+ def transform_fn(sample):
120
+ if "system" in sample and sample["system"]:
121
+ sample["prompt"] = (
122
+ f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
123
+ f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
124
+ )
125
+ else:
126
+ sample[
127
+ "prompt"
128
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
129
+ sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
130
+ sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
131
+ return sample
132
+
133
+ return transform_fn