pseudotensor
commited on
Commit
·
6a1fd9e
1
Parent(s):
7134600
Upload h2oai_pipeline.py
Browse files- h2oai_pipeline.py +648 -2
h2oai_pipeline.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
from transformers import TextGenerationPipeline
|
2 |
from transformers.pipelines.text_generation import ReturnType
|
3 |
|
4 |
-
|
5 |
-
|
6 |
|
7 |
|
8 |
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
@@ -126,3 +126,649 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
126 |
else:
|
127 |
raise ValueError("TF not avaialble.")
|
128 |
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from transformers import TextGenerationPipeline
|
2 |
from transformers.pipelines.text_generation import ReturnType
|
3 |
|
4 |
+
|
5 |
+
|
6 |
|
7 |
|
8 |
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
|
126 |
else:
|
127 |
raise ValueError("TF not avaialble.")
|
128 |
return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text}
|
129 |
+
import torch
|
130 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
131 |
+
|
132 |
+
|
133 |
+
|
134 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
135 |
+
|
136 |
+
def __init__(self, stops=[], encounters=[], device="cuda"):
|
137 |
+
super().__init__()
|
138 |
+
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
139 |
+
self.encounters = encounters
|
140 |
+
self.stops = [stop.to(device) for stop in stops]
|
141 |
+
self.num_stops = [0] * len(stops)
|
142 |
+
|
143 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
144 |
+
for stopi, stop in enumerate(self.stops):
|
145 |
+
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
146 |
+
self.num_stops[stopi] += 1
|
147 |
+
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
|
148 |
+
# print("Stopped", flush=True)
|
149 |
+
return True
|
150 |
+
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
151 |
+
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
152 |
+
return False
|
153 |
+
|
154 |
+
|
155 |
+
def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
|
156 |
+
if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
|
157 |
+
if prompt_type == PromptType.human_bot.name:
|
158 |
+
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
159 |
+
# stopping only starts once output is beyond prompt
|
160 |
+
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
|
161 |
+
stop_words = [human, bot, '\n' + human, '\n' + bot]
|
162 |
+
encounters = [1, 2]
|
163 |
+
elif prompt_type == PromptType.instruct_vicuna.name:
|
164 |
+
# even below is not enough, generic strings and many ways to encode
|
165 |
+
stop_words = [
|
166 |
+
'### Human:',
|
167 |
+
"""
|
168 |
+
### Human:""",
|
169 |
+
"""
|
170 |
+
### Human:
|
171 |
+
""",
|
172 |
+
'### Assistant:',
|
173 |
+
"""
|
174 |
+
### Assistant:""",
|
175 |
+
"""
|
176 |
+
### Assistant:
|
177 |
+
""",
|
178 |
+
]
|
179 |
+
encounters = [1, 2]
|
180 |
+
else:
|
181 |
+
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
|
182 |
+
stop_words = ['### End']
|
183 |
+
encounters = [1]
|
184 |
+
stop_words_ids = [
|
185 |
+
tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
|
186 |
+
# handle single token case
|
187 |
+
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
|
188 |
+
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
189 |
+
# avoid padding in front of tokens
|
190 |
+
if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug
|
191 |
+
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
192 |
+
# handle fake \n added
|
193 |
+
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
194 |
+
# build stopper
|
195 |
+
stopping_criteria = StoppingCriteriaList(
|
196 |
+
[StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
|
197 |
+
else:
|
198 |
+
stopping_criteria = StoppingCriteriaList()
|
199 |
+
return stopping_criteria
|
200 |
+
import time
|
201 |
+
from enum import Enum
|
202 |
+
|
203 |
+
non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
|
204 |
+
|
205 |
+
|
206 |
+
class PromptType(Enum):
|
207 |
+
plain = 0
|
208 |
+
instruct = 1
|
209 |
+
quality = 2
|
210 |
+
human_bot = 3
|
211 |
+
dai_faq = 4
|
212 |
+
summarize = 5
|
213 |
+
simple_instruct = 6
|
214 |
+
instruct_vicuna = 7
|
215 |
+
instruct_with_end = 8
|
216 |
+
human_bot_orig = 9
|
217 |
+
prompt_answer = 10
|
218 |
+
open_assistant = 11
|
219 |
+
wizard_lm = 12
|
220 |
+
wizard_mega = 13
|
221 |
+
instruct_vicuna2 = 14
|
222 |
+
instruct_vicuna3 = 15
|
223 |
+
wizard2 = 16
|
224 |
+
wizard3 = 17
|
225 |
+
|
226 |
+
|
227 |
+
prompt_type_to_model_name = {
|
228 |
+
'plain': [
|
229 |
+
'EleutherAI/gpt-j-6B',
|
230 |
+
'EleutherAI/pythia-6.9b',
|
231 |
+
'EleutherAI/pythia-12b',
|
232 |
+
'EleutherAI/pythia-12b-deduped',
|
233 |
+
'EleutherAI/gpt-neox-20b',
|
234 |
+
'openlm-research/open_llama_7b_700bt_preview',
|
235 |
+
'decapoda-research/llama-7b-hf',
|
236 |
+
'decapoda-research/llama-13b-hf',
|
237 |
+
'decapoda-research/llama-30b-hf',
|
238 |
+
'decapoda-research/llama-65b-hf',
|
239 |
+
'facebook/mbart-large-50-many-to-many-mmt',
|
240 |
+
'philschmid/bart-large-cnn-samsum',
|
241 |
+
'philschmid/flan-t5-base-samsum',
|
242 |
+
'gpt2',
|
243 |
+
'distilgpt2',
|
244 |
+
'mosaicml/mpt-7b-storywriter',
|
245 |
+
'mosaicml/mpt-7b-instruct', # internal code handles instruct
|
246 |
+
'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
|
247 |
+
'gptj', # internally handles prompting
|
248 |
+
'llama', # plain, or need to choose prompt_type for given TheBloke model
|
249 |
+
'gpt4all_llama', # internally handles prompting
|
250 |
+
],
|
251 |
+
'prompt_answer': [
|
252 |
+
'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
|
253 |
+
'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
|
254 |
+
'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
|
255 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
256 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
|
257 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
|
258 |
+
],
|
259 |
+
'instruct': [],
|
260 |
+
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
261 |
+
'quality': [],
|
262 |
+
'human_bot': [
|
263 |
+
'h2oai/h2ogpt-oasst1-512-12b',
|
264 |
+
'h2oai/h2ogpt-oasst1-512-20b',
|
265 |
+
'h2oai/h2ogpt-oig-oasst1-256-6_9b',
|
266 |
+
'h2oai/h2ogpt-oig-oasst1-512-6_9b',
|
267 |
+
'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
|
268 |
+
'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
|
269 |
+
'h2oai/h2ogpt-research-oasst1-512-30b',
|
270 |
+
'h2oai/h2ogpt-oasst1-falcon-40b',
|
271 |
+
],
|
272 |
+
'dai_faq': [],
|
273 |
+
'summarize': [],
|
274 |
+
'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
|
275 |
+
'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
|
276 |
+
'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
|
277 |
+
"open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
|
278 |
+
"wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
|
279 |
+
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
280 |
+
}
|
281 |
+
|
282 |
+
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
283 |
+
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
284 |
+
|
285 |
+
prompt_types_strings = []
|
286 |
+
for p in PromptType:
|
287 |
+
prompt_types_strings.extend([p.name])
|
288 |
+
|
289 |
+
prompt_types = []
|
290 |
+
for p in PromptType:
|
291 |
+
prompt_types.extend([p.name, p.value, str(p.value)])
|
292 |
+
|
293 |
+
|
294 |
+
def get_prompt(prompt_type, chat, context, reduced):
|
295 |
+
if prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
|
296 |
+
PromptType.plain.name]:
|
297 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = ''
|
298 |
+
terminate_response = []
|
299 |
+
chat_sep = ''
|
300 |
+
humanstr = ''
|
301 |
+
botstr = ''
|
302 |
+
elif prompt_type == 'simple_instruct':
|
303 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
304 |
+
terminate_response = []
|
305 |
+
chat_sep = '\n'
|
306 |
+
humanstr = ''
|
307 |
+
botstr = ''
|
308 |
+
elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
|
309 |
+
PromptType.instruct.name] + [PromptType.instruct_with_end.value,
|
310 |
+
str(PromptType.instruct_with_end.value),
|
311 |
+
PromptType.instruct_with_end.name]:
|
312 |
+
promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
|
313 |
+
chat and reduced) else ''
|
314 |
+
promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
|
315 |
+
chat and reduced) else ''
|
316 |
+
|
317 |
+
PreInstruct = """
|
318 |
+
### Instruction:
|
319 |
+
"""
|
320 |
+
|
321 |
+
PreInput = """
|
322 |
+
### Input:
|
323 |
+
"""
|
324 |
+
|
325 |
+
PreResponse = """
|
326 |
+
### Response:
|
327 |
+
"""
|
328 |
+
if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value),
|
329 |
+
PromptType.instruct_with_end.name]:
|
330 |
+
terminate_response = ['### End']
|
331 |
+
else:
|
332 |
+
terminate_response = None
|
333 |
+
chat_sep = '\n'
|
334 |
+
humanstr = PreInstruct
|
335 |
+
botstr = PreResponse
|
336 |
+
elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
|
337 |
+
PromptType.quality.name]:
|
338 |
+
promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
|
339 |
+
chat and reduced) else ''
|
340 |
+
promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
|
341 |
+
chat and reduced) else ''
|
342 |
+
|
343 |
+
PreInstruct = """
|
344 |
+
### Instruction:
|
345 |
+
"""
|
346 |
+
|
347 |
+
PreInput = """
|
348 |
+
### Input:
|
349 |
+
"""
|
350 |
+
|
351 |
+
PreResponse = """
|
352 |
+
### Response:
|
353 |
+
"""
|
354 |
+
terminate_response = None
|
355 |
+
chat_sep = '\n'
|
356 |
+
humanstr = PreInstruct # first thing human says
|
357 |
+
botstr = PreResponse # first thing bot says
|
358 |
+
elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
359 |
+
PromptType.human_bot.name] + [PromptType.human_bot_orig.value,
|
360 |
+
str(PromptType.human_bot_orig.value),
|
361 |
+
PromptType.human_bot_orig.name]:
|
362 |
+
human = '<human>:'
|
363 |
+
bot = "<bot>:"
|
364 |
+
if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
365 |
+
PromptType.human_bot.name]:
|
366 |
+
preprompt = ''
|
367 |
+
else:
|
368 |
+
cur_date = time.strftime('%Y-%m-%d')
|
369 |
+
cur_time = time.strftime('%H:%M:%S %p %Z')
|
370 |
+
|
371 |
+
PRE_PROMPT = """\
|
372 |
+
Current Date: {}
|
373 |
+
Current Time: {}
|
374 |
+
|
375 |
+
"""
|
376 |
+
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
377 |
+
start = human
|
378 |
+
promptB = promptA = '%s%s ' % (preprompt, start)
|
379 |
+
|
380 |
+
PreInstruct = ""
|
381 |
+
|
382 |
+
PreInput = None
|
383 |
+
|
384 |
+
if reduced:
|
385 |
+
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
386 |
+
PreResponse = bot + ' '
|
387 |
+
else:
|
388 |
+
# normally LLM adds space after this, because was how trained.
|
389 |
+
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
390 |
+
PreResponse = bot
|
391 |
+
|
392 |
+
terminate_response = [start, PreResponse]
|
393 |
+
chat_sep = '\n'
|
394 |
+
humanstr = human # tag before human talks
|
395 |
+
botstr = bot # tag before bot talks
|
396 |
+
elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
|
397 |
+
PromptType.dai_faq.name]:
|
398 |
+
promptA = ''
|
399 |
+
promptB = 'Answer the following Driverless AI question.\n'
|
400 |
+
|
401 |
+
PreInstruct = """
|
402 |
+
### Driverless AI frequently asked question:
|
403 |
+
"""
|
404 |
+
|
405 |
+
PreInput = None
|
406 |
+
|
407 |
+
PreResponse = """
|
408 |
+
### Driverless AI documentation answer:
|
409 |
+
"""
|
410 |
+
terminate_response = ['\n\n']
|
411 |
+
chat_sep = terminate_response
|
412 |
+
humanstr = PreInstruct
|
413 |
+
botstr = PreResponse
|
414 |
+
elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
|
415 |
+
PromptType.summarize.name]:
|
416 |
+
promptA = promptB = PreInput = ''
|
417 |
+
PreInstruct = '## Main Text\n\n'
|
418 |
+
PreResponse = '\n\n## Summary\n\n'
|
419 |
+
terminate_response = None
|
420 |
+
chat_sep = '\n'
|
421 |
+
humanstr = PreInstruct
|
422 |
+
botstr = PreResponse
|
423 |
+
elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
|
424 |
+
PromptType.instruct_vicuna.name]:
|
425 |
+
promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
|
426 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
|
427 |
+
chat and reduced) else ''
|
428 |
+
|
429 |
+
PreInstruct = """
|
430 |
+
### Human:
|
431 |
+
"""
|
432 |
+
|
433 |
+
PreInput = None
|
434 |
+
|
435 |
+
PreResponse = """
|
436 |
+
### Assistant:
|
437 |
+
"""
|
438 |
+
terminate_response = [
|
439 |
+
'### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
440 |
+
chat_sep = '\n'
|
441 |
+
humanstr = PreInstruct
|
442 |
+
botstr = PreResponse
|
443 |
+
elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
|
444 |
+
PromptType.prompt_answer.name]:
|
445 |
+
preprompt = ''
|
446 |
+
prompt_tokens = "<|prompt|>"
|
447 |
+
answer_tokens = "<|answer|>"
|
448 |
+
start = prompt_tokens
|
449 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
450 |
+
PreInstruct = ""
|
451 |
+
PreInput = None
|
452 |
+
PreResponse = answer_tokens
|
453 |
+
eos = '<|endoftext|>' # neox eos
|
454 |
+
terminate_response = [start, PreResponse, eos]
|
455 |
+
chat_sep = eos
|
456 |
+
humanstr = prompt_tokens
|
457 |
+
botstr = answer_tokens
|
458 |
+
elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
|
459 |
+
PromptType.open_assistant.name]:
|
460 |
+
# From added_tokens.json
|
461 |
+
preprompt = ''
|
462 |
+
prompt_tokens = "<|prompter|>"
|
463 |
+
answer_tokens = "<|assistant|>"
|
464 |
+
start = prompt_tokens
|
465 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
466 |
+
PreInstruct = ""
|
467 |
+
PreInput = None
|
468 |
+
PreResponse = answer_tokens
|
469 |
+
pend = "<|prefix_end|>"
|
470 |
+
eos = "</s>"
|
471 |
+
terminate_response = [start, PreResponse, pend, eos]
|
472 |
+
chat_sep = eos
|
473 |
+
humanstr = prompt_tokens
|
474 |
+
botstr = answer_tokens
|
475 |
+
elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
|
476 |
+
PromptType.wizard_lm.name]:
|
477 |
+
# https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
|
478 |
+
preprompt = ''
|
479 |
+
start = ''
|
480 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
481 |
+
PreInstruct = ""
|
482 |
+
PreInput = None
|
483 |
+
PreResponse = "\n\n### Response\n"
|
484 |
+
eos = "</s>"
|
485 |
+
terminate_response = [PreResponse, eos]
|
486 |
+
chat_sep = eos
|
487 |
+
humanstr = promptA
|
488 |
+
botstr = PreResponse
|
489 |
+
elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
|
490 |
+
PromptType.wizard_mega.name]:
|
491 |
+
preprompt = ''
|
492 |
+
start = ''
|
493 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
494 |
+
PreInstruct = """
|
495 |
+
### Instruction:
|
496 |
+
"""
|
497 |
+
PreInput = None
|
498 |
+
PreResponse = """
|
499 |
+
### Assistant:
|
500 |
+
"""
|
501 |
+
terminate_response = [PreResponse]
|
502 |
+
chat_sep = '\n'
|
503 |
+
humanstr = PreInstruct
|
504 |
+
botstr = PreResponse
|
505 |
+
elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
|
506 |
+
PromptType.instruct_vicuna2.name]:
|
507 |
+
promptA = promptB = "" if not (
|
508 |
+
chat and reduced) else ''
|
509 |
+
|
510 |
+
PreInstruct = """
|
511 |
+
HUMAN:
|
512 |
+
"""
|
513 |
+
|
514 |
+
PreInput = None
|
515 |
+
|
516 |
+
PreResponse = """
|
517 |
+
ASSISTANT:
|
518 |
+
"""
|
519 |
+
terminate_response = [
|
520 |
+
'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
521 |
+
chat_sep = '\n'
|
522 |
+
humanstr = PreInstruct
|
523 |
+
botstr = PreResponse
|
524 |
+
elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
|
525 |
+
PromptType.instruct_vicuna3.name]:
|
526 |
+
promptA = promptB = "" if not (
|
527 |
+
chat and reduced) else ''
|
528 |
+
|
529 |
+
PreInstruct = """
|
530 |
+
### User:
|
531 |
+
"""
|
532 |
+
|
533 |
+
PreInput = None
|
534 |
+
|
535 |
+
PreResponse = """
|
536 |
+
### Assistant:
|
537 |
+
"""
|
538 |
+
terminate_response = [
|
539 |
+
'### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
540 |
+
chat_sep = '\n'
|
541 |
+
humanstr = PreInstruct
|
542 |
+
botstr = PreResponse
|
543 |
+
elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
|
544 |
+
PromptType.wizard2.name]:
|
545 |
+
# https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
|
546 |
+
preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
|
547 |
+
start = ''
|
548 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
549 |
+
PreInstruct = """
|
550 |
+
### Instruction:
|
551 |
+
"""
|
552 |
+
PreInput = None
|
553 |
+
PreResponse = """
|
554 |
+
### Response:
|
555 |
+
"""
|
556 |
+
terminate_response = [PreResponse]
|
557 |
+
chat_sep = '\n'
|
558 |
+
humanstr = PreInstruct
|
559 |
+
botstr = PreResponse
|
560 |
+
elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
|
561 |
+
PromptType.wizard3.name]:
|
562 |
+
# https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
|
563 |
+
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
|
564 |
+
start = ''
|
565 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
566 |
+
PreInstruct = """USER: """
|
567 |
+
PreInput = None
|
568 |
+
PreResponse = """ASSISTANT: """
|
569 |
+
terminate_response = [PreResponse]
|
570 |
+
chat_sep = '\n'
|
571 |
+
humanstr = PreInstruct
|
572 |
+
botstr = PreResponse
|
573 |
+
|
574 |
+
else:
|
575 |
+
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
576 |
+
|
577 |
+
return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
|
578 |
+
|
579 |
+
|
580 |
+
def generate_prompt(data_point, prompt_type, chat, reduced):
|
581 |
+
context = data_point.get('context')
|
582 |
+
if context is None:
|
583 |
+
context = ''
|
584 |
+
instruction = data_point.get('instruction')
|
585 |
+
input = data_point.get('input')
|
586 |
+
output = data_point.get('output')
|
587 |
+
prompt_type = data_point.get('prompt_type', prompt_type)
|
588 |
+
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
589 |
+
promptA, promptB, PreInstruct, PreInput, PreResponse, \
|
590 |
+
terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, chat, context, reduced)
|
591 |
+
|
592 |
+
prompt = context if not reduced else ''
|
593 |
+
|
594 |
+
if input and promptA:
|
595 |
+
prompt += f"""{promptA}"""
|
596 |
+
elif promptB:
|
597 |
+
prompt += f"""{promptB}"""
|
598 |
+
|
599 |
+
if instruction and PreInstruct is not None and input and PreInput is not None:
|
600 |
+
prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
|
601 |
+
prompt = inject_newline(prompt_type, prompt)
|
602 |
+
elif instruction and input and PreInstruct is None and PreInput is not None:
|
603 |
+
prompt += f"""{PreInput}{instruction}
|
604 |
+
{input}"""
|
605 |
+
prompt = inject_newline(prompt_type, prompt)
|
606 |
+
elif input and instruction and PreInput is None and PreInstruct is not None:
|
607 |
+
prompt += f"""{PreInstruct}{instruction}
|
608 |
+
{input}"""
|
609 |
+
prompt = inject_newline(prompt_type, prompt)
|
610 |
+
elif instruction and PreInstruct is not None:
|
611 |
+
prompt += f"""{PreInstruct}{instruction}"""
|
612 |
+
prompt = inject_newline(prompt_type, prompt)
|
613 |
+
elif input and PreInput is not None:
|
614 |
+
prompt += f"""{PreInput}{input}"""
|
615 |
+
prompt = inject_newline(prompt_type, prompt)
|
616 |
+
elif input and instruction and PreInput is not None:
|
617 |
+
prompt += f"""{PreInput}{instruction}{input}"""
|
618 |
+
prompt = inject_newline(prompt_type, prompt)
|
619 |
+
elif input and instruction and PreInstruct is not None:
|
620 |
+
prompt += f"""{PreInstruct}{instruction}{input}"""
|
621 |
+
prompt = inject_newline(prompt_type, prompt)
|
622 |
+
elif input and instruction:
|
623 |
+
# i.e. for simple_instruct
|
624 |
+
prompt += f"""{instruction}: {input}"""
|
625 |
+
prompt = inject_newline(prompt_type, prompt)
|
626 |
+
elif input:
|
627 |
+
prompt += f"""{input}"""
|
628 |
+
prompt = inject_newline(prompt_type, prompt)
|
629 |
+
elif instruction:
|
630 |
+
prompt += f"""{instruction}"""
|
631 |
+
prompt = inject_newline(prompt_type, prompt)
|
632 |
+
|
633 |
+
if PreResponse is not None:
|
634 |
+
prompt += f"""{PreResponse}"""
|
635 |
+
pre_response = PreResponse # Don't use strip
|
636 |
+
else:
|
637 |
+
pre_response = ''
|
638 |
+
|
639 |
+
if output:
|
640 |
+
prompt += f"""{output}"""
|
641 |
+
|
642 |
+
return prompt, pre_response, terminate_response, chat_sep
|
643 |
+
|
644 |
+
|
645 |
+
def inject_newline(prompt_type, prompt):
|
646 |
+
if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
|
647 |
+
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
648 |
+
prompt += '\n'
|
649 |
+
return prompt
|
650 |
+
|
651 |
+
|
652 |
+
class Prompter(object):
|
653 |
+
def __init__(self, prompt_type, debug=False, chat=False, stream_output=False, repeat_penalty=True,
|
654 |
+
allowed_repeat_line_length=10):
|
655 |
+
self.prompt_type = prompt_type
|
656 |
+
data_point = dict(instruction='', input='', output='')
|
657 |
+
_, self.pre_response, self.terminate_response, self.chat_sep = \
|
658 |
+
generate_prompt(data_point, prompt_type, chat, False)
|
659 |
+
self.debug = debug
|
660 |
+
self.chat = chat
|
661 |
+
self.stream_output = stream_output
|
662 |
+
self.repeat_penalty = repeat_penalty
|
663 |
+
self.allowed_repeat_line_length = allowed_repeat_line_length
|
664 |
+
self.prompt = None
|
665 |
+
context = "" # not for chat context
|
666 |
+
reduced = False # not for chat context
|
667 |
+
self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
|
668 |
+
self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
|
669 |
+
get_prompt(prompt_type, chat, context, reduced)
|
670 |
+
|
671 |
+
def generate_prompt(self, data_point):
|
672 |
+
reduced = False
|
673 |
+
prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
|
674 |
+
if self.debug:
|
675 |
+
print("prompt: ", prompt, flush=True)
|
676 |
+
self.prompt = prompt
|
677 |
+
return prompt
|
678 |
+
|
679 |
+
def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
|
680 |
+
if isinstance(outputs, str):
|
681 |
+
outputs = [outputs]
|
682 |
+
if self.debug:
|
683 |
+
print("output:\n", '\n\n'.join(outputs), flush=True)
|
684 |
+
if prompt is not None:
|
685 |
+
self.prompt = prompt
|
686 |
+
|
687 |
+
def clean_response(response):
|
688 |
+
meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
|
689 |
+
for word in meaningless_words:
|
690 |
+
response = response.replace(word, "")
|
691 |
+
if sanitize_bot_response:
|
692 |
+
from better_profanity import profanity
|
693 |
+
response = profanity.censor(response)
|
694 |
+
response = response.strip("\n")
|
695 |
+
return response
|
696 |
+
|
697 |
+
def clean_repeats(response):
|
698 |
+
lines = response.split('\n')
|
699 |
+
new_lines = []
|
700 |
+
[new_lines.append(line) for line in lines if
|
701 |
+
line not in new_lines or len(line) < self.allowed_repeat_line_length]
|
702 |
+
if self.debug and len(lines) != len(new_lines):
|
703 |
+
print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True)
|
704 |
+
response = '\n'.join(new_lines)
|
705 |
+
return response
|
706 |
+
|
707 |
+
multi_output = len(outputs) > 1
|
708 |
+
|
709 |
+
for oi, output in enumerate(outputs):
|
710 |
+
if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]:
|
711 |
+
output = clean_response(output)
|
712 |
+
elif prompt is None:
|
713 |
+
# then use most basic parsing like pipeline
|
714 |
+
if self.botstr in output:
|
715 |
+
if self.humanstr:
|
716 |
+
output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
|
717 |
+
else:
|
718 |
+
# i.e. use after bot but only up to next bot
|
719 |
+
output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
|
720 |
+
else:
|
721 |
+
# output = clean_response(output.strip())
|
722 |
+
# assume just not printed yet
|
723 |
+
output = ""
|
724 |
+
else:
|
725 |
+
# find first instance of prereponse
|
726 |
+
# prompt sometimes has odd characters, that mutate length,
|
727 |
+
# so can't go by length alone
|
728 |
+
if self.pre_response:
|
729 |
+
outputi = output.find(prompt)
|
730 |
+
if outputi >= 0:
|
731 |
+
output = output[outputi + len(prompt):]
|
732 |
+
allow_terminate = True
|
733 |
+
else:
|
734 |
+
# subtraction is risky due to space offsets sometimes, so only do if necessary
|
735 |
+
output = output[len(prompt) - len(self.pre_response):]
|
736 |
+
# [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat)
|
737 |
+
if self.pre_response in output:
|
738 |
+
output = output.split(self.pre_response)[1]
|
739 |
+
allow_terminate = True
|
740 |
+
else:
|
741 |
+
if output:
|
742 |
+
print("Failure of parsing or not enough output yet: %s" % output, flush=True)
|
743 |
+
allow_terminate = False
|
744 |
+
else:
|
745 |
+
allow_terminate = True
|
746 |
+
output = output[len(prompt):]
|
747 |
+
# clean after subtract prompt out, so correct removal of pre_response
|
748 |
+
output = clean_response(output).strip()
|
749 |
+
if self.repeat_penalty:
|
750 |
+
output = clean_repeats(output).strip()
|
751 |
+
if self.terminate_response and allow_terminate:
|
752 |
+
finds = []
|
753 |
+
for term in self.terminate_response:
|
754 |
+
finds.append(output.find(term))
|
755 |
+
finds = [x for x in finds if x >= 0]
|
756 |
+
if len(finds) > 0:
|
757 |
+
termi = finds[0]
|
758 |
+
output = output[:termi].strip()
|
759 |
+
else:
|
760 |
+
output = output.strip()
|
761 |
+
else:
|
762 |
+
output = output.strip()
|
763 |
+
if multi_output:
|
764 |
+
# prefix with output counter
|
765 |
+
output = "\n=========== Output %d\n\n" % (1 + oi) + output
|
766 |
+
if oi > 0:
|
767 |
+
# post fix outputs with seperator
|
768 |
+
output += '\n'
|
769 |
+
outputs[oi] = output
|
770 |
+
# join all outputs, only one extra new line between outputs
|
771 |
+
output = '\n'.join(outputs)
|
772 |
+
if self.debug:
|
773 |
+
print("outputclean:\n", '\n\n'.join(outputs), flush=True)
|
774 |
+
return output
|