File size: 983 Bytes
8cc0aad ce34d64 4ea9a66 8cc0aad ce34d64 4ea9a66 3a50377 8cc0aad 3a50377 8cc0aad ce34d64 3a50377 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
from typing import Tuple
from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy,
InstructionPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle
def load(tokenizer, cfg):
    """Build the default Alpaca tokenizing strategy using the chat prompt style.

    Args:
        tokenizer: Tokenizer passed straight through to the strategy.
        cfg: Configuration object; only ``cfg.train_on_inputs`` and
            ``cfg.sequence_len`` are read here.

    Returns:
        An ``AlpacaPromptTokenizingStrategy`` wrapping an ``AlpacaPrompter``
        built from ``PromptStyle.CHAT.value``.
    """
    prompter = AlpacaPrompter(PromptStyle.CHAT.value)
    return AlpacaPromptTokenizingStrategy(
        prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for question/answer ("AlpacaQA") records.

    Each record is read through its ``question`` and ``answer`` keys. The
    middle slot of the parsed tuple is always left empty.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Split a QA record into the three fields the base strategy consumes.

        Args:
            prompt: A mapping with ``"question"`` and ``"answer"`` keys
                (presumably one dataset row -- confirm against callers).

        Returns:
            ``(question, "", answer)``. Presumably this is the base class's
            (instruction, input, response) layout; confirm against
            ``InstructionPromptTokenizingStrategy``.

        Raises:
            KeyError: if either key is missing from ``prompt``.
        """
        question = prompt["question"]
        answer = prompt["answer"]
        # QA data has no separate input/context, so that slot stays empty.
        return question, "", answer
def load_qa(tokenizer, cfg):
    """Build the question/answer variant of the Alpaca tokenizing strategy.

    Args:
        tokenizer: Tokenizer passed straight through to the strategy.
        cfg: Configuration object; only ``cfg.train_on_inputs`` and
            ``cfg.sequence_len`` are read here.

    Returns:
        An ``AlpacaQAPromptTokenizingStrategy`` wrapping an ``AlpacaPrompter``
        built from ``PromptStyle.CHAT.value``.
    """
    chat_prompter = AlpacaPrompter(PromptStyle.CHAT.value)
    return AlpacaQAPromptTokenizingStrategy(
        chat_prompter,
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )
|