phi committed on
Commit
a572fd2
·
1 Parent(s): 5100e68

change files

Browse files
Files changed (2) hide show
  1. app.py +254 -133
  2. requirements.txt +1 -0
app.py CHANGED
@@ -28,10 +28,25 @@ from typing import List, Optional, Union, Dict, Tuple
28
  from tqdm.auto import tqdm
29
  from huggingface_hub import snapshot_download
30
 
31
- DEBUG = True
32
 
33
- if not DEBUG:
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
 
 
 
35
  # vllm import
36
  from vllm import LLM, SamplingParams
37
  from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -51,6 +66,22 @@ if not DEBUG:
51
  _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
52
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def hf_model_weights_iterator(
55
  model_name_or_path: str,
56
  cache_dir: Optional[str] = None,
@@ -208,26 +239,26 @@ def llama_load_weights(
208
  if "rotary_emb.inv_freq" in name:
209
  continue
210
 
211
- # if "embed_tokens" in name or "lm_head" in name:
212
- # param = state_dict[name]
213
- # # Consider padding in the vocab size.
214
- # padded_vocab_size = (param.shape[0] * tp_size)
215
- # # num_extra_rows = padded_vocab_size - self.config.vocab_size
216
- # num_extra_rows = padded_vocab_size - loaded_weight.size(0)
217
- # load_size = loaded_weight.size()
218
- # extra_rows = torch.empty(num_extra_rows,
219
- # loaded_weight.shape[1])
220
- # extra_rows = extra_rows.to(loaded_weight)
221
- # loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
222
- # if num_extra_rows > 0:
223
- # print(f'Add empty to {num_extra_rows} extra row for {name}')
224
- # print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')
225
-
226
  if "embed_tokens" in name or "lm_head" in name:
227
  param = state_dict[name]
228
- load_padded_tensor_parallel_vocab(param, loaded_weight, tensor_model_parallel_rank)
229
- loaded += 1
230
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  is_attention_weight = False
233
  for weight_name, shard_size, offset in attention_weight_specs:
@@ -428,29 +459,84 @@ class ChatBot(gr.Chatbot):
428
  ):
429
  x = super()._postprocess_chat_messages(chat_message)
430
  if isinstance(x, str):
431
- x = x.replace("\n", "<br>")
432
  return x
433
 
434
 
435
- def load_ckpt(ckpt_file: str) -> str:
436
- global llm
437
- status = "Failed"
438
- if not os.path.exists(ckpt_file):
439
- status = f"Failed - file not found: {ckpt_file}"
440
- elif not ckpt_file.endswith(".bin"):
441
- status = f"Failed - file not .bin: {ckpt_file}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
442
  else:
443
- try:
444
- state_dict = torch.load(ckpt_file, map_location='cpu')
445
- print(f'loaded state_dict: {ckpt_file}')
446
- llm.llm_engine.workers[0].model.load_state_dict(state_dict)
447
- status = f'Success. Loaded {ckpt_file}'
448
- except Exception as e:
449
- status = f'Failed - {str(e)}'
450
- return status
 
 
 
 
 
 
 
 
451
 
452
 
453
 
 
 
454
  def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
455
  global llm
456
  assert llm is not None
@@ -466,7 +552,6 @@ def chat_response(message, history, temperature: float, max_tokens: int, system_
466
  sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
467
  gen = llm.generate(message, sampling_params)
468
  out = gen[0].outputs[0].text
469
- # print(f'{message}<<<{out}>>>')
470
  return f'{out}'
471
 
472
 
@@ -493,10 +578,6 @@ def _vllm_run_engine(self: Any, use_tqdm: bool = False) -> Dict[str, Any]:
493
  while self.llm_engine.has_unfinished_requests():
494
  step_outputs = self.llm_engine.step()
495
  for output in step_outputs:
496
- # if output.finished:
497
- # outputs.append(output)
498
- # if use_tqdm:
499
- # pbar.update(1)
500
  outputs[output.request_id] = output
501
  # outputs = sorted(outputs, key=lambda x: int(x.request_id))
502
  if len(outputs) > 0:
@@ -565,53 +646,71 @@ def vllm_generate_stream(
565
  yield from _vllm_run_engine(self, use_tqdm)
566
 
567
 
568
- def chat_response_stream(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
569
  message: str,
570
- history: List[Tuple[str, str]],
571
- temperature: float,
572
- max_tokens: int,
573
- frequency_penalty: float,
574
- system_prompt: str
575
  ) -> str:
576
- global llm, RES_PRINTED
577
- assert llm is not None
578
- # force removing all
579
- vllm_abort(llm)
580
-
581
- temperature = float(temperature)
582
- frequency_penalty = float(frequency_penalty)
583
- max_tokens = int(max_tokens)
584
- if system_prompt.strip() != '':
585
- # chat version, add system prompt
586
- message = llama_chat_sys_input_seq_constructor(
587
- message.strip(),
588
- sys_prompt=system_prompt
589
- )
590
- sampling_params = SamplingParams(
591
- temperature=temperature, max_tokens=max_tokens,
592
- frequency_penalty=frequency_penalty,
593
- )
594
- cur_out = None
595
- for gen in vllm_generate_stream(llm, message, sampling_params):
596
- if cur_out is not None:
597
- yield cur_out
598
- assert len(gen) == 1, f'{gen}'
599
- item = next(iter(gen.values()))
600
- cur_out = item.outputs[0].text
601
- if not RES_PRINTED:
602
- print(f'{message}<<<{cur_out}>>>')
603
- RES_PRINTED = True
604
- if cur_out is not None:
605
- yield cur_out
606
-
607
 
 
608
  def chat_response_stream_multiturn(
609
  message: str,
610
  history: List[Tuple[str, str]],
611
  temperature: float,
612
  max_tokens: int,
613
  frequency_penalty: float,
614
- system_prompt: str
615
  ) -> str:
616
  """Build multi turn
617
  <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
@@ -631,27 +730,46 @@ def chat_response_stream_multiturn(
631
  frequency_penalty = float(frequency_penalty)
632
  max_tokens = int(max_tokens)
633
 
 
 
 
 
 
 
 
 
 
 
 
634
  # history.append([message, None])
635
  # history will be appended with message later on
636
  full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
637
  message, history, sys_prompt=system_prompt
638
  )
 
639
  sampling_params = SamplingParams(
640
  temperature=temperature, max_tokens=max_tokens,
641
  frequency_penalty=frequency_penalty,
642
  )
643
  cur_out = None
644
- for gen in vllm_generate_stream(llm, full_prompt, sampling_params):
645
- if cur_out is not None:
 
646
  yield cur_out
647
  assert len(gen) == 1, f'{gen}'
648
  item = next(iter(gen.values()))
649
  cur_out = item.outputs[0].text
650
- if not RES_PRINTED:
651
- print(f'{full_prompt}<<<{cur_out}>>>')
652
- RES_PRINTED = True
 
653
  if cur_out is not None:
654
  yield cur_out
 
 
 
 
 
655
 
656
 
657
  def debug_chat_response_echo(
@@ -662,16 +780,26 @@ def debug_chat_response_echo(
662
  frequency_penalty: float = 0.4,
663
  system_prompt: str = SYSTEM_PROMPT_1,
664
  ) -> str:
 
 
665
  yield f"repeat: {message}"
666
 
667
 
668
  # ============ CONSTANT ============
669
- MODEL_NAME = "DAMO-SeaL-13B"
670
- MODEL_TITLE = "DAMO-SeaL-13B - An Assistant for South East Asian Languages"
 
 
671
  MODEL_DESC = """
672
- This is a 13B DAMO-SeaL-Chat assistant model built by DAMO Academy, Alibaba Group. It can produce helpful responses in English, Vietnamese, Indonesian and Thai.
673
- <br>
674
- #### Citation
 
 
 
 
 
 
675
  If you find our project useful, hope you can star our repo and cite our paper as follows:
676
  ```
677
  @article{damonlpsg2023seallm,
@@ -680,22 +808,21 @@ If you find our project useful, hope you can star our repo and cite our paper as
680
  year = 2023,
681
  }
682
  ```
683
- """.strip()
684
-
685
-
686
- cite_markdown = """
687
  """
688
- # journal = {arXiv preprint arXiv:2306.02858}
689
- # url = {https://arxiv.org/abs/2306.02858}
690
 
 
 
 
691
 
692
- TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
693
- DTYPE = 'bfloat16'
694
- DTYPE = 'float16'
695
-
696
- MODEL_PATH = os.environ.get("MODEL_PATH", "notfound, please set `export MODEL_PATH=`")
697
 
698
 
 
 
 
 
699
 
700
 
701
  def launch():
@@ -707,26 +834,29 @@ def launch():
707
  assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
708
  dtype = DTYPE
709
  sys_prompt = SYSTEM_PROMPT_1
710
- max_tokens = 4096
 
711
 
712
  if DEBUG:
713
  model_desc += "\n<br>!!!!! This is in debug mode, responses will be copy original"
714
  response_fn = debug_chat_response_echo
715
  else:
716
  # ! load the model
 
717
  assert os.path.exists(model_path), f'{model_path} not found'
 
 
718
  llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel)
719
 
720
  print(f'Use system prompt:\n{sys_prompt}')
721
 
722
- # response_fn = chat_response_stream_multiturn if args.multiturn else chat_response_stream
723
  response_fn = chat_response_stream_multiturn
724
  print(F'respond: {response_fn}')
725
 
726
  demo = gr.ChatInterface(
727
  response_fn,
728
  chatbot=ChatBot(
729
- # value=MODEL_NAME,
730
  bubble_full_width=False,
731
  latex_delimiters=[
732
  { "left": "$", "right": "$", "display": False},
@@ -735,7 +865,8 @@ def launch():
735
  ),
736
  textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
737
  submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
738
- # stop_btn=None,
 
739
  title=f"{model_title}",
740
  description=f"{model_desc}",
741
  # ! decide if can change the system prompt.
@@ -743,38 +874,16 @@ def launch():
743
  gr.Number(value=0, label='Temperature (higher -> more random)'),
744
  gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
745
  gr.Number(value=0.4, label='Frequency penalty (> 0 encourage new tokens)'),
746
- gr.Textbox(value=sys_prompt, label='System prompt', lines=8)],
 
747
  )
748
-
749
- # with gr.Blocks() as demo:
750
- # gr.ChatInterface(
751
- # response_fn,
752
- # chatbot=ChatBot(
753
- # bubble_full_width=False,
754
- # latex_delimiters=[
755
- # { "left": "$", "right": "$", "display": False},
756
- # { "left": "$$", "right": "$$", "display": True},
757
- # ]
758
- # ),
759
- # textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
760
- # submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
761
- # # stop_btn=None,
762
- # title=f"{model_title}",
763
- # description=f"{model_desc}",
764
- # # ! decide if can change the system prompt.
765
- # additional_inputs=[
766
- # gr.Number(value=0, label='Temperature (higher -> more random)'),
767
- # gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
768
- # gr.Number(value=0.4, label='Frequency penalty (> 0 encourage new tokens)'),
769
- # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
770
- # ],
771
- # )
772
-
773
- # gr.Markdown(cite_markdown)
774
 
775
  demo.queue()
776
- # demo.launch(server_port=args.port)
777
- demo.launch()
778
 
779
 
780
  def main():
@@ -793,7 +902,19 @@ export CUDA_VISIBLE_DEVICES=0
793
  export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW8k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.FSePlCq13M.FSePlCq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_4000
794
  export MODEL_PATH=${dataroot}/llama-2-7b-lxxp-faster
795
  export MODEL_PATH=${dataroot}/llama-2-7b-chat-xp
 
 
 
 
 
 
 
 
 
796
  python app.py
797
 
798
 
 
 
 
799
  """
 
28
  from tqdm.auto import tqdm
29
  from huggingface_hub import snapshot_download
30
 
 
31
 
32
# @@ constants ================
# Runtime configuration, read once at import time from environment variables.
# Every value has a default so the app can start with no environment set.


def _env_flag(name: str, default: str) -> bool:
    """Read environment variable *name* as a boolean flag.

    Accepts the integer strings the app has always used ("0", "1", ...) and,
    in addition, the spellings true/false, yes/no and on/off (case-insensitive).
    Any other value raises ``ValueError`` from ``int()``.
    """
    raw = os.environ.get(name, default).strip().lower()
    if raw in ("true", "yes", "on"):
        return True
    if raw in ("false", "no", "off"):
        return False
    return bool(int(raw))


# DEBUG defaults to on: the file only imports vllm under `if not DEBUG`.
DEBUG = _env_flag("DEBUG", "1")
# When on, chat_response_stream_multiturn runs block_zh() and yields BLOCK_MESSAGE.
BLOCK_ZH = _env_flag("BLOCK_ZH", "0")

TENSOR_PARALLEL = int(os.environ.get("TENSOR_PARALLEL", "1"))
DTYPE = os.environ.get("DTYPE", "bfloat16")

MODEL_PATH = os.environ.get("MODEL_PATH", "seal_13b_a")
PORT = int(os.environ.get("PORT", "7860"))
# Stream every N-th engine step to the UI; values < 1 stream every step.
STREAM_YIELD_MULTIPLE = int(os.environ.get("STREAM_YIELD_MULTIPLE", "1"))
# Default for the "Max generated tokens" input; overridable like the other settings.
MAX_TOKENS = int(os.environ.get("MAX_TOKENS", "2048"))
46
 
47
+ # @@ constants ================
48
+ if not DEBUG:
49
+
50
  # vllm import
51
  from vllm import LLM, SamplingParams
52
  from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
66
  _MODEL_REGISTRY['FasterLlamaForCausalLM'] = LlamaForCausalLM
67
 
68
 
69
+ def _detect_lang(text):
70
+ from langdetect import detect as detect_lang
71
+ from langdetect.detector import LangDetectException
72
+ dlang = None
73
+ try:
74
+ dlang = detect_lang(text)
75
+ except Exception as e:
76
+ # No features in text.
77
+ print(f'Error: {e}')
78
+ if "No features in text." in str(e):
79
+ return "en"
80
+ else:
81
+ return "zh"
82
+ return dlang
83
+
84
+
85
  def hf_model_weights_iterator(
86
  model_name_or_path: str,
87
  cache_dir: Optional[str] = None,
 
239
  if "rotary_emb.inv_freq" in name:
240
  continue
241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  if "embed_tokens" in name or "lm_head" in name:
243
  param = state_dict[name]
244
+ # Consider padding in the vocab size.
245
+ padded_vocab_size = (param.shape[0] * tp_size)
246
+ # num_extra_rows = padded_vocab_size - self.config.vocab_size
247
+ num_extra_rows = padded_vocab_size - loaded_weight.size(0)
248
+ load_size = loaded_weight.size()
249
+ extra_rows = torch.empty(num_extra_rows,
250
+ loaded_weight.shape[1])
251
+ extra_rows = extra_rows.to(loaded_weight)
252
+ loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)
253
+ if num_extra_rows > 0:
254
+ print(f'Add empty to {num_extra_rows} extra row for {name}')
255
+ print(f'Load: {name} | {padded_vocab_size=} | {self.config.vocab_size=} | {num_extra_rows=} | {param.size()=} | {loaded_weight.size()=} | {load_size=}')
256
+
257
+ # if "embed_tokens" in name or "lm_head" in name:
258
+ # param = state_dict[name]
259
+ # load_padded_tensor_parallel_vocab(param, loaded_weight, tensor_model_parallel_rank)
260
+ # loaded += 1
261
+ # continue
262
 
263
  is_attention_weight = False
264
  for weight_name, shard_size, offset in attention_weight_specs:
 
459
  ):
460
  x = super()._postprocess_chat_messages(chat_message)
461
  if isinstance(x, str):
462
+ x = x.strip().replace("\n", "<br>")
463
  return x
464
 
465
 
466
+ # gr.ChatInterface
467
+ from gradio.components import Button
468
+ from gradio.events import Dependency, EventListenerMethod
469
+
470
+
471
def _setup_stop_events(
    self, event_triggers: list[EventListenerMethod], event_to_cancel: Dependency
) -> None:
    """Wire the submit/stop buttons to the start and end of a response.

    Replacement for ``gr.ChatInterface._setup_stop_events`` (assigned onto the
    class right after this definition). ``event_triggers`` may be a single
    trigger or a list/tuple of them; ``event_to_cancel`` is the event whose
    completion restores the buttons and which the stop button cancels.

    Three configurations are handled:
      * stop button + generator fn + submit button: swap submit <-> stop;
      * stop button + generator fn, no submit button: show/hide stop only;
      * otherwise, if a submit button exists: disable it while responding.
    """
    if not isinstance(event_triggers, (list, tuple)):
        event_triggers = [event_triggers]

    has_stop = self.stop_btn and self.is_generator

    # Pick which buttons are updated and how, once, instead of per branch.
    if has_stop and self.submit_btn:
        targets = [self.submit_btn, self.stop_btn]
        on_start = lambda: (
            Button.update(visible=False),
            Button.update(visible=True),
        )
        on_done = lambda: (Button.update(visible=True), Button.update(visible=False))
    elif has_stop:
        targets = [self.stop_btn]
        on_start = lambda: Button.update(visible=True)
        on_done = lambda: Button.update(visible=False)
    elif self.submit_btn:
        targets = [self.submit_btn]
        on_start = lambda: Button.update(interactive=False)
        on_done = lambda: Button.update(interactive=True)
    else:
        # No button to manage.
        return

    # Each registration gets its own outputs list, as in the unrolled version.
    for trigger in event_triggers:
        trigger(
            on_start,
            None,
            list(targets),
            api_name=False,
            queue=False,
        )
    event_to_cancel.then(
        on_done,
        None,
        list(targets),
        api_name=False,
        queue=False,
    )

    if has_stop:
        # Clicking stop runs no function; it only cancels the running event.
        self.stop_btn.click(
            None,
            None,
            None,
            cancels=event_to_cancel,
            api_name=False,
        )
535
 
536
 
537
 
538
+ gr.ChatInterface._setup_stop_events = _setup_stop_events
539
+
540
  def chat_response(message, history, temperature: float, max_tokens: int, system_prompt: str = '') -> str:
541
  global llm
542
  assert llm is not None
 
552
  sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
553
  gen = llm.generate(message, sampling_params)
554
  out = gen[0].outputs[0].text
 
555
  return f'{out}'
556
 
557
 
 
578
  while self.llm_engine.has_unfinished_requests():
579
  step_outputs = self.llm_engine.step()
580
  for output in step_outputs:
 
 
 
 
581
  outputs[output.request_id] = output
582
  # outputs = sorted(outputs, key=lambda x: int(x.request_id))
583
  if len(outputs) > 0:
 
646
  yield from _vllm_run_engine(self, use_tqdm)
647
 
648
 
649
+ # def chat_response_stream(
650
+ # message: str,
651
+ # history: List[Tuple[str, str]],
652
+ # temperature: float,
653
+ # max_tokens: int,
654
+ # frequency_penalty: float,
655
+ # system_prompt: str
656
+ # ) -> str:
657
+ # global llm, RES_PRINTED
658
+ # assert llm is not None
659
+ # # force removing all
660
+ # vllm_abort(llm)
661
+
662
+ # temperature = float(temperature)
663
+ # frequency_penalty = float(frequency_penalty)
664
+ # max_tokens = int(max_tokens)
665
+ # if system_prompt.strip() != '':
666
+ # # chat version, add system prompt
667
+ # message = llama_chat_sys_input_seq_constructor(
668
+ # message.strip(),
669
+ # sys_prompt=system_prompt
670
+ # )
671
+ # sampling_params = SamplingParams(
672
+ # temperature=temperature, max_tokens=max_tokens,
673
+ # frequency_penalty=frequency_penalty,
674
+ # )
675
+ # cur_out = None
676
+ # for j, gen in enumerate(vllm_generate_stream(llm, message, sampling_params)):
677
+ # if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
678
+ # yield cur_out
679
+ # assert len(gen) == 1, f'{gen}'
680
+ # item = next(iter(gen.values()))
681
+ # cur_out = item.outputs[0].text
682
+ # if not RES_PRINTED:
683
+ # print(f'{message}<<<{cur_out}>>>')
684
+ # RES_PRINTED = True
685
+ # if cur_out is not None:
686
+ # yield cur_out
687
+
688
+
689
+ BLOCK_MESSAGE = """Sorry, Chinese is not currently supported. Please clear the chat box for a new conversation.
690
+ 抱歉,目前不支持中文。 请清除聊天框以进行新对话。"""
691
+
692
def block_zh(
    message: str,
    history: List[Tuple[str, str]]
) -> bool:
    """Return True when the current turn must be answered with BLOCK_MESSAGE.

    A turn is blocked when either
      * an earlier bot reply in *history* already contains BLOCK_MESSAGE
        (the conversation stays blocked until the chat is cleared), or
      * ``_detect_lang(message)`` reports a code containing "zh".

    Args:
        message: the new user message.
        history: previous ``(user, bot)`` turns; only the bot side is checked.
    """
    # NOTE(review): a bot reply may be None (e.g. an unfinished turn) - confirm
    # against the Gradio version in use. Treat it as empty instead of crashing.
    previously_blocked = any(
        BLOCK_MESSAGE in (turn[1] or "").strip() for turn in (history or [])
    )
    if previously_blocked:
        return True
    if 'zh' in _detect_lang(message):
        print(f'Detect zh: {message}')
        return True
    # ! optionally detect every responses message
    return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
705
 
706
+ # 抱歉,目前不支持中文。
707
  def chat_response_stream_multiturn(
708
  message: str,
709
  history: List[Tuple[str, str]],
710
  temperature: float,
711
  max_tokens: int,
712
  frequency_penalty: float,
713
+ system_prompt: Optional[str] = SYSTEM_PROMPT_1
714
  ) -> str:
715
  """Build multi turn
716
  <bos>[INST] B_SYS SytemPrompt E_SYS Prompt [/INST] Answer <eos>
 
730
  frequency_penalty = float(frequency_penalty)
731
  max_tokens = int(max_tokens)
732
 
733
+ message = message.strip()
734
+
735
+ # detect_ = _detect_lang(message)
736
+ # print(f'Message language: {detect_}')
737
+
738
+ # ! lang detect
739
+ if BLOCK_ZH:
740
+ if block_zh(message, history):
741
+ yield BLOCK_MESSAGE
742
+ return
743
+
744
  # history.append([message, None])
745
  # history will be appended with message later on
746
  full_prompt = llama_chat_multiturn_sys_input_seq_constructor(
747
  message, history, sys_prompt=system_prompt
748
  )
749
+ # print(full_prompt)
750
  sampling_params = SamplingParams(
751
  temperature=temperature, max_tokens=max_tokens,
752
  frequency_penalty=frequency_penalty,
753
  )
754
  cur_out = None
755
+ # for gen in vllm_generate_stream(llm, full_prompt, sampling_params):
756
+ for j, gen in enumerate(vllm_generate_stream(llm, full_prompt, sampling_params)):
757
+ if cur_out is not None and (STREAM_YIELD_MULTIPLE < 1 or j % STREAM_YIELD_MULTIPLE == 0) and j > 0:
758
  yield cur_out
759
  assert len(gen) == 1, f'{gen}'
760
  item = next(iter(gen.values()))
761
  cur_out = item.outputs[0].text
762
+
763
+ # if not RES_PRINTED:
764
+ print(f'{full_prompt}<<<{cur_out}>>>\n')
765
+ # RES_PRINTED = True
766
  if cur_out is not None:
767
  yield cur_out
768
+
769
+ # print(f'Output: {_detect_lang(cur_out)}')
770
+ if BLOCK_ZH:
771
+ if "zh" in _detect_lang(cur_out):
772
+ yield BLOCK_MESSAGE
773
 
774
 
775
  def debug_chat_response_echo(
 
780
  frequency_penalty: float = 0.4,
781
  system_prompt: str = SYSTEM_PROMPT_1,
782
  ) -> str:
783
+ import time
784
+ time.sleep(0.5)
785
  yield f"repeat: {message}"
786
 
787
 
788
  # ============ CONSTANT ============
789
+ # https://github.com/gradio-app/gradio/issues/884
790
+ MODEL_NAME = "SeaL-13B"
791
+ MODEL_TITLE = "SeaL-13B - An Assistant for South East Asian Languages"
792
+ # ! add icon: "<img src='file/lion.jpg' alt='image One'>"
793
  MODEL_DESC = """
794
+ <span style="font-size: larger">
795
+ This is a DAMO SeaL-13B chatbot assistant built by DAMO Academy, Alibaba Group. It can produce helpful responses in English 🇬🇧, Vietnamese 🇻🇳, Indonesian 🇮🇩 and Thai 🇹🇭.
796
+ </span>
797
+ """.strip()
798
+ # <br>
799
+
800
+
801
+ cite_markdown = """
802
+ ### Citation
803
  If you find our project useful, hope you can star our repo and cite our paper as follows:
804
  ```
805
  @article{damonlpsg2023seallm,
 
808
  year = 2023,
809
  }
810
  ```
 
 
 
 
811
  """
 
 
812
 
813
+ warning_markdown = """
814
+ ### Warning:
815
+ <span style="color: red">The chatbot may produce inaccurate and harmful information about people, places, or facts.</span>
816
 
817
+ <span style="color: red">We strongly advise against misuse of the chatbot to knowingly generate harmful or unethical content, \
818
+ or content that violates locally applicable and international laws or regulations, including hate speech, violence, pornography, deception, etc!</span>
819
+ """
 
 
820
 
821
 
822
+ path_markdown = """
823
+ #### Model path:
824
+ {model_path}
825
+ """
826
 
827
 
828
  def launch():
 
834
  assert tensor_parallel > 0 , f'{tensor_parallel} invalid'
835
  dtype = DTYPE
836
  sys_prompt = SYSTEM_PROMPT_1
837
+ max_tokens = MAX_TOKENS
838
+ print(f'Launch config: {model_path=} / {model_title=} / {tensor_parallel=} / {dtype=} / {max_tokens}\n{SYSTEM_PROMPT_1} | {BLOCK_ZH=}')
839
 
840
  if DEBUG:
841
  model_desc += "\n<br>!!!!! This is in debug mode, responses will be copy original"
842
  response_fn = debug_chat_response_echo
843
  else:
844
  # ! load the model
845
+ import vllm
846
  assert os.path.exists(model_path), f'{model_path} not found'
847
+ print(F'VLLM: {vllm.__version__}')
848
+ print(f'Load path: {model_path}')
849
  llm = LLM(model=model_path, dtype=dtype, tensor_parallel_size=tensor_parallel)
850
 
851
  print(f'Use system prompt:\n{sys_prompt}')
852
 
 
853
  response_fn = chat_response_stream_multiturn
854
  print(F'respond: {response_fn}')
855
 
856
  demo = gr.ChatInterface(
857
  response_fn,
858
  chatbot=ChatBot(
859
+ label=MODEL_NAME,
860
  bubble_full_width=False,
861
  latex_delimiters=[
862
  { "left": "$", "right": "$", "display": False},
 
865
  ),
866
  textbox=gr.Textbox(placeholder='Type message', lines=8, max_lines=128, min_width=200),
867
  submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
868
+ # ! consider preventing the stop button
869
+ stop_btn=None,
870
  title=f"{model_title}",
871
  description=f"{model_desc}",
872
  # ! decide if can change the system prompt.
 
874
  gr.Number(value=0, label='Temperature (higher -> more random)'),
875
  gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
876
  gr.Number(value=0.4, label='Frequency penalty (> 0 encourage new tokens)'),
877
+ # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
878
+ ],
879
  )
880
+ with demo:
881
+ gr.Markdown(warning_markdown)
882
+ gr.Markdown(cite_markdown)
883
+ gr.Markdown(path_markdown.format(model_path=model_path))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
884
 
885
  demo.queue()
886
+ demo.launch(server_port=PORT)
 
887
 
888
 
889
  def main():
 
902
  export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW8k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.FSePlCq13M.FSePlCq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_4000
903
  export MODEL_PATH=${dataroot}/llama-2-7b-lxxp-faster
904
  export MODEL_PATH=${dataroot}/llama-2-7b-chat-xp
905
+
906
+ export DEBUG=0
907
+ export CUDA_VISIBLE_DEVICES=0
908
+ export MODEL_PATH=seal_13b_a
909
+ export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/merlion13s108Hi8kPretFlCW12k.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.SeaV2Cq13M.SeaV2Cq13M.m4k.b8.lr1e5.linear.wa0k.ms858k.grac1.se1.8g.v4c.zfsdp/step_6000
910
+
911
+ export MODEL_PATH=${dataroot}/hf_train/pretrain_lm/swpn/mer13s108Hi16kPretFlCWNLP12k_SFT2.LMFromHf.a.gc.t5k0.vizhthid.mean_std.TrainTask.NLNL.Multi.Vi.Sft2Censor.Sft2Censor.m4k.b8.lr1e5.linear.wa0k.ms1144k.grac1.se1.6g.v4c.zfsdp/step_2000
912
+ export PORT=8799
913
+ export BLOCK_ZH=1
914
  python app.py
915
 
916
 
917
+ DEBUG=1 python app.py
918
+
919
+
920
  """
requirements.txt CHANGED
@@ -22,5 +22,6 @@ tensorboard
22
  geomloss
23
  einops
24
  gdown
 
25
  vllm==0.1.4
26
  transformers
 
22
  geomloss
23
  einops
24
  gdown
25
+ langdetect
26
  vllm==0.1.4
27
  transformers