pseudotensor commited on
Commit
e0ba5f2
·
1 Parent(s): 09d4719

Update with h2oGPT hash 61628d335bdb685fdcc63ca9821cf5607f41a9e3

Browse files
Files changed (10) hide show
  1. app.py +0 -0
  2. app.py +1 -0
  3. finetune.py +169 -124
  4. generate.py +1185 -0
  5. gradio_runner.py +910 -0
  6. gradio_themes.py +142 -0
  7. prompter.py +2 -1
  8. requirements.txt +2 -1
  9. stopping.py +2 -114
  10. utils.py +77 -2
app.py DELETED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1 @@
 
 
1
+ generate.py
finetune.py CHANGED
@@ -1,55 +1,22 @@
1
  import os
2
- import pathlib
3
- import random
4
- import shutil
5
- import subprocess
6
  import sys
7
  import time
8
- from datetime import datetime
9
  from typing import List, Union
 
10
  import fire
11
  import numpy as np
 
12
  import torch
13
- from datasets import load_dataset, concatenate_datasets
14
- import transformers
15
- import torch.distributed as dist
16
-
17
- from peft import (
18
- prepare_model_for_int8_training,
19
- LoraConfig,
20
- get_peft_model,
21
- get_peft_model_state_dict,
22
- set_peft_model_state_dict,
23
- )
24
-
25
- from peft import mapping
26
- lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING
27
 
28
 
29
  def log(*args, **kwargs):
30
  if int(os.environ.get("LOCAL_RANK", 0)) == 0:
 
 
31
  print(*args, **kwargs)
32
 
33
 
34
- try:
35
- import neptune
36
- from transformers.integrations import NeptuneCallback
37
-
38
- neptune_run = neptune.init_run(
39
- source_files=[],
40
- )
41
- log("Connected to Neptune.")
42
- except ImportError:
43
- neptune_run = None
44
- log("Please pip install neptune for tracking.")
45
- except neptune.exceptions.NeptuneMissingApiTokenException:
46
- neptune_run = None
47
- os.environ["NEPTUNE_MODE"] = 'debug'
48
- log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
49
-
50
- from enum import Enum
51
-
52
-
53
  class PromptType(Enum):
54
  plain = 0
55
  instruct = 1
@@ -87,6 +54,7 @@ prompt_type_to_model_name = {
87
  'h2oai/h2ogpt-oasst1-512-12b',
88
  'h2oai/h2ogpt-oasst1-512-20b',
89
  'h2oai/h2ogpt-oig-oasst1-512-6.9b',
 
90
  ],
91
  'dai_faq': [],
92
  'summarize': [],
