Nanobit commited on
Commit
5d86137
·
1 Parent(s): 01c8a33

Lint prompt_tokenizers

Browse files
Files changed (1) hide show
  1. src/axolotl/prompt_tokenizers.py +89 -22
src/axolotl/prompt_tokenizers.py CHANGED
@@ -1,7 +1,10 @@
 
 
1
  import abc
2
  import copy
3
  import functools
4
  import logging
 
5
 
6
  from transformers import PreTrainedTokenizer
7
 
@@ -15,10 +18,16 @@ LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
15
 
16
 
17
  class InvalidDataException(Exception):
18
- pass
 
 
19
 
20
 
21
  class PromptTokenizingStrategy(abc.ABC):
 
 
 
 
22
  def __init__(
23
  self,
24
  prompter,
@@ -35,14 +44,14 @@ class PromptTokenizingStrategy(abc.ABC):
35
  def tokenize_prompt(self, prompt):
36
  pass
37
 
38
- @functools.cache
39
  def _get_user_token(self):
40
  id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
41
  if isinstance(id_or_ids, (int,)):
42
  return id_or_ids
43
  return False
44
 
45
- @functools.cache
46
  def _get_assistant_token(self):
47
  id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
48
  if isinstance(id_or_ids, (int,)):
@@ -51,11 +60,19 @@ class PromptTokenizingStrategy(abc.ABC):
51
 
52
 
53
  class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
54
- def parse_instruction_fields(self, prompt) -> (str, str, str):
 
 
 
 
55
  raise NotImplementedError
56
 
57
  def tokenize_prompt(self, prompt):
58
- instruction, input, response = self.parse_instruction_fields(prompt)
 
 
 
 
59
  full_prompt = self._build_full_prompt(instruction, input, response)
60
  tokenized_full_prompt = self._tokenize(full_prompt)
61
  if not self.train_on_inputs:
@@ -76,7 +93,9 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
76
 
77
  return tokenized_full_prompt
78
 
79
- def _build_full_prompt(self, instruction, input, response):
 
 
80
  return next(
81
  iter(
82
  self.prompter.build_prompt(
@@ -112,7 +131,11 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
112
 
113
 
114
  class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
115
- def parse_instruction_fields(self, prompt) -> (str, str, str):
 
 
 
 
116
  return (
117
  prompt["instruction"],
118
  prompt["input"] if "input" in prompt else "",
@@ -121,7 +144,11 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
121
 
122
 
123
  class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
124
- def parse_instruction_fields(self, prompt) -> (str, str, str):
 
 
 
 
125
  return (
126
  prompt["question"],
127
  "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
@@ -130,7 +157,11 @@ class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingSt
130
 
131
 
132
  class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
133
- def parse_instruction_fields(self, prompt) -> (str, str, str):
 
 
 
 
134
  return (
135
  prompt["question"],
136
  prompt["category"],
@@ -139,7 +170,11 @@ class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
139
 
140
 
141
  class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
142
- def parse_instruction_fields(self, prompt) -> (str, str, str):
 
 
 
 
143
  return (
144
  prompt["INSTRUCTION"],
145
  "",
@@ -148,7 +183,11 @@ class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
148
 
149
 
150
  class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
151
- def parse_instruction_fields(self, prompt) -> (str, str, str):
 
 
 
 
152
  return (
153
  prompt["article"],
154
  "",
@@ -157,7 +196,11 @@ class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
157
 
158
 
159
  class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
160
- def parse_instruction_fields(self, prompt) -> (str, str, str):
 
 
 
 
161
  return (
162
  prompt["instruction"],
163
  prompt["input"] if "input" in prompt else "",
@@ -166,7 +209,11 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
166
 
167
 
168
  class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
169
- def parse_instruction_fields(self, prompt) -> (str, str, str):
 
 
 
 
170
  return (
171
  prompt["prompt"],
172
  "",
@@ -175,6 +222,10 @@ class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
175
 
176
 
177
  class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
 
 
 
 
178
  def parse_instruction_fields(self, prompt) -> str:
179
  return prompt["text"]
180
 
@@ -185,18 +236,24 @@ class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
185
 
186
  return tokenized_full_prompt
187
 
188
- def _build_full_prompt(self, instruction, input, response):
 
 
189
  return next(iter(self.prompter.build_prompt(instruction)))
190
 
191
 
192
  class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
193
- def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
 
 
 
 
194
  raise NotImplementedError
195
 
196
  def tokenize_prompt(self, prompt):
197
  (
198
  instruction,
199
- input,
200
  output,
201
  reflection,
202
  corrected,
@@ -223,7 +280,9 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
223
 
224
  return tokenized_full_prompt
225
 
226
- def _build_full_prompt(self, instruction, input, output, reflection, corrected):
 
 
227
  return next(
228
  iter(
229
  self.prompter.build_prompt(
@@ -257,7 +316,11 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
257
 
258
 
259
  class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
260
- def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
 
 
 
 
261
  return (
262
  prompt["instruction"],
263
  prompt["input"] if "input" in prompt else "",
@@ -268,6 +331,10 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
268
 
269
 
270
  class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
 
 
 
 
271
  def get_conversation_thread(self, prompt):
272
  return prompt["conversations"]
273
 
@@ -281,7 +348,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
281
  user_token = self._get_user_token()
282
  assistant_token = self._get_assistant_token()
283
  try:
284
- for i, part in enumerate(
285
  self.prompter.build_prompt(self.get_conversation_thread(prompt))
286
  ):
287
  if isinstance(part, tuple):
@@ -307,7 +374,7 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
307
  # not masked out from labels
308
  labels = copy.deepcopy(res["input_ids"])
309
  else:
310
- logging.warning("unhandled role: " + part[0])
311
  else:
312
  # this is only ever the first part, should include the bos token and the user query
313
  res = self._tokenize(
@@ -324,8 +391,8 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
324
  result["labels"][current_len : current_len + input_len] = labels
325
  current_len += input_len
326
  return result
327
- except (KeyError, AssertionError, IndexError) as e:
328
- raise InvalidDataException(str(e))
329
 
330
  def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
331
  result = self.tokenizer(
 
1
+ """Module containing PromptTokenizingStrategy and Prompter classes"""
2
+
3
  import abc
4
  import copy
5
  import functools
6
  import logging
7
+ from typing import Tuple
8
 
9
  from transformers import PreTrainedTokenizer
10
 
 
18
 
19
 
20
  class InvalidDataException(Exception):
21
+ """
22
+ Exception raised when the data is invalid
23
+ """
24
 
25
 
26
  class PromptTokenizingStrategy(abc.ABC):
27
+ """
28
+ Abstract class for tokenizing strategies
29
+ """
30
+
31
  def __init__(
32
  self,
33
  prompter,
 
44
  def tokenize_prompt(self, prompt):
45
  pass
46
 
47
+ @functools.lru_cache(maxsize=128)
48
  def _get_user_token(self):
49
  id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
50
  if isinstance(id_or_ids, (int,)):
51
  return id_or_ids
52
  return False
53
 
54
+ @functools.lru_cache(maxsize=128)
55
  def _get_assistant_token(self):
56
  id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
57
  if isinstance(id_or_ids, (int,)):
 
60
 
61
 
62
  class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
63
+ """
64
+ Tokenizing strategy for instruction-based prompts.
65
+ """
66
+
67
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
68
  raise NotImplementedError
69
 
70
  def tokenize_prompt(self, prompt):
71
+ (
72
+ instruction,
73
+ input, # pylint: disable=redefined-builtin
74
+ response,
75
+ ) = self.parse_instruction_fields(prompt)
76
  full_prompt = self._build_full_prompt(instruction, input, response)
77
  tokenized_full_prompt = self._tokenize(full_prompt)
78
  if not self.train_on_inputs:
 
93
 
94
  return tokenized_full_prompt
95
 
96
+ def _build_full_prompt(
97
+ self, instruction, input, response # pylint: disable=redefined-builtin
98
+ ):
99
  return next(
100
  iter(
101
  self.prompter.build_prompt(
 
131
 
132
 
133
  class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
134
+ """
135
+ Tokenizing strategy for Alpaca prompts.
136
+ """
137
+
138
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
139
  return (
140
  prompt["instruction"],
141
  prompt["input"] if "input" in prompt else "",
 
144
 
145
 
146
  class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
147
+ """
148
+ Tokenizing strategy for Alpaca Multiple Choice prompts.
149
+ """
150
+
151
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
152
  return (
153
  prompt["question"],
154
  "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
 
157
 
158
 
159
  class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
160
+ """
161
+ Tokenizing strategy for Jeopardy prompts.
162
+ """
163
+
164
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
165
  return (
166
  prompt["question"],
167
  prompt["category"],
 
170
 
171
 
172
  class OpenAssistantPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
173
+ """
174
+ Tokenizing strategy for OpenAssistant prompts.
175
+ """
176
+
177
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
178
  return (
179
  prompt["INSTRUCTION"],
180
  "",
 
183
 
184
 
185
  class SummarizeTLDRPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
186
+ """
187
+ Tokenizing strategy for SummarizeTLDR prompts.
188
+ """
189
+
190
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
191
  return (
192
  prompt["article"],
193
  "",
 
196
 
197
 
198
  class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
199
+ """
200
+ Tokenizing strategy for GPTeacher prompts.
201
+ """
202
+
203
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
204
  return (
205
  prompt["instruction"],
206
  prompt["input"] if "input" in prompt else "",
 
209
 
210
 
211
  class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
212
+ """
213
+ Tokenizing strategy for NomicGPT4All prompts.
214
+ """
215
+
216
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
217
  return (
218
  prompt["prompt"],
219
  "",
 
222
 
223
 
224
  class CompletionPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
225
+ """
226
+ Tokenizing strategy for Completion prompts.
227
+ """
228
+
229
  def parse_instruction_fields(self, prompt) -> str:
230
  return prompt["text"]
231
 
 
236
 
237
  return tokenized_full_prompt
238
 
239
+ def _build_full_prompt(
240
+ self, instruction, input, response
241
+ ): # pylint: disable=unused-argument, redefined-builtin
242
  return next(iter(self.prompter.build_prompt(instruction)))
243
 
244
 
245
  class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
246
+ """
247
+ Tokenizing strategy for Reflection prompts.
248
+ """
249
+
250
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
251
  raise NotImplementedError
252
 
253
  def tokenize_prompt(self, prompt):
254
  (
255
  instruction,
256
+ input, # pylint: disable=redefined-builtin
257
  output,
258
  reflection,
259
  corrected,
 
280
 
281
  return tokenized_full_prompt
282
 
283
+ def _build_full_prompt(
284
+ self, instruction, input, output, reflection, corrected
285
+ ): # pylint: disable=redefined-builtin
286
  return next(
287
  iter(
288
  self.prompter.build_prompt(
 
316
 
317
 
318
  class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
319
+ """
320
+ Tokenizing strategy for Alpaca Reflection prompts.
321
+ """
322
+
323
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str, str]:
324
  return (
325
  prompt["instruction"],
326
  prompt["input"] if "input" in prompt else "",
 
331
 
332
 
333
  class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
334
+ """
335
+ Tokenizing strategy for ShareGPT prompts.
336
+ """
337
+
338
  def get_conversation_thread(self, prompt):
339
  return prompt["conversations"]
340
 
 
348
  user_token = self._get_user_token()
349
  assistant_token = self._get_assistant_token()
350
  try:
351
+ for _, part in enumerate(
352
  self.prompter.build_prompt(self.get_conversation_thread(prompt))
353
  ):
354
  if isinstance(part, tuple):
 
374
  # not masked out from labels
375
  labels = copy.deepcopy(res["input_ids"])
376
  else:
377
+ logging.warning(f"unhandled role: {part[0]}")
378
  else:
379
  # this is only ever the first part, should include the bos token and the user query
380
  res = self._tokenize(
 
391
  result["labels"][current_len : current_len + input_len] = labels
392
  current_len += input_len
393
  return result
394
+ except (KeyError, AssertionError, IndexError) as err:
395
+ raise InvalidDataException(str(err)) from err
396
 
397
  def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
398
  result = self.tokenizer(