vilarin commited on
Commit
9eefdf9
·
verified ·
1 Parent(s): 6f1ee3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  from PIL import Image
3
  import gradio as gr
4
  import spaces
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
  import os
7
  from threading import Thread
8
 
@@ -34,6 +34,15 @@ model = AutoModelForCausalLM.from_pretrained(
34
 
35
  tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b",trust_remote_code=True)
36
 
 
 
 
 
 
 
 
 
 
37
 
38
  @spaces.GPU()
39
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
@@ -49,24 +58,29 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
49
  input_ids = tokenizer.build_chat_input(message, history=conversation, role='user').input_ids.to(model.device)
50
  #input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
51
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
 
52
 
53
  generate_kwargs = dict(
 
54
  max_new_tokens=max_new_tokens,
55
  streamer=streamer,
56
  do_sample=True,
57
  top_k=1,
58
  temperature=temperature,
59
  repetition_penalty=1,
 
 
60
  )
61
- gen_kwargs = {**input_ids, **generate_kwargs}
62
 
63
- thread = Thread(target=model.generate, kwargs=gen_kwargs)
64
  thread.start()
65
  buffer = ""
66
  for new_text in streamer:
67
  buffer += new_text
68
  yield buffer
69
-
70
  chatbot = gr.Chatbot(height=600, placeholder = PLACEHOLDER)
71
 
72
  with gr.Blocks(css=CSS) as demo:
 
2
  from PIL import Image
3
  import gradio as gr
4
  import spaces
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
6
  import os
7
  from threading import Thread
8
 
 
34
 
35
  tokenizer = AutoTokenizer.from_pretrained("THUDM/LongWriter-glm4-9b",trust_remote_code=True)
36
 
37
+ class StopOnTokens(StoppingCriteria):
38
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
39
+ # stop_ids = model.config.eos_token_id
40
+ stop_ids = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
41
+ tokenizer.get_command("<|observation|>")]
42
+ for stop_id in stop_ids:
43
+ if input_ids[0][-1] == stop_id:
44
+ return True
45
+ return False
46
 
47
  @spaces.GPU()
48
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
 
58
  input_ids = tokenizer.build_chat_input(message, history=conversation, role='user').input_ids.to(model.device)
59
  #input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
60
  streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
61
+ eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
62
+ tokenizer.get_command("<|observation|>")]
63
 
64
  generate_kwargs = dict(
65
+ input_ids=input_ids,
66
  max_new_tokens=max_new_tokens,
67
  streamer=streamer,
68
  do_sample=True,
69
  top_k=1,
70
  temperature=temperature,
71
  repetition_penalty=1,
72
+ stopping_criteria=StoppingCriteriaList([stop]),
73
+ eos_token_id=eos_token_id,
74
  )
75
+ #gen_kwargs = {**input_ids, **generate_kwargs}
76
 
77
+ thread = Thread(target=model.generate, kwargs=generate_kwargs)
78
  thread.start()
79
  buffer = ""
80
  for new_text in streamer:
81
  buffer += new_text
82
  yield buffer
83
+
84
  chatbot = gr.Chatbot(height=600, placeholder = PLACEHOLDER)
85
 
86
  with gr.Blocks(css=CSS) as demo: