add alpaca multiple choice instruct dataset support
Browse files
- scripts/finetune.py +1 -1
- src/axolotl/prompt_tokenizers.py +9 -0
- src/axolotl/prompters.py +4 -0
- src/axolotl/utils/data.py +8 -2
scripts/finetune.py
CHANGED
@@ -67,7 +67,7 @@ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
|
|
67 |
instruction = get_multi_line_input()
|
68 |
if not instruction:
|
69 |
return
|
70 |
-
prompt = prompter_module().build_prompt(instruction=instruction)
|
71 |
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
72 |
|
73 |
model.eval()
|
|
|
67 |
instruction = get_multi_line_input()
|
68 |
if not instruction:
|
69 |
return
|
70 |
+
prompt: str = next(prompter_module().build_prompt(instruction=instruction))
|
71 |
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
72 |
|
73 |
model.eval()
|
src/axolotl/prompt_tokenizers.py
CHANGED
@@ -92,6 +92,15 @@ class AlpacaPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
|
92 |
)
|
93 |
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
96 |
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
97 |
return (
|
|
|
92 |
)
|
93 |
|
94 |
|
95 |
+
class AlpacaMultipleChoicePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
96 |
+
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
97 |
+
return (
|
98 |
+
prompt["question"],
|
99 |
+
"\n".join(f'- "{choice}"' for choice in prompt["choices"]),
|
100 |
+
prompt["explanation"],
|
101 |
+
)
|
102 |
+
|
103 |
+
|
104 |
class JeopardyPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
105 |
def parse_instruction_fields(self, prompt) -> (str, str, str):
|
106 |
return (
|
src/axolotl/prompters.py
CHANGED
@@ -35,6 +35,10 @@ class JeopardyPrompter(AlpacaPrompter):
|
|
35 |
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
36 |
|
37 |
|
|
|
|
|
|
|
|
|
38 |
class CompletionPrompter(AlpacaPrompter):
|
39 |
def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
|
40 |
yield instruction
|
|
|
35 |
prompt_input = "Below is a Jeopardy clue paired with input providing the category of the clue. Write a concise response that best answers tbe clue given the category.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
36 |
|
37 |
|
38 |
+
class MultipleChoiceExplainPrompter(AlpacaPrompter):
|
39 |
+
prompt_input = "Choose the answer that best answers the question. Explain your reasoning.\n\n### Question:\n{instruction}\n\n### Choices:\n{input}\n\n### Response:\n"
|
40 |
+
|
41 |
+
|
42 |
class CompletionPrompter(AlpacaPrompter):
|
43 |
def build_prompt(self, instruction: str, input=None, output=None) -> Generator[str, None, None]:
|
44 |
yield instruction
|
src/axolotl/utils/data.py
CHANGED
@@ -19,7 +19,7 @@ from axolotl.prompt_tokenizers import (
|
|
19 |
AlpacaReflectionPTStrategy,
|
20 |
ShareGPTPromptTokenizingStrategy,
|
21 |
JeopardyPromptTokenizingStrategy,
|
22 |
-
CompletionPromptTokenizingStrategy,
|
23 |
)
|
24 |
from axolotl.prompters import (
|
25 |
AlpacaPrompter,
|
@@ -27,7 +27,7 @@ from axolotl.prompters import (
|
|
27 |
ReflectAlpacaPrompter,
|
28 |
ShareGPTPrompter,
|
29 |
JeopardyPrompter,
|
30 |
-
CompletionPrompter,
|
31 |
)
|
32 |
|
33 |
|
@@ -88,6 +88,12 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
|
88 |
)
|
89 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
90 |
datasets.append(ds_wrapper)
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
elif d.type == "jeopardy":
|
92 |
ds_strategy = JeopardyPromptTokenizingStrategy(
|
93 |
JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
|
|
19 |
AlpacaReflectionPTStrategy,
|
20 |
ShareGPTPromptTokenizingStrategy,
|
21 |
JeopardyPromptTokenizingStrategy,
|
22 |
+
CompletionPromptTokenizingStrategy, AlpacaMultipleChoicePromptTokenizingStrategy,
|
23 |
)
|
24 |
from axolotl.prompters import (
|
25 |
AlpacaPrompter,
|
|
|
27 |
ReflectAlpacaPrompter,
|
28 |
ShareGPTPrompter,
|
29 |
JeopardyPrompter,
|
30 |
+
CompletionPrompter, MultipleChoiceExplainPrompter,
|
31 |
)
|
32 |
|
33 |
|
|
|
88 |
)
|
89 |
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
90 |
datasets.append(ds_wrapper)
|
91 |
+
elif d.type == "explainchoice":
|
92 |
+
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
93 |
+
MultipleChoiceExplainPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|
94 |
+
)
|
95 |
+
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
96 |
+
datasets.append(ds_wrapper)
|
97 |
elif d.type == "jeopardy":
|
98 |
ds_strategy = JeopardyPromptTokenizingStrategy(
|
99 |
JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
|