File size: 3,448 Bytes
8cc0aad
 
 
37293dc
ce34d64
 
 
 
612aabd
4ea9a66
 
 
 
8cc0aad
ce34d64
 
 
4ea9a66
3a50377
 
4ac9e25
 
4b43a66
4ac9e25
 
4b43a66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4ac9e25
 
7925ddc
 
 
 
 
 
 
 
 
 
 
 
3a50377
8cc0aad
 
 
 
 
3a50377
 
 
 
 
 
 
4ac9e25
 
 
 
 
 
 
 
 
59bb219
4ac9e25
 
 
 
 
 
 
 
 
 
 
 
3a50377
 
4b43a66
ce34d64
 
 
3a50377
4ac9e25
 
 
 
4b43a66
4ac9e25
 
 
 
612aabd
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""

from typing import Tuple

from axolotl.prompt_tokenizers import (
    AlpacaPromptTokenizingStrategy,
    InstructionPromptTokenizingStrategy,
)
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter


def load(tokenizer, cfg):
    """Build the default Alpaca tokenizing strategy using the CHAT prompt style.

    Args:
        tokenizer: tokenizer handed straight through to the strategy.
        cfg: config object; ``train_on_inputs`` and ``sequence_len`` are read
            from it and passed positionally to the strategy.

    Returns:
        An ``AlpacaPromptTokenizingStrategy`` wrapping an ``AlpacaPrompter``
        constructed with ``PromptStyle.CHAT.value``.
    """
    chat_prompter = AlpacaPrompter(PromptStyle.CHAT.value)
    return AlpacaPromptTokenizingStrategy(
        chat_prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


class AlpacaConcisePrompter(AlpacaPrompter):
    """
    Alpaca prompter whose USER/ASSISTANT system prompts ask for concise answers.

    Only the two system-prompt class attributes are overridden; construction and
    prompt building are inherited unchanged from ``AlpacaPrompter``.
    """

    # System prompt for examples that come with an extra input/context field
    # (presumably picked by the parent class when input is present — confirm).
    system_prompt = (
        "Below is an instruction from a USER that describes a task, paired with "
        "an input that provides further context. The ASSISTANT writes a response "
        "that concisely and appropriately completes the request.\n\n"
    )
    # System prompt for instruction-only examples (no input/context field).
    system_no_input_prompt = (
        "Below is an instruction from a USER that describes a task. "
        "The ASSISTANT writes a response that appropriately and concisely "
        "completes the request.\n\n"
    )


class AlpacaChatPrompter(AlpacaPrompter):
    """
    Alpaca prompter hard-wired to the CHAT prompt style for chat-instruct answers.

    NOTE(review): both system prompts below are identical to those of
    ``AlpacaConcisePrompter`` (they ask for a *concise* response); confirm this
    duplication is intended rather than a copy-paste leftover.
    """

    # System prompt used together with an input/context field (presumably).
    system_prompt = (
        "Below is an instruction from a USER that describes a task, paired with "
        "an input that provides further context. The ASSISTANT writes a response "
        "that concisely and appropriately completes the request.\n\n"
    )
    # System prompt used for instruction-only examples (presumably).
    system_no_input_prompt = (
        "Below is an instruction from a USER that describes a task. "
        "The ASSISTANT writes a response that appropriately and concisely "
        "completes the request.\n\n"
    )

    def __init__(self):  # pylint: disable=super-init-not-called
        # The style is fixed rather than passed in, so the parent constructor is
        # deliberately bypassed; set the style and let the parent derive its
        # templates from it.
        chat_style = PromptStyle.CHAT.value
        self.prompt_style = chat_style
        self.match_prompt_style()


class NoSystemPrompter(AlpacaPrompter):
    """
    Prompter that renders only the instruction (and input), with no system prompt.
    """

    # Format templates with ``{instruction}`` / ``{input}`` placeholders; each
    # ends with a single trailing space.
    prompt_input = "{instruction} {input} "
    prompt_no_input = "{instruction} "

    def __init__(self):  # pylint: disable=super-init-not-called
        """Intentionally empty: the parent constructor is skipped and no instance
        state is set, so the class-level templates above are what instances see.

        NOTE(review): ``prompt_style`` is therefore never assigned on instances.
        """


class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for AlpacaQA rows keyed by ``question`` and ``answer``.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Split a row into the three fields the base strategy consumes.

        Presumably the tuple is ``(instruction, input, response)`` per the
        ``InstructionPromptTokenizingStrategy`` contract — confirm. The middle
        slot is always the empty string.

        Raises:
            KeyError: if ``prompt`` lacks ``"question"`` or ``"answer"``.
        """
        question = prompt["question"]
        answer = prompt["answer"]
        return question, "", answer


class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
    """
    Tokenizing strategy for CamelAI rows keyed by ``message_1`` and ``message_2``.
    """

    def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
        """Split a row into the three fields the base strategy consumes.

        ``message_1`` fills the first slot and ``message_2`` the last; the middle
        slot is always the empty string. Presumably the tuple is
        ``(instruction, input, response)`` per the base class — confirm.

        Raises:
            KeyError: if ``prompt`` lacks ``"message_1"`` or ``"message_2"``.
        """
        first_message = prompt["message_1"]
        second_message = prompt["message_2"]
        return first_message, "", second_message


def load_concise(tokenizer, cfg):
    """Build an Alpaca tokenizing strategy whose system prompts request concise answers.

    Args:
        tokenizer: tokenizer handed straight through to the strategy.
        cfg: config object; ``train_on_inputs`` and ``sequence_len`` are read.

    Returns:
        An ``AlpacaPromptTokenizingStrategy`` wrapping an
        ``AlpacaConcisePrompter`` in the CHAT prompt style.
    """
    concise_prompter = AlpacaConcisePrompter(PromptStyle.CHAT.value)
    return AlpacaPromptTokenizingStrategy(
        concise_prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


def load_qa(tokenizer, cfg):
    """Build the tokenizing strategy for question/answer rows.

    Args:
        tokenizer: tokenizer handed straight through to the strategy.
        cfg: config object; ``train_on_inputs`` and ``sequence_len`` are read.

    Returns:
        An ``AlpacaQAPromptTokenizingStrategy`` driven by an
        ``AlpacaChatPrompter``.
    """
    qa_prompter = AlpacaChatPrompter()
    return AlpacaQAPromptTokenizingStrategy(
        qa_prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


def load_camel_ai(tokenizer, cfg):
    """Build the tokenizing strategy for CamelAI ``message_1``/``message_2`` rows.

    Args:
        tokenizer: tokenizer handed straight through to the strategy.
        cfg: config object; ``train_on_inputs`` and ``sequence_len`` are read.

    Returns:
        A ``CamelAIPromptTokenizingStrategy`` driven by an
        ``AlpacaChatPrompter``.
    """
    camel_prompter = AlpacaChatPrompter()
    return CamelAIPromptTokenizingStrategy(
        camel_prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )


def load_no_prompt(tokenizer, cfg):
    """Build an Alpaca tokenizing strategy around an ``UnpromptedPrompter``.

    Args:
        tokenizer: tokenizer handed straight through to the strategy.
        cfg: config object; ``train_on_inputs`` and ``sequence_len`` are read.

    Returns:
        An ``AlpacaPromptTokenizingStrategy`` wrapping an ``UnpromptedPrompter``
        constructed with ``PromptStyle.CHAT.value``.
    """
    bare_prompter = UnpromptedPrompter(PromptStyle.CHAT.value)
    return AlpacaPromptTokenizingStrategy(
        bare_prompter, tokenizer, cfg.train_on_inputs, cfg.sequence_len
    )