winglian commited on
Commit
0f74464
·
1 Parent(s): e0602a9

fix new dataset prompt tokenizers

Browse files
src/axolotl/datasets.py CHANGED
@@ -106,7 +106,7 @@ class ConstantLengthDataset(IterableDataset):
106
  }
107
  else:
108
  logging.warning(
109
- "dropping batch due to tensor size mismatch"
110
  )
111
  buffer = {"input_ids": [], "attention_mask": [], "labels": []}
112
  buffer_len = 0
 
106
  }
107
  else:
108
  logging.warning(
109
+ f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}"
110
  )
111
  buffer = {"input_ids": [], "attention_mask": [], "labels": []}
112
  buffer_len = 0
src/axolotl/prompt_strategies/__init__.py CHANGED
@@ -1,11 +1,13 @@
1
  import importlib
2
- from functools import cache
3
 
4
- @cache
5
  def load(strategy, tokenizer, cfg):
6
  try:
7
- m = importlib.import_module(f".{strategy}", axolotl.prompt_strategies)
8
- fn = getattr(m, "load")
 
 
 
 
9
  return fn(tokenizer, cfg)
10
  except:
11
  pass
 
1
  import importlib
 
2
 
 
3
  def load(strategy, tokenizer, cfg):
4
  try:
5
+ load_fn = "load"
6
+ if strategy.split(".")[-1].startswith("load_"):
7
+ load_fn = strategy.split(".")[-1]
8
+ strategy = ".".join(strategy.split(".")[:-1])
9
+ m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
10
+ fn = getattr(m, load_fn)
11
  return fn(tokenizer, cfg)
12
  except:
13
  pass