@@ -134,7 +102,7 @@ def train(
134
  tokenizer_base_model: str = None,
135
  # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
136
 
137
- data_path: str = None,
138
  data_col_dict: dict = None,
139
  # data_path: str = "./dai_docs.train.json",
140
  prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
@@ -158,6 +126,7 @@ def train(
158
  micro_batch_size: int = 4,
159
  gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
160
  fp16=True,
 
161
 
162
  # general training hyperparams
163
  num_epochs: float = 1,
@@ -175,12 +144,14 @@ def train(
175
  lora_dropout: float = 0.05,
176
  lora_target_modules: List[str] = None,
177
  llama_type: bool = None,
 
178
 
179
  # llm hyperparams
180
  train_on_inputs: bool = True, # if False, masks out inputs in loss
181
  group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
182
  resume_from_checkpoint: str = None, # either training checkpoint or final adapter
183
- cutoff_len: int = 1024, # Good default, especially when have high quality non-trivial data
 
184
 
185
  # torch training params
186
  ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
@@ -190,8 +161,15 @@ def train(
190
  warmup_steps: int = 100,
191
  logging_steps: int = 1,
192
  save_steps: int = None, # must be round multiple of eval_steps
 
193
  add_eos_token: bool = False,
194
  ):
 
 
 
 
 
 
195
  # allow set token directly
196
  use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
197
 
@@ -211,7 +189,7 @@ def train(
211
  if not output_dir:
212
  output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
213
  if os.path.exists(output_dir) and not resume_from_checkpoint:
214
- raise FileExistsError(f"output_dir based on run_id {run_id} already exists. Please pick a different run_id.")
215
  else:
216
  if os.path.exists(output_dir) and not resume_from_checkpoint:
217
  raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
@@ -223,6 +201,21 @@ def train(
223
  tokenizer_base_model = base_model
224
  if llama_type is None:
225
  llama_type = "llama" in base_model.lower()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  assert (
227
  base_model
228
  ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
@@ -254,7 +247,7 @@ def train(
254
 
255
  model = model_loader.from_pretrained(
256
  base_model,
257
- load_in_8bit=True,
258
  device_map=device_map,
259
  torch_dtype=torch.float16,
260
  max_memory=max_memory,
@@ -268,66 +261,28 @@ def train(
268
  model.is_parallelizable = True
269
  model.model_parallel = True
270
 
271
- tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
272
- local_files_only=local_files_only,
273
- resume_download=resume_download,
274
- use_auth_token=use_auth_token)
275
 
276
- tokenizer.pad_token_id = 0 # different from the eos token
277
- # when generating, we will use the logits of right-most token to predict the next token
278
- # so the padding should be on the left,
279
- # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
280
- tokenizer.padding_side = "left" # Allow batched inference
281
-
282
- def tokenize(prompt, add_eos_token=True):
283
- # there's probably a way to do this with the tokenizer settings
284
- # but again, gotta move fast
285
- result = tokenizer(
286
- prompt,
287
- truncation=True,
288
- max_length=cutoff_len,
289
- padding=False,
290
- return_tensors=None,
291
  )
292
- if (
293
- result["input_ids"][-1] != tokenizer.eos_token_id
294
- and len(result["input_ids"]) < cutoff_len
295
- and add_eos_token
296
- ):
297
- result["input_ids"].append(tokenizer.eos_token_id)
298
- result["attention_mask"].append(1)
299
 
300
- result["labels"] = result["input_ids"].copy()
 
 
 
 
 
 
 
301
 
302
- return result
 
 
303
 
304
- def generate_and_tokenize_prompt(data_point, add_eos=add_eos_token):
305
- full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
306
- tokenized_full_prompt = tokenize(full_prompt)
307
- if not train_on_inputs:
308
- user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
309
- tokenized_user_prompt = tokenize(user_prompt, add_eos_token=add_eos)
310
- user_prompt_len = len(tokenized_user_prompt["input_ids"])
311
- if add_eos:
312
- user_prompt_len -= 1
313
-
314
- # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
315
- tokenized_full_prompt["labels"] = [
316
- -100
317
- ] * user_prompt_len + tokenized_full_prompt["labels"][
318
- user_prompt_len:
319
- ] # could be sped up, probably
320
- return tokenized_full_prompt
321
-
322
- if "gpt-neox" not in base_model or True:
323
- model = prepare_model_for_int8_training(model)
324
- else:
325
- model = prepare_model_for_int8_training(
326
- model,
327
- output_embedding_layer_name="embed_out", # keep output logits in float32
328
- layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
329
- )
330
  if lora_weights:
 
331
  from peft import PeftModel
332
  model = PeftModel.from_pretrained(
333
  model,
@@ -338,7 +293,7 @@ def train(
338
  resume_download=resume_download,
339
  use_auth_token=use_auth_token,
340
  )
341
- else:
342
  if lora_target_modules is None:
343
  base_model_lower = base_model.lower()
344
  if base_model_lower in lora_mappings:
@@ -386,7 +341,11 @@ def train(
386
  log(f"Checkpoint {checkpoint_name} not found")
387
 
388
  print(model)
389
- model.print_trainable_parameters() # Be more transparent about the % of trainable params.
 
 
 
 
390
 
391
  metrics = {}
392
  for name in supported_metrics:
@@ -405,6 +364,7 @@ def train(
405
  elif val_set_size < 1.0 and val_set_size != 0:
406
  raise RuntimeError("Fractional validation size not supported.")
407
 
 
408
  if valid_path:
409
  data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
410
  else:
@@ -427,10 +387,16 @@ def train(
427
  else:
428
  data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
429
  data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
 
 
 
 
 
 
430
 
431
  # only get as much as we need to balance
432
  valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
433
- train_size = max(1, min(data_mix_in.num_rows - valid_size, int(num_rows * data_mix_in_factor)))
434
  mixin_small = data_mix_in.train_test_split(
435
  test_size=train_size + valid_size,
436
  shuffle=True, seed=np.random.randint(10000),
@@ -486,10 +452,20 @@ def train(
486
 
487
  assert train_data is not None
488
 
 
 
 
 
489
  # shuffle and tokenize data
490
  if train_data_mix_in:
491
  train_data = concatenate_datasets([train_data, train_data_mix_in])
492
- train_data = train_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
 
 
 
 
 
 
493
  train_set_size = len(train_data)
494
 
495
  if valid_data and valid_data_mix_in:
@@ -498,7 +474,8 @@ def train(
498
  valid_data = valid_data_mix_in
499
 
500
  if valid_data:
501
- valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt, num_proc=os.cpu_count() // torch.cuda.device_count())
 
502
  val_set_size = len(valid_data)
503
  else:
504
  val_set_size = 0
@@ -509,6 +486,22 @@ def train(
509
  del sample_row_dict['labels']
510
  log("Sample input: %s" % sample_row_dict)
511
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
512
  if neptune_run:
513
  neptune_callback = NeptuneCallback(run=neptune_run)
514
  callbacks = [neptune_callback]
@@ -578,6 +571,7 @@ def train(
578
  else:
579
  trainer_kwargs = dict()
580
 
 
581
  trainer = transformers.Trainer(
582
  model=model,
583
  tokenizer=tokenizer,
@@ -605,7 +599,7 @@ def train(
605
  eval_steps=eval_steps if val_set_size > 0 else None,
606
  save_steps=save_steps,
607
  output_dir=output_dir,
608
- save_total_limit=3,
609
  load_best_model_at_end=True if val_set_size > 0 else False,
610
  ddp_find_unused_parameters=False if ddp else None,
611
  group_by_length=group_by_length,
@@ -622,6 +616,8 @@ def train(
622
  model.config.use_cache = False
623
 
624
  old_state_dict = model.state_dict
 
 
625
  model.state_dict = (
626
  lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
627
  ).__get__(model, type(model))
@@ -629,7 +625,8 @@ def train(
629
  if torch.__version__ >= "2" and sys.platform != "win32":
630
  model = torch.compile(model)
631
  # WIP (not generally replacing layers until pytorch 2.1)
632
- torch.backends.cuda.enable_flash_sdp(True)
 
633
 
634
  if gpus > 1 and not ddp:
635
  assert trainer.is_model_parallel
@@ -649,6 +646,9 @@ def get_loaders(llama_type, model_name, reward_type):
649
  from transformers import LlamaForCausalLM, LlamaTokenizer
650
  model_loader = LlamaForCausalLM
651
  tokenizer_loader = LlamaTokenizer
 
 
 
652
  elif 'gpt2' in model_name.lower():
653
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
654
  return GPT2LMHeadModel, GPT2Tokenizer
@@ -676,31 +676,76 @@ def get_loaders(llama_type, model_name, reward_type):
676
  return model_loader, tokenizer_loader
677
 
678
 
679
- def get_githash():
680
- try:
681
- githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
682
- except:
683
- githash = ''
684
- return githash
 
 
 
 
 
685
 
 
686
 
687
- def copy_code(run_id):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
688
  """
689
- copy code to track changes
690
- :param run_id:
 
691
  :return:
692
  """
693
- rnd_num = str(random.randint(0, 2 ** 31))
694
- run_id = 'run_' + str(run_id)
695
- os.makedirs(run_id, exist_ok=True)
696
- me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
697
- me_file = os.path.basename(__file__)
698
- new_me = os.path.join(run_id, me_file + '_' + get_githash())
699
- if os.path.isfile(new_me):
700
- new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
701
- shutil.copy(me_full, new_me)
702
- else:
703
- shutil.copy(me_full, new_me)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
704
 
705
 
706
  def get_prompt(prompt_type, chat, context, reduced):
@@ -824,7 +869,7 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
824
  assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
825
  promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
826
 
827
- prompt = context
828
 
829
  if input and promptA:
830
  prompt += f"""{promptA}"""
 
1
  import os
 
 
 
 
2
  import sys
3
  import time
4
+ from functools import partial
5
  from typing import List, Union
6
+ from enum import Enum
7
  import fire
8
  import numpy as np
9
+ from utils import get_githash, copy_code
10
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def log(*args, **kwargs):
14
  if int(os.environ.get("LOCAL_RANK", 0)) == 0:
15
+ if 'flush' not in kwargs:
16
+ kwargs['flush'] = True
17
  print(*args, **kwargs)
18
 
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  class PromptType(Enum):
21
  plain = 0
22
  instruct = 1
 
54
  'h2oai/h2ogpt-oasst1-512-12b',
55
  'h2oai/h2ogpt-oasst1-512-20b',
56
  'h2oai/h2ogpt-oig-oasst1-512-6.9b',
57
+ 'h2oai/h2ogpt-research-oasst1-512-30b', # private
58
  ],
59
  'dai_faq': [],
60
  'summarize': [],
 
102
  tokenizer_base_model: str = None,
103
  # tokenizer_base_model: str = 'EleutherAI/gpt-neox-20b',
104
 
105
+ data_path: str = "h2oai/openassistant_oasst1_h2ogpt",
106
  data_col_dict: dict = None,
107
  # data_path: str = "./dai_docs.train.json",
108
  prompt_type: Union[str, int] = "plain", # "plain", "instruct", "quality", "human_bot", "dai_faq"
 
126
  micro_batch_size: int = 4,
127
  gradient_checkpointing=False, # unnecessary with gradient accumulation enabled
128
  fp16=True,
129
+ train_8bit=True,
130
 
131
  # general training hyperparams
132
  num_epochs: float = 1,
 
144
  lora_dropout: float = 0.05,
145
  lora_target_modules: List[str] = None,
146
  llama_type: bool = None,
147
+ llama_flash_attn: bool = False,
148
 
149
  # llm hyperparams
150
  train_on_inputs: bool = True, # if False, masks out inputs in loss
151
  group_by_length: bool = False, # if True, faster, but produces an odd training loss curve
152
  resume_from_checkpoint: str = None, # either training checkpoint or final adapter
153
+ cutoff_len: int = 512, # larger values use more memory
154
+ drop_truncations: bool = False, # if True, drop any truncated long sequences
155
 
156
  # torch training params
157
  ddp: bool = True, # set to False if OOM with True, for multi-GPU model parallelism
 
161
  warmup_steps: int = 100,
162
  logging_steps: int = 1,
163
  save_steps: int = None, # must be round multiple of eval_steps
164
+ save_total_limit: int = 3,
165
  add_eos_token: bool = False,
166
  ):
167
+
168
+ if llama_flash_attn:
169
+ # Need to call this before importing transformers.
170
+ from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
171
+ replace_llama_attn_with_flash_attn()
172
+
173
  # allow set token directly
174
  use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
175
 
 
189
  if not output_dir:
190
  output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}"
191
  if os.path.exists(output_dir) and not resume_from_checkpoint:
192
+ raise FileExistsError(f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.")
193
  else:
194
  if os.path.exists(output_dir) and not resume_from_checkpoint:
195
  raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.")
 
201
  tokenizer_base_model = base_model
202
  if llama_type is None:
203
  llama_type = "llama" in base_model.lower()
204
+ if llama_type and llama_flash_attn:
205
+ import pkg_resources
206
+ try:
207
+ pkg_resources.get_distribution('flash_attn')
208
+ can_do_flash_attn = True
209
+ except (pkg_resources.DistributionNotFound, pkg_resources.ContextualVersionConflict):
210
+ can_do_flash_attn = False
211
+
212
+ if not can_do_flash_attn:
213
+ raise RuntimeError("""Flash attention not installed.
214
+ NOTE: for current pytorch 2.0, flash attention requires installing cuda 11.7 via https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=runfile_local and then when running, to avoid installing driver, docs, samples, just install toolkit. Then when pip installing flash attention do:
215
+
216
+ CUDA_HOME=/usr/local/cuda-11.7 pip install flash-attn""")
217
+ from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
218
+ replace_llama_attn_with_flash_attn()
219
  assert (
220
  base_model
221
  ), "Please specify a --base_model, e.g. --base_model='decapoda-research/llama-7b-hf'"
 
247
 
248
  model = model_loader.from_pretrained(
249
  base_model,
250
+ load_in_8bit=train_8bit,
251
  device_map=device_map,
252
  torch_dtype=torch.float16,
253
  max_memory=max_memory,
 
261
  model.is_parallelizable = True
262
  model.model_parallel = True
263
 
264
+ tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
 
 
 
265
 
266
+ if train_8bit:
267
+ from peft import (
268
+ prepare_model_for_int8_training,
 
 
 
 
 
 
 
 
 
 
 
 
269
  )
 
 
 
 
 
 
 
270
 
271
+ if "gpt-neox" not in base_model or True:
272
+ model = prepare_model_for_int8_training(model)
273
+ else:
274
+ model = prepare_model_for_int8_training(
275
+ model,
276
+ output_embedding_layer_name="embed_out", # keep output logits in float32
277
+ layer_norm_names=["layer_norm", "layernorm"], # keep all layer norms in higher precision
278
+ )
279
 
280
+ from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, utils
281
+ lora_mappings = utils.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
282
+ lora_mappings['distilgpt2'] = ["c_attn"]
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  if lora_weights:
285
+
286
  from peft import PeftModel
287
  model = PeftModel.from_pretrained(
288
  model,
 
293
  resume_download=resume_download,
294
  use_auth_token=use_auth_token,
295
  )
296
+ elif lora_r > 0:
297
  if lora_target_modules is None:
298
  base_model_lower = base_model.lower()
299
  if base_model_lower in lora_mappings:
 
341
  log(f"Checkpoint {checkpoint_name} not found")
342
 
343
  print(model)
344
+ try:
345
+ # only for PeftModel
346
+ model.print_trainable_parameters() # Be more transparent about the % of trainable params.
347
+ except:
348
+ pass
349
 
350
  metrics = {}
351
  for name in supported_metrics:
 
364
  elif val_set_size < 1.0 and val_set_size != 0:
365
  raise RuntimeError("Fractional validation size not supported.")
366
 
367
+ from datasets import load_dataset, concatenate_datasets
368
  if valid_path:
369
  data = load_dataset("json", data_files={"train": data_path, "valid": valid_path})
370
  else:
 
387
  else:
388
  data_mix_in = load_dataset(data_mix_in_path)["train"] # can be large
389
  data_mix_in = data_mix_in.rename_columns(data_mix_in_col_dict or {})
390
+ mix_in_rows = int(num_rows * data_mix_in_factor)
391
+
392
+ if mix_in_rows > data_mix_in.num_rows:
393
+ # duplicate rows if mix-in is smaller than required
394
+ log("Duplicating mixin to compensate for its size for training size and mixin fraction")
395
+ data_mix_in = concatenate_datasets([data_mix_in] * int(np.ceil(mix_in_rows / data_mix_in.num_rows)))
396
 
397
  # only get as much as we need to balance
398
  valid_size = min(data_mix_in.num_rows // 2, val_set_size or 0)
399
+ train_size = max(1, min(data_mix_in.num_rows - valid_size, mix_in_rows))
400
  mixin_small = data_mix_in.train_test_split(
401
  test_size=train_size + valid_size,
402
  shuffle=True, seed=np.random.randint(10000),
 
452
 
453
  assert train_data is not None
454
 
455
+ generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
456
+ train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
457
+ cutoff_len=cutoff_len, tokenizer=tokenizer)
458
+
459
  # shuffle and tokenize data
460
  if train_data_mix_in:
461
  train_data = concatenate_datasets([train_data, train_data_mix_in])
462
+ log("Tokenizing %s training rows" % train_data.num_rows)
463
+ train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() // torch.cuda.device_count())
464
+ if drop_truncations:
465
+ log("avoid keeping truncated cases to avoid contaminating model with truncation cases. Original size: %s" % train_data.num_rows)
466
+ prune_long_sequences_func = partial(prune_long_sequences, cutoff_len=cutoff_len)
467
+ train_data = train_data.filter(prune_long_sequences_func, num_proc=os.cpu_count() // torch.cuda.device_count())
468
+ log("avoid keeping truncated cases to avoid contaminating model with truncation cases. New size: %s" % train_data.num_rows)
469
  train_set_size = len(train_data)
470
 
471
  if valid_data and valid_data_mix_in:
 
474
  valid_data = valid_data_mix_in
475
 
476
  if valid_data:
477
+ log("Tokenizing %s validation rows" % valid_data.num_rows)
478
+ valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() // torch.cuda.device_count())
479
  val_set_size = len(valid_data)
480
  else:
481
  val_set_size = 0
 
486
  del sample_row_dict['labels']
487
  log("Sample input: %s" % sample_row_dict)
488
 
489
+ try:
490
+ import neptune
491
+ from transformers.integrations import NeptuneCallback
492
+
493
+ neptune_run = neptune.init_run(
494
+ source_files=[],
495
+ )
496
+ log("Connected to Neptune.")
497
+ except ImportError:
498
+ neptune_run = None
499
+ log("Please pip install neptune for tracking.")
500
+ except neptune.exceptions.NeptuneMissingApiTokenException:
501
+ neptune_run = None
502
+ os.environ["NEPTUNE_MODE"] = 'debug'
503
+ log("No neptune configured, set NEPTUNE_API_TOKEN env var.")
504
+
505
  if neptune_run:
506
  neptune_callback = NeptuneCallback(run=neptune_run)
507
  callbacks = [neptune_callback]
 
571
  else:
572
  trainer_kwargs = dict()
573
 
574
+ import transformers
575
  trainer = transformers.Trainer(
576
  model=model,
577
  tokenizer=tokenizer,
 
599
  eval_steps=eval_steps if val_set_size > 0 else None,
600
  save_steps=save_steps,
601
  output_dir=output_dir,
602
+ save_total_limit=save_total_limit,
603
  load_best_model_at_end=True if val_set_size > 0 else False,
604
  ddp_find_unused_parameters=False if ddp else None,
605
  group_by_length=group_by_length,
 
616
  model.config.use_cache = False
617
 
618
  old_state_dict = model.state_dict
619
+ from peft import get_peft_model_state_dict
620
+
621
  model.state_dict = (
622
  lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
623
  ).__get__(model, type(model))
 
625
  if torch.__version__ >= "2" and sys.platform != "win32":
626
  model = torch.compile(model)
627
  # WIP (not generally replacing layers until pytorch 2.1)
628
+ if not llama_flash_attn:
629
+ torch.backends.cuda.enable_flash_sdp(True)
630
 
631
  if gpus > 1 and not ddp:
632
  assert trainer.is_model_parallel
 
646
  from transformers import LlamaForCausalLM, LlamaTokenizer
647
  model_loader = LlamaForCausalLM
648
  tokenizer_loader = LlamaTokenizer
649
+ elif 'distilgpt2' in model_name.lower():
650
+ from transformers import AutoModelForCausalLM, AutoTokenizer
651
+ return AutoModelForCausalLM, AutoTokenizer
652
  elif 'gpt2' in model_name.lower():
653
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
654
  return GPT2LMHeadModel, GPT2Tokenizer
 
676
  return model_loader, tokenizer_loader
677
 
678
 
679
+ def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
680
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
681
+ local_files_only=local_files_only,
682
+ resume_download=resume_download,
683
+ use_auth_token=use_auth_token)
684
+
685
+ tokenizer.pad_token_id = 0 # different from the eos token
686
+ # when generating, we will use the logits of right-most token to predict the next token
687
+ # so the padding should be on the left,
688
+ # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
689
+ tokenizer.padding_side = "left" # Allow batched inference
690
 
691
+ return tokenizer
692
 
693
+
694
+ def tokenize(prompt, tokenizer, cutoff_len, add_eos_token=False):
695
+ # there's probably a way to do this with the tokenizer settings
696
+ # but again, gotta move fast
697
+ result = tokenizer(
698
+ prompt,
699
+ truncation=True,
700
+ max_length=cutoff_len,
701
+ padding=False,
702
+ return_tensors=None,
703
+ )
704
+ if (
705
+ result["input_ids"][-1] != tokenizer.eos_token_id
706
+ and len(result["input_ids"]) < cutoff_len
707
+ and add_eos_token
708
+ ):
709
+ result["input_ids"].append(tokenizer.eos_token_id)
710
+ result["attention_mask"].append(1)
711
+
712
+ result["labels"] = result["input_ids"].copy()
713
+
714
+ return result
715
+
716
+
717
+ def prune_long_sequences(data_point, cutoff_len=None):
718
  """
719
+ Prune if too long for tokenizer, so truncation doesn't lead training to learn from truncated language
720
+ :param data_point:
721
+ :param cutoff_len:
722
  :return:
723
  """
724
+ assert cutoff_len is not None
725
+ return len(data_point['input_ids']) < cutoff_len
726
+
727
+
728
+ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=False, add_eos_token=False,
729
+ cutoff_len=None, tokenizer=None):
730
+ assert prompt_type is not None
731
+ assert cutoff_len is not None
732
+ assert tokenizer is not None
733
+ full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
734
+ tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
735
+ if not train_on_inputs:
736
+ user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
737
+ tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
738
+ user_prompt_len = len(tokenized_user_prompt["input_ids"])
739
+ if add_eos_token:
740
+ user_prompt_len -= 1
741
+
742
+ # ignore_index=-100 ensures torch/tf don't include padding token id in CrossEntropyLoss
743
+ tokenized_full_prompt["labels"] = [
744
+ -100
745
+ ] * user_prompt_len + tokenized_full_prompt["labels"][
746
+ user_prompt_len:
747
+ ] # could be sped up, probably
748
+ return tokenized_full_prompt
749
 
750
 
751
  def get_prompt(prompt_type, chat, context, reduced):
 
869
  assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
870
  promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response = get_prompt(prompt_type, chat, context, reduced)
871
 
872
+ prompt = context if not reduced else ''
873
 
874
  if input and promptA:
875
  prompt += f"""{promptA}"""
generate.py ADDED
@@ -0,0 +1,1185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import sys
3
+ import os
4
+ import traceback
5
+ import typing
6
+ from threading import Thread
7
+ from datetime import datetime
8
+ import filelock
9
+ import psutil
10
+
11
+ from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial
12
+
13
+ SEED = 1236
14
+ set_seed(SEED)
15
+
16
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
17
+ from typing import Union
18
+ import numpy as np
19
+ import pandas as pd
20
+
21
+ import fire
22
+ import torch
23
+ from peft import PeftModel
24
+ from transformers import GenerationConfig, StoppingCriteriaList, AutoModel, TextIteratorStreamer
25
+ from accelerate import init_empty_weights, infer_auto_device_map
26
+
27
+ from prompter import Prompter
28
+
29
+ from finetune import get_loaders, example_data_points, generate_prompt, human, bot, inv_prompt_type_to_model_lower
30
+ from stopping import StoppingCriteriaSub
31
+
32
+ eval_extra_columns = ['prompt', 'response', 'score']
33
+
34
+
35
+ def main(
36
+ load_8bit: bool = False,
37
+ load_half: bool = True,
38
+ infer_devices: bool = True, # really if to "control" devices now
39
+ base_model: str = '',
40
+ tokenizer_base_model: str = '',
41
+ lora_weights: str = "",
42
+ gpu_id: int = 0, # if infer_devices = True and gpu_id != -1
43
+
44
+ prompt_type: Union[int, str] = None,
45
+ # input to generation
46
+ temperature: float = None,
47
+ top_p: float = None,
48
+ top_k: int = None,
49
+ num_beams: int = None,
50
+ repetition_penalty: float = None,
51
+ num_return_sequences: int = None,
52
+ do_sample: bool = None,
53
+ max_new_tokens: int = None,
54
+ min_new_tokens: int = None,
55
+ early_stopping: Union[bool, str] = None,
56
+ max_time: float = None,
57
+
58
+ debug: bool = False,
59
+ save_dir: str = None,
60
+ share: bool = True,
61
+ local_files_only: bool = False,
62
+ resume_download: bool = True,
63
+ use_auth_token: Union[str, bool] = False, # True requires CLI did huggingface-cli login before running
64
+
65
+ src_lang: str = "English",
66
+ tgt_lang: str = "Russian",
67
+
68
+ gradio: bool = True,
69
+ gradio_avoid_processing_markdown: bool = False,
70
+ chat: bool = True,
71
+ chat_history: int = 4096, # character length of chat context/history
72
+ chat_context: bool = False, # use default context if human_bot
73
+ stream_output: bool = True,
74
+ show_examples: bool = None,
75
+ verbose: bool = False,
76
+ h2ocolors: bool = True,
77
+ height: int = 400,
78
+ show_lora: bool = True,
79
+ # set to True to load --base_model after client logs in,
80
+ # to be able to free GPU memory when model is swapped
81
+ login_mode_if_model0: bool = False,
82
+ block_gradio_exit: bool = True,
83
+ concurrency_count: int = 1,
84
+ api_open: bool = False, # don't let API skip queue
85
+ allow_api: bool = True,
86
+ input_lines: int = 1,
87
+
88
+ sanitize_user_prompt: bool = True,
89
+ sanitize_bot_response: bool = True,
90
+
91
+ extra_model_options: typing.List[str] = [],
92
+ extra_lora_options: typing.List[str] = [],
93
+
94
+ score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
95
+ auto_score: bool = True,
96
+
97
+ eval_sharegpt_prompts_only: int = 0,
98
+ eval_sharegpt_prompts_only_seed: int = 1234,
99
+ eval_sharegpt_as_output: bool = False,
100
+
101
+ hard_stop_list: typing.List[str] = [],
102
+ ):
103
+ is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
104
+ is_gpth2oai = bool(os.getenv("GPT_H2O_AI"))
105
+ is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
106
+ is_low_mem = is_hf # assumes run on 24GB consumer GPU
107
+ admin_pass = os.getenv("ADMIN_PASS")
108
+ # will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
109
+ # but becomes unrecoverable sometimes if raise, so just be silent for now
110
+ raise_generate_gpu_exceptions = not is_public
111
+
112
+ # allow set token directly
113
+ use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
114
+
115
+ if is_public:
116
+ input_lines = 1 # ensure set, for ease of use
117
+ temperature = 0.2
118
+ top_p = 0.85
119
+ top_k = 70
120
+ do_sample = True
121
+ if is_low_mem:
122
+ base_model = 'h2oai/h2ogpt-oasst1-512-12b'
123
+ load_8bit = True
124
+ else:
125
+ base_model = 'h2oai/h2ogpt-oasst1-512-20b'
126
+ if is_low_mem:
127
+ load_8bit = True
128
+ if is_hf:
129
+ # must override share if in spaces
130
+ share = False
131
+ save_dir = os.getenv('SAVE_DIR', save_dir)
132
+ score_model = os.getenv('SCORE_MODEL', score_model)
133
+ if score_model == 'None':
134
+ score_model = ''
135
+ concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count))
136
+ api_open = bool(int(os.getenv('API_OPEN', api_open)))
137
+ allow_api = bool(int(os.getenv('ALLOW_API', allow_api)))
138
+
139
+ n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
140
+ if n_gpus == 0:
141
+ gpu_id = None
142
+ load_8bit = False
143
+ load_half = False
144
+ infer_devices = False
145
+ torch.backends.cudnn.benchmark = True
146
+ torch.backends.cudnn.enabled = False
147
+ torch.set_default_dtype(torch.float32)
148
+ if psutil.virtual_memory().available < 94*1024**3:
149
+ # 12B uses ~94GB
150
+ # 6.9B uses ~47GB
151
+ base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
152
+
153
+ # get defaults
154
+ model_lower = base_model.lower()
155
+ if not gradio:
156
+ # force, else not single response like want to look at
157
+ stream_output = False
158
+ # else prompt removal can mess up output
159
+ chat = False
160
+
161
+ placeholder_instruction, placeholder_input, \
162
+ stream_output, show_examples, \
163
+ prompt_type, temperature, top_p, top_k, num_beams, \
164
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
165
+ repetition_penalty, num_return_sequences, \
166
+ do_sample, \
167
+ src_lang, tgt_lang, \
168
+ examples, \
169
+ task_info = \
170
+ get_generate_params(model_lower, chat,
171
+ stream_output, show_examples,
172
+ prompt_type, temperature, top_p, top_k, num_beams,
173
+ max_new_tokens, min_new_tokens, early_stopping, max_time,
174
+ repetition_penalty, num_return_sequences,
175
+ do_sample,
176
+ )
177
+
178
+ if not gradio:
179
+ if eval_sharegpt_prompts_only > 0:
180
+ # override default examples with shareGPT ones for human-level eval purposes only
181
+ eval_filename = 'ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json'
182
+ if not os.path.isfile(eval_filename):
183
+ os.system(
184
+ 'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % eval_filename)
185
+ import json
186
+ data = json.load(open(eval_filename, 'rt'))
187
+ # focus on data that starts with human, else likely chopped from other data
188
+ turn_start = 0 # odd in general
189
+ data = [x for x in data if len(x['conversations']) > turn_start + 1 and
190
+ x['conversations'][turn_start]['from'] == 'human' and
191
+ x['conversations'][turn_start + 1]['from'] == 'gpt']
192
+ np.random.seed(eval_sharegpt_prompts_only_seed)
193
+ example1 = examples[-1] # pick reference example
194
+ examples = []
195
+ responses = []
196
+ for i in list(np.random.randint(0, len(data), size=eval_sharegpt_prompts_only)):
197
+ assert data[i]['conversations'][turn_start]['from'] == 'human'
198
+ instruction = data[i]['conversations'][turn_start]['value']
199
+ assert data[i]['conversations'][turn_start + 1]['from'] == 'gpt'
200
+ output = data[i]['conversations'][turn_start + 1]['value']
201
+ examplenew = example1.copy()
202
+ assert not chat, "No gradio must use chat=False, uses nochat instruct"
203
+ examplenew[eval_func_param_names.index('instruction_nochat')] = instruction
204
+ examplenew[eval_func_param_names.index('iinput_nochat')] = '' # no input
205
+ examplenew[eval_func_param_names.index('context')] = get_context(chat_context, prompt_type)
206
+ examples.append(examplenew)
207
+ responses.append(output)
208
+
209
+ num_examples = len(examples)
210
+ scoring_path = 'scoring'
211
+ os.makedirs(scoring_path, exist_ok=True)
212
+ if eval_sharegpt_as_output:
213
+ used_base_model = 'gpt35'
214
+ used_lora_weights = ''
215
+ else:
216
+ used_base_model = str(base_model.split('/')[-1])
217
+ used_lora_weights = str(lora_weights.split('/')[-1])
218
+ eval_filename = "df_scores_%s_%s_%s_%s_%s_%s.parquet" % (num_examples, eval_sharegpt_prompts_only,
219
+ eval_sharegpt_prompts_only_seed,
220
+ eval_sharegpt_as_output,
221
+ used_base_model,
222
+ used_lora_weights)
223
+ eval_filename = os.path.join(scoring_path, eval_filename)
224
+
225
+ # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
226
+ context_class = NullContext() if n_gpus > 1 or n_gpus == 0 else torch.device("cuda")
227
+
228
+ with context_class:
229
+ # ensure was set right above before examples generated
230
+ assert not stream_output, "stream_output=True does not make sense with example loop"
231
+ import time
232
+ from functools import partial
233
+
234
+ # get score model
235
+ smodel, stokenizer, sdevice = get_score_model(**locals())
236
+
237
+ if not eval_sharegpt_as_output:
238
+ model, tokenizer, device = get_model(**locals())
239
+ model_state = [model, tokenizer, device, base_model]
240
+ fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir, is_low_mem=is_low_mem,
241
+ raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
242
+ chat_context=chat_context,
243
+ concurrency_count=concurrency_count)
244
+ else:
245
+ assert eval_sharegpt_prompts_only > 0
246
+
247
+ def get_response(*args, exi=0):
248
+ # assumes same ordering of examples and responses
249
+ yield responses[exi]
250
+
251
+ fun = get_response
252
+ t0 = time.time()
253
+ score_dump = []
254
+
255
+ import matplotlib.pyplot as plt
256
+
257
+ for exi, ex in enumerate(examples):
258
+ instruction = ex[eval_func_param_names.index('instruction_nochat')]
259
+ iinput = ex[eval_func_param_names.index('iinput_nochat')]
260
+ context = ex[eval_func_param_names.index('context')]
261
+ clear_torch_cache()
262
+ print("")
263
+ print("START" + "=" * 100)
264
+ print("Question: %s %s" % (instruction, ('input=%s' % iinput if iinput else '')))
265
+ print("-" * 105)
266
+ # fun yields as generator, so have to iterate over it
267
+ # Also means likely do NOT want --stream_output=True, else would show all generations
268
+ gener = fun(*tuple(ex), exi=exi) if eval_sharegpt_as_output else fun(*tuple(ex))
269
+ for res in gener:
270
+ print(res)
271
+ if smodel:
272
+ score_with_prompt = False
273
+ if score_with_prompt:
274
+ data_point = dict(instruction=instruction, input=iinput, context=context)
275
+ prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
276
+ prompt = prompter.generate_prompt(data_point)
277
+ else:
278
+ # just raw input and output
279
+ if eval_sharegpt_prompts_only > 0:
280
+ # only our own examples have this filled at moment
281
+ assert iinput in [None, ''], iinput # should be no iinput
282
+ if not (chat_context and prompt_type == 'human_bot'):
283
+ assert context in [None, ''], context # should be no context
284
+ prompt = instruction
285
+ cutoff_len = 768 if is_low_mem else 2048
286
+ inputs = stokenizer(prompt, res,
287
+ return_tensors="pt",
288
+ truncation=True,
289
+ max_length=cutoff_len)
290
+ try:
291
+ score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
292
+ except torch.cuda.OutOfMemoryError as e:
293
+ print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
294
+ traceback.print_exc()
295
+ score = 0.0
296
+ clear_torch_cache()
297
+ except (Exception, RuntimeError) as e:
298
+ if 'Expected all tensors to be on the same device' in str(e) or \
299
+ 'expected scalar type Half but found Float' in str(e) or \
300
+ 'probability tensor contains either' in str(e) or \
301
+ 'cublasLt ran into an error!' in str(e):
302
+ print("GPU error: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
303
+ flush=True)
304
+ traceback.print_exc()
305
+ score = 0.0
306
+ clear_torch_cache()
307
+ else:
308
+ raise
309
+ print("SCORE %s: %s" % (exi, score), flush=True)
310
+ score_dump.append(ex + [prompt, res, score])
311
+ # dump every score in case abort
312
+ df_scores = pd.DataFrame(score_dump,
313
+ columns=eval_func_param_names + eval_extra_columns)
314
+ df_scores.to_parquet(eval_filename, index=False)
315
+ # plot histogram so far
316
+ plt.figure(figsize=(10, 10))
317
+ plt.hist(df_scores['score'], bins=20)
318
+ score_avg = np.mean(df_scores['score'])
319
+ score_median = np.median(df_scores['score'])
320
+ plt.title("Score avg: %s median: %s" % (score_avg, score_median))
321
+ plt.savefig(eval_filename.replace('.parquet', '.png'))
322
+ plt.close()
323
+
324
+ print("END" + "=" * 102)
325
+ print("")
326
+ t2 = time.time()
327
+ print("Time taken so far: %.4f about %.4g per example" % (t2 - t0, (t2 - t0) / (1 + exi)))
328
+ t1 = time.time()
329
+ print("Total time taken: %.4f about %.4g per example" % (t1 - t0, (t1 - t0) / num_examples))
330
+ return eval_filename
331
+
332
+ if gradio:
333
+ # imported here so don't require gradio to run generate
334
+ from gradio_runner import go_gradio
335
+
336
+ # get default model
337
+ all_kwargs = locals().copy()
338
+ if all_kwargs.get('base_model') and not all_kwargs['login_mode_if_model0']:
339
+ model0, tokenizer0, device = get_model(**all_kwargs)
340
+ else:
341
+ # if empty model, then don't load anything, just get gradio up
342
+ model0, tokenizer0, device = None, None, None
343
+ model_state0 = [model0, tokenizer0, device, all_kwargs['base_model']]
344
+
345
+ # get score model
346
+ smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
347
+ score_model_state0 = [smodel, stokenizer, sdevice, score_model]
348
+
349
+ go_gradio(**locals())
350
+
351
+
352
+ def get_device():
353
+ if torch.cuda.is_available():
354
+ device = "cuda"
355
+ else:
356
+ device = "cpu"
357
+
358
+ return device
359
+
360
+
361
+ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
362
+ gpu_id=0,
363
+ use_auth_token=False):
364
+ """
365
+ Ensure model gets on correct device
366
+ :param base_model:
367
+ :param model_loader:
368
+ :param load_half:
369
+ :param model_kwargs:
370
+ :param reward_type:
371
+ :param gpu_id:
372
+ :param use_auth_token:
373
+ :return:
374
+ """
375
+ with init_empty_weights():
376
+ from transformers import AutoConfig
377
+ config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
378
+ model = AutoModel.from_config(
379
+ config,
380
+ )
381
+
382
+ # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
383
+ # NOTE: Some models require avoiding sharding some layers,
384
+ # then would pass no_split_module_classes and give list of those layers.
385
+ device_map = infer_auto_device_map(
386
+ model,
387
+ dtype=torch.float16 if load_half else torch.float32,
388
+ )
389
+ if hasattr(model, 'model'):
390
+ device_map_model = infer_auto_device_map(
391
+ model.model,
392
+ dtype=torch.float16 if load_half else torch.float32,
393
+ )
394
+ device_map.update(device_map_model)
395
+ print('device_map: %s' % device_map, flush=True)
396
+
397
+ n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
398
+
399
+ if n_gpus > 0:
400
+ if gpu_id >= 0:
401
+ # FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
402
+ # So avoid for now, just put on first GPU, unless score_model, put on last
403
+ if reward_type:
404
+ device_map = {'': n_gpus - 1}
405
+ else:
406
+ device_map = {'': min(n_gpus - 1, gpu_id)}
407
+ if gpu_id == -1:
408
+ device_map = {'': 'cuda'}
409
+ else:
410
+ device_map = {'': 'cpu'}
411
+ model_kwargs['load_in_8bit'] = False
412
+
413
+ load_in_8bit = model_kwargs.get('load_in_8bit', False)
414
+ model_kwargs['device_map'] = device_map
415
+
416
+ if load_in_8bit or not load_half:
417
+ model = model_loader.from_pretrained(
418
+ base_model,
419
+ **model_kwargs,
420
+ )
421
+ else:
422
+ model = model_loader.from_pretrained(
423
+ base_model,
424
+ **model_kwargs,
425
+ ).half()
426
+ return model
427
+
428
+
429
+ def get_model(
430
+ load_8bit: bool = False,
431
+ load_half: bool = True,
432
+ infer_devices: bool = True,
433
+ base_model: str = '',
434
+ tokenizer_base_model: str = '',
435
+ lora_weights: str = "",
436
+ gpu_id: int = 0,
437
+
438
+ reward_type: bool = None,
439
+ local_files_only: bool = False,
440
+ resume_download: bool = True,
441
+ use_auth_token: Union[str, bool] = False,
442
+ compile: bool = True,
443
+ **kwargs,
444
+ ):
445
+ """
446
+
447
+ :param load_8bit: load model in 8-bit, not supported by all models
448
+ :param load_half: load model in 16-bit
449
+ :param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case)
450
+ For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
451
+ So it is not the default
452
+ :param base_model: name/path of base model
453
+ :param tokenizer_base_model: name/path of tokenizer
454
+ :param lora_weights: name/path
455
+ :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
456
+ :param reward_type: reward type model for sequence classification
457
+ :param local_files_only: use local files instead of from HF
458
+ :param resume_download: resume downloads from HF
459
+ :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
460
+ :param compile: whether to compile torch model
461
+ :param kwargs:
462
+ :return:
463
+ """
464
+ print("Get %s model" % base_model, flush=True)
465
+ if lora_weights is not None and lora_weights.strip():
466
+ print("Get %s lora weights" % lora_weights, flush=True)
467
+ device = get_device()
468
+
469
+ if 'gpt2' in base_model.lower():
470
+ # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half
471
+ load_8bit = False
472
+
473
+ assert base_model.strip(), (
474
+ "Please choose a base model with --base_model (CLI) or in Models Tab (gradio)"
475
+ )
476
+
477
+ from transformers import AutoConfig
478
+ config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token)
479
+ llama_type_from_config = 'llama' in str(config).lower()
480
+ llama_type_from_name = "llama" in base_model.lower()
481
+ llama_type = llama_type_from_config or llama_type_from_name
482
+ if llama_type:
483
+ print("Detected as llama type from"
484
+ " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
485
+
486
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
487
+ if not tokenizer_base_model:
488
+ tokenizer_base_model = base_model
489
+
490
+ if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
491
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
492
+ local_files_only=local_files_only,
493
+ resume_download=resume_download,
494
+ use_auth_token=use_auth_token,
495
+ )
496
+ else:
497
+ tokenizer = tokenizer_loader
498
+
499
+ if isinstance(tokenizer, str):
500
+ # already a pipeline, tokenizer_loader is string for task
501
+ model = model_loader(tokenizer,
502
+ model=base_model,
503
+ device=0 if device == "cuda" else -1,
504
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32)
505
+ else:
506
+ assert device in ["cuda", "cpu"], "Unsupported device %s" % device
507
+ model_kwargs = dict(local_files_only=local_files_only,
508
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
509
+ resume_download=resume_download,
510
+ use_auth_token=use_auth_token)
511
+ if 'mbart-' not in base_model.lower():
512
+ model_kwargs.update(dict(load_in_8bit=load_8bit,
513
+ device_map={"": 0} if load_8bit and device == 'cuda' else "auto",
514
+ ))
515
+ if 'OpenAssistant/reward-model'.lower() in base_model.lower():
516
+ # could put on other GPUs
517
+ model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
518
+ model_kwargs.pop('torch_dtype', None)
519
+
520
+ if not lora_weights:
521
+ with torch.device(device):
522
+ if infer_devices:
523
+ model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
524
+ gpu_id=gpu_id, use_auth_token=use_auth_token)
525
+ else:
526
+ if load_half and not load_8bit:
527
+ model = model_loader.from_pretrained(
528
+ base_model,
529
+ **model_kwargs).half()
530
+ else:
531
+ model = model_loader.from_pretrained(
532
+ base_model,
533
+ **model_kwargs)
534
+ elif load_8bit:
535
+ model = model_loader.from_pretrained(
536
+ base_model,
537
+ **model_kwargs
538
+ )
539
+ model = PeftModel.from_pretrained(
540
+ model,
541
+ lora_weights,
542
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
543
+ local_files_only=local_files_only,
544
+ resume_download=resume_download,
545
+ use_auth_token=use_auth_token,
546
+ device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required
547
+ )
548
+ else:
549
+ with torch.device(device):
550
+ model = model_loader.from_pretrained(
551
+ base_model,
552
+ **model_kwargs
553
+ )
554
+ model = PeftModel.from_pretrained(
555
+ model,
556
+ lora_weights,
557
+ torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
558
+ local_files_only=local_files_only,
559
+ resume_download=resume_download,
560
+ use_auth_token=use_auth_token,
561
+ device_map="auto",
562
+ )
563
+ if load_half:
564
+ model.half()
565
+
566
+ # unwind broken decapoda-research config
567
+ if llama_type:
568
+ model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk
569
+ model.config.bos_token_id = 1
570
+ model.config.eos_token_id = 2
571
+ if 'gpt2' in base_model.lower():
572
+ # add special tokens that otherwise all share the same id
573
+ tokenizer.add_special_tokens({'bos_token': '<bos>',
574
+ 'eos_token': '<eos>',
575
+ 'pad_token': '<pad>'})
576
+
577
+ if not isinstance(tokenizer, str):
578
+ model.eval()
579
+ if torch.__version__ >= "2" and sys.platform != "win32" and compile:
580
+ model = torch.compile(model)
581
+
582
+ return model, tokenizer, device
583
+
584
+
585
+ def get_score_model(**kwargs):
586
+ # score model
587
+ if kwargs.get('score_model') is not None and kwargs.get('score_model').strip():
588
+ score_all_kwargs = kwargs.copy()
589
+ score_all_kwargs['load_8bit'] = False
590
+ score_all_kwargs['load_half'] = False
591
+ score_all_kwargs['base_model'] = kwargs.get('score_model').strip()
592
+ score_all_kwargs['tokenizer_base_model'] = ''
593
+ score_all_kwargs['lora_weights'] = ''
594
+ score_all_kwargs['llama_type'] = False
595
+ score_all_kwargs['compile'] = False
596
+ smodel, stokenizer, sdevice = get_model(**score_all_kwargs)
597
+ else:
598
+ smodel, stokenizer, sdevice = None, None, None
599
+ return smodel, stokenizer, sdevice
600
+
601
+
602
+ eval_func_param_names = ['instruction',
603
+ 'iinput',
604
+ 'context',
605
+ 'stream_output',
606
+ 'prompt_type',
607
+ 'temperature',
608
+ 'top_p',
609
+ 'top_k',
610
+ 'num_beams',
611
+ 'max_new_tokens',
612
+ 'min_new_tokens',
613
+ 'early_stopping',
614
+ 'max_time',
615
+ 'repetition_penalty',
616
+ 'num_return_sequences',
617
+ 'do_sample',
618
+ 'chat',
619
+ 'instruction_nochat',
620
+ 'iinput_nochat',
621
+ ]
622
+
623
+
624
+ def evaluate(
625
+ model_state,
626
+ # START NOTE: Examples must have same order of parameters
627
+ instruction,
628
+ iinput,
629
+ context,
630
+ stream_output,
631
+ prompt_type,
632
+ temperature,
633
+ top_p,
634
+ top_k,
635
+ num_beams,
636
+ max_new_tokens,
637
+ min_new_tokens,
638
+ early_stopping,
639
+ max_time,
640
+ repetition_penalty,
641
+ num_return_sequences,
642
+ do_sample,
643
+ chat,
644
+ instruction_nochat,
645
+ iinput_nochat,
646
+ # END NOTE: Examples must have same order of parameters
647
+ src_lang=None,
648
+ tgt_lang=None,
649
+ debug=False,
650
+ concurrency_count=None,
651
+ save_dir=None,
652
+ hard_stop_list=None,
653
+ sanitize_bot_response=True,
654
+ model_state0=None,
655
+ is_low_mem=None,
656
+ raise_generate_gpu_exceptions=None,
657
+ chat_context=None,
658
+ ):
659
+ # ensure passed these
660
+ assert concurrency_count is not None
661
+ assert is_low_mem is not None
662
+ assert raise_generate_gpu_exceptions is not None
663
+ assert chat_context is not None
664
+
665
+ if debug:
666
+ locals_dict = locals().copy()
667
+ locals_dict.pop('model_state', None)
668
+ locals_dict.pop('model_state0', None)
669
+ print(locals_dict)
670
+
671
+ no_model_msg = "Please choose a base model with --base_model (CLI) or in Models Tab (gradio).\nThen start New Conversation"
672
+
673
+ if model_state0 is None:
674
+ # e.g. for no gradio case, set dummy value, else should be set
675
+ model_state0 = [None, None, None, None]
676
+
677
+ if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
678
+ # try to free-up original model (i.e. list was passed as reference)
679
+ if model_state0 is not None and model_state0[0] is not None:
680
+ model_state0[0].cpu()
681
+ model_state0[0] = None
682
+ # try to free-up original tokenizer (i.e. list was passed as reference)
683
+ if model_state0 is not None and model_state0[1] is not None:
684
+ model_state0[1] = None
685
+ clear_torch_cache()
686
+ model, tokenizer, device, base_model = model_state
687
+ elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
688
+ assert isinstance(model_state[0], str)
689
+ model, tokenizer, device, base_model = model_state0
690
+ else:
691
+ raise AssertionError(no_model_msg)
692
+
693
+ if base_model is None:
694
+ raise AssertionError(no_model_msg)
695
+
696
+ assert base_model.strip(), no_model_msg
697
+ assert model, "Model is missing"
698
+ assert tokenizer, "Tokenizer is missing"
699
+
700
+ # choose chat or non-chat mode
701
+ if not chat:
702
+ instruction = instruction_nochat
703
+ iinput = iinput_nochat
704
+
705
+ if not context:
706
+ # get hidden context if have one
707
+ context = get_context(chat_context, prompt_type)
708
+
709
+ data_point = dict(context=context, instruction=instruction, input=iinput)
710
+ prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
711
+ prompt = prompter.generate_prompt(data_point)
712
+
713
+ if hard_stop_list is None:
714
+ # acts like undo on user entry and bot response
715
+ hard_stop_list = []
716
+
717
+ if isinstance(tokenizer, str):
718
+ # pipeline
719
+ if tokenizer == "summarization":
720
+ key = 'summary_text'
721
+ else:
722
+ raise RuntimeError("No such task type %s" % tokenizer)
723
+ # NOTE: uses max_length only
724
+ yield model(prompt, max_length=max_new_tokens)[0][key]
725
+
726
+ if 'mbart-' in base_model.lower():
727
+ assert src_lang is not None
728
+ tokenizer.src_lang = languages_covered()[src_lang]
729
+
730
+ if chat:
731
+ # override, ignore user change
732
+ num_return_sequences = 1
733
+ if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
734
+ if prompt_type == 'human_bot':
735
+ # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
736
+ # stopping only starts once output is beyond prompt
737
+ # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
738
+ stop_words = [human, bot, '\n' + human, '\n' + bot]
739
+ encounters = [1, 2]
740
+ elif prompt_type == 'instruct_vicuna':
741
+ # even below is not enough, generic strings and many ways to encode
742
+ stop_words = [
743
+ '### Human:',
744
+ """
745
+ ### Human:""",
746
+ """
747
+ ### Human:
748
+ """,
749
+ '### Assistant:',
750
+ """
751
+ ### Assistant:""",
752
+ """
753
+ ### Assistant:
754
+ """,
755
+ ]
756
+ encounters = [1, 2]
757
+ else:
758
+ # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
759
+ stop_words = ['### End']
760
+ encounters = [1]
761
+ stop_words_ids = [
762
+ tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
763
+ # handle single token case
764
+ stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
765
+ stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
766
+ # avoid padding in front of tokens
767
+ if tokenizer.pad_token:
768
+ stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
769
+ # handle fake \n added
770
+ stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
771
+ # build stopper
772
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
773
+ else:
774
+ stopping_criteria = StoppingCriteriaList()
775
+
776
+ # help to avoid errors like:
777
+ # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
778
+ # RuntimeError: expected scalar type Half but found Float
779
+ # with - 256
780
+ max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
781
+ cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
782
+ output_smallest = 30 * 4
783
+ prompt = prompt[-cutoff_len - output_smallest:]
784
+ inputs = tokenizer(prompt,
785
+ return_tensors="pt",
786
+ truncation=True,
787
+ max_length=max_length_tokenize)
788
+ if debug and len(inputs["input_ids"]) > 0:
789
+ print('input_ids length', len(inputs["input_ids"][0]), flush=True)
790
+ input_ids = inputs["input_ids"].to(device)
791
+ generation_config = GenerationConfig(
792
+ temperature=float(temperature),
793
+ top_p=float(top_p),
794
+ top_k=top_k,
795
+ num_beams=num_beams,
796
+ do_sample=do_sample,
797
+ repetition_penalty=float(repetition_penalty),
798
+ num_return_sequences=num_return_sequences,
799
+ renormalize_logits=True,
800
+ remove_invalid_values=True,
801
+ )
802
+
803
+ gen_kwargs = dict(input_ids=input_ids,
804
+ generation_config=generation_config,
805
+ return_dict_in_generate=True,
806
+ output_scores=True,
807
+ max_new_tokens=max_new_tokens, # prompt + new
808
+ min_new_tokens=min_new_tokens, # prompt + new
809
+ early_stopping=early_stopping, # False, True, "never"
810
+ max_time=max_time,
811
+ stopping_criteria=stopping_criteria,
812
+ )
813
+ if 'gpt2' in base_model.lower():
814
+ gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id))
815
+ elif 'mbart-' in base_model.lower():
816
+ assert tgt_lang is not None
817
+ tgt_lang = languages_covered()[tgt_lang]
818
+ gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
819
+ else:
820
+ gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
821
+
822
+ decoder = functools.partial(tokenizer.decode,
823
+ skip_special_tokens=True,
824
+ clean_up_tokenization_spaces=True,
825
+ )
826
+ decoder_raw = functools.partial(tokenizer.decode,
827
+ skip_special_tokens=False,
828
+ clean_up_tokenization_spaces=True,
829
+ )
830
+
831
+ with torch.no_grad():
832
+ # protection for gradio not keeping track of closed users,
833
+ # else hit bitsandbytes lack of thread safety:
834
+ # https://github.com/h2oai/h2ogpt/issues/104
835
+ # but only makes sense if concurrency_count == 1
836
+ context_class = NullContext #if concurrency_count > 1 else filelock.FileLock
837
+ print('Pre-Generate: %s' % str(datetime.now()), flush=True)
838
+ decoded_output = None
839
+ with context_class("generate.lock"):
840
+ print('Generate: %s' % str(datetime.now()), flush=True)
841
+ # decoded tokenized prompt can deviate from prompt due to special characters
842
+ inputs_decoded = decoder(input_ids[0])
843
+ inputs_decoded_raw = decoder_raw(input_ids[0])
844
+ if inputs_decoded == prompt:
845
+ # normal
846
+ pass
847
+ elif inputs_decoded.lstrip() == prompt.lstrip():
848
+ # sometimes extra space in front, make prompt same for prompt removal
849
+ prompt = inputs_decoded
850
+ elif inputs_decoded_raw == prompt:
851
+ # some models specify special tokens that are part of normal prompt, so can't skip them
852
+ inputs_decoded_raw = inputs_decoded
853
+ decoder = decoder_raw
854
+ else:
855
+ print("WARNING: Special characters in prompt", flush=True)
856
+ if stream_output:
857
+ skip_prompt = False
858
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=skip_prompt)
859
+ gen_kwargs.update(dict(streamer=streamer))
860
+ target_func = generate_with_exceptions
861
+ target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
862
+ raise_generate_gpu_exceptions, **gen_kwargs)
863
+ thread = Thread(target=target)
864
+ thread.start()
865
+ outputs = ""
866
+ for new_text in streamer:
867
+ outputs += new_text
868
+ yield prompter.get_response(outputs, prompt=inputs_decoded,
869
+ sanitize_bot_response=sanitize_bot_response)
870
+ decoded_output = outputs
871
+ else:
872
+ outputs = model.generate(**gen_kwargs)
873
+ outputs = [decoder(s) for s in outputs.sequences]
874
+ yield prompter.get_response(outputs, prompt=inputs_decoded,
875
+ sanitize_bot_response=sanitize_bot_response)
876
+ if outputs and len(outputs) >= 1:
877
+ decoded_output = prompt + outputs[0]
878
+ if save_dir and decoded_output:
879
+ save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
880
+ print('Post-Generate: %s decoded_output: %s' % (str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
881
+
882
+
883
+ def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
884
+ try:
885
+ func(**kwargs)
886
+ except torch.cuda.OutOfMemoryError as e:
887
+ print("GPU OOM 2: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
888
+ flush=True)
889
+ if kwargs['input_ids'] is not None:
890
+ kwargs['input_ids'].cpu()
891
+ kwargs['input_ids'] = None
892
+ traceback.print_exc()
893
+ clear_torch_cache()
894
+ return
895
+ except (Exception, RuntimeError) as e:
896
+ if 'Expected all tensors to be on the same device' in str(e) or \
897
+ 'expected scalar type Half but found Float' in str(e) or \
898
+ 'probability tensor contains either' in str(e) or \
899
+ 'cublasLt ran into an error!' in str(e) or \
900
+ 'mat1 and mat2 shapes cannot be multiplied' in str(e):
901
+ print(
902
+ "GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
903
+ flush=True)
904
+ traceback.print_exc()
905
+ clear_torch_cache()
906
+ if raise_generate_gpu_exceptions:
907
+ raise
908
+ return
909
+ else:
910
+ clear_torch_cache()
911
+ raise
912
+
913
+
914
+ def get_generate_params(model_lower, chat,
915
+ stream_output, show_examples,
916
+ prompt_type, temperature, top_p, top_k, num_beams,
917
+ max_new_tokens, min_new_tokens, early_stopping, max_time,
918
+ repetition_penalty, num_return_sequences,
919
+ do_sample):
920
+ use_defaults = False
921
+ use_default_examples = True
922
+ examples = []
923
+ task_info = f"{prompt_type}"
924
+ if model_lower:
925
+ print(f"Using Model {model_lower}", flush=True)
926
+ else:
927
+ print("No model defined yet", flush=True)
928
+
929
+ min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
930
+ early_stopping = early_stopping if early_stopping is not None else False
931
+ max_time_defaults = 60 * 3
932
+ max_time = max_time if max_time is not None else max_time_defaults
933
+
934
+ if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
935
+ prompt_type = inv_prompt_type_to_model_lower[model_lower]
936
+
937
+ # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
938
+ if show_examples is None:
939
+ if chat:
940
+ show_examples = False
941
+ else:
942
+ show_examples = True
943
+
944
+ summarize_example1 = """Jeff: Can I train a ? Transformers model on Amazon SageMaker?
945
+ Philipp: Sure you can use the new Hugging Face Deep Learning Container.
946
+ Jeff: ok.
947
+ Jeff: and how can I get started?
948
+ Jeff: where can I find documentation?
949
+ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face"""
950
+
951
+ if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower:
952
+ placeholder_instruction = summarize_example1
953
+ placeholder_input = ""
954
+ use_defaults = True
955
+ use_default_examples = False
956
+ examples += [
957
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
958
+ 1.0, 1,
959
+ False]]
960
+ task_info = "Summarization"
961
+ elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower:
962
+ placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?"
963
+ placeholder_input = ""
964
+ use_defaults = True
965
+ use_default_examples = True
966
+ task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)"
967
+ elif 'mbart-' in model_lower:
968
+ placeholder_instruction = "The girl has long hair."
969
+ placeholder_input = ""
970
+ use_defaults = True
971
+ use_default_examples = False
972
+ examples += [
973
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
974
+ 1.0, 1,
975
+ False]]
976
+ elif 'gpt2' in model_lower:
977
+ placeholder_instruction = "The sky is"
978
+ placeholder_input = ""
979
+ prompt_type = prompt_type or 'plain'
980
+ use_default_examples = True # some will be odd "continuations" but can be ok
981
+ examples += [
982
+ [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults,
983
+ 1.0, 1,
984
+ False]]
985
+ task_info = "Auto-complete phrase, code, etc."
986
+ use_defaults = True
987
+ else:
988
+ if chat:
989
+ placeholder_instruction = "Enter a question or imperative."
990
+ else:
991
+ placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
992
+ placeholder_input = ""
993
+ if model_lower:
994
+ prompt_type = prompt_type or 'human_bot'
995
+ else:
996
+ prompt_type = ''
997
+ examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
998
+ stream_output, prompt_type or 'plain', 0.1, 0.75, 40, 4, 256, 0, False, max_time_defaults, 1.0, 1,
999
+ False]]
1000
+ task_info = "No task"
1001
+ if prompt_type == 'instruct':
1002
+ task_info = "Answer question or follow imperative as instruction with optionally input."
1003
+ elif prompt_type == 'plain':
1004
+ task_info = "Auto-complete phrase, code, etc."
1005
+ elif prompt_type == 'human_bot':
1006
+ if chat:
1007
+ task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)"
1008
+ else:
1009
+ task_info = "Ask question/imperative (input concatenated with instruction)"
1010
+
1011
+ # revert to plain if still nothing
1012
+ prompt_type = prompt_type or 'plain'
1013
+ if use_defaults:
1014
+ temperature = 1.0 if temperature is None else temperature
1015
+ top_p = 1.0 if top_p is None else top_p
1016
+ top_k = 40 if top_k is None else top_k
1017
+ num_beams = num_beams or 1
1018
+ max_new_tokens = max_new_tokens or 128
1019
+ repetition_penalty = repetition_penalty or 1.07
1020
+ num_return_sequences = min(num_beams, num_return_sequences or 1)
1021
+ do_sample = False if do_sample is None else do_sample
1022
+ else:
1023
+ temperature = 0.2 if temperature is None else temperature
1024
+ top_p = 0.85 if top_p is None else top_p
1025
+ top_k = 70 if top_k is None else top_k
1026
+ if chat:
1027
+ num_beams = num_beams or 1
1028
+ else:
1029
+ num_beams = num_beams or 4
1030
+ max_new_tokens = max_new_tokens or 256
1031
+ repetition_penalty = repetition_penalty or 1.07
1032
+ num_return_sequences = min(num_beams, num_return_sequences or 1)
1033
+ do_sample = True if do_sample is None else do_sample
1034
+ # doesn't include chat, instruction_nochat, iinput_nochat, added later
1035
+ params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
1036
+ early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
1037
+
1038
+ if use_default_examples:
1039
+ examples += [
1040
+ ["Translate English to French", "Good morning"] + params_list,
1041
+ ["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list,
1042
+ ["Explain in detailed list, all the best practices for coding in python.", ''] + params_list,
1043
+ [
1044
+ "Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.",
1045
+ ''] + params_list,
1046
+ ['Translate to German: My name is Arthur', ''] + params_list,
1047
+ ["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list,
1048
+ ['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.',
1049
+ ''] + params_list,
1050
+ ['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list,
1051
+ ['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list,
1052
+ ["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list,
1053
+ [
1054
+ "Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?",
1055
+ ''] + params_list,
1056
+ ['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list,
1057
+ [
1058
+ 'Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apple do they have?',
1059
+ ''] + params_list,
1060
+ ["""def area_of_rectangle(a: float, b: float):
1061
+ \"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list,
1062
+ ["""# a function in native python:
1063
+ def mean(a):
1064
+ return sum(a)/len(a)
1065
+
1066
+ # the same function using numpy:
1067
+ import numpy as np
1068
+ def mean(a):""", ''] + params_list,
1069
+ ["""X = np.random.randn(100, 100)
1070
+ y = np.random.randint(0, 1, 100)
1071
+
1072
+ # fit random forest classifier with 20 estimators""", ''] + params_list,
1073
+ ]
1074
+
1075
+ src_lang = "English"
1076
+ tgt_lang = "Russian"
1077
+
1078
+ # move to correct position
1079
+ for example in examples:
1080
+ example += [chat, '', '']
1081
+ # adjust examples if non-chat mode
1082
+ if not chat:
1083
+ example[eval_func_param_names.index('instruction_nochat')] = example[
1084
+ eval_func_param_names.index('instruction')]
1085
+ example[eval_func_param_names.index('instruction')] = ''
1086
+
1087
+ example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
1088
+ example[eval_func_param_names.index('iinput')] = ''
1089
+
1090
+ return placeholder_instruction, placeholder_input, \
1091
+ stream_output, show_examples, \
1092
+ prompt_type, temperature, top_p, top_k, num_beams, \
1093
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
1094
+ repetition_penalty, num_return_sequences, \
1095
+ do_sample, \
1096
+ src_lang, tgt_lang, \
1097
+ examples, \
1098
+ task_info
1099
+
1100
+
1101
+ def languages_covered():
1102
+ # https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered
1103
+ covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)"""
1104
+ covered = covered.split(', ')
1105
+ covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered}
1106
+ return covered
1107
+
1108
+
1109
+ def get_context(chat_context, prompt_type):
1110
+ if chat_context and prompt_type == 'human_bot':
1111
+ context0 = """<bot>: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT, who will give accurate, balanced, and reliable responses. I will not respond with I don't know or I don't understand.
1112
+ <human>: I am a human person seeking useful assistance and request all questions be answered completely, and typically expect detailed responses. Give answers in numbered list format if several distinct but related items are being listed."""
1113
+ else:
1114
+ context0 = ''
1115
+ return context0
1116
+
1117
+
1118
+ def test_test_prompt(prompt_type='instruct', data_point=0):
1119
+ example_data_point = example_data_points[data_point]
1120
+ example_data_point.pop('output', None)
1121
+ return generate_prompt(example_data_point, prompt_type, False, False)
1122
+
1123
+
1124
+ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len):
1125
+ question = question[-cutoff_len:]
1126
+ answer = answer[-cutoff_len:]
1127
+
1128
+ inputs = stokenizer(question, answer,
1129
+ return_tensors="pt",
1130
+ truncation=True,
1131
+ max_length=max_length_tokenize).to(smodel.device)
1132
+ try:
1133
+ score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
1134
+ except torch.cuda.OutOfMemoryError as e:
1135
+ print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
1136
+ del inputs
1137
+ traceback.print_exc()
1138
+ clear_torch_cache()
1139
+ return 'Response Score: GPU OOM'
1140
+ except (Exception, RuntimeError) as e:
1141
+ if 'Expected all tensors to be on the same device' in str(e) or \
1142
+ 'expected scalar type Half but found Float' in str(e) or \
1143
+ 'probability tensor contains either' in str(e) or \
1144
+ 'cublasLt ran into an error!' in str(e):
1145
+ print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
1146
+ flush=True)
1147
+ traceback.print_exc()
1148
+ clear_torch_cache()
1149
+ return 'Response Score: GPU Error'
1150
+ else:
1151
+ raise
1152
+ os.environ['TOKENIZERS_PARALLELISM'] = 'true'
1153
+ return score
1154
+
1155
+
1156
+ if __name__ == "__main__":
1157
+ print("""
1158
+ WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B
1159
+ python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B'
1160
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B'
1161
+
1162
+ # generate without lora weights, no prompt
1163
+ python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain'
1164
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq'
1165
+
1166
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq'
1167
+ # OpenChatKit settings:
1168
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0
1169
+
1170
+ python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False
1171
+ python generate.py --base_model='t5-large' --prompt_type='simple_instruct'
1172
+ python generate.py --base_model='philschmid/bart-large-cnn-samsum'
1173
+ python generate.py --base_model='philschmid/flan-t5-base-samsum'
1174
+ python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt'
1175
+
1176
+ python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28'
1177
+
1178
+ must have 4*48GB GPU and run without 8bit in order for sharding to work with infer_devices=False
1179
+ can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned
1180
+ python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot'
1181
+
1182
+ python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
1183
+
1184
+ """, flush=True)
1185
+ fire.Fire(main)
gradio_runner.py ADDED
@@ -0,0 +1,910 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import inspect
3
+ import os
4
+ import sys
5
+
6
+ from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
7
+ from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
8
+ ping
9
+ from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower
10
+ from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa
11
+
12
+ import gradio as gr
13
+ from apscheduler.schedulers.background import BackgroundScheduler
14
+
15
+
16
+ def go_gradio(**kwargs):
17
+ allow_api = kwargs['allow_api']
18
+ is_public = kwargs['is_public']
19
+ is_hf = kwargs['is_hf']
20
+ is_low_mem = kwargs['is_low_mem']
21
+ n_gpus = kwargs['n_gpus']
22
+ admin_pass = kwargs['admin_pass']
23
+ model_state0 = kwargs['model_state0']
24
+ score_model_state0 = kwargs['score_model_state0']
25
+ queue = True
26
+
27
+ # easy update of kwargs needed for evaluate() etc.
28
+ kwargs.update(locals())
29
+
30
+ if 'mbart-' in kwargs['model_lower']:
31
+ instruction_label_nochat = "Text to translate"
32
+ else:
33
+ instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \
34
+ " use Enter for multiple input lines)"
35
+ if kwargs['input_lines'] > 1:
36
+ instruction_label = "You (Shift-Enter or push Submit to send message, use Enter for multiple input lines)"
37
+ else:
38
+ instruction_label = "You (Enter or push Submit to send message, shift-enter for more lines)"
39
+
40
+ title = 'h2oGPT'
41
+ if 'h2ogpt-research' in kwargs['base_model']:
42
+ title += " [Research demonstration]"
43
+ if kwargs['verbose']:
44
+ description = f"""Model {kwargs['base_model']} Instruct dataset.
45
+ For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
46
+ Command: {str(' '.join(sys.argv))}
47
+ Hash: {get_githash()}
48
+ """
49
+ else:
50
+ description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).<br>"
51
+ if is_public:
52
+ description += "If this host is busy, try [gpt.h2o.ai 20B](https://gpt.h2o.ai) and [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) and [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
53
+ description += """<p><b> DISCLAIMERS: </b><ul><i><li>The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk.</i></li>"""
54
+ if kwargs['load_8bit']:
55
+ description += """<i><li> Model is loaded in 8-bit and has other restrictions on this host. UX can be worse than non-hosted version.</i></li>"""
56
+ description += """<i><li>Conversations may be used to improve h2oGPT. Do not share sensitive information.</i></li>"""
57
+ if 'h2ogpt-research' in kwargs['base_model']:
58
+ description += """<i><li>Research demonstration only, not used for commercial purposes.</i></li>"""
59
+ description += """<i><li>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md).</i></li></ul></p>"""
60
+
61
+ if kwargs['verbose']:
62
+ task_info_md = f"""
63
+ ### Task: {kwargs['task_info']}"""
64
+ else:
65
+ task_info_md = ''
66
+
67
+ if kwargs['h2ocolors']:
68
+ css_code = """footer {visibility: hidden;}
69
+ body{background:linear-gradient(#f5f5f5,#e5e5e5);}
70
+ body.dark{background:linear-gradient(#000000,#0d0d0d);}
71
+ """
72
+ else:
73
+ css_code = """footer {visibility: hidden}"""
74
+
75
+ if kwargs['gradio_avoid_processing_markdown']:
76
+ from gradio_client import utils as client_utils
77
+ from gradio.components import Chatbot
78
+
79
+ # gradio has issue with taking too long to process input/output for markdown etc.
80
+ # Avoid for now, allow raw html to render, good enough for chatbot.
81
+ def _postprocess_chat_messages(self, chat_message: str):
82
+ if chat_message is None:
83
+ return None
84
+ elif isinstance(chat_message, (tuple, list)):
85
+ filepath = chat_message[0]
86
+ mime_type = client_utils.get_mimetype(filepath)
87
+ filepath = self.make_temp_copy_if_needed(filepath)
88
+ return {
89
+ "name": filepath,
90
+ "mime_type": mime_type,
91
+ "alt_text": chat_message[1] if len(chat_message) > 1 else None,
92
+ "data": None, # These last two fields are filled in by the frontend
93
+ "is_file": True,
94
+ }
95
+ elif isinstance(chat_message, str):
96
+ return chat_message
97
+ else:
98
+ raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
99
+
100
+ Chatbot._postprocess_chat_messages = _postprocess_chat_messages
101
+
102
+ theme = H2oTheme() if kwargs['h2ocolors'] else SoftTheme()
103
+ demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
104
+ callback = gr.CSVLogger()
105
+
106
+ model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
107
+ if kwargs['base_model'].strip() not in model_options:
108
+ lora_options = [kwargs['base_model'].strip()] + model_options
109
+ lora_options = kwargs['extra_lora_options']
110
+ if kwargs['lora_weights'].strip() not in lora_options:
111
+ lora_options = [kwargs['lora_weights'].strip()] + lora_options
112
+ # always add in no lora case
113
+ # add fake space so doesn't go away in gradio dropdown
114
+ no_lora_str = no_model_str = '[None/Remove]'
115
+ lora_options = [no_lora_str] + kwargs['extra_lora_options'] # FIXME: why double?
116
+ # always add in no model case so can free memory
117
+ # add fake space so doesn't go away in gradio dropdown
118
+ model_options = [no_model_str] + model_options
119
+
120
+ # transcribe, will be detranscribed before use by evaluate()
121
+ if not kwargs['lora_weights'].strip():
122
+ kwargs['lora_weights'] = no_lora_str
123
+
124
+ if not kwargs['base_model'].strip():
125
+ kwargs['base_model'] = no_model_str
126
+
127
+ # transcribe for gradio
128
+ kwargs['gpu_id'] = str(kwargs['gpu_id'])
129
+
130
+ no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! ]'
131
+ output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get(
132
+ 'base_model') else no_model_msg
133
+ output_label0_model2 = no_model_msg
134
+
135
+ with demo:
136
+ # avoid actual model/tokenizer here or anything that would be bad to deepcopy
137
+ # https://github.com/gradio-app/gradio/issues/3558
138
+ model_state = gr.State(['model', 'tokenizer', kwargs['device'], kwargs['base_model']])
139
+ model_state2 = gr.State([None, None, None, None])
140
+ model_options_state = gr.State([model_options])
141
+ lora_options_state = gr.State([lora_options])
142
+ gr.Markdown(f"""
143
+ {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
144
+
145
+ {description}
146
+ {task_info_md}
147
+ """)
148
+ if is_hf:
149
+ gr.HTML(
150
+ '''<center><a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate this Space to skip the queue and run in a private space</center>''')
151
+
152
+ # go button visible if
153
+ base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
154
+ go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary")
155
+ normal_block = gr.Row(visible=not base_wanted)
156
+ with normal_block:
157
+ with gr.Tabs():
158
+ with gr.Row():
159
+ col_nochat = gr.Column(visible=not kwargs['chat'])
160
+ with col_nochat: # FIXME: for model comparison, and check rest
161
+ text_output_nochat = gr.Textbox(lines=5, label=output_label0)
162
+ instruction_nochat = gr.Textbox(
163
+ lines=kwargs['input_lines'],
164
+ label=instruction_label_nochat,
165
+ placeholder=kwargs['placeholder_instruction'],
166
+ )
167
+ iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction",
168
+ placeholder=kwargs['placeholder_input'])
169
+ submit_nochat = gr.Button("Submit")
170
+ flag_btn_nochat = gr.Button("Flag")
171
+ if not kwargs['auto_score']:
172
+ with gr.Column(visible=kwargs['score_model']):
173
+ score_btn_nochat = gr.Button("Score last prompt & response")
174
+ score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
175
+ else:
176
+ with gr.Column(visible=kwargs['score_model']):
177
+ score_text_nochat = gr.Textbox("Response Score: NA", show_label=False)
178
+ col_chat = gr.Column(visible=kwargs['chat'])
179
+ with col_chat:
180
+ with gr.Row():
181
+ text_output = gr.Chatbot(label=output_label0).style(height=kwargs['height'] or 400)
182
+ text_output2 = gr.Chatbot(label=output_label0_model2, visible=False).style(
183
+ height=kwargs['height'] or 400)
184
+ with gr.Row():
185
+ with gr.Column(scale=50):
186
+ instruction = gr.Textbox(
187
+ lines=kwargs['input_lines'],
188
+ label=instruction_label,
189
+ placeholder=kwargs['placeholder_instruction'],
190
+ )
191
+ with gr.Row():
192
+ submit = gr.Button(value='Submit').style(full_width=False, size='sm')
193
+ stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
194
+ with gr.Row():
195
+ clear = gr.Button("New Conversation")
196
+ flag_btn = gr.Button("Flag")
197
+ if not kwargs['auto_score']: # FIXME: For checkbox model2
198
+ with gr.Column(visible=kwargs['score_model']):
199
+ with gr.Row():
200
+ score_btn = gr.Button("Score last prompt & response").style(
201
+ full_width=False, size='sm')
202
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
203
+ score_res2 = gr.Row(visible=False)
204
+ with score_res2:
205
+ score_btn2 = gr.Button("Score last prompt & response 2").style(
206
+ full_width=False, size='sm')
207
+ score_text2 = gr.Textbox("Response Score2: NA", show_label=False)
208
+ else:
209
+ with gr.Column(visible=kwargs['score_model']):
210
+ score_text = gr.Textbox("Response Score: NA", show_label=False)
211
+ score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
212
+ retry = gr.Button("Regenerate")
213
+ undo = gr.Button("Undo")
214
+ with gr.TabItem("Input/Output"):
215
+ with gr.Row():
216
+ if 'mbart-' in kwargs['model_lower']:
217
+ src_lang = gr.Dropdown(list(languages_covered().keys()),
218
+ value=kwargs['src_lang'],
219
+ label="Input Language")
220
+ tgt_lang = gr.Dropdown(list(languages_covered().keys()),
221
+ value=kwargs['tgt_lang'],
222
+ label="Output Language")
223
+ with gr.TabItem("Expert"):
224
+ with gr.Row():
225
+ with gr.Column():
226
+ stream_output = gr.components.Checkbox(label="Stream output",
227
+ value=kwargs['stream_output'])
228
+ prompt_type = gr.Dropdown(prompt_types_strings,
229
+ value=kwargs['prompt_type'], label="Prompt Type",
230
+ visible=not is_public)
231
+ prompt_type2 = gr.Dropdown(prompt_types_strings,
232
+ value=kwargs['prompt_type'], label="Prompt Type Model 2",
233
+ visible=not is_public and False)
234
+ do_sample = gr.Checkbox(label="Sample",
235
+ info="Enable sampler, required for use of temperature, top_p, top_k",
236
+ value=kwargs['do_sample'])
237
+ temperature = gr.Slider(minimum=0.01, maximum=3,
238
+ value=kwargs['temperature'],
239
+ label="Temperature",
240
+ info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
241
+ top_p = gr.Slider(minimum=0, maximum=1,
242
+ value=kwargs['top_p'], label="Top p",
243
+ info="Cumulative probability of tokens to sample from")
244
+ top_k = gr.Slider(
245
+ minimum=0, maximum=100, step=1,
246
+ value=kwargs['top_k'], label="Top k",
247
+ info='Num. tokens to sample from'
248
+ )
249
+ max_beams = 8 if not is_low_mem else 2
250
+ num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
251
+ value=min(max_beams, kwargs['num_beams']), label="Beams",
252
+ info="Number of searches for optimal overall probability. "
253
+ "Uses more GPU memory/compute")
254
+ max_max_new_tokens = 2048 if not is_low_mem else kwargs['max_new_tokens']
255
+ max_new_tokens = gr.Slider(
256
+ minimum=1, maximum=max_max_new_tokens, step=1,
257
+ value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
258
+ )
259
+ min_new_tokens = gr.Slider(
260
+ minimum=0, maximum=max_max_new_tokens, step=1,
261
+ value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
262
+ )
263
+ early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
264
+ value=kwargs['early_stopping'])
265
+ max_max_time = 60 * 5 if not is_low_mem else 60
266
+ max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
267
+ value=min(max_max_time, kwargs['max_time']), label="Max. time",
268
+ info="Max. time to search optimal output.")
269
+ repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
270
+ value=kwargs['repetition_penalty'],
271
+ label="Repetition Penalty")
272
+ num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
273
+ value=kwargs['num_return_sequences'],
274
+ label="Number Returns", info="Must be <= num_beams",
275
+ visible=not is_public)
276
+ iinput = gr.Textbox(lines=4, label="Input",
277
+ placeholder=kwargs['placeholder_input'],
278
+ visible=not is_public)
279
+ context = gr.Textbox(lines=3, label="System Pre-Context",
280
+ info="Directly pre-appended without prompt processing",
281
+ visible=not is_public)
282
+ chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
283
+ visible=not is_public)
284
+
285
+ with gr.TabItem("Models"):
286
+ load_msg = "Load-Unload Model/LORA" if not is_public \
287
+ else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
288
+ load_msg2 = "Load-Unload Model/LORA 2" if not is_public \
289
+ else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
290
+ compare_checkbox = gr.components.Checkbox(label="Compare Mode",
291
+ value=False, visible=not is_public)
292
+ with gr.Row():
293
+ n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
294
+ with gr.Column():
295
+ with gr.Row():
296
+ with gr.Column(scale=50):
297
+ model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
298
+ value=kwargs['base_model'])
299
+ lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
300
+ value=kwargs['lora_weights'], visible=kwargs['show_lora'])
301
+ with gr.Column(scale=1):
302
+ load_model_button = gr.Button(load_msg)
303
+ model_load8bit_checkbox = gr.components.Checkbox(
304
+ label="Load 8-bit [requires support]",
305
+ value=kwargs['load_8bit'])
306
+ model_infer_devices_checkbox = gr.components.Checkbox(
307
+ label="Choose Devices [If not Checked, use all GPUs]",
308
+ value=kwargs['infer_devices'])
309
+ model_gpu = gr.Dropdown(n_gpus_list,
310
+ label="GPU ID 2 [-1 = all GPUs, if Choose is enabled]",
311
+ value=kwargs['gpu_id'])
312
+ model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'])
313
+ lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
314
+ visible=kwargs['show_lora'])
315
+ with gr.Row():
316
+ with gr.Column(scale=50):
317
+ new_model = gr.Textbox(label="New Model HF name/path")
318
+ new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
319
+ with gr.Column(scale=1):
320
+ add_model_button = gr.Button("Add new model name")
321
+ add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora'])
322
+ col_model2 = gr.Column(visible=False)
323
+ with col_model2:
324
+ with gr.Row():
325
+ with gr.Column(scale=50):
326
+ model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
327
+ value=no_model_str)
328
+ lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
329
+ value=no_lora_str,
330
+ visible=kwargs['show_lora'])
331
+ with gr.Column(scale=1):
332
+ load_model_button2 = gr.Button(load_msg2)
333
+ model_load8bit_checkbox2 = gr.components.Checkbox(
334
+ label="Load 8-bit 2 [requires support]",
335
+ value=kwargs['load_8bit'])
336
+ model_infer_devices_checkbox2 = gr.components.Checkbox(
337
+ label="Choose Devices 2 [If not Checked, use all GPUs]",
338
+ value=kwargs[
339
+ 'infer_devices'])
340
+ model_gpu2 = gr.Dropdown(n_gpus_list,
341
+ label="GPU ID [-1 = all GPUs, if choose is enabled]",
342
+ value=kwargs['gpu_id'])
343
+ # no model/lora loaded ever in model2 by default
344
+ model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
345
+ lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
346
+ visible=kwargs['show_lora'])
347
+ with gr.TabItem("System"):
348
+ admin_row = gr.Row()
349
+ with admin_row:
350
+ admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
351
+ admin_btn = gr.Button(value="Admin Access", visible=is_public)
352
+ system_row = gr.Row(visible=not is_public)
353
+ with system_row:
354
+ with gr.Column():
355
+ with gr.Row():
356
+ system_btn = gr.Button(value='Get System Info')
357
+ system_text = gr.Textbox(label='System Info')
358
+
359
+ with gr.Row():
360
+ zip_btn = gr.Button("Zip")
361
+ zip_text = gr.Textbox(label="Zip file name")
362
+ file_output = gr.File()
363
+ with gr.Row():
364
+ s3up_btn = gr.Button("S3UP")
365
+ s3up_text = gr.Textbox(label='S3UP result')
366
+
367
+ # Get flagged data
368
+ zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
369
+ zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text])
370
+ s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text)
371
+
372
+ def check_admin_pass(x):
373
+ return gr.update(visible=x == admin_pass)
374
+
375
+ def close_admin(x):
376
+ return gr.update(visible=not (x == admin_pass))
377
+
378
+ admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row) \
379
+ .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row)
380
+
381
+ # Get inputs to evaluate()
382
+ all_kwargs = kwargs.copy()
383
+ all_kwargs.update(locals())
384
+ inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
385
+ from functools import partial
386
+ kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list}
387
+ # ensure present
388
+ for k in inputs_kwargs_list:
389
+ assert k in kwargs_evaluate, "Missing %s" % k
390
+ fun = partial(evaluate,
391
+ **kwargs_evaluate)
392
+ fun2 = partial(evaluate,
393
+ **kwargs_evaluate)
394
+
395
+ dark_mode_btn = gr.Button("Dark Mode", variant="primary").style(
396
+ size="sm",
397
+ )
398
+ dark_mode_btn.click(
399
+ None,
400
+ None,
401
+ None,
402
+ _js=get_dark_js(),
403
+ api_name="dark" if allow_api else None,
404
+ )
405
+
406
+ # Control chat and non-chat blocks, which can be independently used by chat checkbox swap
407
+ def col_nochat_fun(x):
408
+ return gr.Column.update(visible=not x)
409
+
410
+ def col_chat_fun(x):
411
+ return gr.Column.update(visible=x)
412
+
413
+ def context_fun(x):
414
+ return gr.Textbox.update(visible=not x)
415
+
416
+ chat.select(col_nochat_fun, chat, col_nochat, api_name="chat_checkbox" if allow_api else None) \
417
+ .then(col_chat_fun, chat, col_chat) \
418
+ .then(context_fun, chat, context)
419
+
420
+ # examples after submit or any other buttons for chat or no chat
421
+ if kwargs['examples'] is not None and kwargs['show_examples']:
422
+ gr.Examples(examples=kwargs['examples'], inputs=inputs_list)
423
+
424
+ # Score
425
+ def score_last_response(*args, nochat=False, model2=False):
426
+ """ Similar to user() """
427
+ args_list = list(args)
428
+
429
+ max_length_tokenize = 512 if is_low_mem else 2048
430
+ cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM
431
+ smodel = score_model_state0[0]
432
+ stokenizer = score_model_state0[1]
433
+ sdevice = score_model_state0[2]
434
+ if not nochat:
435
+ history = args_list[-1]
436
+ if history is None:
437
+ if not model2:
438
+ # maybe only doing first model, no need to complain
439
+ print("Bad history in scoring last response, fix for now", flush=True)
440
+ history = []
441
+ if smodel is not None and \
442
+ stokenizer is not None and \
443
+ sdevice is not None and \
444
+ history is not None and len(history) > 0 and \
445
+ history[-1] is not None and \
446
+ len(history[-1]) >= 2:
447
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
448
+
449
+ question = history[-1][0]
450
+
451
+ answer = history[-1][1]
452
+ else:
453
+ return 'Response Score: NA'
454
+ else:
455
+ answer = args_list[-1]
456
+ instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat')
457
+ question = args_list[instruction_nochat_arg_id]
458
+
459
+ if question is None:
460
+ return 'Response Score: Bad Question'
461
+ if answer is None:
462
+ return 'Response Score: Bad Answer'
463
+ score = score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len)
464
+ if isinstance(score, str):
465
+ return 'Response Score: NA'
466
+ return 'Response Score: {:.1%}'.format(score)
467
+
468
+ def noop_score_last_response(*args, **kwargs):
469
+ return "Response Score: Disabled"
470
+
471
+ if kwargs['score_model']:
472
+ score_fun = score_last_response
473
+ else:
474
+ score_fun = noop_score_last_response
475
+
476
+ score_args = dict(fn=score_fun,
477
+ inputs=inputs_list + [text_output],
478
+ outputs=[score_text],
479
+ )
480
+ score_args2 = dict(fn=partial(score_fun, model2=True),
481
+ inputs=inputs_list + [text_output2],
482
+ outputs=[score_text2],
483
+ )
484
+
485
+ score_args_nochat = dict(fn=partial(score_fun, nochat=True),
486
+ inputs=inputs_list + [text_output_nochat],
487
+ outputs=[score_text_nochat],
488
+ )
489
+ if not kwargs['auto_score']:
490
+ score_event = score_btn.click(**score_args, queue=queue, api_name='score' if allow_api else None) \
491
+ .then(**score_args2, queue=queue, api_name='score2' if allow_api else None)
492
+ score_event_nochat = score_btn_nochat.click(**score_args_nochat, queue=queue,
493
+ api_name='score_nochat' if allow_api else None)
494
+
495
+ def user(*args, undo=False, sanitize_user_prompt=True, model2=False):
496
+ """
497
+ User that fills history for bot
498
+ :param args:
499
+ :param undo:
500
+ :param sanitize_user_prompt:
501
+ :param model2:
502
+ :return:
503
+ """
504
+ args_list = list(args)
505
+ user_message = args_list[0]
506
+ input1 = args_list[1]
507
+ context1 = args_list[2]
508
+ if input1 and not user_message.endswith(':'):
509
+ user_message1 = user_message + ":" + input1
510
+ elif input1:
511
+ user_message1 = user_message + input1
512
+ else:
513
+ user_message1 = user_message
514
+ if sanitize_user_prompt:
515
+ from better_profanity import profanity
516
+ user_message1 = profanity.censor(user_message1)
517
+
518
+ history = args_list[-1]
519
+ if undo and history:
520
+ history.pop()
521
+ args_list = args_list[:-1] # FYI, even if unused currently
522
+ if history is None:
523
+ if not model2:
524
+ # no need to complain so often unless model1
525
+ print("Bad history, fix for now", flush=True)
526
+ history = []
527
+ # ensure elements not mixed across models as output,
528
+ # even if input is currently same source
529
+ history = history.copy()
530
+ if undo:
531
+ return history
532
+ else:
533
+ # FIXME: compare, same history for now
534
+ return history + [[user_message1, None]]
535
+
536
+ def bot(*args, retry=False):
537
+ """
538
+ bot that consumes history for user input
539
+ instruction (from input_list) itself is not consumed by bot
540
+ :param args:
541
+ :param retry:
542
+ :return:
543
+ """
544
+ args_list = list(args).copy()
545
+ history = args_list[-1] # model_state is -2
546
+ if retry and history:
547
+ history.pop()
548
+ if not history:
549
+ print("No history", flush=True)
550
+ return
551
+ # ensure output will be unique to models
552
+ history = history.copy()
553
+ instruction1 = history[-1][0]
554
+ context1 = ''
555
+ if kwargs['chat_history'] > 0:
556
+ prompt_type_arg_id = eval_func_param_names.index('prompt_type')
557
+ prompt_type1 = args_list[prompt_type_arg_id]
558
+ chat_arg_id = eval_func_param_names.index('chat')
559
+ chat1 = args_list[chat_arg_id]
560
+ context1 = ''
561
+ for histi in range(len(history) - 1):
562
+ data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
563
+ context1 += generate_prompt(data_point, prompt_type1, chat1, reduced=True)[0].replace(
564
+ '<br>', '\n')
565
+ if not context1.endswith('\n'):
566
+ context1 += '\n'
567
+ if context1 and not context1.endswith('\n'):
568
+ context1 += '\n' # ensure if terminates abruptly, then human continues on next line
569
+ args_list[0] = instruction1 # override original instruction with history from user
570
+ # only include desired chat history
571
+ args_list[2] = context1[-kwargs['chat_history']:]
572
+ model_state1 = args_list[-2]
573
+ if model_state1[0] is None or model_state1[0] == no_model_str:
574
+ return
575
+ args_list = args_list[:-2]
576
+ fun1 = partial(evaluate,
577
+ model_state1,
578
+ **kwargs_evaluate)
579
+ try:
580
+ for output in fun1(*tuple(args_list)):
581
+ bot_message = output
582
+ history[-1][1] = bot_message
583
+ yield history
584
+ except StopIteration:
585
+ yield history
586
+ except RuntimeError as e:
587
+ if "generator raised StopIteration" in str(e):
588
+ # assume last entry was bad, undo
589
+ history.pop()
590
+ yield history
591
+ raise
592
+ except Exception as e:
593
+ # put error into user input
594
+ history[-1][0] = "Exception: %s" % str(e)
595
+ yield history
596
+ raise
597
+ return
598
+
599
+ # NORMAL MODEL
600
+ user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
601
+ inputs=inputs_list + [text_output],
602
+ outputs=text_output,
603
+ )
604
+ bot_args = dict(fn=bot,
605
+ inputs=inputs_list + [model_state] + [text_output],
606
+ outputs=text_output,
607
+ )
608
+ retry_bot_args = dict(fn=functools.partial(bot, retry=True),
609
+ inputs=inputs_list + [model_state] + [text_output],
610
+ outputs=text_output,
611
+ )
612
+ undo_user_args = dict(fn=functools.partial(user, undo=True),
613
+ inputs=inputs_list + [text_output],
614
+ outputs=text_output,
615
+ )
616
+
617
+ # MODEL2
618
+ user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True),
619
+ inputs=inputs_list + [text_output2],
620
+ outputs=text_output2,
621
+ )
622
+ bot_args2 = dict(fn=bot,
623
+ inputs=inputs_list + [model_state2] + [text_output2],
624
+ outputs=text_output2,
625
+ )
626
+ retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
627
+ inputs=inputs_list + [model_state2] + [text_output2],
628
+ outputs=text_output2,
629
+ )
630
+ undo_user_args2 = dict(fn=functools.partial(user, undo=True),
631
+ inputs=inputs_list + [text_output2],
632
+ outputs=text_output2,
633
+ )
634
+
635
+ def clear_instruct():
636
+ return gr.Textbox.update(value='')
637
+
638
+ if kwargs['auto_score']:
639
+ # in case 2nd model, consume instruction first, so can clear quickly
640
+ # bot doesn't consume instruction itself, just history from user, so why works
641
+ submit_event = instruction.submit(**user_args, queue=queue,
642
+ api_name='instruction' if allow_api else None) \
643
+ .then(**user_args2, api_name='instruction2' if allow_api else None) \
644
+ .then(clear_instruct, None, instruction) \
645
+ .then(clear_instruct, None, iinput) \
646
+ .then(**bot_args, api_name='instruction_bot' if allow_api else None, queue=queue) \
647
+ .then(**score_args, api_name='instruction_bot_score' if allow_api else None, queue=queue) \
648
+ .then(**bot_args2, api_name='instruction_bot2' if allow_api else None, queue=queue) \
649
+ .then(**score_args2, api_name='instruction_bot_score2' if allow_api else None, queue=queue) \
650
+ .then(clear_torch_cache)
651
+ submit_event2 = submit.click(**user_args, api_name='submit' if allow_api else None) \
652
+ .then(**user_args2, api_name='submit2' if allow_api else None) \
653
+ .then(clear_instruct, None, instruction) \
654
+ .then(clear_instruct, None, iinput) \
655
+ .then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) \
656
+ .then(**score_args, api_name='submit_bot_score' if allow_api else None, queue=queue) \
657
+ .then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) \
658
+ .then(**score_args2, api_name='submit_bot_score2' if allow_api else None, queue=queue) \
659
+ .then(clear_torch_cache)
660
+ submit_event3 = retry.click(**user_args, api_name='retry' if allow_api else None) \
661
+ .then(**user_args2, api_name='retry2' if allow_api else None) \
662
+ .then(clear_instruct, None, instruction) \
663
+ .then(clear_instruct, None, iinput) \
664
+ .then(**retry_bot_args, api_name='retry_bot' if allow_api else None, queue=queue) \
665
+ .then(**score_args, api_name='retry_bot_score' if allow_api else None, queue=queue) \
666
+ .then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, queue=queue) \
667
+ .then(**score_args2, api_name='retry_bot_score2' if allow_api else None, queue=queue) \
668
+ .then(clear_torch_cache)
669
+ submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
670
+ .then(**undo_user_args2, api_name='undo2' if allow_api else None) \
671
+ .then(clear_instruct, None, instruction) \
672
+ .then(clear_instruct, None, iinput) \
673
+ .then(**score_args, api_name='undo_score' if allow_api else None) \
674
+ .then(**score_args2, api_name='undo_score2' if allow_api else None)
675
+ else:
676
+ submit_event = instruction.submit(**user_args,
677
+ api_name='instruction' if allow_api else None) \
678
+ .then(**user_args2, api_name='instruction2' if allow_api else None) \
679
+ .then(clear_instruct, None, instruction) \
680
+ .then(clear_instruct, None, iinput) \
681
+ .then(**bot_args, api_name='instruction_bot' if allow_api else None, queue=queue) \
682
+ .then(**bot_args2, api_name='instruction_bot2' if allow_api else None, queue=queue) \
683
+ .then(clear_torch_cache)
684
+ submit_event2 = submit.click(**user_args, api_name='submit' if allow_api else None) \
685
+ .then(**user_args2, api_name='submit2' if allow_api else None) \
686
+ .then(clear_instruct, None, instruction) \
687
+ .then(clear_instruct, None, iinput) \
688
+ .then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) \
689
+ .then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue) \
690
+ .then(clear_torch_cache)
691
+ submit_event3 = retry.click(**user_args, api_name='retry' if allow_api else None) \
692
+ .then(**user_args2, api_name='retry2' if allow_api else None) \
693
+ .then(clear_instruct, None, instruction) \
694
+ .then(clear_instruct, None, iinput) \
695
+ .then(**retry_bot_args, api_name='retry_bot' if allow_api else None, queue=queue) \
696
+ .then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, queue=queue) \
697
+ .then(clear_torch_cache)
698
+ submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
699
+ .then(**undo_user_args2, api_name='undo2' if allow_api else None)
700
+
701
+ # does both models
702
+ clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
703
+ .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None)
704
+ # NOTE: clear of instruction/iinput for nochat has to come after score,
705
+ # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
706
+ submit_event_nochat = submit_nochat.click(fun, inputs=[model_state] + inputs_list,
707
+ outputs=text_output_nochat,
708
+ queue=queue,
709
+ api_name='submit_nochat' if allow_api else None) \
710
+ .then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None, queue=queue) \
711
+ .then(clear_instruct, None, instruction_nochat) \
712
+ .then(clear_instruct, None, iinput_nochat) \
713
+ .then(clear_torch_cache)
714
+
715
+ def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id):
716
+ # ensure old model removed from GPU memory
717
+ if kwargs['debug']:
718
+ print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True)
719
+
720
+ model0 = model_state0[0]
721
+ if isinstance(model_state_old[0], str) and model0 is not None:
722
+ # best can do, move model loaded at first to CPU
723
+ model0.cpu()
724
+
725
+ if model_state_old[0] is not None and not isinstance(model_state_old[0], str):
726
+ try:
727
+ model_state_old[0].cpu()
728
+ except Exception as e:
729
+ # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data!
730
+ print("Unable to put model on CPU: %s" % str(e), flush=True)
731
+ del model_state_old[0]
732
+ model_state_old[0] = None
733
+
734
+ if model_state_old[1] is not None and not isinstance(model_state_old[1], str):
735
+ del model_state_old[1]
736
+ model_state_old[1] = None
737
+
738
+ clear_torch_cache()
739
+ if kwargs['debug']:
740
+ print("Pre-switch post-del GPU memory: %s" % get_torch_allocated(), flush=True)
741
+
742
+ if model_name is None or model_name == no_model_str:
743
+ # no-op if no model, just free memory
744
+ # no detranscribe needed for model, never go into evaluate
745
+ lora_weights = no_lora_str
746
+ return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
747
+
748
+ all_kwargs1 = all_kwargs.copy()
749
+ all_kwargs1['base_model'] = model_name.strip()
750
+ all_kwargs1['load_8bit'] = load_8bit
751
+ all_kwargs1['infer_devices'] = infer_devices
752
+ all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe
753
+ model_lower = model_name.strip().lower()
754
+ if model_lower in inv_prompt_type_to_model_lower:
755
+ prompt_type1 = inv_prompt_type_to_model_lower[model_lower]
756
+ else:
757
+ prompt_type1 = prompt_type_old
758
+
759
+ # detranscribe
760
+ if lora_weights == no_lora_str:
761
+ lora_weights = ''
762
+
763
+ all_kwargs1['lora_weights'] = lora_weights.strip()
764
+ model1, tokenizer1, device1 = get_model(**all_kwargs1)
765
+ clear_torch_cache()
766
+
767
+ if kwargs['debug']:
768
+ print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True)
769
+ return [model1, tokenizer1, device1, model_name], model_name, lora_weights, prompt_type1
770
+
771
+ def dropdown_prompt_type_list(x):
772
+ return gr.Dropdown.update(value=x)
773
+
774
+ def chatbot_list(x, model_used_in):
775
+ return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]')
776
+
777
+ load_model_args = dict(fn=load_model,
778
+ inputs=[model_choice, lora_choice, model_state, prompt_type,
779
+ model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu],
780
+ outputs=[model_state, model_used, lora_used, prompt_type])
781
+ prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type)
782
+ chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output)
783
+ nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat)
784
+ if not is_public:
785
+ load_model_event = load_model_button.click(**load_model_args) \
786
+ .then(**prompt_update_args) \
787
+ .then(**chatbot_update_args) \
788
+ .then(**nochat_update_args) \
789
+ .then(clear_torch_cache)
790
+
791
+ load_model_args2 = dict(fn=load_model,
792
+ inputs=[model_choice2, lora_choice2, model_state2, prompt_type2,
793
+ model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2],
794
+ outputs=[model_state2, model_used2, lora_used2, prompt_type2])
795
+ prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2)
796
+ chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2)
797
+ if not is_public:
798
+ load_model_event2 = load_model_button2.click(**load_model_args2) \
799
+ .then(**prompt_update_args2) \
800
+ .then(**chatbot_update_args2) \
801
+ .then(clear_torch_cache)
802
+
803
+ def dropdown_model_list(list0, x):
804
+ new_state = [list0[0] + [x]]
805
+ new_options = [*new_state[0]]
806
+ return gr.Dropdown.update(value=x, choices=new_options), \
807
+ gr.Dropdown.update(value=x, choices=new_options), \
808
+ '', new_state
809
+
810
+ add_model_event = add_model_button.click(fn=dropdown_model_list,
811
+ inputs=[model_options_state, new_model],
812
+ outputs=[model_choice, model_choice2, new_model, model_options_state])
813
+
814
+ def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2):
815
+ new_state = [list0[0] + [x]]
816
+ new_options = [*new_state[0]]
817
+ # don't switch drop-down to added lora if already have model loaded
818
+ x1 = x if model_used1 == no_model_str else lora_used1
819
+ x2 = x if model_used2 == no_model_str else lora_used2
820
+ return gr.Dropdown.update(value=x1, choices=new_options), \
821
+ gr.Dropdown.update(value=x2, choices=new_options), \
822
+ '', new_state
823
+
824
+ add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
825
+ inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2,
826
+ lora_used2],
827
+ outputs=[lora_choice, lora_choice2, new_lora, lora_options_state])
828
+
829
+ go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None) \
830
+ .then(lambda: gr.update(visible=True), None, normal_block) \
831
+ .then(**load_model_args).then(**prompt_update_args)
832
+
833
+ def compare_textbox_fun(x):
834
+ return gr.Textbox.update(visible=x)
835
+
836
+ def compare_column_fun(x):
837
+ return gr.Column.update(visible=x)
838
+
839
+ def compare_prompt_fun(x):
840
+ return gr.Dropdown.update(visible=x)
841
+
842
+ compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2,
843
+ api_name="compare_checkbox" if allow_api else None) \
844
+ .then(compare_column_fun, compare_checkbox, col_model2) \
845
+ .then(compare_prompt_fun, compare_checkbox, prompt_type2) \
846
+ .then(compare_textbox_fun, compare_checkbox, score_text2)
847
+ # FIXME: add score_res2 in condition, but do better
848
+
849
+ # callback for logging flagged input/output
850
+ callback.setup(inputs_list + [text_output, text_output2], "flagged_data_points")
851
+ flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2], None,
852
+ preprocess=False,
853
+ api_name='flag' if allow_api else None)
854
+ flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None,
855
+ preprocess=False,
856
+ api_name='flag_nochat' if allow_api else None)
857
+
858
+ def get_system_info():
859
+ return gr.Textbox.update(value=system_info_print())
860
+
861
+ system_event = system_btn.click(get_system_info, outputs=system_text,
862
+ api_name='system_info' if allow_api else None)
863
+
864
+ # don't pass text_output, don't want to clear output, just stop it
865
+ # FIXME: have to click once to stop output and second time to stop GPUs going
866
+ stop_btn.click(lambda: None, None, None,
867
+ cancels=[submit_event_nochat, submit_event, submit_event2, submit_event3],
868
+ queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache)
869
+ demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None)
870
+
871
+ demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
872
+ favicon_path = "h2o-logo.svg"
873
+
874
+ scheduler = BackgroundScheduler()
875
+ scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
876
+ if is_public:
877
+ scheduler.add_job(func=ping, trigger="interval", seconds=60)
878
+ scheduler.start()
879
+
880
+ demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
881
+ favicon_path=favicon_path, prevent_thread_lock=True) # , enable_queue=True)
882
+ print("Started GUI", flush=True)
883
+ if kwargs['block_gradio_exit']:
884
+ demo.block_thread()
885
+
886
+
887
+ input_args_list = ['model_state']
888
+ inputs_kwargs_list = ['debug', 'save_dir', 'hard_stop_list', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
889
+ 'raise_generate_gpu_exceptions', 'chat_context', 'concurrency_count']
890
+
891
+
892
+ def get_inputs_list(inputs_dict, model_lower):
893
+ """
894
+ map gradio objects in locals() to inputs for evaluate().
895
+ :param inputs_dict:
896
+ :param model_lower:
897
+ :return:
898
+ """
899
+ inputs_list_names = list(inspect.signature(evaluate).parameters)
900
+ inputs_list = []
901
+ for k in inputs_list_names:
902
+ if k == 'kwargs':
903
+ continue
904
+ if k in input_args_list + inputs_kwargs_list:
905
+ # these are added via partial, not taken as input
906
+ continue
907
+ if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
908
+ continue
909
+ inputs_list.append(inputs_dict[k])
910
+ return inputs_list
gradio_themes.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from gradio.themes.soft import Soft
3
+ from gradio.themes.utils import Color, colors, sizes
4
+
5
+ h2o_yellow = Color(
6
+ name="yellow",
7
+ c50="#fffef2",
8
+ c100="#fff9e6",
9
+ c200="#ffecb3",
10
+ c300="#ffe28c",
11
+ c400="#ffd659",
12
+ c500="#fec925",
13
+ c600="#e6ac00",
14
+ c700="#bf8f00",
15
+ c800="#a67c00",
16
+ c900="#664d00",
17
+ c950="#403000",
18
+ )
19
+ h2o_gray = Color(
20
+ name="gray",
21
+ c50="#f8f8f8",
22
+ c100="#e5e5e5",
23
+ c200="#cccccc",
24
+ c300="#b2b2b2",
25
+ c400="#999999",
26
+ c500="#7f7f7f",
27
+ c600="#666666",
28
+ c700="#4c4c4c",
29
+ c800="#333333",
30
+ c900="#191919",
31
+ c950="#0d0d0d",
32
+ )
33
+
34
+
35
+ class H2oTheme(Soft):
36
+ def __init__(
37
+ self,
38
+ *,
39
+ primary_hue: colors.Color | str = h2o_yellow,
40
+ secondary_hue: colors.Color | str = h2o_yellow,
41
+ neutral_hue: colors.Color | str = h2o_gray,
42
+ spacing_size: sizes.Size | str = sizes.spacing_md,
43
+ radius_size: sizes.Size | str = sizes.radius_md,
44
+ text_size: sizes.Size | str = sizes.text_lg,
45
+ ):
46
+ super().__init__(
47
+ primary_hue=primary_hue,
48
+ secondary_hue=secondary_hue,
49
+ neutral_hue=neutral_hue,
50
+ spacing_size=spacing_size,
51
+ radius_size=radius_size,
52
+ text_size=text_size,
53
+ )
54
+ super().set(
55
+ link_text_color="#3344DD",
56
+ link_text_color_hover="#3344DD",
57
+ link_text_color_visited="#3344DD",
58
+ link_text_color_dark="#74abff",
59
+ link_text_color_hover_dark="#a3c8ff",
60
+ link_text_color_active_dark="#a3c8ff",
61
+ link_text_color_visited_dark="#74abff",
62
+ button_primary_text_color="*neutral_950",
63
+ button_primary_text_color_dark="*neutral_950",
64
+ button_primary_background_fill="*primary_500",
65
+ button_primary_background_fill_dark="*primary_500",
66
+ block_label_background_fill="*primary_500",
67
+ block_label_background_fill_dark="*primary_500",
68
+ block_label_text_color="*neutral_950",
69
+ block_label_text_color_dark="*neutral_950",
70
+ block_title_text_color="*neutral_950",
71
+ block_title_text_color_dark="*neutral_950",
72
+ block_background_fill_dark="*neutral_950",
73
+ body_background_fill="*neutral_50",
74
+ body_background_fill_dark="*neutral_900",
75
+ background_fill_primary_dark="*block_background_fill",
76
+ block_radius="0 0 8px 8px",
77
+ )
78
+
79
+
80
+ class SoftTheme(Soft):
81
+ def __init__(
82
+ self,
83
+ *,
84
+ primary_hue: colors.Color | str = colors.indigo,
85
+ secondary_hue: colors.Color | str = colors.indigo,
86
+ neutral_hue: colors.Color | str = colors.gray,
87
+ spacing_size: sizes.Size | str = sizes.spacing_md,
88
+ radius_size: sizes.Size | str = sizes.radius_md,
89
+ text_size: sizes.Size | str = sizes.text_md,
90
+ ):
91
+ super().__init__(
92
+ primary_hue=primary_hue,
93
+ secondary_hue=secondary_hue,
94
+ neutral_hue=neutral_hue,
95
+ spacing_size=spacing_size,
96
+ radius_size=radius_size,
97
+ text_size=text_size,
98
+ )
99
+
100
+
101
+ h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"' \
102
+ ' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:' \
103
+ '#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" ' \
104
+ 'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.' \
105
+ '47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.' \
106
+ '82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10' \
107
+ '.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"' \
108
+ '/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.' \
109
+ '76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69' \
110
+ ',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.' \
111
+ '85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.' \
112
+ '69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30.' \
113
+ '62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8.' \
114
+ '62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-' \
115
+ '12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"' \
116
+ ' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,' \
117
+ '11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
118
+
119
+
120
+ def get_h2o_title(title):
121
+ return f"""<div style="display:flex; justify-content:center; margin-bottom:30px;">
122
+ <div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
123
+ <h1 style="line-height:60px">{title}</h1>
124
+ </div>
125
+ <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
126
+ <img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/h2o-qr.png></img>
127
+ </div>
128
+ """
129
+
130
+
131
+ def get_simple_title(title):
132
+ return f"""<h1 align="center"> {title}</h1>"""
133
+
134
+
135
+ def get_dark_js():
136
+ return """() => {
137
+ if (document.querySelectorAll('.dark').length) {
138
+ document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
139
+ } else {
140
+ document.querySelector('body').classList.add('dark');
141
+ }
142
+ }"""
prompter.py CHANGED
@@ -71,7 +71,8 @@ class Prompter(object):
71
  output = output.split(self.pre_response)[1]
