File size: 3,394 Bytes
2ca0c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
061118b
2ca0c5e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# chat helper
class ChatState:
    """Tracks a multi-turn chat and formats prompts for a causal LM.

    The class stores alternating user/model turns as pre-tagged strings and
    concatenates them (plus an optional system preamble) into a single prompt
    string in the turn-tag dialect the underlying model expects.

    Args:
        model: A model object exposing ``generate(prompt, max_length=...,
            strip_prompt=...)``. When ``chat_template="auto"``, the model's
            class name selects the turn tags.
        system: Optional system prompt prepended to every full prompt.
        chat_template: ``"auto"`` (use ``type(model).__name__``) or one of
            ``"Llama3CausalLM"``, ``"GemmaCausalLM"``, ``"MistralCausalLM"``,
            ``"Vicuna"``.

    Raises:
        ValueError: If the resolved template name is not recognized.
    """

    def __init__(self, model, system="", chat_template="auto"):
        # Resolve "auto" to the model's class name so known model classes
        # pick their template without the caller naming it.
        chat_template = (
            type(model).__name__ if chat_template == "auto" else chat_template
        )

        if chat_template == "Llama3CausalLM":
            self.__START_TURN_SYSTEM__ = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
            )
            self.__START_TURN_USER__ = (
                "<|start_header_id|>user<|end_header_id|>\n\n"
            )
            self.__START_TURN_MODEL__ = (
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
            self.__END_TURN_SYSTEM__ = "<|eot_id|>"
            self.__END_TURN_USER__ = "<|eot_id|>"
            self.__END_TURN_MODEL__ = "<|eot_id|>"
            print("Using chat template for: Llama")
        elif chat_template == "GemmaCausalLM":
            self.__START_TURN_SYSTEM__ = ""
            self.__START_TURN_USER__ = "<start_of_turn>user\n"
            self.__START_TURN_MODEL__ = "<start_of_turn>model\n"
            self.__END_TURN_SYSTEM__ = "\n"
            self.__END_TURN_USER__ = "<end_of_turn>\n"
            self.__END_TURN_MODEL__ = "<end_of_turn>\n"
            print("Using chat template for: Gemma")
        elif chat_template == "MistralCausalLM":
            self.__START_TURN_SYSTEM__ = ""
            self.__START_TURN_USER__ = "[INST]"
            self.__START_TURN_MODEL__ = ""
            self.__END_TURN_SYSTEM__ = "<s>"
            self.__END_TURN_USER__ = "[/INST]"
            self.__END_TURN_MODEL__ = "</s>"
            print("Using chat template for: Mistral")
        elif chat_template == "Vicuna":
            self.__START_TURN_SYSTEM__ = ""
            self.__START_TURN_USER__ = "USER: "
            self.__START_TURN_MODEL__ = "ASSISTANT: "
            self.__END_TURN_SYSTEM__ = "\n\n"
            self.__END_TURN_USER__ = "\n"
            self.__END_TURN_MODEL__ = "</s>\n"
            print("Using chat template for: Vicuna")
        else:
            # BUG FIX: the original used `assert (0, "...")`, which asserts
            # a non-empty tuple and therefore always passes — an unknown
            # template silently left the turn-tag attributes unset and the
            # failure surfaced later as an AttributeError. Fail fast instead.
            raise ValueError(
                f"Unknown turn tags for model class: {chat_template!r}"
            )

        self.model = model
        self.system = system
        self.history = []  # list of already-tagged turn strings, in order

    def add_to_history_as_user(self, message):
        """Append *message* to the history wrapped in user turn tags."""
        self.history.append(
            self.__START_TURN_USER__ + message + self.__END_TURN_USER__
        )

    def add_to_history_as_model(self, message):
        """Append *message* to the history wrapped in model turn tags."""
        self.history.append(
            self.__START_TURN_MODEL__ + message + self.__END_TURN_MODEL__
        )

    def get_history(self):
        """Return the concatenated, already-tagged conversation history."""
        return "".join(self.history)

    def get_full_prompt(self):
        """Build the full generation prompt.

        Returns the (optional) system preamble, the tagged history, and an
        opening model-turn tag so the model continues as the assistant.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        if len(self.system) > 0:
            prompt = (
                self.__START_TURN_SYSTEM__
                + self.system
                + self.__END_TURN_SYSTEM__
                + prompt
            )
        return prompt

    def send_message(self, message):
        """
        Handles sending a user message and getting a model response.

        Args:
            message: The user's message.

        Returns:
            A ``(message, response)`` tuple; the exchange is also appended
            to the history.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        response = self.model.generate(
            prompt, max_length=2048, strip_prompt=True
        )
        self.add_to_history_as_model(response)
        return (message, response)