Spaces:
Running
on
TPU v5e
Running
on
TPU v5e
File size: 3,394 Bytes
2ca0c5e 061118b 2ca0c5e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
# chat helper
class ChatState:
    """Track a multi-turn conversation and format it with model-specific turn tags.

    Each turn is stored already formatted (opening tag + text + closing tag) in
    ``self.history``. ``get_full_prompt`` concatenates the turns, prefixes the
    optional system prompt, and appends the opening tag of a model turn so the
    model continues speaking as the assistant.

    Args:
        model: Object exposing ``generate(prompt, max_length=..., strip_prompt=...)``
            that is expected to return the response text. When
            ``chat_template="auto"`` its class name selects the template.
        system: Optional system prompt. An empty string or ``None`` means no
            system prompt is emitted.
        chat_template: ``"Llama3CausalLM"``, ``"GemmaCausalLM"``,
            ``"MistralCausalLM"``, ``"Vicuna"``, or ``"auto"`` to use
            ``type(model).__name__``.

    Raises:
        ValueError: If the resolved template name is not one of the above.
    """

    def __init__(self, model, system="", chat_template="auto"):
        if chat_template == "auto":
            chat_template = type(model).__name__

        # Attribute names end in "__", so Python does NOT name-mangle them;
        # they are kept as-is because external code may read them.
        if chat_template == "Llama3CausalLM":
            self.__START_TURN_SYSTEM__ = (
                "<|start_header_id|>system<|end_header_id|>\n\n"
            )
            self.__START_TURN_USER__ = (
                "<|start_header_id|>user<|end_header_id|>\n\n"
            )
            self.__START_TURN_MODEL__ = (
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
            self.__END_TURN_SYSTEM__ = "<|eot_id|>"
            self.__END_TURN_USER__ = "<|eot_id|>"
            self.__END_TURN_MODEL__ = "<|eot_id|>"
            print("Using chat template for: Llama")
        elif chat_template == "GemmaCausalLM":
            # No dedicated system tag: the system text is just followed by "\n".
            self.__START_TURN_SYSTEM__ = ""
            self.__START_TURN_USER__ = "<start_of_turn>user\n"
            self.__START_TURN_MODEL__ = "<start_of_turn>model\n"
            self.__END_TURN_SYSTEM__ = "\n"
            self.__END_TURN_USER__ = "<end_of_turn>\n"
            self.__END_TURN_MODEL__ = "<end_of_turn>\n"
            print("Using chat template for: Gemma")
        elif chat_template == "MistralCausalLM":
            # NOTE(review): "<s>" is emitted *after* the system text rather
            # than at the very start of the prompt; confirm this matches the
            # intended Mistral prompt format.
            self.__START_TURN_SYSTEM__ = ""
            self.__START_TURN_USER__ = "[INST]"
            self.__START_TURN_MODEL__ = ""
            self.__END_TURN_SYSTEM__ = "<s>"
            self.__END_TURN_USER__ = "[/INST]"
            self.__END_TURN_MODEL__ = "</s>"
            print("Using chat template for: Mistral")
        elif chat_template == "Vicuna":
            self.__START_TURN_SYSTEM__ = ""
            self.__START_TURN_USER__ = "USER: "
            self.__START_TURN_MODEL__ = "ASSISTANT: "
            self.__END_TURN_SYSTEM__ = "\n\n"
            self.__END_TURN_USER__ = "\n"
            self.__END_TURN_MODEL__ = "</s>\n"
            print("Using chat template for : Vicuna")
        else:
            # Raise explicitly: `assert (cond, msg)` asserts a non-empty tuple,
            # which is always true, and asserts vanish under `python -O`.
            # Without this, an unknown template would only fail later with an
            # AttributeError on the missing turn tags.
            raise ValueError(
                f"Unknown turn tags for this model class: {chat_template!r}"
            )

        self.model = model
        self.system = system
        # Formatted turn strings, oldest first.
        self.history = []

    def add_to_history_as_user(self, message):
        """Append ``message`` to the history wrapped in user-turn tags."""
        self.history.append(
            self.__START_TURN_USER__ + message + self.__END_TURN_USER__
        )

    def add_to_history_as_model(self, message):
        """Append ``message`` to the history wrapped in model-turn tags."""
        self.history.append(
            self.__START_TURN_MODEL__ + message + self.__END_TURN_MODEL__
        )

    def get_history(self):
        """Return all formatted turns concatenated, oldest first."""
        return "".join(self.history)

    def get_full_prompt(self):
        """Build the prompt to send to the model.

        Returns:
            The optional system block, then the conversation history, then an
            open model-turn tag so generation continues as the assistant.
        """
        prompt = self.get_history() + self.__START_TURN_MODEL__
        # Truthiness test also tolerates ``system=None``.
        if self.system:
            prompt = (
                self.__START_TURN_SYSTEM__
                + self.system
                + self.__END_TURN_SYSTEM__
                + prompt
            )
        return prompt

    def send_message(self, message, max_length=2048):
        """Send a user message and record the model's response.

        Args:
            message: The user's message.
            max_length: Forwarded to ``model.generate`` (default 2048).

        Returns:
            A ``(message, response)`` tuple.

        Raises:
            Whatever ``model.generate`` raises. In that case the user turn
            added by this call is removed again, so a retry does not
            duplicate it.
        """
        self.add_to_history_as_user(message)
        prompt = self.get_full_prompt()
        try:
            response = self.model.generate(
                prompt, max_length=max_length, strip_prompt=True
            )
        except BaseException:
            # Roll back the dangling user turn, then propagate unchanged.
            self.history.pop()
            raise
        self.add_to_history_as_model(response)
        return (message, response)
|