72
  allow_terminate = True
73
  else:
74
- print("Failure of parsing: %s" % output, flush=True)
 
75
  allow_terminate = False
76
  else:
77
  allow_terminate = True
 
71
  output = output.split(self.pre_response)[1]
72
  allow_terminate = True
73
  else:
74
+ if output:
75
+ print("Failure of parsing or not enough output yet: %s" % output, flush=True)
76
  allow_terminate = False
77
  else:
78
  allow_terminate = True
requirements.txt CHANGED
@@ -19,9 +19,10 @@ pandas==2.0.0
19
  matplotlib==3.7.1
20
  loralib==0.1.1
21
  bitsandbytes==0.38.1
22
- git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
23
  transformers==4.28.1
24
  tokenizers==0.13.3
 
25
 
26
  # optional for generate
27
  pynvml==11.5.0
 
19
  matplotlib==3.7.1
20
  loralib==0.1.1
21
  bitsandbytes==0.38.1
22
+ git+https://github.com/huggingface/peft.git@e8f66b8a425eced6c592089d40b8d33d82c2b2f0
23
  transformers==4.28.1
24
  tokenizers==0.13.3
25
+ APScheduler==3.10.1
26
 
27
  # optional for generate
28
  pynvml==11.5.0
stopping.py CHANGED
@@ -9,11 +9,11 @@ from transformers import StoppingCriteria
9
 
