winglian commited on
Commit
b46bc02
·
1 Parent(s): f98e173

add alpaca multiple choice instruct dataset support

Browse files
scripts/finetune.py CHANGED
@@ -67,7 +67,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
67
  instruction = get_multi_line_input()
68
  if not instruction:
69
  return
70
- prompt = prompter_module().build_prompt(instruction=instruction)
71
  batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
72
 
73
  model.eval()
 
67
  instruction = get_multi_line_input()
68
  if not instruction:
69
  return
70
+ prompt: str = next(prompter_module().build_prompt(instruction=instruction))
71
  batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
72
 
73
  model.eval()
src/axolotl/prompt_tokenizers.py CHANGED
@@ -92,6 +92,15 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
92
  )
93
 
94
 
 
 
 
 
 
 
 
 
 
95
  class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
96
  def parse_instruction_fields(self, prompt) -> (str, str, str):
97
  return (
 
92
  )
93
 
94
 
95
class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """Tokenizing strategy for multiple-choice datasets that include an explanation.

    Expects each dataset row to be a mapping with "question", "choices"
    (an iterable of answer options), and "explanation" keys.
    """

    def parse_instruction_fields(self, prompt) -> (str, str, str):
        # Returns the (instruction, input, response) triple consumed by the
        # base strategy — presumably in that order, matching the sibling
        # Jeopardy strategy; TODO(review): confirm against the base class.
        return (
            prompt["question"],
            # Render each choice as a quoted bullet, one option per line.
            "\n".join(f'- "{choice}"' for choice in prompt["choices"]),
            prompt["explanation"],
        )
104
  class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
105
  def parse_instruction_fields(self, prompt) -> (str, str, str):
106
  return (
src/axolotl/prompters.py CHANGED
@@ -35,6 +35,10 @@ class JeopardyPrompter(AlpacaPrompter):
35
  prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
36
 
37
 
 
 
 
 
38
  class CompletionPrompter(AlpacaPrompter):
39
  def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
40
  yield instruction
 
35
  prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
36
 
37
 
38
class MultipleChoiceExplainPrompter(AlpacaPrompter):
    """Prompter for multiple-choice questions with explained answers.

    Formats {instruction} as the question and {input} as the pre-rendered
    list of choices, then asks the model to pick an answer and explain
    its reasoning.
    """

    prompt_input = "Choose the answer that best answers the question. Explain your reasoning.\n\n### Question:\n{instruction}\n\n### Choices:\n{input}\n\n### Response:\n"
  class CompletionPrompter(AlpacaPrompter):
43
  def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
44
  yield instruction
src/axolotl/utils/data.py CHANGED
@@ -19,7 +19,7 @@ from axolotl.prompt_tokenizers import (
19
  AlpacaReflectionPTStrategy,
20
  ShareGPTPromptTokenizingStrategy,
21
  JeopardyPromptTokenizingStrategy,
22
- CompletionPromptTokenizingStrategy,
23
  )
24
  from axolotl.prompters import (
25
  AlpacaPrompter,
@@ -27,7 +27,7 @@ from axolotl.prompters import (
27
  ReflectAlpacaPrompter,
28
  ShareGPTPrompter,
29
  JeopardyPrompter,
30
- CompletionPrompter,
31
  )
32
 
33
 
@@ -88,6 +88,12 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
88
  )
89
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
90
  datasets.append(ds_wrapper)
 
 
 
 
 
 
91
  elif d.type == "jeopardy":
92
  ds_strategy = JeopardyPromptTokenizingStrategy(
93
  JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
 
19
  AlpacaReflectionPTStrategy,
20
  ShareGPTPromptTokenizingStrategy,
21
  JeopardyPromptTokenizingStrategy,
22
+ CompletionPromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
23
  )
24
  from axolotl.prompters import (
25
  AlpacaPrompter,
 
27
  ReflectAlpacaPrompter,
28
  ShareGPTPrompter,
29
  JeopardyPrompter,
30
+ CompletionPrompter, MultipleChoiceExplainPrompter,
31
  )
32
 
33
 
 
88
  )
89
  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
90
  datasets.append(ds_wrapper)
91
+ elif d.type == "explainchoice":
92
+ ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
93
+ MultipleChoiceExplainPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
94
+ )
95
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
96
+ datasets.append(ds_wrapper)
97
  elif d.type == "jeopardy":
98
  ds_strategy = JeopardyPromptTokenizingStrategy(
99
  JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len