|
""" |
|
Prompt strategies loader for alpaca instruction datasets with system prompts |
|
""" |
|
from typing import Generator, Tuple, Union |
|
|
|
from axolotl.prompt_tokenizers import PromptTokenizingStrategy |
|
from axolotl.prompters import AlpacaPrompter, PromptStyle |
|
|
|
|
|
class InstructionWSystemPromptTokenizingStrategy(PromptTokenizingStrategy):

    """

    Tokenizing strategy for instruction-based prompts that carry a
    per-example system prompt in the dataset row.

    """



    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str, str]:
        """Extract ``(instruction, input, output, system)`` from a dataset row.

        ``input`` defaults to ``""`` when the key is absent; ``instruction``,
        ``output`` and ``system`` are required and raise ``KeyError`` if missing.
        """

        return (

            prompt["instruction"],

            prompt.get("input", ""),  # dict.get instead of the LBYL double lookup

            prompt["output"],

            prompt["system"],

        )



    def tokenize_prompt(self, prompt):
        """Tokenize one example.

        Builds the user-facing prompt (system + instruction + optional input),
        tokenizes it without EOS, optionally masks its labels, then appends the
        tokenized response (no BOS, with EOS) so that only the response
        contributes to the loss when ``train_on_inputs`` is False.
        """

        (

            instruction,

            input_text,  # renamed from `input` to avoid shadowing the builtin

            response,

            system,

        ) = self.parse_instruction_fields(prompt)

        user_prompt = next(

            iter(

                self.prompter.build_prompt_w_system(

                    system,

                    instruction,

                    input_text,

                )

            )

        )

        tokenized_prompt = self._tokenize(user_prompt, add_eos_token=False)

        if not self.train_on_inputs:

            user_prompt_len = len(tokenized_prompt["input_ids"])

            # -100 is ignored by the loss, so the model is not trained to
            # reproduce the prompt tokens.
            tokenized_prompt["labels"] = [-100] * user_prompt_len

        tokenized_res_prompt = self._tokenize(

            response, strip_bos_token=True, add_eos_token=True

        )

        tokenized_prompt["input_ids"] += tokenized_res_prompt["input_ids"]

        tokenized_prompt["attention_mask"] += tokenized_res_prompt["attention_mask"]

        tokenized_prompt["labels"] += tokenized_res_prompt["input_ids"]



        return tokenized_prompt
|
|
|
|
|
class SystemDataPrompter(AlpacaPrompter):

    """

    Alpaca Style Prompter that uses system prompts from the dataset

    """



    def build_prompt_w_system(

        self,

        system: str,

        instruction: str,

        input: Union[None, str] = None,  # pylint: disable=redefined-builtin

        output: Union[None, str] = None,

    ) -> Generator[str, None, None]:
        """Yield a single prompt string.

        The dataset-supplied ``system`` text is prepended to the formatted
        turn; when ``input`` is empty the no-input turn template is used, and
        ``output`` (when truthy) is appended verbatim.
        """

        turn = (
            self.turn_format.format(instruction=instruction, input=input)
            if input
            else self.turn_no_input_format.format(instruction=instruction)
        )

        res = system + turn

        if output:

            res = f"{res}{output}"

        yield res
|
|
|
|
|
def load(tokenizer, cfg):
    """Build the tokenizing strategy wired to a chat-style system-data prompter."""

    prompter = SystemDataPrompter(PromptStyle.CHAT.value)

    return InstructionWSystemPromptTokenizingStrategy(

        prompter,

        tokenizer,

        cfg.train_on_inputs,

        cfg.sequence_len,

    )
|
|