src/axolotl/prompt_strategies/creative_acr.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Generator
2
+
3
+ import yaml
4
+ from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
5
+
6
+
7
+ class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
8
+ def parse_instruction_fields(self, prompt) -> (str, str, str):
9
+ question = prompt["instruction"]
10
+ answer = prompt["revision"] # don't use prompt[answer], that's data we don't want in the dataset
11
+ return (
12
+ question,
13
+ "",
14
+ answer,
15
+ )
16
+
17
+
18
+ class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
19
+ user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria:
20
+ refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
21
+ prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias.
22
+ creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
23
+ comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
24
+ Then for each criteria, give a detailed explanation of why the response doesn't meet the criteria (don't reference the score), what it did right, and how it could be improved for that criteria.
25
+
26
+ Output your answer in YAML format like so:
27
+ scores:
28
+ refusal: <0-10>
29
+ prescriptive_bias: <0-10>
30
+ creativity: <0-10>
31
+ comprehensiveness: <0-10>
32
+ critiques:
33
+ refusal:
34
+ explanation: ...
35
+ improvements: ...
36
+ prescriptive_bias:
37
+ explanation: ...
38
+ improvements: ...
39
+ creativity:
40
+ explanation: ...
41
+ improvements: ...
42
+ comprehensiveness:
43
+ explanation: ...
44
+ improvements: ...
45
+
46
+ Question: {question}
47
+ Answer: {answer}
48
+ """
49
+
50
+ def parse_instruction_fields(self, prompt) -> (str, str, str):
51
+ scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
52
+ critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
53
+ evaluation = scores + critiques
54
+ question = prompt["instruction"]
55
+ answer = prompt["answer"]
56
+ return (
57
+ self.user_prompt.format(question=question, answer=answer),
58
+ "",
59
+ evaluation,
60
+ )
61
+
62
+
63
+ class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
64
+ user_prompt = """Definitions:
65
+ refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question.
66
+ prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias.
67
+ creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity
68
+ comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth
69
+
70
+ Given the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response.
71
+
72
+ Question: {question}
73
+ Answer: {answer}
74
+ Evaluation:
75
+ {evaluation}
76
+ """
77
+
78
+ def parse_instruction_fields(self, prompt) -> (str, str, str):
79
+ scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper)
80
+ critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper)
81
+ evaluation = scores + critiques
82
+ question = prompt["instruction"]
83
+ answer = prompt["answer"]
84
+ return (
85
+ self.user_prompt.format(question=question, answer=answer, evaluation=evaluation),
86
+ "",
87
+ prompt["revision"],
88
+ )
89
+
90
+
91
+ class CreativePrompterBase:
92
+ system_prompt = ""
93
+ prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:"
94
+
95
+ def build_prompt(
96
+ self,
97
+ instruction: str,
98
+ input: Union[None, str] = None,
99
+ output: Union[None, str] = None,
100
+ ) -> Generator[str, None, None]:
101
+ if self.system_prompt:
102
+ res = f"{self.system_prompt}\nUSER: {instruction}\nASSISTANT:"
103
+ else:
104
+ res = f"USER: {instruction}\nASSISTANT:"
105
+ if output:
106
+ res = f"{res}{output}"
107
+ yield res
108
+
109
+
110
+ class CreativeAnswerPrompter(CreativePrompterBase):
111
+ system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity."
112
+
113
+
114
+ class CreativeCritiquePrompter(CreativePrompterBase):
115
+ system_prompt = ""
116
+
117
+
118
+ class CreativeRevisePrompter(CreativePrompterBase):
119
+ system_prompt = ""
120
+
121
+
122
+ def load_answer(tokenizer, cfg):
123
+ return CreativeAnsweringPromptTokenizingStrategy(
124
+ CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
125
+ )
126
+
127
+
128
+ def load_critique(tokenizer, cfg):
129
+ return CreativeCritiquePromptTokenizingStrategy(
130
+ CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
131
+ )
132
+
133
+
134
+ def load_revise(tokenizer, cfg):
135
+ return CreativeRevisePromptTokenizingStrategy(
136
+ CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
137
+ )
src/axolotl/prompt_strategies/pygmalion.py CHANGED
@@ -41,9 +41,9 @@ class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy):
41
  elif role == "bot":
42
  prefix = "<|model|>"
43
  res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True)
44
- res["input_ids"] = [*self.bot_prefix_token_ids, *res["input_ids"]]
45
  # mask out the prefix token, rest is not masked out from labels
46
- labels = [ IGNORE_TOKEN_ID ] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])]
 
47
  else:
48
  logging.warning(f"unknown role in conversation: {role}")
49
  res = defaultdict(lambda: [])
 
41
  elif role == "bot":
42
  prefix = "<|model|>"
43
  res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True)
 
44
  # mask out the prefix token, rest is not masked out from labels
45
+ # make sure we create the labels first, otherwise we get incorrect lengths
46
+ labels = [ IGNORE_TOKEN_ID ] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])][len(self.bot_prefix_token_ids):]
47
  else:
48
  logging.warning(f"unknown role in conversation: {role}")
49
  res = defaultdict(lambda: [])
src/axolotl/utils/data.py CHANGED
@@ -75,7 +75,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
75
  ds = None
76
  ds_from_hub = False
77
  try:
78
- load_dataset(d.path, streaming=True)
79
  ds_from_hub = True
80
  except FileNotFoundError:
81
  pass
@@ -83,18 +83,18 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
83
  # prefer local dataset, even if hub exists
84
  if Path(d.path).exists():
85
  ds: IterableDataset = load_dataset(
86
- "json", data_files=d.path, streaming=True, split=None
87
  )
88
  elif ds_from_hub:
89
  if d.data_files:
90
- ds = load_dataset(d.path, streaming=True, data_files=d.data_files)
91
  else:
92
- ds = load_dataset(d.path, streaming=True)
93
  else:
94
  fp = hf_hub_download(
95
  repo_id=d.path, repo_type="dataset", filename=d.data_files
96
  )
97
- ds = load_dataset("json", data_files=fp, streaming=True, split=None)
98
  if not ds:
99
  raise Exception("unhandled dataset load")
100
  d_type = d.type
 
75
  ds = None
76
  ds_from_hub = False
77
  try:
78
+ load_dataset(d.path, streaming=True, use_auth_token=True)
79
  ds_from_hub = True
80
  except FileNotFoundError:
81
  pass
 
83
  # prefer local dataset, even if hub exists
84
  if Path(d.path).exists():
85
  ds: IterableDataset = load_dataset(
86
+ "json", data_files=d.path, streaming=False, split=None
87
  )
88
  elif ds_from_hub:
89
  if d.data_files:
90
+ ds = load_dataset(d.path, streaming=False, data_files=d.data_files, use_auth_token=True)
91
  else:
92
+ ds = load_dataset(d.path, streaming=False, use_auth_token=True)
93
  else:
94
  fp = hf_hub_download(
95
  repo_id=d.path, repo_type="dataset", filename=d.data_files
96
  )
97
+ ds = load_dataset("json", data_files=fp, streaming=False, split=None)
98
  if not ds:
99
  raise Exception("unhandled dataset load")
100
  d_type = d.type