Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
·
e0ba5f2
1
Parent(s):
09d4719
Update with h2oGPT hash 61628d335bdb685fdcc63ca9821cf5607f41a9e3
Browse files- app.py +0 -0
- app.py +1 -0
- finetune.py +169 -124
- generate.py +1185 -0
- gradio_runner.py +910 -0
- gradio_themes.py +142 -0
- prompter.py +2 -1
- requirements.txt +2 -1
- stopping.py +2 -114
- utils.py +77 -2
app.py
DELETED
The diff for this file is too large to render.
See raw diff
|
|
app.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
generate.py
|
finetune.py
CHANGED
@@ -1,55 +1,22 @@
|
|
1 |
import os
|
2 |
-
import pathlib
|
3 |
-
import random
|
4 |
-
import shutil
|
5 |
-
import subprocess
|
6 |
import sys
|
7 |
import time
|
8 |
-
from
|
9 |
from typing import List, Union
|
|
|
10 |
import fire
|
11 |
import numpy as np
|
|
|
12 |
import torch
|
13 |
-
from datasets import load_dataset, concatenate_datasets
|
14 |
-
import transformers
|
15 |
-
import torch.distributed as dist
|
16 |
-
|
17 |
-
from peft import (
|
18 |
-
prepare_model_for_int8_training,
|
19 |
-
LoraConfig,
|
20 |
-
get_peft_model,
|
21 |
-
get_peft_model_state_dict,
|
22 |
-
set_peft_model_state_dict,
|
23 |
-
)
|
24 |
-
|
25 |
-
from peft import mapping
|
26 |
-
lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
|
27 |
|
28 |
|
29 |
def log(*args, **kwargs):
|
30 |
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
|
|
|
|
|
31 |
print(*args, **kwargs)
|
32 |
|
33 |
|
34 |
-
try:
|
35 |
-
import neptune
|
36 |
-
from transformers.integrations import NeptuneCallback
|
37 |
-
|
38 |
-
neptune_run = neptune.init_run(
|
39 |
-
source_files=[],
|
40 |
-
)
|
41 |
-
log("Connected to Neptune.")
|
42 |
-
except ImportError:
|
43 |
-
neptune_run = None
|
44 |
-
log("Please pip install neptune for tracking.")
|
45 |
-
except neptune.exceptions.NeptuneMissingApiTokenException:
|
46 |
-
neptune_run = None
|
47 |
-
os.environ["NEPTUNE_MODE"] = 'debug'
|
48 |
-
log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
|
49 |
-
|
50 |
-
from enum import Enum
|
51 |
-
|
52 |
-
|
53 |
class PromptType(Enum):
|
54 |
plain = 0
|
55 |
instruct = 1
|
@@ -87,6 +54,7 @@ prompt_type_to_model_name = {
|
|
87 |
'h2oai/h2ogpt-oasst1-512-12b',
|
88 |
'h2oai/h2ogpt-oasst1-512-20b',
|
89 |
'h2oai/h2ogpt-oig-oasst1-512-6.9b',
|
|
|
90 |
],
|
91 |
'dai_faq': [],
|
92 |
'summarize': [],
|
@@ -134,7 +102,7 @@ def train(
|
|
134 |
tokenizer_base_model: str = None,
|
135 |
# tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
|
136 |
|
137 |
-
data_path: str =
|
138 |
data_col_dict: dict = None,
|
139 |
# data_path: str = "./dai_docs.train.json",
|
140 |
prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
|
@@ -158,6 +126,7 @@ def train(
|
|
158 |
micro_batch_size: int = 4,
|
159 |
gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
|
160 |
fp16=True,
|
|
|
161 |
|
162 |
# general training hyperparams
|
163 |
num_epochs: float = 1,
|
@@ -175,12 +144,14 @@ def train(
|
|
175 |
lora_dropout: float = 0.05,
|
176 |
lora_target_modules: List[str] = None,
|
177 |
llama_type: bool = None,
|
|
|
178 |
|
179 |
# llm hyperparams
|
180 |
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
181 |
group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
|
182 |
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
183 |
-
cutoff_len: int =
|
|
|
184 |
|
185 |
# torch training params
|
186 |
ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
|
@@ -190,8 +161,15 @@ def train(
|
|
190 |
warmup_steps: int = 100,
|
191 |
logging_steps: int = 1,
|
192 |
save_steps: int = None, # must be round multiple of eval_steps
|
|
|
193 |
add_eos_token: bool = False,
|
194 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
# allow set token directly
|
196 |
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
197 |
|
@@ -211,7 +189,7 @@ def train(
|
|
211 |
if not output_dir:
|
212 |
output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
|
213 |
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
214 |
-
raise FileExistsError(f"output_dir based on run_id {run_id} already exists. Please pick a different run_id.")
|
215 |
else:
|
216 |
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
217 |
raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
|
@@ -223,6 +201,21 @@ def train(
|
|
223 |
tokenizer_base_model = base_model
|
224 |
if llama_type is None:
|
225 |
llama_type = "llama" in base_model.lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
assert (
|
227 |
base_model
|
228 |
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
@@ -254,7 +247,7 @@ def train(
|
|
254 |
|
255 |
model = model_loader.from_pretrained(
|
256 |
base_model,
|
257 |
-
load_in_8bit=
|
258 |
device_map=device_map,
|
259 |
torch_dtype=torch.float16,
|
260 |
max_memory=max_memory,
|
@@ -268,66 +261,28 @@ def train(
|
|
268 |
model.is_parallelizable = True
|
269 |
model.model_parallel = True
|
270 |
|
271 |
-
tokenizer = tokenizer_loader
|
272 |
-
local_files_only=local_files_only,
|
273 |
-
resume_download=resume_download,
|
274 |
-
use_auth_token=use_auth_token)
|
275 |
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
|
280 |
-
tokenizer.padding_side = "left" # Allow batched inference
|
281 |
-
|
282 |
-
def tokenize(prompt, add_eos_token=True):
|
283 |
-
# there's probably a way to do this with the tokenizer settings
|
284 |
-
# but again, gotta move fast
|
285 |
-
result = tokenizer(
|
286 |
-
prompt,
|
287 |
-
truncation=True,
|
288 |
-
max_length=cutoff_len,
|
289 |
-
padding=False,
|
290 |
-
return_tensors=None,
|
291 |
)
|
292 |
-
if (
|
293 |
-
result["input_ids"][-1] != tokenizer.eos_token_id
|
294 |
-
and len(result["input_ids"]) < cutoff_len
|
295 |
-
and add_eos_token
|
296 |
-
):
|
297 |
-
result["input_ids"].append(tokenizer.eos_token_id)
|
298 |
-
result["attention_mask"].append(1)
|
299 |
|
300 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
301 |
|
302 |
-
|
|
|
|
|
303 |
|
304 |
-
def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
|
305 |
-
full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
|
306 |
-
tokenized_full_prompt = tokenize(full_prompt)
|
307 |
-
if not train_on_inputs:
|
308 |
-
user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
|
309 |
-
tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos)
|
310 |
-
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
311 |
-
if add_eos:
|
312 |
-
user_prompt_len -= 1
|
313 |
-
|
314 |
-
# ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
|
315 |
-
tokenized_full_prompt["labels"] = [
|
316 |
-
-100
|
317 |
-
] * user_prompt_len + tokenized_full_prompt["labels"][
|
318 |
-
user_prompt_len:
|
319 |
-
] # could be sped up, probably
|
320 |
-
return tokenized_full_prompt
|
321 |
-
|
322 |
-
if "gpt-neox" not in base_model or True:
|
323 |
-
model = prepare_model_for_int8_training(model)
|
324 |
-
else:
|
325 |
-
model = prepare_model_for_int8_training(
|
326 |
-
model,
|
327 |
-
output_embedding_layer_name="embed_out", # keep output logits in float32
|
328 |
-
layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
|
329 |
-
)
|
330 |
if lora_weights:
|
|
|
331 |
from peft import PeftModel
|
332 |
model = PeftModel.from_pretrained(
|
333 |
model,
|
@@ -338,7 +293,7 @@ def train(
|
|
338 |
resume_download=resume_download,
|
339 |
use_auth_token=use_auth_token,
|
340 |
)
|
341 |
-
|
342 |
if lora_target_modules is None:
|
343 |
base_model_lower = base_model.lower()
|
344 |
if base_model_lower in lora_mappings:
|
@@ -386,7 +341,11 @@ def train(
|
|
386 |
log(f"Checkpoint {checkpoint_name} not found")
|
387 |
|
388 |
print(model)
|
389 |
-
|
|
|
|
|
|
|
|
|
390 |
|
391 |
metrics = {}
|
392 |
for name in supported_metrics:
|
@@ -405,6 +364,7 @@ def train(
|
|
405 |
elif val_set_size < 1.0 and val_set_size != 0:
|
406 |
raise RuntimeError("Fractional validation size not supported.")
|
407 |
|
|
|
408 |
if valid_path:
|
409 |
data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
|
410 |
else:
|
@@ -427,10 +387,16 @@ def train(
|
|
427 |
else:
|
428 |
data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
|
429 |
data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
|
|
|
|
|
|
|
|
|
|
|
|
|
430 |
|
431 |
# only get as much as we need to balance
|
432 |
valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
|
433 |
-
train_size = max(1, min(data_mix_in.num_rows - valid_size,
|
434 |
mixin_small = data_mix_in.train_test_split(
|
435 |
test_size=train_size + valid_size,
|
436 |
shuffle=True, seed=np.random.randint(10000),
|
@@ -486,10 +452,20 @@ def train(
|
|
486 |
|
487 |
assert train_data is not None
|
488 |
|
|
|
|
|
|
|
|
|
489 |
# shuffle and tokenize data
|
490 |
if train_data_mix_in:
|
491 |
train_data = concatenate_datasets([train_data, train_data_mix_in])
|
492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
train_set_size = len(train_data)
|
494 |
|
495 |
if valid_data and valid_data_mix_in:
|
@@ -498,7 +474,8 @@ def train(
|
|
498 |
valid_data = valid_data_mix_in
|
499 |
|
500 |
if valid_data:
|
501 |
-
|
|
|
502 |
val_set_size = len(valid_data)
|
503 |
else:
|
504 |
val_set_size = 0
|
@@ -509,6 +486,22 @@ def train(
|
|
509 |
del sample_row_dict['labels']
|
510 |
log("Sample input: %s" % sample_row_dict)
|
511 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
512 |
if neptune_run:
|
513 |
neptune_callback = NeptuneCallback(run=neptune_run)
|
514 |
callbacks = [neptune_callback]
|
@@ -578,6 +571,7 @@ def train(
|
|
578 |
else:
|
579 |
trainer_kwargs = dict()
|
580 |
|
|
|
581 |
trainer = transformers.Trainer(
|
582 |
model=model,
|
583 |
tokenizer=tokenizer,
|
@@ -605,7 +599,7 @@ def train(
|
|
605 |
eval_steps=eval_steps if val_set_size > 0 else None,
|
606 |
save_steps=save_steps,
|
607 |
output_dir=output_dir,
|
608 |
-
save_total_limit=
|
609 |
load_best_model_at_end=True if val_set_size > 0 else False,
|
610 |
ddp_find_unused_parameters=False if ddp else None,
|
611 |
group_by_length=group_by_length,
|
@@ -622,6 +616,8 @@ def train(
|
|
622 |
model.config.use_cache = False
|
623 |
|
624 |
old_state_dict = model.state_dict
|
|
|
|
|
625 |
model.state_dict = (
|
626 |
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
627 |
).__get__(model, type(model))
|
@@ -629,7 +625,8 @@ def train(
|
|
629 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
630 |
model = torch.compile(model)
|
631 |
# WIP (not generally replacing layers until pytorch 2.1)
|
632 |
-
|
|
|
633 |
|
634 |
if gpus > 1 and not ddp:
|
635 |
assert trainer.is_model_parallel
|
@@ -649,6 +646,9 @@ def get_loaders(llama_type, model_name, reward_type):
|
|
649 |
from transformers import LlamaForCausalLM, LlamaTokenizer
|
650 |
model_loader = LlamaForCausalLM
|
651 |
tokenizer_loader = LlamaTokenizer
|
|
|
|
|
|
|
652 |
elif 'gpt2' in model_name.lower():
|
653 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
654 |
return GPT2LMHeadModel, GPT2Tokenizer
|
@@ -676,31 +676,76 @@ def get_loaders(llama_type, model_name, reward_type):
|
|
676 |
return model_loader, tokenizer_loader
|
677 |
|
678 |
|
679 |
-
def
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
684 |
-
|
|
|
|
|
|
|
|
|
|
|
685 |
|
|
|
686 |
|
687 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
688 |
"""
|
689 |
-
|
690 |
-
:param
|
|
|
691 |
:return:
|
692 |
"""
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
704 |
|
705 |
|
706 |
def get_prompt(prompt_type, chat, context, reduced):
|
@@ -824,7 +869,7 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
|
|
824 |
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
825 |
promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
|
826 |
|
827 |
-
prompt = context
|
828 |
|
829 |
if input and promptA:
|
830 |
prompt += f"""{promptA}"""
|
|
|
1 |
import os
|
|
|
|
|
|
|
|
|
2 |
import sys
|
3 |
import time
|
4 |
+
from functools import partial
|
5 |
from typing import List, Union
|
6 |
+
from enum import Enum
|
7 |
import fire
|
8 |
import numpy as np
|
9 |
+
from utils import get_githash, copy_code
|
10 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
def log(*args, **kwargs):
|
14 |
if int(os.environ.get("LOCAL_RANK", 0)) == 0:
|
15 |
+
if 'flush' not in kwargs:
|
16 |
+
kwargs['flush'] = True
|
17 |
print(*args, **kwargs)
|
18 |
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
class PromptType(Enum):
|
21 |
plain = 0
|
22 |
instruct = 1
|
|
|
54 |
'h2oai/h2ogpt-oasst1-512-12b',
|
55 |
'h2oai/h2ogpt-oasst1-512-20b',
|
56 |
'h2oai/h2ogpt-oig-oasst1-512-6.9b',
|
57 |
+
'h2oai/h2ogpt-research-oasst1-512-30b', # private
|
58 |
],
|
59 |
'dai_faq': [],
|
60 |
'summarize': [],
|
|
|
102 |
tokenizer_base_model: str = None,
|
103 |
# tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
|
104 |
|
105 |
+
data_path: str = "h2oai/openassistant_oasst1_h2ogpt",
|
106 |
data_col_dict: dict = None,
|
107 |
# data_path: str = "./dai_docs.train.json",
|
108 |
prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
|
|
|
126 |
micro_batch_size: int = 4,
|
127 |
gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
|
128 |
fp16=True,
|
129 |
+
train_8bit=True,
|
130 |
|
131 |
# general training hyperparams
|
132 |
num_epochs: float = 1,
|
|
|
144 |
lora_dropout: float = 0.05,
|
145 |
lora_target_modules: List[str] = None,
|
146 |
llama_type: bool = None,
|
147 |
+
llama_flash_attn: bool = False,
|
148 |
|
149 |
# llm hyperparams
|
150 |
train_on_inputs: bool = True, # if False, masks out inputs in loss
|
151 |
group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
|
152 |
resume_from_checkpoint: str = None, # either training checkpoint or final adapter
|
153 |
+
cutoff_len: int = 512, # larger values use more memory
|
154 |
+
drop_truncations: bool = False, # if True, drop any truncated long sequences
|
155 |
|
156 |
# torch training params
|
157 |
ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
|
|
|
161 |
warmup_steps: int = 100,
|
162 |
logging_steps: int = 1,
|
163 |
save_steps: int = None, # must be round multiple of eval_steps
|
164 |
+
save_total_limit: int = 3,
|
165 |
add_eos_token: bool = False,
|
166 |
):
|
167 |
+
|
168 |
+
if llama_flash_attn:
|
169 |
+
# Need to call this before importing transformers.
|
170 |
+
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
171 |
+
replace_llama_attn_with_flash_attn()
|
172 |
+
|
173 |
# allow set token directly
|
174 |
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
175 |
|
|
|
189 |
if not output_dir:
|
190 |
output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
|
191 |
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
192 |
+
raise FileExistsError(f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.")
|
193 |
else:
|
194 |
if os.path.exists(output_dir) and not resume_from_checkpoint:
|
195 |
raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
|
|
|
201 |
tokenizer_base_model = base_model
|
202 |
if llama_type is None:
|
203 |
llama_type = "llama" in base_model.lower()
|
204 |
+
if llama_type and llama_flash_attn:
|
205 |
+
import pkg_resources
|
206 |
+
try:
|
207 |
+
pkg_resources.get_distribution('flash_attn')
|
208 |
+
can_do_flash_attn = True
|
209 |
+
except (pkg_resources.DistributionNotFound, pkg_resources.ContextualVersionConflict):
|
210 |
+
can_do_flash_attn = False
|
211 |
+
|
212 |
+
if not can_do_flash_attn:
|
213 |
+
raise RuntimeError("""Flash attention not installed.
|
214 |
+
NOTE: for current pytorch 2.0, flash attention requires installing cuda 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local and then when running, to avoid installing driver, docs, samples, just install toolkit. Then when pip installing flash attention do:
|
215 |
+
|
216 |
+
CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
|
217 |
+
from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
|
218 |
+
replace_llama_attn_with_flash_attn()
|
219 |
assert (
|
220 |
base_model
|
221 |
), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
|
|
|
247 |
|
248 |
model = model_loader.from_pretrained(
|
249 |
base_model,
|
250 |
+
load_in_8bit=train_8bit,
|
251 |
device_map=device_map,
|
252 |
torch_dtype=torch.float16,
|
253 |
max_memory=max_memory,
|
|
|
261 |
model.is_parallelizable = True
|
262 |
model.model_parallel = True
|
263 |
|
264 |
+
tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
|
|
|
|
|
|
|
265 |
|
266 |
+
if train_8bit:
|
267 |
+
from peft import (
|
268 |
+
prepare_model_for_int8_training,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
270 |
|
271 |
+
if "gpt-neox" not in base_model or True:
|
272 |
+
model = prepare_model_for_int8_training(model)
|
273 |
+
else:
|
274 |
+
model = prepare_model_for_int8_training(
|
275 |
+
model,
|
276 |
+
output_embedding_layer_name="embed_out", # keep output logits in float32
|
277 |
+
layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
|
278 |
+
)
|
279 |
|
280 |
+
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, utils
|
281 |
+
lora_mappings = utils.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
|
282 |
+
lora_mappings['distilgpt2'] = ["c_attn"]
|
283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
284 |
if lora_weights:
|
285 |
+
|
286 |
from peft import PeftModel
|
287 |
model = PeftModel.from_pretrained(
|
288 |
model,
|
|
|
293 |
resume_download=resume_download,
|
294 |
use_auth_token=use_auth_token,
|
295 |
)
|
296 |
+
elif lora_r > 0:
|
297 |
if lora_target_modules is None:
|
298 |
base_model_lower = base_model.lower()
|
299 |
if base_model_lower in lora_mappings:
|
|
|
341 |
log(f"Checkpoint {checkpoint_name} not found")
|
342 |
|
343 |
print(model)
|
344 |
+
try:
|
345 |
+
# only for PeftModel
|
346 |
+
model.print_trainable_parameters() # Be more transparent about the % of trainable params.
|
347 |
+
except:
|
348 |
+
pass
|
349 |
|
350 |
metrics = {}
|
351 |
for name in supported_metrics:
|
|
|
364 |
elif val_set_size < 1.0 and val_set_size != 0:
|
365 |
raise RuntimeError("Fractional validation size not supported.")
|
366 |
|
367 |
+
from datasets import load_dataset, concatenate_datasets
|
368 |
if valid_path:
|
369 |
data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
|
370 |
else:
|
|
|
387 |
else:
|
388 |
data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
|
389 |
data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
|
390 |
+
mix_in_rows = int(num_rows * data_mix_in_factor)
|
391 |
+
|
392 |
+
if mix_in_rows > data_mix_in.num_rows:
|
393 |
+
# duplicate rows if mix-in is smaller than required
|
394 |
+
log("Duplicating mixin to compensate for its size for training size and mixin fraction")
|
395 |
+
data_mix_in = concatenate_datasets([data_mix_in] * int(np.ceil(mix_in_rows / data_mix_in.num_rows)))
|
396 |
|
397 |
# only get as much as we need to balance
|
398 |
valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
|
399 |
+
train_size = max(1, min(data_mix_in.num_rows - valid_size, mix_in_rows))
|
400 |
mixin_small = data_mix_in.train_test_split(
|
401 |
test_size=train_size + valid_size,
|
402 |
shuffle=True, seed=np.random.randint(10000),
|
|
|
452 |
|
453 |
assert train_data is not None
|
454 |
|
455 |
+
generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
|
456 |
+
train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
|
457 |
+
cutoff_len=cutoff_len, tokenizer=tokenizer)
|
458 |
+
|
459 |
# shuffle and tokenize data
|
460 |
if train_data_mix_in:
|
461 |
train_data = concatenate_datasets([train_data, train_data_mix_in])
|
462 |
+
log("Tokenizing %s training rows" % train_data.num_rows)
|
463 |
+
train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() // torch.cuda.device_count())
|
464 |
+
if drop_truncations:
|
465 |
+
log("avoid keeping truncated cases to avoid contaminating model with truncation cases. Original size: %s" % train_data.num_rows)
|
466 |
+
prune_long_sequences_func = partial(prune_long_sequences, cutoff_len=cutoff_len)
|
467 |
+
train_data = train_data.filter(prune_long_sequences_func, num_proc=os.cpu_count() // torch.cuda.device_count())
|
468 |
+
log("avoid keeping truncated cases to avoid contaminating model with truncation cases. New size: %s" % train_data.num_rows)
|
469 |
train_set_size = len(train_data)
|
470 |
|
471 |
if valid_data and valid_data_mix_in:
|
|
|
474 |
valid_data = valid_data_mix_in
|
475 |
|
476 |
if valid_data:
|
477 |
+
log("Tokenizing %s validation rows" % valid_data.num_rows)
|
478 |
+
valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() // torch.cuda.device_count())
|
479 |
val_set_size = len(valid_data)
|
480 |
else:
|
481 |
val_set_size = 0
|
|
|
486 |
del sample_row_dict['labels']
|
487 |
log("Sample input: %s" % sample_row_dict)
|
488 |
|
489 |
+
try:
|
490 |
+
import neptune
|
491 |
+
from transformers.integrations import NeptuneCallback
|
492 |
+
|
493 |
+
neptune_run = neptune.init_run(
|
494 |
+
source_files=[],
|
495 |
+
)
|
496 |
+
log("Connected to Neptune.")
|
497 |
+
except ImportError:
|
498 |
+
neptune_run = None
|
499 |
+
log("Please pip install neptune for tracking.")
|
500 |
+
except neptune.exceptions.NeptuneMissingApiTokenException:
|
501 |
+
neptune_run = None
|
502 |
+
os.environ["NEPTUNE_MODE"] = 'debug'
|
503 |
+
log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
|
504 |
+
|
505 |
if neptune_run:
|
506 |
neptune_callback = NeptuneCallback(run=neptune_run)
|
507 |
callbacks = [neptune_callback]
|
|
|
571 |
else:
|
572 |
trainer_kwargs = dict()
|
573 |
|
574 |
+
import transformers
|
575 |
trainer = transformers.Trainer(
|
576 |
model=model,
|
577 |
tokenizer=tokenizer,
|
|
|
599 |
eval_steps=eval_steps if val_set_size > 0 else None,
|
600 |
save_steps=save_steps,
|
601 |
output_dir=output_dir,
|
602 |
+
save_total_limit=save_total_limit,
|
603 |
load_best_model_at_end=True if val_set_size > 0 else False,
|
604 |
ddp_find_unused_parameters=False if ddp else None,
|
605 |
group_by_length=group_by_length,
|
|
|
616 |
model.config.use_cache = False
|
617 |
|
618 |
old_state_dict = model.state_dict
|
619 |
+
from peft import get_peft_model_state_dict
|
620 |
+
|
621 |
model.state_dict = (
|
622 |
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
623 |
).__get__(model, type(model))
|
|
|
625 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
626 |
model = torch.compile(model)
|
627 |
# WIP (not generally replacing layers until pytorch 2.1)
|
628 |
+
if not llama_flash_attn:
|
629 |
+
torch.backends.cuda.enable_flash_sdp(True)
|
630 |
|
631 |
if gpus > 1 and not ddp:
|
632 |
assert trainer.is_model_parallel
|
|
|
646 |
from transformers import LlamaForCausalLM, LlamaTokenizer
|
647 |
model_loader = LlamaForCausalLM
|
648 |
tokenizer_loader = LlamaTokenizer
|
649 |
+
elif 'distilgpt2' in model_name.lower():
|
650 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
651 |
+
return AutoModelForCausalLM, AutoTokenizer
|
652 |
elif 'gpt2' in model_name.lower():
|
653 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
654 |
return GPT2LMHeadModel, GPT2Tokenizer
|
|
|
676 |
return model_loader, tokenizer_loader
|
677 |
|
678 |
|
679 |
+
def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
|
680 |
+
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
681 |
+
local_files_only=local_files_only,
|
682 |
+
resume_download=resume_download,
|
683 |
+
use_auth_token=use_auth_token)
|
684 |
+
|
685 |
+
tokenizer.pad_token_id = 0 # different from the eos token
|
686 |
+
# when generating, we will use the logits of right-most token to predict the next token
|
687 |
+
# so the padding should be on the left,
|
688 |
+
# e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
|
689 |
+
tokenizer.padding_side = "left" # Allow batched inference
|
690 |
|
691 |
+
return tokenizer
|
692 |
|
693 |
+
|
694 |
+
def tokenize(prompt, tokenizer, cutoff_len, add_eos_token=False):
|
695 |
+
# there's probably a way to do this with the tokenizer settings
|
696 |
+
# but again, gotta move fast
|
697 |
+
result = tokenizer(
|
698 |
+
prompt,
|
699 |
+
truncation=True,
|
700 |
+
max_length=cutoff_len,
|
701 |
+
padding=False,
|
702 |
+
return_tensors=None,
|
703 |
+
)
|
704 |
+
if (
|
705 |
+
result["input_ids"][-1] != tokenizer.eos_token_id
|
706 |
+
and len(result["input_ids"]) < cutoff_len
|
707 |
+
and add_eos_token
|
708 |
+
):
|
709 |
+
result["input_ids"].append(tokenizer.eos_token_id)
|
710 |
+
result["attention_mask"].append(1)
|
711 |
+
|
712 |
+
result["labels"] = result["input_ids"].copy()
|
713 |
+
|
714 |
+
return result
|
715 |
+
|
716 |
+
|
717 |
+
def prune_long_sequences(data_point, cutoff_len=None):
|
718 |
"""
|
719 |
+
Prune if too long for tokenizer, so truncation doesn't lead training to learn from truncated language
|
720 |
+
:param data_point:
|
721 |
+
:param cutoff_len:
|
722 |
:return:
|
723 |
"""
|
724 |
+
assert cutoff_len is not None
|
725 |
+
return len(data_point['input_ids']) < cutoff_len
|
726 |
+
|
727 |
+
|
728 |
+
def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=False, add_eos_token=False,
|
729 |
+
cutoff_len=None, tokenizer=None):
|
730 |
+
assert prompt_type is not None
|
731 |
+
assert cutoff_len is not None
|
732 |
+
assert tokenizer is not None
|
733 |
+
full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
|
734 |
+
tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
735 |
+
if not train_on_inputs:
|
736 |
+
user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
|
737 |
+
tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
738 |
+
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
739 |
+
if add_eos_token:
|
740 |
+
user_prompt_len -= 1
|
741 |
+
|
742 |
+
# ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
|
743 |
+
tokenized_full_prompt["labels"] = [
|
744 |
+
-100
|
745 |
+
] * user_prompt_len + tokenized_full_prompt["labels"][
|
746 |
+
user_prompt_len:
|
747 |
+
] # could be sped up, probably
|
748 |
+
return tokenized_full_prompt
|
749 |
|
750 |
|
751 |
def get_prompt(prompt_type, chat, context, reduced):
|
|
|
869 |
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
870 |
promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
|
871 |
|
872 |
+
prompt = context if not reduced else ''
|
873 |
|
874 |
if input and promptA:
|
875 |
prompt += f"""{promptA}"""
|
generate.py
ADDED
@@ -0,0 +1,1185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
import traceback
|
5 |
+
import typing
|
6 |
+
from threading import Thread
|
7 |
+
from datetime import datetime
|
8 |
+
import filelock
|
9 |
+
import psutil
|
10 |
+
|
11 |
+
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
|
12 |
+
|
13 |
+
SEED = 1236
|
14 |
+
set_seed(SEED)
|
15 |
+
|
16 |
+
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
17 |
+
from typing import Union
|
18 |
+
import numpy as np
|
19 |
+
import pandas as pd
|
20 |
+
|
21 |
+
import fire
|
22 |
+
import torch
|
23 |
+
from peft import PeftModel
|
24 |
+
from transformers import GenerationConfig, StoppingCriteriaList, AutoModel, TextIteratorStreamer
|
25 |
+
from accelerate import init_empty_weights, infer_auto_device_map
|
26 |
+
|
27 |
+
from prompter import Prompter
|
28 |
+
|
29 |
+
from finetune import get_loaders, example_data_points, generate_prompt, human, bot, inv_prompt_type_to_model_lower
|
30 |
+
from stopping import StoppingCriteriaSub
|
31 |
+
|
32 |
+
eval_extra_columns = ['prompt', 'response', 'score']
|
33 |
+
|
34 |
+
|
35 |
+
def main(
|
36 |
+
load_8bit: bool = False,
|
37 |
+
load_half: bool = True,
|
38 |
+
infer_devices: bool = True, # really if to "control" devices now
|
39 |
+
base_model: str = '',
|
40 |
+
tokenizer_base_model: str = '',
|
41 |
+
lora_weights: str = "",
|
42 |
+
gpu_id: int = 0, # if infer_devices = True and gpu_id != -1
|
43 |
+
|
44 |
+
prompt_type: Union[int, str] = None,
|
45 |
+
# input to generation
|
46 |
+
temperature: float = None,
|
47 |
+
top_p: float = None,
|
48 |
+
top_k: int = None,
|
49 |
+
num_beams: int = None,
|
50 |
+
repetition_penalty: float = None,
|
51 |
+
num_return_sequences: int = None,
|
52 |
+
do_sample: bool = None,
|
53 |
+
max_new_tokens: int = None,
|
54 |
+
min_new_tokens: int = None,
|
55 |
+
early_stopping: Union[bool, str] = None,
|
56 |
+
max_time: float = None,
|
57 |
+
|
58 |
+
debug: bool = False,
|
59 |
+
save_dir: str = None,
|
60 |
+
share: bool = True,
|
61 |
+
local_files_only: bool = False,
|
62 |
+
resume_download: bool = True,
|
63 |
+
use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
|
64 |
+
|
65 |
+
src_lang: str = "English",
|
66 |
+
tgt_lang: str = "Russian",
|
67 |
+
|
68 |
+
gradio: bool = True,
|
69 |
+
gradio_avoid_processing_markdown: bool = False,
|
70 |
+
chat: bool = True,
|
71 |
+
chat_history: int = 4096, # character length of chat context/history
|
72 |
+
chat_context: bool = False, # use default context if human_bot
|
73 |
+
stream_output: bool = True,
|
74 |
+
show_examples: bool = None,
|
75 |
+
verbose: bool = False,
|
76 |
+
h2ocolors: bool = True,
|
77 |
+
height: int = 400,
|
78 |
+
show_lora: bool = True,
|
79 |
+
# set to True to load --base_model after client logs in,
|
80 |
+
# to be able to free GPU memory when model is swapped
|
81 |
+
login_mode_if_model0: bool = False,
|
82 |
+
block_gradio_exit: bool = True,
|
83 |
+
concurrency_count: int = 1,
|
84 |
+
api_open: bool = False, # don't let API skip queue
|
85 |
+
allow_api: bool = True,
|
86 |
+
input_lines: int = 1,
|
87 |
+
|
88 |
+
sanitize_user_prompt: bool = True,
|
89 |
+
sanitize_bot_response: bool = True,
|
90 |
+
|
91 |
+
extra_model_options: typing.List[str] = [],
|
92 |
+
extra_lora_options: typing.List[str] = [],
|
93 |
+
|
94 |
+
score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
|
95 |
+
auto_score: bool = True,
|
96 |
+
|
97 |
+
eval_sharegpt_prompts_only: int = 0,
|
98 |
+
eval_sharegpt_prompts_only_seed: int = 1234,
|
99 |
+
eval_sharegpt_as_output: bool = False,
|
100 |
+
|
101 |
+
hard_stop_list: typing.List[str] = [],
|
102 |
+
):
|
103 |
+
is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
|
104 |
+
is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
|
105 |
+
is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
|
106 |
+
is_low_mem = is_hf # assumes run on 24GB consumer GPU
|
107 |
+
admin_pass = os.getenv("ADMIN_PASS")
|
108 |
+
# will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
|
109 |
+
# but becomes unrecoverable sometimes if raise, so just be silent for now
|
110 |
+
raise_generate_gpu_exceptions = not is_public
|
111 |
+
|
112 |
+
# allow set token directly
|
113 |
+
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
114 |
+
|
115 |
+
if is_public:
|
116 |
+
input_lines = 1 # ensure set, for ease of use
|
117 |
+
temperature = 0.2
|
118 |
+
top_p = 0.85
|
119 |
+
top_k = 70
|
120 |
+
do_sample = True
|
121 |
+
if is_low_mem:
|
122 |
+
base_model = 'h2oai/h2ogpt-oasst1-512-12b'
|
123 |
+
load_8bit = True
|
124 |
+
else:
|
125 |
+
base_model = 'h2oai/h2ogpt-oasst1-512-20b'
|
126 |
+
if is_low_mem:
|
127 |
+
load_8bit = True
|
128 |
+
if is_hf:
|
129 |
+
# must override share if in spaces
|
130 |
+
share = False
|
131 |
+
save_dir = os.getenv('SAVE_DIR', save_dir)
|
132 |
+
score_model = os.getenv('SCORE_MODEL', score_model)
|
133 |
+
if score_model == 'None':
|
134 |
+
score_model = ''
|
135 |
+
concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
|
136 |
+
api_open = bool(int(os.getenv('API_OPEN', api_open)))
|
137 |
+
allow_api = bool(int(os.getenv('ALLOW_API', allow_api)))
|
138 |
+
|
139 |
+
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
|
140 |
+
if n_gpus == 0:
|
141 |
+
gpu_id = None
|
142 |
+
load_8bit = False
|
143 |
+
load_half = False
|
144 |
+
infer_devices = False
|
145 |
+
torch.backends.cudnn.benchmark = True
|
146 |
+
torch.backends.cudnn.enabled = False
|
147 |
+
torch.set_default_dtype(torch.float32)
|
148 |
+
if psutil.virtual_memory().available < 94*1024**3:
|
149 |
+
# 12B uses ~94GB
|
150 |
+
# 6.9B uses ~47GB
|
151 |
+
base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
|
152 |
+
|
153 |
+
# get defaults
|
154 |
+
model_lower = base_model.lower()
|
155 |
+
if not gradio:
|
156 |
+
# force, else not single response like want to look at
|
157 |
+
stream_output = False
|
158 |
+
# else prompt removal can mess up output
|
159 |
+
chat = False
|
160 |
+
|
161 |
+
placeholder_instruction, placeholder_input, \
|
162 |
+
stream_output, show_examples, \
|
163 |
+
prompt_type, temperature, top_p, top_k, num_beams, \
|
164 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time, \
|
165 |
+
repetition_penalty, num_return_sequences, \
|
166 |
+
do_sample, \
|
167 |
+
src_lang, tgt_lang, \
|
168 |
+
examples, \
|
169 |
+
task_info = \
|
170 |
+
get_generate_params(model_lower, chat,
|
171 |
+
stream_output, show_examples,
|
172 |
+
prompt_type, temperature, top_p, top_k, num_beams,
|
173 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time,
|
174 |
+
repetition_penalty, num_return_sequences,
|
175 |
+
do_sample,
|
176 |
+
)
|
177 |
+
|
178 |
+
if not gradio:
|
179 |
+
if eval_sharegpt_prompts_only > 0:
|
180 |
+
# override default examples with shareGPT ones for human-level eval purposes only
|
181 |
+
eval_filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
|
182 |
+
if not os.path.isfile(eval_filename):
|
183 |
+
os.system(
|
184 |
+
'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % eval_filename)
|
185 |
+
import json
|
186 |
+
data = json.load(open(eval_filename, 'rt'))
|
187 |
+
# focus on data that starts with human, else likely chopped from other data
|
188 |
+
turn_start = 0 # odd in general
|
189 |
+
data = [x for x in data if len(x['conversations']) > turn_start + 1 and
|
190 |
+
x['conversations'][turn_start]['from'] == 'human' and
|
191 |
+
x['conversations'][turn_start + 1]['from'] == 'gpt']
|
192 |
+
np.random.seed(eval_sharegpt_prompts_only_seed)
|
193 |
+
example1 = examples[-1] # pick reference example
|
194 |
+
examples = []
|
195 |
+
responses = []
|
196 |
+
for i in list(np.random.randint(0, len(data), size=eval_sharegpt_prompts_only)):
|
197 |
+
assert data[i]['conversations'][turn_start]['from'] == 'human'
|
198 |
+
instruction = data[i]['conversations'][turn_start]['value']
|
199 |
+
assert data[i]['conversations'][turn_start + 1]['from'] == 'gpt'
|
200 |
+
output = data[i]['conversations'][turn_start + 1]['value']
|
201 |
+
examplenew = example1.copy()
|
202 |
+
assert not chat, "No gradio must use chat=False, uses nochat instruct"
|
203 |
+
examplenew[eval_func_param_names.index('instruction_nochat')] = instruction
|
204 |
+
examplenew[eval_func_param_names.index('iinput_nochat')] = '' # no input
|
205 |
+
examplenew[eval_func_param_names.index('context')] = get_context(chat_context, prompt_type)
|
206 |
+
examples.append(examplenew)
|
207 |
+
responses.append(output)
|
208 |
+
|
209 |
+
num_examples = len(examples)
|
210 |
+
scoring_path = 'scoring'
|
211 |
+
os.makedirs(scoring_path, exist_ok=True)
|
212 |
+
if eval_sharegpt_as_output:
|
213 |
+
used_base_model = 'gpt35'
|
214 |
+
used_lora_weights = ''
|
215 |
+
else:
|
216 |
+
used_base_model = str(base_model.split('/')[-1])
|
217 |
+
used_lora_weights = str(lora_weights.split('/')[-1])
|
218 |
+
eval_filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
|
219 |
+
eval_sharegpt_prompts_only_seed,
|
220 |
+
eval_sharegpt_as_output,
|
221 |
+
used_base_model,
|
222 |
+
used_lora_weights)
|
223 |
+
eval_filename = os.path.join(scoring_path, eval_filename)
|
224 |
+
|
225 |
+
# torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
|
226 |
+
context_class = NullContext() if n_gpus > 1 or n_gpus == 0 else torch.device("cuda")
|
227 |
+
|
228 |
+
with context_class:
|
229 |
+
# ensure was set right above before examples generated
|
230 |
+
assert not stream_output, "stream_output=True does not make sense with example loop"
|
231 |
+
import time
|
232 |
+
from functools import partial
|
233 |
+
|
234 |
+
# get score model
|
235 |
+
smodel, stokenizer, sdevice = get_score_model(**locals())
|
236 |
+
|
237 |
+
if not eval_sharegpt_as_output:
|
238 |
+
model, tokenizer, device = get_model(**locals())
|
239 |
+
model_state = [model, tokenizer, device, base_model]
|
240 |
+
fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir, is_low_mem=is_low_mem,
|
241 |
+
raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
|
242 |
+
chat_context=chat_context,
|
243 |
+
concurrency_count=concurrency_count)
|
244 |
+
else:
|
245 |
+
assert eval_sharegpt_prompts_only > 0
|
246 |
+
|
247 |
+
def get_response(*args, exi=0):
|
248 |
+
# assumes same ordering of examples and responses
|
249 |
+
yield responses[exi]
|
250 |
+
|
251 |
+
fun = get_response
|
252 |
+
t0 = time.time()
|
253 |
+
score_dump = []
|
254 |
+
|
255 |
+
import matplotlib.pyplot as plt
|
256 |
+
|
257 |
+
for exi, ex in enumerate(examples):
|
258 |
+
instruction = ex[eval_func_param_names.index('instruction_nochat')]
|
259 |
+
iinput = ex[eval_func_param_names.index('iinput_nochat')]
|
260 |
+
context = ex[eval_func_param_names.index('context')]
|
261 |
+
clear_torch_cache()
|
262 |
+
print("")
|
263 |
+
print("START" + "=" * 100)
|
264 |
+
print("Question: %s %s" % (instruction, ('input=%s' % iinput if iinput else '')))
|
265 |
+
print("-" * 105)
|
266 |
+
# fun yields as generator, so have to iterate over it
|
267 |
+
# Also means likely do NOT want --stream_output=True, else would show all generations
|
268 |
+
gener = fun(*tuple(ex), exi=exi) if eval_sharegpt_as_output else fun(*tuple(ex))
|
269 |
+
for res in gener:
|
270 |
+
print(res)
|
271 |
+
if smodel:
|
272 |
+
score_with_prompt = False
|
273 |
+
if score_with_prompt:
|
274 |
+
data_point = dict(instruction=instruction, input=iinput, context=context)
|
275 |
+
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
276 |
+
prompt = prompter.generate_prompt(data_point)
|
277 |
+
else:
|
278 |
+
# just raw input and output
|
279 |
+
if eval_sharegpt_prompts_only > 0:
|
280 |
+
# only our own examples have this filled at moment
|
281 |
+
assert iinput in [None, ''], iinput # should be no iinput
|
282 |
+
if not (chat_context and prompt_type == 'human_bot'):
|
283 |
+
assert context in [None, ''], context # should be no context
|
284 |
+
prompt = instruction
|
285 |
+
cutoff_len = 768 if is_low_mem else 2048
|
286 |
+
inputs = stokenizer(prompt, res,
|
287 |
+
return_tensors="pt",
|
288 |
+
truncation=True,
|
289 |
+
max_length=cutoff_len)
|
290 |
+
try:
|
291 |
+
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
|
292 |
+
except torch.cuda.OutOfMemoryError as e:
|
293 |
+
print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
|
294 |
+
traceback.print_exc()
|
295 |
+
score = 0.0
|
296 |
+
clear_torch_cache()
|
297 |
+
except (Exception, RuntimeError) as e:
|
298 |
+
if 'Expected all tensors to be on the same device' in str(e) or \
|
299 |
+
'expected scalar type Half but found Float' in str(e) or \
|
300 |
+
'probability tensor contains either' in str(e) or \
|
301 |
+
'cublasLt ran into an error!' in str(e):
|
302 |
+
print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
|
303 |
+
flush=True)
|
304 |
+
traceback.print_exc()
|
305 |
+
score = 0.0
|
306 |
+
clear_torch_cache()
|
307 |
+
else:
|
308 |
+
raise
|
309 |
+
print("SCORE %s: %s" % (exi, score), flush=True)
|
310 |
+
score_dump.append(ex + [prompt, res, score])
|
311 |
+
# dump every score in case abort
|
312 |
+
df_scores = pd.DataFrame(score_dump,
|
313 |
+
columns=eval_func_param_names + eval_extra_columns)
|
314 |
+
df_scores.to_parquet(eval_filename, index=False)
|
315 |
+
# plot histogram so far
|
316 |
+
plt.figure(figsize=(10, 10))
|
317 |
+
plt.hist(df_scores['score'], bins=20)
|
318 |
+
score_avg = np.mean(df_scores['score'])
|
319 |
+
score_median = np.median(df_scores['score'])
|
320 |
+
plt.title("Score avg: %s median: %s" % (score_avg, score_median))
|
321 |
+
plt.savefig(eval_filename.replace('.parquet', '.png'))
|
322 |
+
plt.close()
|
323 |
+
|
324 |
+
print("END" + "=" * 102)
|
325 |
+
print("")
|
326 |
+
t2 = time.time()
|
327 |
+
print("Time taken so far: %.4f about %.4g per example" % (t2 - t0, (t2 - t0) / (1 + exi)))
|
328 |
+
t1 = time.time()
|
329 |
+
print("Total time taken: %.4f about %.4g per example" % (t1 - t0, (t1 - t0) / num_examples))
|
330 |
+
return eval_filename
|
331 |
+
|
332 |
+
if gradio:
|
333 |
+
# imported here so don't require gradio to run generate
|
334 |
+
from gradio_runner import go_gradio
|
335 |
+
|
336 |
+
# get default model
|
337 |
+
all_kwargs = locals().copy()
|
338 |
+
if all_kwargs.get('base_model') and not all_kwargs['login_mode_if_model0']:
|
339 |
+
model0, tokenizer0, device = get_model(**all_kwargs)
|
340 |
+
else:
|
341 |
+
# if empty model, then don't load anything, just get gradio up
|
342 |
+
model0, tokenizer0, device = None, None, None
|
343 |
+
model_state0 = [model0, tokenizer0, device, all_kwargs['base_model']]
|
344 |
+
|
345 |
+
# get score model
|
346 |
+
smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
|
347 |
+
score_model_state0 = [smodel, stokenizer, sdevice, score_model]
|
348 |
+
|
349 |
+
go_gradio(**locals())
|
350 |
+
|
351 |
+
|
352 |
+
def get_device():
|
353 |
+
if torch.cuda.is_available():
|
354 |
+
device = "cuda"
|
355 |
+
else:
|
356 |
+
device = "cpu"
|
357 |
+
|
358 |
+
return device
|
359 |
+
|
360 |
+
|
361 |
+
def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
362 |
+
gpu_id=0,
|
363 |
+
use_auth_token=False):
|
364 |
+
"""
|
365 |
+
Ensure model gets on correct device
|
366 |
+
:param base_model:
|
367 |
+
:param model_loader:
|
368 |
+
:param load_half:
|
369 |
+
:param model_kwargs:
|
370 |
+
:param reward_type:
|
371 |
+
:param gpu_id:
|
372 |
+
:param use_auth_token:
|
373 |
+
:return:
|
374 |
+
"""
|
375 |
+
with init_empty_weights():
|
376 |
+
from transformers import AutoConfig
|
377 |
+
config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
|
378 |
+
model = AutoModel.from_config(
|
379 |
+
config,
|
380 |
+
)
|
381 |
+
|
382 |
+
# NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
|
383 |
+
# NOTE: Some models require avoiding sharding some layers,
|
384 |
+
# then would pass no_split_module_classes and give list of those layers.
|
385 |
+
device_map = infer_auto_device_map(
|
386 |
+
model,
|
387 |
+
dtype=torch.float16 if load_half else torch.float32,
|
388 |
+
)
|
389 |
+
if hasattr(model, 'model'):
|
390 |
+
device_map_model = infer_auto_device_map(
|
391 |
+
model.model,
|
392 |
+
dtype=torch.float16 if load_half else torch.float32,
|
393 |
+
)
|
394 |
+
device_map.update(device_map_model)
|
395 |
+
print('device_map: %s' % device_map, flush=True)
|
396 |
+
|
397 |
+
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
|
398 |
+
|
399 |
+
if n_gpus > 0:
|
400 |
+
if gpu_id >= 0:
|
401 |
+
# FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
|
402 |
+
# So avoid for now, just put on first GPU, unless score_model, put on last
|
403 |
+
if reward_type:
|
404 |
+
device_map = {'': n_gpus - 1}
|
405 |
+
else:
|
406 |
+
device_map = {'': min(n_gpus - 1, gpu_id)}
|
407 |
+
if gpu_id == -1:
|
408 |
+
device_map = {'': 'cuda'}
|
409 |
+
else:
|
410 |
+
device_map = {'': 'cpu'}
|
411 |
+
model_kwargs['load_in_8bit'] = False
|
412 |
+
|
413 |
+
load_in_8bit = model_kwargs.get('load_in_8bit', False)
|
414 |
+
model_kwargs['device_map'] = device_map
|
415 |
+
|
416 |
+
if load_in_8bit or not load_half:
|
417 |
+
model = model_loader.from_pretrained(
|
418 |
+
base_model,
|
419 |
+
**model_kwargs,
|
420 |
+
)
|
421 |
+
else:
|
422 |
+
model = model_loader.from_pretrained(
|
423 |
+
base_model,
|
424 |
+
**model_kwargs,
|
425 |
+
).half()
|
426 |
+
return model
|
427 |
+
|
428 |
+
|
429 |
+
def get_model(
|
430 |
+
load_8bit: bool = False,
|
431 |
+
load_half: bool = True,
|
432 |
+
infer_devices: bool = True,
|
433 |
+
base_model: str = '',
|
434 |
+
tokenizer_base_model: str = '',
|
435 |
+
lora_weights: str = "",
|
436 |
+
gpu_id: int = 0,
|
437 |
+
|
438 |
+
reward_type: bool = None,
|
439 |
+
local_files_only: bool = False,
|
440 |
+
resume_download: bool = True,
|
441 |
+
use_auth_token: Union[str, bool] = False,
|
442 |
+
compile: bool = True,
|
443 |
+
**kwargs,
|
444 |
+
):
|
445 |
+
"""
|
446 |
+
|
447 |
+
:param load_8bit: load model in 8-bit, not supported by all models
|
448 |
+
:param load_half: load model in 16-bit
|
449 |
+
:param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
|
450 |
+
For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
|
451 |
+
So it is not the default
|
452 |
+
:param base_model: name/path of base model
|
453 |
+
:param tokenizer_base_model: name/path of tokenizer
|
454 |
+
:param lora_weights: name/path
|
455 |
+
:param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
|
456 |
+
:param reward_type: reward type model for sequence classification
|
457 |
+
:param local_files_only: use local files instead of from HF
|
458 |
+
:param resume_download: resume downloads from HF
|
459 |
+
:param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
|
460 |
+
:param compile: whether to compile torch model
|
461 |
+
:param kwargs:
|
462 |
+
:return:
|
463 |
+
"""
|
464 |
+
print("Get %s model" % base_model, flush=True)
|
465 |
+
if lora_weights is not None and lora_weights.strip():
|
466 |
+
print("Get %s lora weights" % lora_weights, flush=True)
|
467 |
+
device = get_device()
|
468 |
+
|
469 |
+
if 'gpt2' in base_model.lower():
|
470 |
+
# RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
|
471 |
+
load_8bit = False
|
472 |
+
|
473 |
+
assert base_model.strip(), (
|
474 |
+
"Please choose a base model with --base_model (CLI) or in Models Tab (gradio)"
|
475 |
+
)
|
476 |
+
|
477 |
+
from transformers import AutoConfig
|
478 |
+
config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
|
479 |
+
llama_type_from_config = 'llama' in str(config).lower()
|
480 |
+
llama_type_from_name = "llama" in base_model.lower()
|
481 |
+
llama_type = llama_type_from_config or llama_type_from_name
|
482 |
+
if llama_type:
|
483 |
+
print("Detected as llama type from"
|
484 |
+
" config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
|
485 |
+
|
486 |
+
model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
|
487 |
+
if not tokenizer_base_model:
|
488 |
+
tokenizer_base_model = base_model
|
489 |
+
|
490 |
+
if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
|
491 |
+
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
492 |
+
local_files_only=local_files_only,
|
493 |
+
resume_download=resume_download,
|
494 |
+
use_auth_token=use_auth_token,
|
495 |
+
)
|
496 |
+
else:
|
497 |
+
tokenizer = tokenizer_loader
|
498 |
+
|
499 |
+
if isinstance(tokenizer, str):
|
500 |
+
# already a pipeline, tokenizer_loader is string for task
|
501 |
+
model = model_loader(tokenizer,
|
502 |
+
model=base_model,
|
503 |
+
device=0 if device == "cuda" else -1,
|
504 |
+
torch_dtype=torch.float16 if device == 'cuda' else torch.float32)
|
505 |
+
else:
|
506 |
+
assert device in ["cuda", "cpu"], "Unsupported device %s" % device
|
507 |
+
model_kwargs = dict(local_files_only=local_files_only,
|
508 |
+
torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
|
509 |
+
resume_download=resume_download,
|
510 |
+
use_auth_token=use_auth_token)
|
511 |
+
if 'mbart-' not in base_model.lower():
|
512 |
+
model_kwargs.update(dict(load_in_8bit=load_8bit,
|
513 |
+
device_map={"": 0} if load_8bit and device == 'cuda' else "auto",
|
514 |
+
))
|
515 |
+
if 'OpenAssistant/reward-model'.lower() in base_model.lower():
|
516 |
+
# could put on other GPUs
|
517 |
+
model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
|
518 |
+
model_kwargs.pop('torch_dtype', None)
|
519 |
+
|
520 |
+
if not lora_weights:
|
521 |
+
with torch.device(device):
|
522 |
+
if infer_devices:
|
523 |
+
model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
524 |
+
gpu_id=gpu_id, use_auth_token=use_auth_token)
|
525 |
+
else:
|
526 |
+
if load_half and not load_8bit:
|
527 |
+
model = model_loader.from_pretrained(
|
528 |
+
base_model,
|
529 |
+
**model_kwargs).half()
|
530 |
+
else:
|
531 |
+
model = model_loader.from_pretrained(
|
532 |
+
base_model,
|
533 |
+
**model_kwargs)
|
534 |
+
elif load_8bit:
|
535 |
+
model = model_loader.from_pretrained(
|
536 |
+
base_model,
|
537 |
+
**model_kwargs
|
538 |
+
)
|
539 |
+
model = PeftModel.from_pretrained(
|
540 |
+
model,
|
541 |
+
lora_weights,
|
542 |
+
torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
|
543 |
+
local_files_only=local_files_only,
|
544 |
+
resume_download=resume_download,
|
545 |
+
use_auth_token=use_auth_token,
|
546 |
+
device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required
|
547 |
+
)
|
548 |
+
else:
|
549 |
+
with torch.device(device):
|
550 |
+
model = model_loader.from_pretrained(
|
551 |
+
base_model,
|
552 |
+
**model_kwargs
|
553 |
+
)
|
554 |
+
model = PeftModel.from_pretrained(
|
555 |
+
model,
|
556 |
+
lora_weights,
|
557 |
+
torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
|
558 |
+
local_files_only=local_files_only,
|
559 |
+
resume_download=resume_download,
|
560 |
+
use_auth_token=use_auth_token,
|
561 |
+
device_map="auto",
|
562 |
+
)
|
563 |
+
if load_half:
|
564 |
+
model.half()
|
565 |
+
|
566 |
+
# unwind broken decapoda-research config
|
567 |
+
if llama_type:
|
568 |
+
model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
|
569 |
+
model.config.bos_token_id = 1
|
570 |
+
model.config.eos_token_id = 2
|
571 |
+
if 'gpt2' in base_model.lower():
|
572 |
+
# add special tokens that otherwise all share the same id
|
573 |
+
tokenizer.add_special_tokens({'bos_token': '<bos>',
|
574 |
+
'eos_token': '<eos>',
|
575 |
+
'pad_token': '<pad>'})
|
576 |
+
|
577 |
+
if not isinstance(tokenizer, str):
|
578 |
+
model.eval()
|
579 |
+
if torch.__version__ >= "2" and sys.platform != "win32" and compile:
|
580 |
+
model = torch.compile(model)
|
581 |
+
|
582 |
+
return model, tokenizer, device
|
583 |
+
|
584 |
+
|
585 |
+
def get_score_model(**kwargs):
|
586 |
+
# score model
|
587 |
+
if kwargs.get('score_model') is not None and kwargs.get('score_model').strip():
|
588 |
+
score_all_kwargs = kwargs.copy()
|
589 |
+
score_all_kwargs['load_8bit'] = False
|
590 |
+
score_all_kwargs['load_half'] = False
|
591 |
+
score_all_kwargs['base_model'] = kwargs.get('score_model').strip()
|
592 |
+
score_all_kwargs['tokenizer_base_model'] = ''
|
593 |
+
score_all_kwargs['lora_weights'] = ''
|
594 |
+
score_all_kwargs['llama_type'] = False
|
595 |
+
score_all_kwargs['compile'] = False
|
596 |
+
smodel, stokenizer, sdevice = get_model(**score_all_kwargs)
|
597 |
+
else:
|
598 |
+
smodel, stokenizer, sdevice = None, None, None
|
599 |
+
return smodel, stokenizer, sdevice
|
600 |
+
|
601 |
+
|
602 |
+
eval_func_param_names = ['instruction',
|
603 |
+
'iinput',
|
604 |
+
'context',
|
605 |
+
'stream_output',
|
606 |
+
'prompt_type',
|
607 |
+
'temperature',
|
608 |
+
'top_p',
|
609 |
+
'top_k',
|
610 |
+
'num_beams',
|
611 |
+
'max_new_tokens',
|
612 |
+
'min_new_tokens',
|
613 |
+
'early_stopping',
|
614 |
+
'max_time',
|
615 |
+
'repetition_penalty',
|
616 |
+
'num_return_sequences',
|
617 |
+
'do_sample',
|
618 |
+
'chat',
|
619 |
+
'instruction_nochat',
|
620 |
+
'iinput_nochat',
|
621 |
+
]
|
622 |
+
|
623 |
+
|
624 |
+
def evaluate(
|
625 |
+
model_state,
|
626 |
+
# START NOTE: Examples must have same order of parameters
|
627 |
+
instruction,
|
628 |
+
iinput,
|
629 |
+
context,
|
630 |
+
stream_output,
|
631 |
+
prompt_type,
|
632 |
+
temperature,
|
633 |
+
top_p,
|
634 |
+
top_k,
|
635 |
+
num_beams,
|
636 |
+
max_new_tokens,
|
637 |
+
min_new_tokens,
|
638 |
+
early_stopping,
|
639 |
+
max_time,
|
640 |
+
repetition_penalty,
|
641 |
+
num_return_sequences,
|
642 |
+
do_sample,
|
643 |
+
chat,
|
644 |
+
instruction_nochat,
|
645 |
+
iinput_nochat,
|
646 |
+
# END NOTE: Examples must have same order of parameters
|
647 |
+
src_lang=None,
|
648 |
+
tgt_lang=None,
|
649 |
+
debug=False,
|
650 |
+
concurrency_count=None,
|
651 |
+
save_dir=None,
|
652 |
+
hard_stop_list=None,
|
653 |
+
sanitize_bot_response=True,
|
654 |
+
model_state0=None,
|
655 |
+
is_low_mem=None,
|
656 |
+
raise_generate_gpu_exceptions=None,
|
657 |
+
chat_context=None,
|
658 |
+
):
|
659 |
+
# ensure passed these
|
660 |
+
assert concurrency_count is not None
|
661 |
+
assert is_low_mem is not None
|
662 |
+
assert raise_generate_gpu_exceptions is not None
|
663 |
+
assert chat_context is not None
|
664 |
+
|
665 |
+
if debug:
|
666 |
+
locals_dict = locals().copy()
|
667 |
+
locals_dict.pop('model_state', None)
|
668 |
+
locals_dict.pop('model_state0', None)
|
669 |
+
print(locals_dict)
|
670 |
+
|
671 |
+
no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
|
672 |
+
|
673 |
+
if model_state0 is None:
|
674 |
+
# e.g. for no gradio case, set dummy value, else should be set
|
675 |
+
model_state0 = [None, None, None, None]
|
676 |
+
|
677 |
+
if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
|
678 |
+
# try to free-up original model (i.e. list was passed as reference)
|
679 |
+
if model_state0 is not None and model_state0[0] is not None:
|
680 |
+
model_state0[0].cpu()
|
681 |
+
model_state0[0] = None
|
682 |
+
# try to free-up original tokenizer (i.e. list was passed as reference)
|
683 |
+
if model_state0 is not None and model_state0[1] is not None:
|
684 |
+
model_state0[1] = None
|
685 |
+
clear_torch_cache()
|
686 |
+
model, tokenizer, device, base_model = model_state
|
687 |
+
elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
|
688 |
+
assert isinstance(model_state[0], str)
|
689 |
+
model, tokenizer, device, base_model = model_state0
|
690 |
+
else:
|
691 |
+
raise AssertionError(no_model_msg)
|
692 |
+
|
693 |
+
if base_model is None:
|
694 |
+
raise AssertionError(no_model_msg)
|
695 |
+
|
696 |
+
assert base_model.strip(), no_model_msg
|
697 |
+
assert model, "Model is missing"
|
698 |
+
assert tokenizer, "Tokenizer is missing"
|
699 |
+
|
700 |
+
# choose chat or non-chat mode
|
701 |
+
if not chat:
|
702 |
+
instruction = instruction_nochat
|
703 |
+
iinput = iinput_nochat
|
704 |
+
|
705 |
+
if not context:
|
706 |
+
# get hidden context if have one
|
707 |
+
context = get_context(chat_context, prompt_type)
|
708 |
+
|
709 |
+
data_point = dict(context=context, instruction=instruction, input=iinput)
|
710 |
+
prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
|
711 |
+
prompt = prompter.generate_prompt(data_point)
|
712 |
+
|
713 |
+
if hard_stop_list is None:
|
714 |
+
# acts like undo on user entry and bot response
|
715 |
+
hard_stop_list = []
|
716 |
+
|
717 |
+
if isinstance(tokenizer, str):
|
718 |
+
# pipeline
|
719 |
+
if tokenizer == "summarization":
|
720 |
+
key = 'summary_text'
|
721 |
+
else:
|
722 |
+
raise RuntimeError("No such task type %s" % tokenizer)
|
723 |
+
# NOTE: uses max_length only
|
724 |
+
yield model(prompt, max_length=max_new_tokens)[0][key]
|
725 |
+
|
726 |
+
if 'mbart-' in base_model.lower():
|
727 |
+
assert src_lang is not None
|
728 |
+
tokenizer.src_lang = languages_covered()[src_lang]
|
729 |
+
|
730 |
+
if chat:
|
731 |
+
# override, ignore user change
|
732 |
+
num_return_sequences = 1
|
733 |
+
if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
|
734 |
+
if prompt_type == 'human_bot':
|
735 |
+
# encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
|
736 |
+
# stopping only starts once output is beyond prompt
|
737 |
+
# 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
|
738 |
+
stop_words = [human, bot, '\n' + human, '\n' + bot]
|
739 |
+
encounters = [1, 2]
|
740 |
+
elif prompt_type == 'instruct_vicuna':
|
741 |
+
# even below is not enough, generic strings and many ways to encode
|
742 |
+
stop_words = [
|
743 |
+
'### Human:',
|
744 |
+
"""
|
745 |
+
### Human:""",
|
746 |
+
"""
|
747 |
+
### Human:
|
748 |
+
""",
|
749 |
+
'### Assistant:',
|
750 |
+
"""
|
751 |
+
### Assistant:""",
|
752 |
+
"""
|
753 |
+
### Assistant:
|
754 |
+
""",
|
755 |
+
]
|
756 |
+
encounters = [1, 2]
|
757 |
+
else:
|
758 |
+
# some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
|
759 |
+
stop_words = ['### End']
|
760 |
+
encounters = [1]
|
761 |
+
stop_words_ids = [
|
762 |
+
tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
|
763 |
+
# handle single token case
|
764 |
+
stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
|
765 |
+
stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
|
766 |
+
# avoid padding in front of tokens
|
767 |
+
if tokenizer.pad_token:
|
768 |
+
stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
|
769 |
+
# handle fake \n added
|
770 |
+
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
771 |
+
# build stopper
|
772 |
+
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
|
773 |
+
else:
|
774 |
+
stopping_criteria = StoppingCriteriaList()
|
775 |
+
|
776 |
+
# help to avoid errors like:
|
777 |
+
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
|
778 |
+
# RuntimeError: expected scalar type Half but found Float
|
779 |
+
# with - 256
|
780 |
+
max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
|
781 |
+
cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
|
782 |
+
output_smallest = 30 * 4
|
783 |
+
prompt = prompt[-cutoff_len - output_smallest:]
|
784 |
+
inputs = tokenizer(prompt,
|
785 |
+
return_tensors="pt",
|
786 |
+
truncation=True,
|
787 |
+
max_length=max_length_tokenize)
|
788 |
+
if debug and len(inputs["input_ids"]) > 0:
|
789 |
+
print('input_ids length', len(inputs["input_ids"][0]), flush=True)
|
790 |
+
input_ids = inputs["input_ids"].to(device)
|
791 |
+
generation_config = GenerationConfig(
|
792 |
+
temperature=float(temperature),
|
793 |
+
top_p=float(top_p),
|
794 |
+
top_k=top_k,
|
795 |
+
num_beams=num_beams,
|
796 |
+
do_sample=do_sample,
|
797 |
+
repetition_penalty=float(repetition_penalty),
|
798 |
+
num_return_sequences=num_return_sequences,
|
799 |
+
renormalize_logits=True,
|
800 |
+
remove_invalid_values=True,
|
801 |
+
)
|
802 |
+
|
803 |
+
gen_kwargs = dict(input_ids=input_ids,
|
804 |
+
generation_config=generation_config,
|
805 |
+
return_dict_in_generate=True,
|
806 |
+
output_scores=True,
|
807 |
+
max_new_tokens=max_new_tokens, # prompt + new
|
808 |
+
min_new_tokens=min_new_tokens, # prompt + new
|
809 |
+
early_stopping=early_stopping, # False, True, "never"
|
810 |
+
max_time=max_time,
|
811 |
+
stopping_criteria=stopping_criteria,
|
812 |
+
)
|
813 |
+
if 'gpt2' in base_model.lower():
|
814 |
+
gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id))
|
815 |
+
elif 'mbart-' in base_model.lower():
|
816 |
+
assert tgt_lang is not None
|
817 |
+
tgt_lang = languages_covered()[tgt_lang]
|
818 |
+
gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
|
819 |
+
else:
|
820 |
+
gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
|
821 |
+
|
822 |
+
decoder = functools.partial(tokenizer.decode,
|
823 |
+
skip_special_tokens=True,
|
824 |
+
clean_up_tokenization_spaces=True,
|
825 |
+
)
|
826 |
+
decoder_raw = functools.partial(tokenizer.decode,
|
827 |
+
skip_special_tokens=False,
|
828 |
+
clean_up_tokenization_spaces=True,
|
829 |
+
)
|
830 |
+
|
831 |
+
with torch.no_grad():
|
832 |
+
# protection for gradio not keeping track of closed users,
|
833 |
+
# else hit bitsandbytes lack of thread safety:
|
834 |
+
# https://github.com/h2oai/h2ogpt/issues/104
|
835 |
+
# but only makes sense if concurrency_count == 1
|
836 |
+
context_class = NullContext #if concurrency_count > 1 else filelock.FileLock
|
837 |
+
print('Pre-Generate: %s' % str(datetime.now()), flush=True)
|
838 |
+
decoded_output = None
|
839 |
+
with context_class("generate.lock"):
|
840 |
+
print('Generate: %s' % str(datetime.now()), flush=True)
|
841 |
+
# decoded tokenized prompt can deviate from prompt due to special characters
|
842 |
+
inputs_decoded = decoder(input_ids[0])
|
843 |
+
inputs_decoded_raw = decoder_raw(input_ids[0])
|
844 |
+
if inputs_decoded == prompt:
|
845 |
+
# normal
|
846 |
+
pass
|
847 |
+
elif inputs_decoded.lstrip() == prompt.lstrip():
|
848 |
+
# sometimes extra space in front, make prompt same for prompt removal
|
849 |
+
prompt = inputs_decoded
|
850 |
+
elif inputs_decoded_raw == prompt:
|
851 |
+
# some models specify special tokens that are part of normal prompt, so can't skip them
|
852 |
+
inputs_decoded_raw = inputs_decoded
|
853 |
+
decoder = decoder_raw
|
854 |
+
else:
|
855 |
+
print("WARNING: Special characters in prompt", flush=True)
|
856 |
+
if stream_output:
|
857 |
+
skip_prompt = False
|
858 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
|
859 |
+
gen_kwargs.update(dict(streamer=streamer))
|
860 |
+
target_func = generate_with_exceptions
|
861 |
+
target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
|
862 |
+
raise_generate_gpu_exceptions, **gen_kwargs)
|
863 |
+
thread = Thread(target=target)
|
864 |
+
thread.start()
|
865 |
+
outputs = ""
|
866 |
+
for new_text in streamer:
|
867 |
+
outputs += new_text
|
868 |
+
yield prompter.get_response(outputs, prompt=inputs_decoded,
|
869 |
+
sanitize_bot_response=sanitize_bot_response)
|
870 |
+
decoded_output = outputs
|
871 |
+
else:
|
872 |
+
outputs = model.generate(**gen_kwargs)
|
873 |
+
outputs = [decoder(s) for s in outputs.sequences]
|
874 |
+
yield prompter.get_response(outputs, prompt=inputs_decoded,
|
875 |
+
sanitize_bot_response=sanitize_bot_response)
|
876 |
+
if outputs and len(outputs) >= 1:
|
877 |
+
decoded_output = prompt + outputs[0]
|
878 |
+
if save_dir and decoded_output:
|
879 |
+
save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
|
880 |
+
print('Post-Generate: %s decoded_output: %s' % (str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
|
881 |
+
|
882 |
+
|
883 |
+
def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
|
884 |
+
try:
|
885 |
+
func(**kwargs)
|
886 |
+
except torch.cuda.OutOfMemoryError as e:
|
887 |
+
print("GPU OOM 2: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
|
888 |
+
flush=True)
|
889 |
+
if kwargs['input_ids'] is not None:
|
890 |
+
kwargs['input_ids'].cpu()
|
891 |
+
kwargs['input_ids'] = None
|
892 |
+
traceback.print_exc()
|
893 |
+
clear_torch_cache()
|
894 |
+
return
|
895 |
+
except (Exception, RuntimeError) as e:
|
896 |
+
if 'Expected all tensors to be on the same device' in str(e) or \
|
897 |
+
'expected scalar type Half but found Float' in str(e) or \
|
898 |
+
'probability tensor contains either' in str(e) or \
|
899 |
+
'cublasLt ran into an error!' in str(e) or \
|
900 |
+
'mat1 and mat2 shapes cannot be multiplied' in str(e):
|
901 |
+
print(
|
902 |
+
"GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
|
903 |
+
flush=True)
|
904 |
+
traceback.print_exc()
|
905 |
+
clear_torch_cache()
|
906 |
+
if raise_generate_gpu_exceptions:
|
907 |
+
raise
|
908 |
+
return
|
909 |
+
else:
|
910 |
+
clear_torch_cache()
|
911 |
+
raise
|
912 |
+
|
913 |
+
|
914 |
+
def get_generate_params(model_lower, chat,
|
915 |
+
stream_output, show_examples,
|
916 |
+
prompt_type, temperature, top_p, top_k, num_beams,
|
917 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time,
|
918 |
+
repetition_penalty, num_return_sequences,
|
919 |
+
do_sample):
|
920 |
+
use_defaults = False
|
921 |
+
use_default_examples = True
|
922 |
+
examples = []
|
923 |
+
task_info = f"{prompt_type}"
|
924 |
+
if model_lower:
|
925 |
+
print(f"Using Model {model_lower}", flush=True)
|
926 |
+
else:
|
927 |
+
print("No model defined yet", flush=True)
|
928 |
+
|
929 |
+
min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
|
930 |
+
early_stopping = early_stopping if early_stopping is not None else False
|
931 |
+
max_time_defaults = 60 * 3
|
932 |
+
max_time = max_time if max_time is not None else max_time_defaults
|
933 |
+
|
934 |
+
if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
|
935 |
+
prompt_type = inv_prompt_type_to_model_lower[model_lower]
|
936 |
+
|
937 |
+
# examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
|
938 |
+
if show_examples is None:
|
939 |
+
if chat:
|
940 |
+
show_examples = False
|
941 |
+
else:
|
942 |
+
show_examples = True
|
943 |
+
|
944 |
+
summarize_example1 = """Jeff: Can I train a ? Transformers model on Amazon SageMaker?
|
945 |
+
Philipp: Sure you can use the new Hugging Face Deep Learning Container.
|
946 |
+
Jeff: ok.
|
947 |
+
Jeff: and how can I get started?
|
948 |
+
Jeff: where can I find documentation?
|
949 |
+
Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face"""
|
950 |
+
|
951 |
+
if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower:
|
952 |
+
placeholder_instruction = summarize_example1
|
953 |
+
placeholder_input = ""
|
954 |
+
use_defaults = True
|
955 |
+
use_default_examples = False
|
956 |
+
examples += [
|
957 |
+
[placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
|
958 |
+
1.0, 1,
|
959 |
+
False]]
|
960 |
+
task_info = "Summarization"
|
961 |
+
elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower:
|
962 |
+
placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?"
|
963 |
+
placeholder_input = ""
|
964 |
+
use_defaults = True
|
965 |
+
use_default_examples = True
|
966 |
+
task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)"
|
967 |
+
elif 'mbart-' in model_lower:
|
968 |
+
placeholder_instruction = "The girl has long hair."
|
969 |
+
placeholder_input = ""
|
970 |
+
use_defaults = True
|
971 |
+
use_default_examples = False
|
972 |
+
examples += [
|
973 |
+
[placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
|
974 |
+
1.0, 1,
|
975 |
+
False]]
|
976 |
+
elif 'gpt2' in model_lower:
|
977 |
+
placeholder_instruction = "The sky is"
|
978 |
+
placeholder_input = ""
|
979 |
+
prompt_type = prompt_type or 'plain'
|
980 |
+
use_default_examples = True # some will be odd "continuations" but can be ok
|
981 |
+
examples += [
|
982 |
+
[placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
|
983 |
+
1.0, 1,
|
984 |
+
False]]
|
985 |
+
task_info = "Auto-complete phrase, code, etc."
|
986 |
+
use_defaults = True
|
987 |
+
else:
|
988 |
+
if chat:
|
989 |
+
placeholder_instruction = "Enter a question or imperative."
|
990 |
+
else:
|
991 |
+
placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
|
992 |
+
placeholder_input = ""
|
993 |
+
if model_lower:
|
994 |
+
prompt_type = prompt_type or 'human_bot'
|
995 |
+
else:
|
996 |
+
prompt_type = ''
|
997 |
+
examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
|
998 |
+
stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1,
|
999 |
+
False]]
|
1000 |
+
task_info = "No task"
|
1001 |
+
if prompt_type == 'instruct':
|
1002 |
+
task_info = "Answer question or follow imperative as instruction with optionally input."
|
1003 |
+
elif prompt_type == 'plain':
|
1004 |
+
task_info = "Auto-complete phrase, code, etc."
|
1005 |
+
elif prompt_type == 'human_bot':
|
1006 |
+
if chat:
|
1007 |
+
task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)"
|
1008 |
+
else:
|
1009 |
+
task_info = "Ask question/imperative (input concatenated with instruction)"
|
1010 |
+
|
1011 |
+
# revert to plain if still nothing
|
1012 |
+
prompt_type = prompt_type or 'plain'
|
1013 |
+
if use_defaults:
|
1014 |
+
temperature = 1.0 if temperature is None else temperature
|
1015 |
+
top_p = 1.0 if top_p is None else top_p
|
1016 |
+
top_k = 40 if top_k is None else top_k
|
1017 |
+
num_beams = num_beams or 1
|
1018 |
+
max_new_tokens = max_new_tokens or 128
|
1019 |
+
repetition_penalty = repetition_penalty or 1.07
|
1020 |
+
num_return_sequences = min(num_beams, num_return_sequences or 1)
|
1021 |
+
do_sample = False if do_sample is None else do_sample
|
1022 |
+
else:
|
1023 |
+
temperature = 0.2 if temperature is None else temperature
|
1024 |
+
top_p = 0.85 if top_p is None else top_p
|
1025 |
+
top_k = 70 if top_k is None else top_k
|
1026 |
+
if chat:
|
1027 |
+
num_beams = num_beams or 1
|
1028 |
+
else:
|
1029 |
+
num_beams = num_beams or 4
|
1030 |
+
max_new_tokens = max_new_tokens or 256
|
1031 |
+
repetition_penalty = repetition_penalty or 1.07
|
1032 |
+
num_return_sequences = min(num_beams, num_return_sequences or 1)
|
1033 |
+
do_sample = True if do_sample is None else do_sample
|
1034 |
+
# doesn't include chat, instruction_nochat, iinput_nochat, added later
|
1035 |
+
params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
|
1036 |
+
early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
|
1037 |
+
|
1038 |
+
if use_default_examples:
|
1039 |
+
examples += [
|
1040 |
+
["Translate English to French", "Good morning"] + params_list,
|
1041 |
+
["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list,
|
1042 |
+
["Explain in detailed list, all the best practices for coding in python.", ''] + params_list,
|
1043 |
+
[
|
1044 |
+
"Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.",
|
1045 |
+
''] + params_list,
|
1046 |
+
['Translate to German: My name is Arthur', ''] + params_list,
|
1047 |
+
["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list,
|
1048 |
+
['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.',
|
1049 |
+
''] + params_list,
|
1050 |
+
['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list,
|
1051 |
+
['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list,
|
1052 |
+
["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list,
|
1053 |
+
[
|
1054 |
+
"Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?",
|
1055 |
+
''] + params_list,
|
1056 |
+
['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list,
|
1057 |
+
[
|
1058 |
+
'Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apple do they have?',
|
1059 |
+
''] + params_list,
|
1060 |
+
["""def area_of_rectangle(a: float, b: float):
|
1061 |
+
\"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list,
|
1062 |
+
["""# a function in native python:
|
1063 |
+
def mean(a):
|
1064 |
+
return sum(a)/len(a)
|
1065 |
+
|
1066 |
+
# the same function using numpy:
|
1067 |
+
import numpy as np
|
1068 |
+
def mean(a):""", ''] + params_list,
|
1069 |
+
["""X = np.random.randn(100, 100)
|
1070 |
+
y = np.random.randint(0, 1, 100)
|
1071 |
+
|
1072 |
+
# fit random forest classifier with 20 estimators""", ''] + params_list,
|
1073 |
+
]
|
1074 |
+
|
1075 |
+
src_lang = "English"
|
1076 |
+
tgt_lang = "Russian"
|
1077 |
+
|
1078 |
+
# move to correct position
|
1079 |
+
for example in examples:
|
1080 |
+
example += [chat, '', '']
|
1081 |
+
# adjust examples if non-chat mode
|
1082 |
+
if not chat:
|
1083 |
+
example[eval_func_param_names.index('instruction_nochat')] = example[
|
1084 |
+
eval_func_param_names.index('instruction')]
|
1085 |
+
example[eval_func_param_names.index('instruction')] = ''
|
1086 |
+
|
1087 |
+
example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
|
1088 |
+
example[eval_func_param_names.index('iinput')] = ''
|
1089 |
+
|
1090 |
+
return placeholder_instruction, placeholder_input, \
|
1091 |
+
stream_output, show_examples, \
|
1092 |
+
prompt_type, temperature, top_p, top_k, num_beams, \
|
1093 |
+
max_new_tokens, min_new_tokens, early_stopping, max_time, \
|
1094 |
+
repetition_penalty, num_return_sequences, \
|
1095 |
+
do_sample, \
|
1096 |
+
src_lang, tgt_lang, \
|
1097 |
+
examples, \
|
1098 |
+
task_info
|
1099 |
+
|
1100 |
+
|
1101 |
+
def languages_covered():
|
1102 |
+
# https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered
|
1103 |
+
covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
|
1104 |
+
covered = covered.split(', ')
|
1105 |
+
covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered}
|
1106 |
+
return covered
|
1107 |
+
|
1108 |
+
|
1109 |
+
def get_context(chat_context, prompt_type):
|
1110 |
+
if chat_context and prompt_type == 'human_bot':
|
1111 |
+
context0 = """<bot>: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT, who will give accurate, balanced, and reliable responses. I will not respond with I don't know or I don't understand.
|
1112 |
+
<human>: I am a human person seeking useful assistance and request all questions be answered completely, and typically expect detailed responses. Give answers in numbered list format if several distinct but related items are being listed."""
|
1113 |
+
else:
|
1114 |
+
context0 = ''
|
1115 |
+
return context0
|
1116 |
+
|
1117 |
+
|
1118 |
+
def test_test_prompt(prompt_type='instruct', data_point=0):
|
1119 |
+
example_data_point = example_data_points[data_point]
|
1120 |
+
example_data_point.pop('output', None)
|
1121 |
+
return generate_prompt(example_data_point, prompt_type, False, False)
|
1122 |
+
|
1123 |
+
|
1124 |
+
def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len):
|
1125 |
+
question = question[-cutoff_len:]
|
1126 |
+
answer = answer[-cutoff_len:]
|
1127 |
+
|
1128 |
+
inputs = stokenizer(question, answer,
|
1129 |
+
return_tensors="pt",
|
1130 |
+
truncation=True,
|
1131 |
+
max_length=max_length_tokenize).to(smodel.device)
|
1132 |
+
try:
|
1133 |
+
score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
|
1134 |
+
except torch.cuda.OutOfMemoryError as e:
|
1135 |
+
print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
|
1136 |
+
del inputs
|
1137 |
+
traceback.print_exc()
|
1138 |
+
clear_torch_cache()
|
1139 |
+
return 'Response Score: GPU OOM'
|
1140 |
+
except (Exception, RuntimeError) as e:
|
1141 |
+
if 'Expected all tensors to be on the same device' in str(e) or \
|
1142 |
+
'expected scalar type Half but found Float' in str(e) or \
|
1143 |
+
'probability tensor contains either' in str(e) or \
|
1144 |
+
'cublasLt ran into an error!' in str(e):
|
1145 |
+
print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
|
1146 |
+
flush=True)
|
1147 |
+
traceback.print_exc()
|
1148 |
+
clear_torch_cache()
|
1149 |
+
return 'Response Score: GPU Error'
|
1150 |
+
else:
|
1151 |
+
raise
|
1152 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
1153 |
+
return score
|
1154 |
+
|
1155 |
+
|
1156 |
+
if __name__ == "__main__":
|
1157 |
+
print("""
|
1158 |
+
WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
|
1159 |
+
python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
|
1160 |
+
python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
|
1161 |
+
|
1162 |
+
# generate without lora weights, no prompt
|
1163 |
+
python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
|
1164 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
|
1165 |
+
|
1166 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
|
1167 |
+
# OpenChatKit settings:
|
1168 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
|
1169 |
+
|
1170 |
+
python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
|
1171 |
+
python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
|
1172 |
+
python generate.py --base_model='philschmid/bart-large-cnn-samsum'
|
1173 |
+
python generate.py --base_model='philschmid/flan-t5-base-samsum'
|
1174 |
+
python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
|
1175 |
+
|
1176 |
+
python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
|
1177 |
+
|
1178 |
+
must have 4*48GB GPU and run without 8bit in order for sharding to work with infer_devices=False
|
1179 |
+
can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
|
1180 |
+
python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
|
1181 |
+
|
1182 |
+
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
|
1183 |
+
|
1184 |
+
""", flush=True)
|
1185 |
+
fire.Fire(main)
|
gradio_runner.py
ADDED
@@ -0,0 +1,910 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import inspect
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
|
6 |
+
from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
|
7 |
+
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
|
8 |
+
ping
|
9 |
+
from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower
|
10 |
+
from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
from apscheduler.schedulers.background import BackgroundScheduler
|
14 |
+
|
15 |
+
|
16 |
+
def go_gradio(**kwargs):
|
17 |
+
allow_api = kwargs['allow_api']
|
18 |
+
is_public = kwargs['is_public']
|
19 |
+
is_hf = kwargs['is_hf']
|
20 |
+
is_low_mem = kwargs['is_low_mem']
|
21 |
+
n_gpus = kwargs['n_gpus']
|
22 |
+
admin_pass = kwargs['admin_pass']
|
23 |
+
model_state0 = kwargs['model_state0']
|
24 |
+
score_model_state0 = kwargs['score_model_state0']
|
25 |
+
queue = True
|
26 |
+
|
27 |
+
# easy update of kwargs needed for evaluate() etc.
|
28 |
+
kwargs.update(locals())
|
29 |
+
|
30 |
+
if 'mbart-' in kwargs['model_lower']:
|
31 |
+
instruction_label_nochat = "Text to translate"
|
32 |
+
else:
|
33 |
+
instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \
|
34 |
+
" use Enter for multiple input lines)"
|
35 |
+
if kwargs['input_lines'] > 1:
|
36 |
+
instruction_label = "You (Shift-Enter or push Submit to send message, use Enter for multiple input lines)"
|
37 |
+
else:
|
38 |
+
instruction_label = "You (Enter or push Submit to send message, shift-enter for more lines)"
|
39 |
+
|
40 |
+
title = 'h2oGPT'
|
41 |
+
if 'h2ogpt-research' in kwargs['base_model']:
|
42 |
+
title += " [Research demonstration]"
|
43 |
+
if kwargs['verbose']:
|
44 |
+
description = f"""Model {kwargs['base_model']} Instruct dataset.
|
45 |
+
For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
|
46 |
+
Command: {str(' '.join(sys.argv))}
|
47 |
+
Hash: {get_githash()}
|
48 |
+
"""
|
49 |
+
else:
|
50 |
+
description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).<br>"
|
51 |
+
if is_public:
|
52 |
+
description += "If this host is busy, try [gpt.h2o.ai 20B](https://gpt.h2o.ai) and [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) and [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
|
53 |
+
description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
|
54 |
+
if kwargs['load_8bit']:
|
55 |
+
description += """<i><li> Model is loaded in 8-bit and has other restrictions on this host. UX can be worse than non-hosted version.</i></li>"""
|
56 |
+
description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
|
57 |
+
if 'h2ogpt-research' in kwargs['base_model']:
|
58 |
+
description += """<i><li>Research demonstration only, not used for commercial purposes.</i></li>"""
|
59 |
+
description += """<i><li>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).</i></li></ul></p>"""
|
60 |
+
|
61 |
+
if kwargs['verbose']:
|
62 |
+
task_info_md = f"""
|
63 |
+
### Task: {kwargs['task_info']}"""
|
64 |
+
else:
|
65 |
+
task_info_md = ''
|
66 |
+
|
67 |
+
if kwargs['h2ocolors']:
|
68 |
+
css_code = """footer {visibility: hidden;}
|
69 |
+
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
|
70 |
+
body.dark{background:linear-gradient(#000000,#0d0d0d);}
|
71 |
+
"""
|
72 |
+
else:
|
73 |
+
css_code = """footer {visibility: hidden}"""
|
74 |
+
|
75 |
+
if kwargs['gradio_avoid_processing_markdown']:
|
76 |
+
from gradio_client import utils as client_utils
|
77 |
+
from gradio.components import Chatbot
|
78 |
+
|
79 |
+
# gradio has issue with taking too long to process input/output for markdown etc.
|
80 |
+
# Avoid for now, allow raw html to render, good enough for chatbot.
|
81 |
+
def _postprocess_chat_messages(self, chat_message: str):
|
82 |
+
if chat_message is None:
|
83 |
+
return None
|
84 |
+
elif isinstance(chat_message, (tuple, list)):
|
85 |
+
filepath = chat_message[0]
|
86 |
+
mime_type = client_utils.get_mimetype(filepath)
|
87 |
+
filepath = self.make_temp_copy_if_needed(filepath)
|
88 |
+
return {
|
89 |
+
"name": filepath,
|
90 |
+
"mime_type": mime_type,
|
91 |
+
"alt_text": chat_message[1] if len(chat_message) > 1 else None,
|
92 |
+
"data": None, # These last two fields are filled in by the frontend
|
93 |
+
"is_file": True,
|
94 |
+
}
|
95 |
+
elif isinstance(chat_message, str):
|
96 |
+
return chat_message
|
97 |
+
else:
|
98 |
+
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
|
99 |
+
|
100 |
+
Chatbot._postprocess_chat_messages = _postprocess_chat_messages
|
101 |
+
|
102 |
+
theme = H2oTheme() if kwargs['h2ocolors'] else SoftTheme()
|
103 |
+
demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
|
104 |
+
callback = gr.CSVLogger()
|
105 |
+
|
106 |
+
model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
|
107 |
+
if kwargs['base_model'].strip() not in model_options:
|
108 |
+
lora_options = [kwargs['base_model'].strip()] + model_options
|
109 |
+
lora_options = kwargs['extra_lora_options']
|
110 |
+
if kwargs['lora_weights'].strip() not in lora_options:
|
111 |
+
lora_options = [kwargs['lora_weights'].strip()] + lora_options
|
112 |
+
# always add in no lora case
|
113 |
+
# add fake space so doesn't go away in gradio dropdown
|
114 |
+
no_lora_str = no_model_str = '[None/Remove]'
|
115 |
+
lora_options = [no_lora_str] + kwargs['extra_lora_options'] # FIXME: why double?
|
116 |
+
# always add in no model case so can free memory
|
117 |
+
# add fake space so doesn't go away in gradio dropdown
|
118 |
+
model_options = [no_model_str] + model_options
|
119 |
+
|
120 |
+
# transcribe, will be detranscribed before use by evaluate()
|
121 |
+
if not kwargs['lora_weights'].strip():
|
122 |
+
kwargs['lora_weights'] = no_lora_str
|
123 |
+
|
124 |
+
if not kwargs['base_model'].strip():
|
125 |
+
kwargs['base_model'] = no_model_str
|
126 |
+
|
127 |
+
# transcribe for gradio
|
128 |
+
kwargs['gpu_id'] = str(kwargs['gpu_id'])
|
129 |
+
|
130 |
+
no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
|
131 |
+
output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
|
132 |
+
'base_model') else no_model_msg
|
133 |
+
output_label0_model2 = no_model_msg
|
134 |
+
|
135 |
+
with demo:
|
136 |
+
# avoid actual model/tokenizer here or anything that would be bad to deepcopy
|
137 |
+
# https://github.com/gradio-app/gradio/issues/3558
|
138 |
+
model_state = gr.State(['model', 'tokenizer', kwargs['device'], kwargs['base_model']])
|
139 |
+
model_state2 = gr.State([None, None, None, None])
|
140 |
+
model_options_state = gr.State([model_options])
|
141 |
+
lora_options_state = gr.State([lora_options])
|
142 |
+
gr.Markdown(f"""
|
143 |
+
{get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
|
144 |
+
|
145 |
+
{description}
|
146 |
+
{task_info_md}
|
147 |
+
""")
|
148 |
+
if is_hf:
|
149 |
+
gr.HTML(
|
150 |
+
'''<center><a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate this Space to skip the queue and run in a private space</center>''')
|
151 |
+
|
152 |
+
# go button visible if
|
153 |
+
base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
|
154 |
+
go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
|
155 |
+
normal_block = gr.Row(visible=not base_wanted)
|
156 |
+
with normal_block:
|
157 |
+
with gr.Tabs():
|
158 |
+
with gr.Row():
|
159 |
+
col_nochat = gr.Column(visible=not kwargs['chat'])
|
160 |
+
with col_nochat: # FIXME: for model comparison, and check rest
|
161 |
+
text_output_nochat = gr.Textbox(lines=5, label=output_label0)
|
162 |
+
instruction_nochat = gr.Textbox(
|
163 |
+
lines=kwargs['input_lines'],
|
164 |
+
label=instruction_label_nochat,
|
165 |
+
placeholder=kwargs['placeholder_instruction'],
|
166 |
+
)
|
167 |
+
iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
|
168 |
+
placeholder=kwargs['placeholder_input'])
|
169 |
+
submit_nochat = gr.Button("Submit")
|
170 |
+
flag_btn_nochat = gr.Button("Flag")
|
171 |
+
if not kwargs['auto_score']:
|
172 |
+
with gr.Column(visible=kwargs['score_model']):
|
173 |
+
score_btn_nochat = gr.Button("Score last prompt & response")
|
174 |
+
score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
|
175 |
+
else:
|
176 |
+
with gr.Column(visible=kwargs['score_model']):
|
177 |
+
score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
|
178 |
+
col_chat = gr.Column(visible=kwargs['chat'])
|
179 |
+
with col_chat:
|
180 |
+
with gr.Row():
|
181 |
+
text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
|
182 |
+
text_output2 = gr.Chatbot(label=output_label0_model2, visible=False).style(
|
183 |
+
height=kwargs['height'] or 400)
|
184 |
+
with gr.Row():
|
185 |
+
with gr.Column(scale=50):
|
186 |
+
instruction = gr.Textbox(
|
187 |
+
lines=kwargs['input_lines'],
|
188 |
+
label=instruction_label,
|
189 |
+
placeholder=kwargs['placeholder_instruction'],
|
190 |
+
)
|
191 |
+
with gr.Row():
|
192 |
+
submit = gr.Button(value='Submit').style(full_width=False, size='sm')
|
193 |
+
stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
|
194 |
+
with gr.Row():
|
195 |
+
clear = gr.Button("New Conversation")
|
196 |
+
flag_btn = gr.Button("Flag")
|
197 |
+
if not kwargs['auto_score']: # FIXME: For checkbox model2
|
198 |
+
with gr.Column(visible=kwargs['score_model']):
|
199 |
+
with gr.Row():
|
200 |
+
score_btn = gr.Button("Score last prompt & response").style(
|
201 |
+
full_width=False, size='sm')
|
202 |
+
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
203 |
+
score_res2 = gr.Row(visible=False)
|
204 |
+
with score_res2:
|
205 |
+
score_btn2 = gr.Button("Score last prompt & response 2").style(
|
206 |
+
full_width=False, size='sm')
|
207 |
+
score_text2 = gr.Textbox("Response Score2: NA", show_label=False)
|
208 |
+
else:
|
209 |
+
with gr.Column(visible=kwargs['score_model']):
|
210 |
+
score_text = gr.Textbox("Response Score: NA", show_label=False)
|
211 |
+
score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
|
212 |
+
retry = gr.Button("Regenerate")
|
213 |
+
undo = gr.Button("Undo")
|
214 |
+
with gr.TabItem("Input/Output"):
|
215 |
+
with gr.Row():
|
216 |
+
if 'mbart-' in kwargs['model_lower']:
|
217 |
+
src_lang = gr.Dropdown(list(languages_covered().keys()),
|
218 |
+
value=kwargs['src_lang'],
|
219 |
+
label="Input Language")
|
220 |
+
tgt_lang = gr.Dropdown(list(languages_covered().keys()),
|
221 |
+
value=kwargs['tgt_lang'],
|
222 |
+
label="Output Language")
|
223 |
+
with gr.TabItem("Expert"):
|
224 |
+
with gr.Row():
|
225 |
+
with gr.Column():
|
226 |
+
stream_output = gr.components.Checkbox(label="Stream output",
|
227 |
+
value=kwargs['stream_output'])
|
228 |
+
prompt_type = gr.Dropdown(prompt_types_strings,
|
229 |
+
value=kwargs['prompt_type'], label="Prompt Type",
|
230 |
+
visible=not is_public)
|
231 |
+
prompt_type2 = gr.Dropdown(prompt_types_strings,
|
232 |
+
value=kwargs['prompt_type'], label="Prompt Type Model 2",
|
233 |
+
visible=not is_public and False)
|
234 |
+
do_sample = gr.Checkbox(label="Sample",
|
235 |
+
info="Enable sampler, required for use of temperature, top_p, top_k",
|
236 |
+
value=kwargs['do_sample'])
|
237 |
+
temperature = gr.Slider(minimum=0.01, maximum=3,
|
238 |
+
value=kwargs['temperature'],
|
239 |
+
label="Temperature",
|
240 |
+
info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
|
241 |
+
top_p = gr.Slider(minimum=0, maximum=1,
|
242 |
+
value=kwargs['top_p'], label="Top p",
|
243 |
+
info="Cumulative probability of tokens to sample from")
|
244 |
+
top_k = gr.Slider(
|
245 |
+
minimum=0, maximum=100, step=1,
|
246 |
+
value=kwargs['top_k'], label="Top k",
|
247 |
+
info='Num. tokens to sample from'
|
248 |
+
)
|
249 |
+
max_beams = 8 if not is_low_mem else 2
|
250 |
+
num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
|
251 |
+
value=min(max_beams, kwargs['num_beams']), label="Beams",
|
252 |
+
info="Number of searches for optimal overall probability. "
|
253 |
+
"Uses more GPU memory/compute")
|
254 |
+
max_max_new_tokens = 2048 if not is_low_mem else kwargs['max_new_tokens']
|
255 |
+
max_new_tokens = gr.Slider(
|
256 |
+
minimum=1, maximum=max_max_new_tokens, step=1,
|
257 |
+
value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
|
258 |
+
)
|
259 |
+
min_new_tokens = gr.Slider(
|
260 |
+
minimum=0, maximum=max_max_new_tokens, step=1,
|
261 |
+
value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
|
262 |
+
)
|
263 |
+
early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
|
264 |
+
value=kwargs['early_stopping'])
|
265 |
+
max_max_time = 60 * 5 if not is_low_mem else 60
|
266 |
+
max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
|
267 |
+
value=min(max_max_time, kwargs['max_time']), label="Max. time",
|
268 |
+
info="Max. time to search optimal output.")
|
269 |
+
repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
|
270 |
+
value=kwargs['repetition_penalty'],
|
271 |
+
label="Repetition Penalty")
|
272 |
+
num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
|
273 |
+
value=kwargs['num_return_sequences'],
|
274 |
+
label="Number Returns", info="Must be <= num_beams",
|
275 |
+
visible=not is_public)
|
276 |
+
iinput = gr.Textbox(lines=4, label="Input",
|
277 |
+
placeholder=kwargs['placeholder_input'],
|
278 |
+
visible=not is_public)
|
279 |
+
context = gr.Textbox(lines=3, label="System Pre-Context",
|
280 |
+
info="Directly pre-appended without prompt processing",
|
281 |
+
visible=not is_public)
|
282 |
+
chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
|
283 |
+
visible=not is_public)
|
284 |
+
|
285 |
+
with gr.TabItem("Models"):
|
286 |
+
load_msg = "Load-Unload Model/LORA" if not is_public \
|
287 |
+
else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
|
288 |
+
load_msg2 = "Load-Unload Model/LORA 2" if not is_public \
|
289 |
+
else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
|
290 |
+
compare_checkbox = gr.components.Checkbox(label="Compare Mode",
|
291 |
+
value=False, visible=not is_public)
|
292 |
+
with gr.Row():
|
293 |
+
n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
|
294 |
+
with gr.Column():
|
295 |
+
with gr.Row():
|
296 |
+
with gr.Column(scale=50):
|
297 |
+
model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
|
298 |
+
value=kwargs['base_model'])
|
299 |
+
lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
|
300 |
+
value=kwargs['lora_weights'], visible=kwargs['show_lora'])
|
301 |
+
with gr.Column(scale=1):
|
302 |
+
load_model_button = gr.Button(load_msg)
|
303 |
+
model_load8bit_checkbox = gr.components.Checkbox(
|
304 |
+
label="Load 8-bit [requires support]",
|
305 |
+
value=kwargs['load_8bit'])
|
306 |
+
model_infer_devices_checkbox = gr.components.Checkbox(
|
307 |
+
label="Choose Devices [If not Checked, use all GPUs]",
|
308 |
+
value=kwargs['infer_devices'])
|
309 |
+
model_gpu = gr.Dropdown(n_gpus_list,
|
310 |
+
label="GPU ID 2 [-1 = all GPUs, if Choose is enabled]",
|
311 |
+
value=kwargs['gpu_id'])
|
312 |
+
model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
|
313 |
+
lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
|
314 |
+
visible=kwargs['show_lora'])
|
315 |
+
with gr.Row():
|
316 |
+
with gr.Column(scale=50):
|
317 |
+
new_model = gr.Textbox(label="New Model HF name/path")
|
318 |
+
new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
|
319 |
+
with gr.Column(scale=1):
|
320 |
+
add_model_button = gr.Button("Add new model name")
|
321 |
+
add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
|
322 |
+
col_model2 = gr.Column(visible=False)
|
323 |
+
with col_model2:
|
324 |
+
with gr.Row():
|
325 |
+
with gr.Column(scale=50):
|
326 |
+
model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
|
327 |
+
value=no_model_str)
|
328 |
+
lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
|
329 |
+
value=no_lora_str,
|
330 |
+
visible=kwargs['show_lora'])
|
331 |
+
with gr.Column(scale=1):
|
332 |
+
load_model_button2 = gr.Button(load_msg2)
|
333 |
+
model_load8bit_checkbox2 = gr.components.Checkbox(
|
334 |
+
label="Load 8-bit 2 [requires support]",
|
335 |
+
value=kwargs['load_8bit'])
|
336 |
+
model_infer_devices_checkbox2 = gr.components.Checkbox(
|
337 |
+
label="Choose Devices 2 [If not Checked, use all GPUs]",
|
338 |
+
value=kwargs[
|
339 |
+
'infer_devices'])
|
340 |
+
model_gpu2 = gr.Dropdown(n_gpus_list,
|
341 |
+
label="GPU ID [-1 = all GPUs, if choose is enabled]",
|
342 |
+
value=kwargs['gpu_id'])
|
343 |
+
# no model/lora loaded ever in model2 by default
|
344 |
+
model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
|
345 |
+
lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
|
346 |
+
visible=kwargs['show_lora'])
|
347 |
+
with gr.TabItem("System"):
|
348 |
+
admin_row = gr.Row()
|
349 |
+
with admin_row:
|
350 |
+
admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
|
351 |
+
admin_btn = gr.Button(value="Admin Access", visible=is_public)
|
352 |
+
system_row = gr.Row(visible=not is_public)
|
353 |
+
with system_row:
|
354 |
+
with gr.Column():
|
355 |
+
with gr.Row():
|
356 |
+
system_btn = gr.Button(value='Get System Info')
|
357 |
+
system_text = gr.Textbox(label='System Info')
|
358 |
+
|
359 |
+
with gr.Row():
|
360 |
+
zip_btn = gr.Button("Zip")
|
361 |
+
zip_text = gr.Textbox(label="Zip file name")
|
362 |
+
file_output = gr.File()
|
363 |
+
with gr.Row():
|
364 |
+
s3up_btn = gr.Button("S3UP")
|
365 |
+
s3up_text = gr.Textbox(label='S3UP result')
|
366 |
+
|
367 |
+
# Get flagged data
|
368 |
+
zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
|
369 |
+
zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text])
|
370 |
+
s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text)
|
371 |
+
|
372 |
+
def check_admin_pass(x):
|
373 |
+
return gr.update(visible=x == admin_pass)
|
374 |
+
|
375 |
+
def close_admin(x):
|
376 |
+
return gr.update(visible=not (x == admin_pass))
|
377 |
+
|
378 |
+
admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row) \
|
379 |
+
.then(close_admin, inputs=admin_pass_textbox, outputs=admin_row)
|
380 |
+
|
381 |
+
# Get inputs to evaluate()
|
382 |
+
all_kwargs = kwargs.copy()
|
383 |
+
all_kwargs.update(locals())
|
384 |
+
inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
|
385 |
+
from functools import partial
|
386 |
+
kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
|
387 |
+
# ensure present
|
388 |
+
for k in inputs_kwargs_list:
|
389 |
+
assert k in kwargs_evaluate, "Missing %s" % k
|
390 |
+
fun = partial(evaluate,
|
391 |
+
**kwargs_evaluate)
|
392 |
+
fun2 = partial(evaluate,
|
393 |
+
**kwargs_evaluate)
|
394 |
+
|
395 |
+
dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
|
396 |
+
size="sm",
|
397 |
+
)
|
398 |
+
dark_mode_btn.click(
|
399 |
+
None,
|
400 |
+
None,
|
401 |
+
None,
|
402 |
+
_js=get_dark_js(),
|
403 |
+
api_name="dark" if allow_api else None,
|
404 |
+
)
|
405 |
+
|
406 |
+
# Control chat and non-chat blocks, which can be independently used by chat checkbox swap
|
407 |
+
def col_nochat_fun(x):
|
408 |
+
return gr.Column.update(visible=not x)
|
409 |
+
|
410 |
+
def col_chat_fun(x):
|
411 |
+
return gr.Column.update(visible=x)
|
412 |
+
|
413 |
+
def context_fun(x):
|
414 |
+
return gr.Textbox.update(visible=not x)
|
415 |
+
|
416 |
+
chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox" if allow_api else None) \
|
417 |
+
.then(col_chat_fun, chat, col_chat) \
|
418 |
+
.then(context_fun, chat, context)
|
419 |
+
|
420 |
+
# examples after submit or any other buttons for chat or no chat
|
421 |
+
if kwargs['examples'] is not None and kwargs['show_examples']:
|
422 |
+
gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
|
423 |
+
|
424 |
+
# Score
|
425 |
+
def score_last_response(*args, nochat=False, model2=False):
|
426 |
+
""" Similar to user() """
|
427 |
+
args_list = list(args)
|
428 |
+
|
429 |
+
max_length_tokenize = 512 if is_low_mem else 2048
|
430 |
+
cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM
|
431 |
+
smodel = score_model_state0[0]
|
432 |
+
stokenizer = score_model_state0[1]
|
433 |
+
sdevice = score_model_state0[2]
|
434 |
+
if not nochat:
|
435 |
+
history = args_list[-1]
|
436 |
+
if history is None:
|
437 |
+
if not model2:
|
438 |
+
# maybe only doing first model, no need to complain
|
439 |
+
print("Bad history in scoring last response, fix for now", flush=True)
|
440 |
+
history = []
|
441 |
+
if smodel is not None and \
|
442 |
+
stokenizer is not None and \
|
443 |
+
sdevice is not None and \
|
444 |
+
history is not None and len(history) > 0 and \
|
445 |
+
history[-1] is not None and \
|
446 |
+
len(history[-1]) >= 2:
|
447 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
448 |
+
|
449 |
+
question = history[-1][0]
|
450 |
+
|
451 |
+
answer = history[-1][1]
|
452 |
+
else:
|
453 |
+
return 'Response Score: NA'
|
454 |
+
else:
|
455 |
+
answer = args_list[-1]
|
456 |
+
instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat')
|
457 |
+
question = args_list[instruction_nochat_arg_id]
|
458 |
+
|
459 |
+
if question is None:
|
460 |
+
return 'Response Score: Bad Question'
|
461 |
+
if answer is None:
|
462 |
+
return 'Response Score: Bad Answer'
|
463 |
+
score = score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len)
|
464 |
+
if isinstance(score, str):
|
465 |
+
return 'Response Score: NA'
|
466 |
+
return 'Response Score: {:.1%}'.format(score)
|
467 |
+
|
468 |
+
def noop_score_last_response(*args, **kwargs):
|
469 |
+
return "Response Score: Disabled"
|
470 |
+
|
471 |
+
if kwargs['score_model']:
|
472 |
+
score_fun = score_last_response
|
473 |
+
else:
|
474 |
+
score_fun = noop_score_last_response
|
475 |
+
|
476 |
+
score_args = dict(fn=score_fun,
|
477 |
+
inputs=inputs_list + [text_output],
|
478 |
+
outputs=[score_text],
|
479 |
+
)
|
480 |
+
score_args2 = dict(fn=partial(score_fun, model2=True),
|
481 |
+
inputs=inputs_list + [text_output2],
|
482 |
+
outputs=[score_text2],
|
483 |
+
)
|
484 |
+
|
485 |
+
score_args_nochat = dict(fn=partial(score_fun, nochat=True),
|
486 |
+
inputs=inputs_list + [text_output_nochat],
|
487 |
+
outputs=[score_text_nochat],
|
488 |
+
)
|
489 |
+
if not kwargs['auto_score']:
|
490 |
+
score_event = score_btn.click(**score_args, queue=queue, api_name='score' if allow_api else None) \
|
491 |
+
.then(**score_args2, queue=queue, api_name='score2' if allow_api else None)
|
492 |
+
score_event_nochat = score_btn_nochat.click(**score_args_nochat, queue=queue,
|
493 |
+
api_name='score_nochat' if allow_api else None)
|
494 |
+
|
495 |
+
def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
|
496 |
+
"""
|
497 |
+
User that fills history for bot
|
498 |
+
:param args:
|
499 |
+
:param undo:
|
500 |
+
:param sanitize_user_prompt:
|
501 |
+
:param model2:
|
502 |
+
:return:
|
503 |
+
"""
|
504 |
+
args_list = list(args)
|
505 |
+
user_message = args_list[0]
|
506 |
+
input1 = args_list[1]
|
507 |
+
context1 = args_list[2]
|
508 |
+
if input1 and not user_message.endswith(':'):
|
509 |
+
user_message1 = user_message + ":" + input1
|
510 |
+
elif input1:
|
511 |
+
user_message1 = user_message + input1
|
512 |
+
else:
|
513 |
+
user_message1 = user_message
|
514 |
+
if sanitize_user_prompt:
|
515 |
+
from better_profanity import profanity
|
516 |
+
user_message1 = profanity.censor(user_message1)
|
517 |
+
|
518 |
+
history = args_list[-1]
|
519 |
+
if undo and history:
|
520 |
+
history.pop()
|
521 |
+
args_list = args_list[:-1] # FYI, even if unused currently
|
522 |
+
if history is None:
|
523 |
+
if not model2:
|
524 |
+
# no need to complain so often unless model1
|
525 |
+
print("Bad history, fix for now", flush=True)
|
526 |
+
history = []
|
527 |
+
# ensure elements not mixed across models as output,
|
528 |
+
# even if input is currently same source
|
529 |
+
history = history.copy()
|
530 |
+
if undo:
|
531 |
+
return history
|
532 |
+
else:
|
533 |
+
# FIXME: compare, same history for now
|
534 |
+
return history + [[user_message1, None]]
|
535 |
+
|
536 |
+
def bot(*args, retry=False):
|
537 |
+
"""
|
538 |
+
bot that consumes history for user input
|
539 |
+
instruction (from input_list) itself is not consumed by bot
|
540 |
+
:param args:
|
541 |
+
:param retry:
|
542 |
+
:return:
|
543 |
+
"""
|
544 |
+
args_list = list(args).copy()
|
545 |
+
history = args_list[-1] # model_state is -2
|
546 |
+
if retry and history:
|
547 |
+
history.pop()
|
548 |
+
if not history:
|
549 |
+
print("No history", flush=True)
|
550 |
+
return
|
551 |
+
# ensure output will be unique to models
|
552 |
+
history = history.copy()
|
553 |
+
instruction1 = history[-1][0]
|
554 |
+
context1 = ''
|
555 |
+
if kwargs['chat_history'] > 0:
|
556 |
+
prompt_type_arg_id = eval_func_param_names.index('prompt_type')
|
557 |
+
prompt_type1 = args_list[prompt_type_arg_id]
|
558 |
+
chat_arg_id = eval_func_param_names.index('chat')
|
559 |
+
chat1 = args_list[chat_arg_id]
|
560 |
+
context1 = ''
|
561 |
+
for histi in range(len(history) - 1):
|
562 |
+
data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
|
563 |
+
context1 += generate_prompt(data_point, prompt_type1, chat1, reduced=True)[0].replace(
|
564 |
+
'<br>', '\n')
|
565 |
+
if not context1.endswith('\n'):
|
566 |
+
context1 += '\n'
|
567 |
+
if context1 and not context1.endswith('\n'):
|
568 |
+
context1 += '\n' # ensure if terminates abruptly, then human continues on next line
|
569 |
+
args_list[0] = instruction1 # override original instruction with history from user
|
570 |
+
# only include desired chat history
|
571 |
+
args_list[2] = context1[-kwargs['chat_history']:]
|
572 |
+
model_state1 = args_list[-2]
|
573 |
+
if model_state1[0] is None or model_state1[0] == no_model_str:
|
574 |
+
return
|
575 |
+
args_list = args_list[:-2]
|
576 |
+
fun1 = partial(evaluate,
|
577 |
+
model_state1,
|
578 |
+
**kwargs_evaluate)
|
579 |
+
try:
|
580 |
+
for output in fun1(*tuple(args_list)):
|
581 |
+
bot_message = output
|
582 |
+
history[-1][1] = bot_message
|
583 |
+
yield history
|
584 |
+
except StopIteration:
|
585 |
+
yield history
|
586 |
+
except RuntimeError as e:
|
587 |
+
if "generator raised StopIteration" in str(e):
|
588 |
+
# assume last entry was bad, undo
|
589 |
+
history.pop()
|
590 |
+
yield history
|
591 |
+
raise
|
592 |
+
except Exception as e:
|
593 |
+
# put error into user input
|
594 |
+
history[-1][0] = "Exception: %s" % str(e)
|
595 |
+
yield history
|
596 |
+
raise
|
597 |
+
return
|
598 |
+
|
599 |
+
# NORMAL MODEL
|
600 |
+
user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
|
601 |
+
inputs=inputs_list + [text_output],
|
602 |
+
outputs=text_output,
|
603 |
+
)
|
604 |
+
bot_args = dict(fn=bot,
|
605 |
+
inputs=inputs_list + [model_state] + [text_output],
|
606 |
+
outputs=text_output,
|
607 |
+
)
|
608 |
+
retry_bot_args = dict(fn=functools.partial(bot, retry=True),
|
609 |
+
inputs=inputs_list + [model_state] + [text_output],
|
610 |
+
outputs=text_output,
|
611 |
+
)
|
612 |
+
undo_user_args = dict(fn=functools.partial(user, undo=True),
|
613 |
+
inputs=inputs_list + [text_output],
|
614 |
+
outputs=text_output,
|
615 |
+
)
|
616 |
+
|
617 |
+
# MODEL2
|
618 |
+
user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True),
|
619 |
+
inputs=inputs_list + [text_output2],
|
620 |
+
outputs=text_output2,
|
621 |
+
)
|
622 |
+
bot_args2 = dict(fn=bot,
|
623 |
+
inputs=inputs_list + [model_state2] + [text_output2],
|
624 |
+
outputs=text_output2,
|
625 |
+
)
|
626 |
+
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
|
627 |
+
inputs=inputs_list + [model_state2] + [text_output2],
|
628 |
+
outputs=text_output2,
|
629 |
+
)
|
630 |
+
undo_user_args2 = dict(fn=functools.partial(user, undo=True),
|
631 |
+
inputs=inputs_list + [text_output2],
|
632 |
+
outputs=text_output2,
|
633 |
+
)
|
634 |
+
|
635 |
+
def clear_instruct():
|
636 |
+
return gr.Textbox.update(value='')
|
637 |
+
|
638 |
+
if kwargs['auto_score']:
|
639 |
+
# in case 2nd model, consume instruction first, so can clear quickly
|
640 |
+
# bot doesn't consume instruction itself, just history from user, so why works
|
641 |
+
submit_event = instruction.submit(**user_args, queue=queue,
|
642 |
+
api_name='instruction' if allow_api else None) \
|
643 |
+
.then(**user_args2, api_name='instruction2' if allow_api else None) \
|
644 |
+
.then(clear_instruct, None, instruction) \
|
645 |
+
.then(clear_instruct, None, iinput) \
|
646 |
+
.then(**bot_args, api_name='instruction_bot' if allow_api else None, queue=queue) \
|
647 |
+
.then(**score_args, api_name='instruction_bot_score' if allow_api else None, queue=queue) \
|
648 |
+
.then(**bot_args2, api_name='instruction_bot2' if allow_api else None, queue=queue) \
|
649 |
+
.then(**score_args2, api_name='instruction_bot_score2' if allow_api else None, queue=queue) \
|
650 |
+
.then(clear_torch_cache)
|
651 |
+
submit_event2 = submit.click(**user_args, api_name='submit' if allow_api else None) \
|
652 |
+
.then(**user_args2, api_name='submit2' if allow_api else None) \
|
653 |
+
.then(clear_instruct, None, instruction) \
|
654 |
+
.then(clear_instruct, None, iinput) \
|
655 |
+
.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) \
|
656 |
+
.then(**score_args, api_name='submit_bot_score' if allow_api else None, queue=queue) \
|
657 |
+
.then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) \
|
658 |
+
.then(**score_args2, api_name='submit_bot_score2' if allow_api else None, queue=queue) \
|
659 |
+
.then(clear_torch_cache)
|
660 |
+
submit_event3 = retry.click(**user_args, api_name='retry' if allow_api else None) \
|
661 |
+
.then(**user_args2, api_name='retry2' if allow_api else None) \
|
662 |
+
.then(clear_instruct, None, instruction) \
|
663 |
+
.then(clear_instruct, None, iinput) \
|
664 |
+
.then(**retry_bot_args, api_name='retry_bot' if allow_api else None, queue=queue) \
|
665 |
+
.then(**score_args, api_name='retry_bot_score' if allow_api else None, queue=queue) \
|
666 |
+
.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, queue=queue) \
|
667 |
+
.then(**score_args2, api_name='retry_bot_score2' if allow_api else None, queue=queue) \
|
668 |
+
.then(clear_torch_cache)
|
669 |
+
submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
|
670 |
+
.then(**undo_user_args2, api_name='undo2' if allow_api else None) \
|
671 |
+
.then(clear_instruct, None, instruction) \
|
672 |
+
.then(clear_instruct, None, iinput) \
|
673 |
+
.then(**score_args, api_name='undo_score' if allow_api else None) \
|
674 |
+
.then(**score_args2, api_name='undo_score2' if allow_api else None)
|
675 |
+
else:
|
676 |
+
submit_event = instruction.submit(**user_args,
|
677 |
+
api_name='instruction' if allow_api else None) \
|
678 |
+
.then(**user_args2, api_name='instruction2' if allow_api else None) \
|
679 |
+
.then(clear_instruct, None, instruction) \
|
680 |
+
.then(clear_instruct, None, iinput) \
|
681 |
+
.then(**bot_args, api_name='instruction_bot' if allow_api else None, queue=queue) \
|
682 |
+
.then(**bot_args2, api_name='instruction_bot2' if allow_api else None, queue=queue) \
|
683 |
+
.then(clear_torch_cache)
|
684 |
+
submit_event2 = submit.click(**user_args, api_name='submit' if allow_api else None) \
|
685 |
+
.then(**user_args2, api_name='submit2' if allow_api else None) \
|
686 |
+
.then(clear_instruct, None, instruction) \
|
687 |
+
.then(clear_instruct, None, iinput) \
|
688 |
+
.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) \
|
689 |
+
.then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) \
|
690 |
+
.then(clear_torch_cache)
|
691 |
+
submit_event3 = retry.click(**user_args, api_name='retry' if allow_api else None) \
|
692 |
+
.then(**user_args2, api_name='retry2' if allow_api else None) \
|
693 |
+
.then(clear_instruct, None, instruction) \
|
694 |
+
.then(clear_instruct, None, iinput) \
|
695 |
+
.then(**retry_bot_args, api_name='retry_bot' if allow_api else None, queue=queue) \
|
696 |
+
.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, queue=queue) \
|
697 |
+
.then(clear_torch_cache)
|
698 |
+
submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
|
699 |
+
.then(**undo_user_args2, api_name='undo2' if allow_api else None)
|
700 |
+
|
701 |
+
# does both models
|
702 |
+
clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
|
703 |
+
.then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None)
|
704 |
+
# NOTE: clear of instruction/iinput for nochat has to come after score,
|
705 |
+
# because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
|
706 |
+
submit_event_nochat = submit_nochat.click(fun, inputs=[model_state] + inputs_list,
|
707 |
+
outputs=text_output_nochat,
|
708 |
+
queue=queue,
|
709 |
+
api_name='submit_nochat' if allow_api else None) \
|
710 |
+
.then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None, queue=queue) \
|
711 |
+
.then(clear_instruct, None, instruction_nochat) \
|
712 |
+
.then(clear_instruct, None, iinput_nochat) \
|
713 |
+
.then(clear_torch_cache)
|
714 |
+
|
715 |
+
def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id):
|
716 |
+
# ensure old model removed from GPU memory
|
717 |
+
if kwargs['debug']:
|
718 |
+
print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
|
719 |
+
|
720 |
+
model0 = model_state0[0]
|
721 |
+
if isinstance(model_state_old[0], str) and model0 is not None:
|
722 |
+
# best can do, move model loaded at first to CPU
|
723 |
+
model0.cpu()
|
724 |
+
|
725 |
+
if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
|
726 |
+
try:
|
727 |
+
model_state_old[0].cpu()
|
728 |
+
except Exception as e:
|
729 |
+
# sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
|
730 |
+
print("Unable to put model on CPU: %s" % str(e), flush=True)
|
731 |
+
del model_state_old[0]
|
732 |
+
model_state_old[0] = None
|
733 |
+
|
734 |
+
if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
|
735 |
+
del model_state_old[1]
|
736 |
+
model_state_old[1] = None
|
737 |
+
|
738 |
+
clear_torch_cache()
|
739 |
+
if kwargs['debug']:
|
740 |
+
print("Pre-switch post-del GPU memory: %s" % get_torch_allocated(), flush=True)
|
741 |
+
|
742 |
+
if model_name is None or model_name == no_model_str:
|
743 |
+
# no-op if no model, just free memory
|
744 |
+
# no detranscribe needed for model, never go into evaluate
|
745 |
+
lora_weights = no_lora_str
|
746 |
+
return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
|
747 |
+
|
748 |
+
all_kwargs1 = all_kwargs.copy()
|
749 |
+
all_kwargs1['base_model'] = model_name.strip()
|
750 |
+
all_kwargs1['load_8bit'] = load_8bit
|
751 |
+
all_kwargs1['infer_devices'] = infer_devices
|
752 |
+
all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
|
753 |
+
model_lower = model_name.strip().lower()
|
754 |
+
if model_lower in inv_prompt_type_to_model_lower:
|
755 |
+
prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
|
756 |
+
else:
|
757 |
+
prompt_type1 = prompt_type_old
|
758 |
+
|
759 |
+
# detranscribe
|
760 |
+
if lora_weights == no_lora_str:
|
761 |
+
lora_weights = ''
|
762 |
+
|
763 |
+
all_kwargs1['lora_weights'] = lora_weights.strip()
|
764 |
+
model1, tokenizer1, device1 = get_model(**all_kwargs1)
|
765 |
+
clear_torch_cache()
|
766 |
+
|
767 |
+
if kwargs['debug']:
|
768 |
+
print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True)
|
769 |
+
return [model1, tokenizer1, device1, model_name], model_name, lora_weights, prompt_type1
|
770 |
+
|
771 |
+
def dropdown_prompt_type_list(x):
|
772 |
+
return gr.Dropdown.update(value=x)
|
773 |
+
|
774 |
+
def chatbot_list(x, model_used_in):
|
775 |
+
return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
|
776 |
+
|
777 |
+
load_model_args = dict(fn=load_model,
|
778 |
+
inputs=[model_choice, lora_choice, model_state, prompt_type,
|
779 |
+
model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu],
|
780 |
+
outputs=[model_state, model_used, lora_used, prompt_type])
|
781 |
+
prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
|
782 |
+
chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
|
783 |
+
nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
|
784 |
+
if not is_public:
|
785 |
+
load_model_event = load_model_button.click(**load_model_args) \
|
786 |
+
.then(**prompt_update_args) \
|
787 |
+
.then(**chatbot_update_args) \
|
788 |
+
.then(**nochat_update_args) \
|
789 |
+
.then(clear_torch_cache)
|
790 |
+
|
791 |
+
load_model_args2 = dict(fn=load_model,
|
792 |
+
inputs=[model_choice2, lora_choice2, model_state2, prompt_type2,
|
793 |
+
model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2],
|
794 |
+
outputs=[model_state2, model_used2, lora_used2, prompt_type2])
|
795 |
+
prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
|
796 |
+
chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
|
797 |
+
if not is_public:
|
798 |
+
load_model_event2 = load_model_button2.click(**load_model_args2) \
|
799 |
+
.then(**prompt_update_args2) \
|
800 |
+
.then(**chatbot_update_args2) \
|
801 |
+
.then(clear_torch_cache)
|
802 |
+
|
803 |
+
def dropdown_model_list(list0, x):
|
804 |
+
new_state = [list0[0] + [x]]
|
805 |
+
new_options = [*new_state[0]]
|
806 |
+
return gr.Dropdown.update(value=x, choices=new_options), \
|
807 |
+
gr.Dropdown.update(value=x, choices=new_options), \
|
808 |
+
'', new_state
|
809 |
+
|
810 |
+
add_model_event = add_model_button.click(fn=dropdown_model_list,
|
811 |
+
inputs=[model_options_state, new_model],
|
812 |
+
outputs=[model_choice, model_choice2, new_model, model_options_state])
|
813 |
+
|
814 |
+
def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2):
|
815 |
+
new_state = [list0[0] + [x]]
|
816 |
+
new_options = [*new_state[0]]
|
817 |
+
# don't switch drop-down to added lora if already have model loaded
|
818 |
+
x1 = x if model_used1 == no_model_str else lora_used1
|
819 |
+
x2 = x if model_used2 == no_model_str else lora_used2
|
820 |
+
return gr.Dropdown.update(value=x1, choices=new_options), \
|
821 |
+
gr.Dropdown.update(value=x2, choices=new_options), \
|
822 |
+
'', new_state
|
823 |
+
|
824 |
+
add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
|
825 |
+
inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2,
|
826 |
+
lora_used2],
|
827 |
+
outputs=[lora_choice, lora_choice2, new_lora, lora_options_state])
|
828 |
+
|
829 |
+
go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None) \
|
830 |
+
.then(lambda: gr.update(visible=True), None, normal_block) \
|
831 |
+
.then(**load_model_args).then(**prompt_update_args)
|
832 |
+
|
833 |
+
def compare_textbox_fun(x):
|
834 |
+
return gr.Textbox.update(visible=x)
|
835 |
+
|
836 |
+
def compare_column_fun(x):
|
837 |
+
return gr.Column.update(visible=x)
|
838 |
+
|
839 |
+
def compare_prompt_fun(x):
|
840 |
+
return gr.Dropdown.update(visible=x)
|
841 |
+
|
842 |
+
compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2,
|
843 |
+
api_name="compare_checkbox" if allow_api else None) \
|
844 |
+
.then(compare_column_fun, compare_checkbox, col_model2) \
|
845 |
+
.then(compare_prompt_fun, compare_checkbox, prompt_type2) \
|
846 |
+
.then(compare_textbox_fun, compare_checkbox, score_text2)
|
847 |
+
# FIXME: add score_res2 in condition, but do better
|
848 |
+
|
849 |
+
# callback for logging flagged input/output
|
850 |
+
callback.setup(inputs_list + [text_output, text_output2], "flagged_data_points")
|
851 |
+
flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2], None,
|
852 |
+
preprocess=False,
|
853 |
+
api_name='flag' if allow_api else None)
|
854 |
+
flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None,
|
855 |
+
preprocess=False,
|
856 |
+
api_name='flag_nochat' if allow_api else None)
|
857 |
+
|
858 |
+
def get_system_info():
|
859 |
+
return gr.Textbox.update(value=system_info_print())
|
860 |
+
|
861 |
+
system_event = system_btn.click(get_system_info, outputs=system_text,
|
862 |
+
api_name='system_info' if allow_api else None)
|
863 |
+
|
864 |
+
# don't pass text_output, don't want to clear output, just stop it
|
865 |
+
# FIXME: have to click once to stop output and second time to stop GPUs going
|
866 |
+
stop_btn.click(lambda: None, None, None,
|
867 |
+
cancels=[submit_event_nochat, submit_event, submit_event2, submit_event3],
|
868 |
+
queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache)
|
869 |
+
demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)
|
870 |
+
|
871 |
+
demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
|
872 |
+
favicon_path = "h2o-logo.svg"
|
873 |
+
|
874 |
+
scheduler = BackgroundScheduler()
|
875 |
+
scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
|
876 |
+
if is_public:
|
877 |
+
scheduler.add_job(func=ping, trigger="interval", seconds=60)
|
878 |
+
scheduler.start()
|
879 |
+
|
880 |
+
demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
|
881 |
+
favicon_path=favicon_path, prevent_thread_lock=True) # , enable_queue=True)
|
882 |
+
print("Started GUI", flush=True)
|
883 |
+
if kwargs['block_gradio_exit']:
|
884 |
+
demo.block_thread()
|
885 |
+
|
886 |
+
|
887 |
+
input_args_list = ['model_state']
|
888 |
+
inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
|
889 |
+
'raise_generate_gpu_exceptions', 'chat_context', 'concurrency_count']
|
890 |
+
|
891 |
+
|
892 |
+
def get_inputs_list(inputs_dict, model_lower):
|
893 |
+
"""
|
894 |
+
map gradio objects in locals() to inputs for evaluate().
|
895 |
+
:param inputs_dict:
|
896 |
+
:param model_lower:
|
897 |
+
:return:
|
898 |
+
"""
|
899 |
+
inputs_list_names = list(inspect.signature(evaluate).parameters)
|
900 |
+
inputs_list = []
|
901 |
+
for k in inputs_list_names:
|
902 |
+
if k == 'kwargs':
|
903 |
+
continue
|
904 |
+
if k in input_args_list + inputs_kwargs_list:
|
905 |
+
# these are added via partial, not taken as input
|
906 |
+
continue
|
907 |
+
if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
|
908 |
+
continue
|
909 |
+
inputs_list.append(inputs_dict[k])
|
910 |
+
return inputs_list
|
gradio_themes.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
from gradio.themes.soft import Soft
|
3 |
+
from gradio.themes.utils import Color, colors, sizes
|
4 |
+
|
5 |
+
h2o_yellow = Color(
|
6 |
+
name="yellow",
|
7 |
+
c50="#fffef2",
|
8 |
+
c100="#fff9e6",
|
9 |
+
c200="#ffecb3",
|
10 |
+
c300="#ffe28c",
|
11 |
+
c400="#ffd659",
|
12 |
+
c500="#fec925",
|
13 |
+
c600="#e6ac00",
|
14 |
+
c700="#bf8f00",
|
15 |
+
c800="#a67c00",
|
16 |
+
c900="#664d00",
|
17 |
+
c950="#403000",
|
18 |
+
)
|
19 |
+
h2o_gray = Color(
|
20 |
+
name="gray",
|
21 |
+
c50="#f8f8f8",
|
22 |
+
c100="#e5e5e5",
|
23 |
+
c200="#cccccc",
|
24 |
+
c300="#b2b2b2",
|
25 |
+
c400="#999999",
|
26 |
+
c500="#7f7f7f",
|
27 |
+
c600="#666666",
|
28 |
+
c700="#4c4c4c",
|
29 |
+
c800="#333333",
|
30 |
+
c900="#191919",
|
31 |
+
c950="#0d0d0d",
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
class H2oTheme(Soft):
|
36 |
+
def __init__(
|
37 |
+
self,
|
38 |
+
*,
|
39 |
+
primary_hue: colors.Color | str = h2o_yellow,
|
40 |
+
secondary_hue: colors.Color | str = h2o_yellow,
|
41 |
+
neutral_hue: colors.Color | str = h2o_gray,
|
42 |
+
spacing_size: sizes.Size | str = sizes.spacing_md,
|
43 |
+
radius_size: sizes.Size | str = sizes.radius_md,
|
44 |
+
text_size: sizes.Size | str = sizes.text_lg,
|
45 |
+
):
|
46 |
+
super().__init__(
|
47 |
+
primary_hue=primary_hue,
|
48 |
+
secondary_hue=secondary_hue,
|
49 |
+
neutral_hue=neutral_hue,
|
50 |
+
spacing_size=spacing_size,
|
51 |
+
radius_size=radius_size,
|
52 |
+
text_size=text_size,
|
53 |
+
)
|
54 |
+
super().set(
|
55 |
+
link_text_color="#3344DD",
|
56 |
+
link_text_color_hover="#3344DD",
|
57 |
+
link_text_color_visited="#3344DD",
|
58 |
+
link_text_color_dark="#74abff",
|
59 |
+
link_text_color_hover_dark="#a3c8ff",
|
60 |
+
link_text_color_active_dark="#a3c8ff",
|
61 |
+
link_text_color_visited_dark="#74abff",
|
62 |
+
button_primary_text_color="*neutral_950",
|
63 |
+
button_primary_text_color_dark="*neutral_950",
|
64 |
+
button_primary_background_fill="*primary_500",
|
65 |
+
button_primary_background_fill_dark="*primary_500",
|
66 |
+
block_label_background_fill="*primary_500",
|
67 |
+
block_label_background_fill_dark="*primary_500",
|
68 |
+
block_label_text_color="*neutral_950",
|
69 |
+
block_label_text_color_dark="*neutral_950",
|
70 |
+
block_title_text_color="*neutral_950",
|
71 |
+
block_title_text_color_dark="*neutral_950",
|
72 |
+
block_background_fill_dark="*neutral_950",
|
73 |
+
body_background_fill="*neutral_50",
|
74 |
+
body_background_fill_dark="*neutral_900",
|
75 |
+
background_fill_primary_dark="*block_background_fill",
|
76 |
+
block_radius="0 0 8px 8px",
|
77 |
+
)
|
78 |
+
|
79 |
+
|
80 |
+
class SoftTheme(Soft):
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
*,
|
84 |
+
primary_hue: colors.Color | str = colors.indigo,
|
85 |
+
secondary_hue: colors.Color | str = colors.indigo,
|
86 |
+
neutral_hue: colors.Color | str = colors.gray,
|
87 |
+
spacing_size: sizes.Size | str = sizes.spacing_md,
|
88 |
+
radius_size: sizes.Size | str = sizes.radius_md,
|
89 |
+
text_size: sizes.Size | str = sizes.text_md,
|
90 |
+
):
|
91 |
+
super().__init__(
|
92 |
+
primary_hue=primary_hue,
|
93 |
+
secondary_hue=secondary_hue,
|
94 |
+
neutral_hue=neutral_hue,
|
95 |
+
spacing_size=spacing_size,
|
96 |
+
radius_size=radius_size,
|
97 |
+
text_size=text_size,
|
98 |
+
)
|
99 |
+
|
100 |
+
|
101 |
+
h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
|
102 |
+
' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:' \
|
103 |
+
'#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" ' \
|
104 |
+
'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.' \
|
105 |
+
'47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.' \
|
106 |
+
'82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10' \
|
107 |
+
'.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"' \
|
108 |
+
'/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.' \
|
109 |
+
'76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69' \
|
110 |
+
',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.' \
|
111 |
+
'85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.' \
|
112 |
+
'69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30.' \
|
113 |
+
'62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8.' \
|
114 |
+
'62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-' \
|
115 |
+
'12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"' \
|
116 |
+
' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,' \
|
117 |
+
'11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
|
118 |
+
|
119 |
+
|
120 |
+
def get_h2o_title(title):
|
121 |
+
return f"""<div style="display:flex; justify-content:center; margin-bottom:30px;">
|
122 |
+
<div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
|
123 |
+
<h1 style="line-height:60px">{title}</h1>
|
124 |
+
</div>
|
125 |
+
<div style="float:right; height: 80px; width: 80px; margin-top:-100px">
|
126 |
+
<img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/h2o-qr.png></img>
|
127 |
+
</div>
|
128 |
+
"""
|
129 |
+
|
130 |
+
|
131 |
+
def get_simple_title(title):
|
132 |
+
return f"""<h1 align="center"> {title}</h1>"""
|
133 |
+
|
134 |
+
|
135 |
+
def get_dark_js():
|
136 |
+
return """() => {
|
137 |
+
if (document.querySelectorAll('.dark').length) {
|
138 |
+
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
|
139 |
+
} else {
|
140 |
+
document.querySelector('body').classList.add('dark');
|
141 |
+
}
|
142 |
+
}"""
|
prompter.py
CHANGED
@@ -71,7 +71,8 @@ class Prompter(object):
|
|
71 |
output = output.split(self.pre_response)[1]
|
72 |
allow_terminate = True
|
73 |
else:
|
74 |
-
|
|
|
75 |
allow_terminate = False
|
76 |
else:
|
77 |
allow_terminate = True
|
|
|
71 |
output = output.split(self.pre_response)[1]
|
72 |
allow_terminate = True
|
73 |
else:
|
74 |
+
if output:
|
75 |
+
print("Failure of parsing or not enough output yet: %s" % output, flush=True)
|
76 |
allow_terminate = False
|
77 |
else:
|
78 |
allow_terminate = True
|
requirements.txt
CHANGED
@@ -19,9 +19,10 @@ pandas==2.0.0
|
|
19 |
matplotlib==3.7.1
|
20 |
loralib==0.1.1
|
21 |
bitsandbytes==0.38.1
|
22 |
-
git+https://github.com/huggingface/peft.git@
|
23 |
transformers==4.28.1
|
24 |
tokenizers==0.13.3
|
|
|
25 |
|
26 |
# optional for generate
|
27 |
pynvml==11.5.0
|
|
|
19 |
matplotlib==3.7.1
|
20 |
loralib==0.1.1
|
21 |
bitsandbytes==0.38.1
|
22 |
+
git+https://github.com/huggingface/peft.git@e8f66b8a425eced6c592089d40b8d33d82c2b2f0
|
23 |
transformers==4.28.1
|
24 |
tokenizers==0.13.3
|
25 |
+
APScheduler==3.10.1
|
26 |
|
27 |
# optional for generate
|
28 |
pynvml==11.5.0
|
stopping.py
CHANGED
@@ -9,11 +9,11 @@ from transformers import StoppingCriteria
|
|
9 |
|
10 |
class StoppingCriteriaSub(StoppingCriteria):
|
11 |
|
12 |
-
def __init__(self, stops=[], encounters=[]):
|
13 |
super().__init__()
|
14 |
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
15 |
self.encounters = encounters
|
16 |
-
self.stops = [stop.to(
|
17 |
self.num_stops = [0] * len(stops)
|
18 |
|
19 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
@@ -25,115 +25,3 @@ class StoppingCriteriaSub(StoppingCriteria):
|
|
25 |
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
26 |
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
27 |
return False
|
28 |
-
|
29 |
-
|
30 |
-
class Stream(StoppingCriteria):
|
31 |
-
"""
|
32 |
-
This class can be used to callback during generation. Keep
|
33 |
-
in mind for decoder-only type of transformers, this will include the initial prompted tokens.
|
34 |
-
|
35 |
-
Args:
|
36 |
-
func (`callable`):
|
37 |
-
A callable function to apply on first input in list every iteration of generation
|
38 |
-
"""
|
39 |
-
|
40 |
-
def __init__(self, func=None):
|
41 |
-
self.func = func
|
42 |
-
|
43 |
-
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
44 |
-
if self.func is not None:
|
45 |
-
# only consume first of multiple responses
|
46 |
-
self.func(input_ids[0])
|
47 |
-
return False
|
48 |
-
|
49 |
-
|
50 |
-
class CallbackToGenerator(collections.abc.Generator):
|
51 |
-
"""
|
52 |
-
A generator wrapper for a function that invokes a callback multiple times.
|
53 |
-
|
54 |
-
Calling `send` on the generator emits a value from one callback, and returns
|
55 |
-
the next.
|
56 |
-
|
57 |
-
Note this starts a background thread
|
58 |
-
"""
|
59 |
-
|
60 |
-
def __init__(self, func, *args, callback=None, **kwargs):
|
61 |
-
self.func = func
|
62 |
-
self.args = args
|
63 |
-
self.kwargs = kwargs
|
64 |
-
self.callback = callback
|
65 |
-
|
66 |
-
self._ready_queue = Queue(1)
|
67 |
-
self._done_queue = Queue(1)
|
68 |
-
self._done_holder = [False]
|
69 |
-
|
70 |
-
# local to avoid reference cycles
|
71 |
-
ready_queue = self._ready_queue
|
72 |
-
done_queue = self._done_queue
|
73 |
-
done_holder = self._done_holder
|
74 |
-
|
75 |
-
def val_callback(value):
|
76 |
-
done_queue.put((False, value))
|
77 |
-
cmd, val = ready_queue.get()
|
78 |
-
if cmd == 'send':
|
79 |
-
return val
|
80 |
-
elif cmd == 'throw':
|
81 |
-
raise val
|
82 |
-
else:
|
83 |
-
assert False # pragma: no cover
|
84 |
-
|
85 |
-
def thread_func():
|
86 |
-
while True:
|
87 |
-
cmd, val = ready_queue.get()
|
88 |
-
if cmd == 'send' and val is not None:
|
89 |
-
done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
|
90 |
-
continue
|
91 |
-
break
|
92 |
-
try:
|
93 |
-
if cmd == 'throw':
|
94 |
-
raise val
|
95 |
-
ret = func(callback=val_callback, **self.kwargs)
|
96 |
-
raise StopIteration(ret) if ret is not None else StopIteration
|
97 |
-
except BaseException as e:
|
98 |
-
done_holder[0] = True
|
99 |
-
done_queue.put((True, e))
|
100 |
-
|
101 |
-
self._thread = Thread(target=thread_func)
|
102 |
-
self._thread.start()
|
103 |
-
|
104 |
-
def _put(self, *args):
|
105 |
-
if self._done_holder[0]:
|
106 |
-
raise StopIteration
|
107 |
-
self._ready_queue.put(args)
|
108 |
-
is_exception, val = self._done_queue.get()
|
109 |
-
if is_exception:
|
110 |
-
try:
|
111 |
-
raise val
|
112 |
-
finally:
|
113 |
-
# prevent val's traceback containing a reference cycle
|
114 |
-
del val
|
115 |
-
else:
|
116 |
-
return val
|
117 |
-
|
118 |
-
def send(self, value):
|
119 |
-
return self._put('send', value)
|
120 |
-
|
121 |
-
def throw(self, exc):
|
122 |
-
return self._put('throw', exc)
|
123 |
-
|
124 |
-
def close(self):
|
125 |
-
try:
|
126 |
-
self.throw(GeneratorExit)
|
127 |
-
except StopIteration:
|
128 |
-
self._thread.join()
|
129 |
-
except GeneratorExit:
|
130 |
-
self._thread.join()
|
131 |
-
except BaseException:
|
132 |
-
self._thread.join()
|
133 |
-
raise
|
134 |
-
else:
|
135 |
-
# yielded again, can't clean up the thread
|
136 |
-
raise RuntimeError('Task with callback ignored GeneratorExit')
|
137 |
-
|
138 |
-
def __del__(self):
|
139 |
-
self.close()
|
|
|
9 |
|
10 |
class StoppingCriteriaSub(StoppingCriteria):
|
11 |
|
12 |
+
def __init__(self, stops=[], encounters=[], device="cuda"):
|
13 |
super().__init__()
|
14 |
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
15 |
self.encounters = encounters
|
16 |
+
self.stops = [stop.to(device) for stop in stops]
|
17 |
self.num_stops = [0] * len(stops)
|
18 |
|
19 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
|
|
25 |
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
26 |
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
27 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
CHANGED
@@ -1,6 +1,12 @@
|
|
|
|
1 |
import os
|
2 |
import gc
|
|
|
3 |
import random
|
|
|
|
|
|
|
|
|
4 |
import time
|
5 |
import traceback
|
6 |
import zipfile
|
@@ -8,7 +14,6 @@ from datetime import datetime
|
|
8 |
import filelock
|
9 |
import numpy as np
|
10 |
import pandas as pd
|
11 |
-
import torch
|
12 |
|
13 |
|
14 |
def set_seed(seed: int):
|
@@ -16,6 +21,7 @@ def set_seed(seed: int):
|
|
16 |
Sets the seed of the entire notebook so results are the same every time we run.
|
17 |
This is for REPRODUCIBILITY.
|
18 |
"""
|
|
|
19 |
np.random.seed(seed)
|
20 |
random_state = np.random.RandomState(seed)
|
21 |
random.seed(seed)
|
@@ -39,12 +45,22 @@ def flatten_list(lis):
|
|
39 |
|
40 |
|
41 |
def clear_torch_cache():
|
42 |
-
|
|
|
43 |
torch.cuda.empty_cache()
|
44 |
torch.cuda.ipc_collect()
|
45 |
gc.collect()
|
46 |
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
def system_info():
|
49 |
import psutil
|
50 |
|
@@ -184,3 +200,62 @@ def _s3up(filename):
|
|
184 |
)
|
185 |
if ret in [None, '']:
|
186 |
return "Successfully uploaded %s" % filename
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
import os
|
3 |
import gc
|
4 |
+
import pathlib
|
5 |
import random
|
6 |
+
import shutil
|
7 |
+
import subprocess
|
8 |
+
import sys
|
9 |
+
import threading
|
10 |
import time
|
11 |
import traceback
|
12 |
import zipfile
|
|
|
14 |
import filelock
|
15 |
import numpy as np
|
16 |
import pandas as pd
|
|
|
17 |
|
18 |
|
19 |
def set_seed(seed: int):
|
|
|
21 |
Sets the seed of the entire notebook so results are the same every time we run.
|
22 |
This is for REPRODUCIBILITY.
|
23 |
"""
|
24 |
+
import torch
|
25 |
np.random.seed(seed)
|
26 |
random_state = np.random.RandomState(seed)
|
27 |
random.seed(seed)
|
|
|
45 |
|
46 |
|
47 |
def clear_torch_cache():
|
48 |
+
import torch
|
49 |
+
if torch.cuda.is_available():
|
50 |
torch.cuda.empty_cache()
|
51 |
torch.cuda.ipc_collect()
|
52 |
gc.collect()
|
53 |
|
54 |
|
55 |
+
def ping():
|
56 |
+
print('Ping: %s' % str(datetime.now()), flush=True)
|
57 |
+
|
58 |
+
|
59 |
+
def get_torch_allocated():
|
60 |
+
import torch
|
61 |
+
return torch.cuda.memory_allocated()
|
62 |
+
|
63 |
+
|
64 |
def system_info():
|
65 |
import psutil
|
66 |
|
|
|
200 |
)
|
201 |
if ret in [None, '']:
|
202 |
return "Successfully uploaded %s" % filename
|
203 |
+
|
204 |
+
|
205 |
+
def get_githash():
|
206 |
+
try:
|
207 |
+
githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
|
208 |
+
except:
|
209 |
+
githash = ''
|
210 |
+
return githash
|
211 |
+
|
212 |
+
|
213 |
+
def copy_code(run_id):
|
214 |
+
"""
|
215 |
+
copy code to track changes
|
216 |
+
:param run_id:
|
217 |
+
:return:
|
218 |
+
"""
|
219 |
+
rnd_num = str(random.randint(0, 2 ** 31))
|
220 |
+
run_id = 'run_' + str(run_id)
|
221 |
+
os.makedirs(run_id, exist_ok=True)
|
222 |
+
me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
|
223 |
+
me_file = os.path.basename(__file__)
|
224 |
+
new_me = os.path.join(run_id, me_file + '_' + get_githash())
|
225 |
+
if os.path.isfile(new_me):
|
226 |
+
new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
|
227 |
+
shutil.copy(me_full, new_me)
|
228 |
+
else:
|
229 |
+
shutil.copy(me_full, new_me)
|
230 |
+
|
231 |
+
|
232 |
+
class NullContext(threading.local):
|
233 |
+
"""No-op context manager, executes block without doing any additional processing.
|
234 |
+
|
235 |
+
Used as a stand-in if a particular block of code is only sometimes
|
236 |
+
used with a normal context manager:
|
237 |
+
"""
|
238 |
+
def __init__(self, *args, **kwargs):
|
239 |
+
pass
|
240 |
+
|
241 |
+
def __enter__(self):
|
242 |
+
return self
|
243 |
+
|
244 |
+
def __exit__(self, exc_type, exc_value, exc_traceback):
|
245 |
+
self.finally_act()
|
246 |
+
|
247 |
+
def finally_act(self):
|
248 |
+
pass
|
249 |
+
|
250 |
+
|
251 |
+
def wrapped_partial(func, *args, **kwargs):
|
252 |
+
"""
|
253 |
+
Give partial properties of normal function, like __name__ attribute etc.
|
254 |
+
:param func:
|
255 |
+
:param args:
|
256 |
+
:param kwargs:
|
257 |
+
:return:
|
258 |
+
"""
|
259 |
+
partial_func = functools.partial(func, *args, **kwargs)
|
260 |
+
functools.update_wrapper(partial_func, func)
|
261 |
+
return partial_func
|