10
  class StoppingCriteriaSub(StoppingCriteria):
11
 
12
- def __init__(self, stops=[], encounters=[]):
13
  super().__init__()
14
  assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
15
  self.encounters = encounters
16
- self.stops = [stop.to("cuda") for stop in stops]
17
  self.num_stops = [0] * len(stops)
18
 
19
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
@@ -25,115 +25,3 @@ class StoppingCriteriaSub(StoppingCriteria):
25
  # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
26
  # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
27
  return False
28
-
29
-
30
- class Stream(StoppingCriteria):
31
- """
32
- This class can be used to callback during generation. Keep
33
- in mind for decoder-only type of transformers, this will include the initial prompted tokens.
34
-
35
- Args:
36
- func (`callable`):
37
- A callable function to apply on first input in list every iteration of generation
38
- """
39
-
40
- def __init__(self, func=None):
41
- self.func = func
42
-
43
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
44
- if self.func is not None:
45
- # only consume first of multiple responses
46
- self.func(input_ids[0])
47
- return False
48
-
49
-
50
- class CallbackToGenerator(collections.abc.Generator):
51
- """
52
- A generator wrapper for a function that invokes a callback multiple times.
53
-
54
- Calling `send` on the generator emits a value from one callback, and returns
55
- the next.
56
-
57
- Note this starts a background thread
58
- """
59
-
60
- def __init__(self, func, *args, callback=None, **kwargs):
61
- self.func = func
62
- self.args = args
63
- self.kwargs = kwargs
64
- self.callback = callback
65
-
66
- self._ready_queue = Queue(1)
67
- self._done_queue = Queue(1)
68
- self._done_holder = [False]
69
-
70
- # local to avoid reference cycles
71
- ready_queue = self._ready_queue
72
- done_queue = self._done_queue
73
- done_holder = self._done_holder
74
-
75
- def val_callback(value):
76
- done_queue.put((False, value))
77
- cmd, val = ready_queue.get()
78
- if cmd == 'send':
79
- return val
80
- elif cmd == 'throw':
81
- raise val
82
- else:
83
- assert False # pragma: no cover
84
-
85
- def thread_func():
86
- while True:
87
- cmd, val = ready_queue.get()
88
- if cmd == 'send' and val is not None:
89
- done_queue.put((True, TypeError("can't send non-None value to a just-started generator")))
90
- continue
91
- break
92
- try:
93
- if cmd == 'throw':
94
- raise val
95
- ret = func(callback=val_callback, **self.kwargs)
96
- raise StopIteration(ret) if ret is not None else StopIteration
97
- except BaseException as e:
98
- done_holder[0] = True
99
- done_queue.put((True, e))
100
-
101
- self._thread = Thread(target=thread_func)
102
- self._thread.start()
103
-
104
- def _put(self, *args):
105
- if self._done_holder[0]:
106
- raise StopIteration
107
- self._ready_queue.put(args)
108
- is_exception, val = self._done_queue.get()
109
- if is_exception:
110
- try:
111
- raise val
112
- finally:
113
- # prevent val's traceback containing a reference cycle
114
- del val
115
- else:
116
- return val
117
-
118
- def send(self, value):
119
- return self._put('send', value)
120
-
121
- def throw(self, exc):
122
- return self._put('throw', exc)
123
-
124
- def close(self):
125
- try:
126
- self.throw(GeneratorExit)
127
- except StopIteration:
128
- self._thread.join()
129
- except GeneratorExit:
130
- self._thread.join()
131
- except BaseException:
132
- self._thread.join()
133
- raise
134
- else:
135
- # yielded again, can't clean up the thread
136
- raise RuntimeError('Task with callback ignored GeneratorExit')
137
-
138
- def __del__(self):
139
- self.close()
 
