""" Prompt Strategy for finetuning Llama2 chat models see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for ma reference implementation. This implementation is based on the Vicuna PR and the fastchat repo, see also: https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847 Use dataset type: "llama2_chat" in conig.yml to use this prompt style. E.g. in the config.yml: ``` datasets: - path: llama_finetune_train.jsonl type: llama2_chat ``` The dataset itself should look like this: ``` {'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]} ``` in a jsonl file. The first message should be from the human, the second from gpt. For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns). Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing! """ import logging from dataclasses import dataclass, field from typing import Generator, List, Sequence from axolotl.prompt_tokenizers import PromptTokenizingStrategy from axolotl.prompters import IGNORE_TOKEN_ID, SHAREGPT_ASSERTION_FAILED_ROLE @dataclass class Llama2ChatConversation: """A class that manages prompt templates and keeps all conversation history. copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py""" name: str = "llama2" # The system prompt system: str = ( "[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " "Please ensure that your responses are socially unbiased and positive in nature.\n\n" "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. " "If you don't know the answer to a question, please don't share false information.\n<>\n\n" ) roles: Sequence[str] = ("[INST]", "[/INST]") messages: List[List[str]] = field(default_factory=list) offset: int = 0 sep = " " sep2 = " " stop_token_ids = [2] def get_prompt(self) -> str: """Get the prompt for generation.""" seps = [self.sep, self.sep2] ret = "" for i, (role, message) in enumerate(self.messages): if (i == len(self.messages) - 1) and (role == self.roles[0]): # last message is from user (due to length), # return prompt without it for training return ret if i == 0: ret += self.system + message.strip() else: ret += role + " " + message.strip() + seps[i % 2] return ret def append_message(self, role: str, message: str): """Append a new message.""" self.messages.append([role, message]) class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy): """ Tokenizing strategy for ShareGPT prompts. adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.tokenizer.add_special_tokens( {"pad_token": getattr(self.tokenizer, "pad_token", "")} ) # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json def tokenize_prompt(self, prompt): conv = next(self.prompter.build_prompt(prompt)) conversation_str = conv.get_prompt() # Tokenize conversations input_ids = self.tokenizer( conversation_str, return_tensors="pt", padding="max_length", max_length=self.sequence_len, truncation=True, ).input_ids[0] target = input_ids.clone() # Mask targets. Only compute loss on the assistant outputs. sep = conv.roles[1] total_len = int(target.ne(self.tokenizer.pad_token_id).sum()) turns = conversation_str.split(conv.sep2) cur_len = 1 target[:cur_len] = IGNORE_TOKEN_ID for turn in turns: if turn == "": break turn_len = len(self.tokenizer(turn).input_ids) parts = turn.split(sep) if len(parts) != 2: break parts[0] += sep # "-1" is hardcoded for the LLaMA tokenizer to make the offset correct. instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1 # Ignore the user instructions target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID cur_len += turn_len + 2 # due to length of role token target[cur_len:] = IGNORE_TOKEN_ID if cur_len < self.sequence_len: if cur_len != total_len: target[:] = IGNORE_TOKEN_ID logging.warning( f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." f" (ignored)" ) attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist() input_ids = input_ids.tolist() target = target.tolist() # this is a fix for the tokenizer which tokenizes [ differently with eos tokens and # follows the original llama implementation for i in range(2, total_len - 2): if input_ids[i] == 29961: input_ids[i] = 518 if target[i] == 29961: target[i] = 518 return { "input_ids": input_ids, "labels": target, "attention_mask": attention_mask, } class Llama2ChatPrompter: # pylint: disable=too-few-public-methods """ A prompter that generates prompts for Llama2 models. """ system_prompt = ( "[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. " "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. " "Please ensure that your responses are socially unbiased and positive in nature.\n\n" "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. " "If you don't know the answer to a question, please don't share false information.\n<>\n\n" ) def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]: # see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78 source = source["conversations"] # fix data structure for datasets # if system prompt provided, use it if source[0]["from"] == "system": system = f"[INST] <>\n{source[0]['value']}\n<>\n\n" source = source[1:] else: system = self.system_prompt conv = Llama2ChatConversation(system=system) if len(source) < 2: # If there isn't a back and forth conversation, ignore it # also happens on the data splitting leaving empty conversations raise IndexError roles = {"human": conv.roles[0], "gpt": conv.roles[1]} if roles[source[0]["from"]] != conv.roles[0]: # Skip the first one if it is not from human source = source[1:] conv.messages = [] # pylint: disable=R0801 for j, sentence in enumerate(source): role = roles[sentence["from"]] assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE if sentence["value"]: conv.append_message(role, sentence["value"]) yield conv def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy: return LLama2ChatTokenizingStrategy( Llama2ChatPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len, )