winglian committed on
Commit
7925ddc
·
1 Parent(s): 4b43a66

Bugfix for a potential off-by-one error in prompt label masking

Browse files
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -40,6 +40,18 @@ class AlpacaChatPrompter(AlpacaPrompter):
40
  self.match_prompt_style()
41
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
44
  """
45
  Tokenizing strategy for AlpacaQA
 
40
  self.match_prompt_style()
41
 
42
 
43
+ class NoSystemPrompter(AlpacaPrompter):
44
+ """
45
+ Null Prompter with no system prompts
46
+ """
47
+
48
+ prompt_input = "{instruction} {input} "
49
+ prompt_no_input = "{instruction} "
50
+
51
+ def __init__(self): # pylint: disable=super-init-not-called
52
+ pass
53
+
54
+
55
  class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
56
  """
57
  Tokenizing strategy for AlpacaQA
src/axolotl/prompt_tokenizers.py CHANGED
@@ -96,25 +96,27 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
96
  input, # pylint: disable=redefined-builtin
97
  response,
98
  ) = self.parse_instruction_fields(prompt)
99
- full_prompt = self._build_full_prompt(instruction, input, response)
100
- tokenized_full_prompt = self._tokenize(full_prompt)
101
- if not self.train_on_inputs:
102
- user_prompt = next(
103
- iter(
104
- self.prompter.build_prompt(
105
- instruction,
106
- input,
107
- )
108
  )
109
  )
110
- tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
111
- user_prompt_len = len(tokenized_user_prompt["input_ids"])
 
 
112
  # TODO this could be sped up using numpy array slicing
113
- tokenized_full_prompt["labels"] = [
114
- -100
115
- ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
 
 
 
 
116
 
117
- return tokenized_full_prompt
118
 
119
  def _build_full_prompt(
120
  self, instruction, input, response # pylint: disable=redefined-builtin
 
96
  input, # pylint: disable=redefined-builtin
97
  response,
98
  ) = self.parse_instruction_fields(prompt)
99
+ user_prompt = next(
100
+ iter(
101
+ self.prompter.build_prompt(
102
+ instruction,
103
+ input,
 
 
 
 
104
  )
105
  )
106
+ )
107
+ tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)
108
+ if not self.train_on_inputs:
109
+ user_prompt_len = len(tokenized_prompt["input_ids"])
110
  # TODO this could be sped up using numpy array slicing
111
+ tokenized_prompt["labels"] = [-100] * user_prompt_len
112
+ tokenized_res_prompt = self._tokenize(
113
+ response, strip_bos_token=True, add_eos_token=True
114
+ )
115
+ tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]
116
+ tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]
117
+ tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]
118
 
119
+ return tokenized_prompt
120
 
121
  def _build_full_prompt(
122
  self, instruction, input, response # pylint: disable=redefined-builtin
tests/test_prompt_tokenizers.py CHANGED
@@ -6,8 +6,12 @@ from pathlib import Path
6
 
7
  from transformers import AutoTokenizer
8
 
9
- from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
10
- from axolotl.prompters import ShareGPTPrompter
 
 
 
 
11
 
12
  logging.basicConfig(level="INFO")
13
 
@@ -29,7 +33,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
29
  )
30
 
31
  def test_sharegpt_integration(self):
32
- print(Path(__file__).parent)
33
  with open(
34
  Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
35
  ) as fin:
@@ -53,6 +56,43 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
53
  self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
54
  self.assertEqual(example[fields], tokenized_conversation[fields])
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  if __name__ == "__main__":
58
  unittest.main()
 
6
 
7
  from transformers import AutoTokenizer
8
 
9
+ from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
10
+ from axolotl.prompt_tokenizers import (
11
+ AlpacaPromptTokenizingStrategy,
12
+ ShareGPTPromptTokenizingStrategy,
13
+ )
14
+ from axolotl.prompters import AlpacaPrompter, ShareGPTPrompter
15
 
16
  logging.basicConfig(level="INFO")
17
 
 
33
  )
34
 
35
  def test_sharegpt_integration(self):
 
36
  with open(
37
  Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
38
  ) as fin:
 
56
  self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
57
  self.assertEqual(example[fields], tokenized_conversation[fields])
58
 
59
+ def test_completion(self):
60
+ """
61
+ tests the interface between the user and assistant parts
62
+ """
63
+ prompter = NoSystemPrompter()
64
+ strat = AlpacaPromptTokenizingStrategy(
65
+ prompter,
66
+ self.tokenizer,
67
+ False,
68
+ 2048,
69
+ )
70
+ sample = {
71
+ "instruction": "hello cruel. lorem ipsum dolor sit amet.",
72
+ "output": "world!",
73
+ }
74
+ example = strat.tokenize_prompt(sample)
75
+ world_idx = example["input_ids"].index(3186)
76
+ assert example["labels"][world_idx] == 3186
77
+ assert example["labels"][world_idx - 1] == -100
78
+
79
+ def test_alpaca(self):
80
+ """
81
+ tests the interface between the user and assistant parts
82
+ """
83
+ prompter = AlpacaPrompter()
84
+ strat = AlpacaPromptTokenizingStrategy(
85
+ prompter,
86
+ self.tokenizer,
87
+ False,
88
+ 2048,
89
+ )
90
+ sample = {"instruction": "hello!", "output": "Hi! How can I help?"}
91
+ example = strat.tokenize_prompt(sample)
92
+ world_idx = example["input_ids"].index(6324)
93
+ assert example["labels"][world_idx] == 6324
94
+ assert example["labels"][world_idx - 1] == -100
95
+
96
 
97
  if __name__ == "__main__":
98
  unittest.main()