9
 
10
  class StoppingCriteriaSub(StoppingCriteria):
11
 
12
+ def __init__(self, stops=[], encounters=[], device="cuda"):
13
  super().__init__()
14
  assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
15
  self.encounters = encounters
16
+ self.stops = [stop.to(device) for stop in stops]
17
  self.num_stops = [0] * len(stops)
18
 
19
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
 
25
  # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
26
  # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
27
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils.py CHANGED
@@ -1,6 +1,12 @@
 
1
  import os
2
  import gc
 
3
  import random
 
 
 
 
4
  import time
5
  import traceback
6
  import zipfile
@@ -8,7 +14,6 @@ from datetime import datetime
8
  import filelock
9
  import numpy as np
10
  import pandas as pd
11
- import torch
12
 
13
 
14
  def set_seed(seed: int):
@@ -16,6 +21,7 @@ def set_seed(seed: int):
16
  Sets the seed of the entire notebook so results are the same every time we run.
17
  This is for REPRODUCIBILITY.
18
  """
 
19
  np.random.seed(seed)
20
  random_state = np.random.RandomState(seed)
21
  random.seed(seed)
@@ -39,12 +45,22 @@ def flatten_list(lis):
39
 
40
 
41
  def clear_torch_cache():
42
- if torch.cuda.is_available:
 
43
  torch.cuda.empty_cache()
44
  torch.cuda.ipc_collect()
45
  gc.collect()
46
 
47
 
 
 
 
 
 
 
 
 
 
48
  def system_info():
49
  import psutil
50
 
@@ -184,3 +200,62 @@ def _s3up(filename):
184
  )
185
  if ret in [None, '']:
186
  return "Successfully uploaded %s" % filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
  import os
3
  import gc
4
+ import pathlib
5
  import random
6
+ import shutil
7
+ import subprocess
8
+ import sys
9
+ import threading
10
  import time
11
  import traceback
12
  import zipfile
 
14
  import filelock
15
  import numpy as np
16
  import pandas as pd
 
17
 
18
 
19
  def set_seed(seed: int):
 
21
  Sets the seed of the entire notebook so results are the same every time we run.
22
  This is for REPRODUCIBILITY.
23
  """
24
+ import torch
25
  np.random.seed(seed)
26
  random_state = np.random.RandomState(seed)
27
  random.seed(seed)
 
45
 
46
 
47
  def clear_torch_cache():
48
+ import torch
49
+ if torch.cuda.is_available():
50
  torch.cuda.empty_cache()
51
  torch.cuda.ipc_collect()
52
  gc.collect()
53
 
54
 
55
+ def ping():
56
+ print('Ping: %s' % str(datetime.now()), flush=True)
57
+
58
+
59
+ def get_torch_allocated():
60
+ import torch
61
+ return torch.cuda.memory_allocated()
62
+
63
+
64
  def system_info():
65
  import psutil
66
 
 
200
  )
201
  if ret in [None, '']:
202
  return "Successfully uploaded %s" % filename
203
+
204
+
205
+ def get_githash():
206
+ try:
207
+ githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1]
208
+ except:
209
+ githash = ''
210
+ return githash
211
+
212
+
213
+ def copy_code(run_id):
214
+ """
215
+ copy code to track changes
216
+ :param run_id:
217
+ :return:
218
+ """
219
+ rnd_num = str(random.randint(0, 2 ** 31))
220
+ run_id = 'run_' + str(run_id)
221
+ os.makedirs(run_id, exist_ok=True)
222
+ me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__)
223
+ me_file = os.path.basename(__file__)
224
+ new_me = os.path.join(run_id, me_file + '_' + get_githash())
225
+ if os.path.isfile(new_me):
226
+ new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num)
227
+ shutil.copy(me_full, new_me)
228
+ else:
229
+ shutil.copy(me_full, new_me)
230
+
231
+
232
+ class NullContext(threading.local):
233
+ """No-op context manager, executes block without doing any additional processing.
234
+
235
+ Used as a stand-in if a particular block of code is only sometimes
236
+ used with a normal context manager:
237
+ """
238
+ def __init__(self, *args, **kwargs):
239
+ pass
240
+
241
+ def __enter__(self):
242
+ return self
243
+
244
+ def __exit__(self, exc_type, exc_value, exc_traceback):
245
+ self.finally_act()
246
+
247
+ def finally_act(self):
248
+ pass
249
+
250
+
251
+ def wrapped_partial(func, *args, **kwargs):
252
+ """
253
+ Give partial properties of normal function, like __name__ attribute etc.
254
+ :param func:
255
+ :param args:
256
+ :param kwargs:
257
+ :return:
258
+ """
259
+ partial_func = functools.partial(func, *args, **kwargs)
260
+ functools.update_wrapper(partial_func, func)
261
+ return partial_func