vilarin commited on
Commit
030c23d
·
verified ·
1 Parent(s): 6e9bcd6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -3,6 +3,7 @@ from PIL import Image
3
  import gradio as gr
4
  import spaces
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
6
  import os
7
  from threading import Thread
8
 
@@ -37,7 +38,7 @@ tokenizer = AutoTokenizer.from_pretrained(MODELS,trust_remote_code=True)
37
 
38
 
39
  @spaces.GPU
40
- def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
41
  print(f'message is - {message}')
42
  print(f'history is - {history}')
43
  conversation = []
@@ -48,22 +49,31 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
48
  print(f"Conversation is -\n{conversation}")
49
 
50
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
51
-
52
 
53
  generate_kwargs = dict(
54
- max_length=2500,
55
- max_new_tokens=max_new_tokens,
56
  do_sample=True,
57
  top_k=1,
58
  temperature=temperature,
59
  repetition_penalty=1.2,
60
  )
 
61
 
62
  with torch.no_grad():
63
- outputs = model.generate(**inputs, **generate_kwargs)
64
- outputs = outputs[:, inputs['input_ids'].shape[1]:]
65
- results = tokenizer.decode(outputs[0], skip_special_tokens=True)
66
- return results
 
 
 
 
 
 
 
 
67
 
68
 
69
 
@@ -90,10 +100,10 @@ with gr.Blocks(css=CSS) as demo:
90
  ),
91
  gr.Slider(
92
  minimum=128,
93
- maximum=4096,
94
  step=1,
95
  value=1024,
96
- label="Max new tokens",
97
  render=False,
98
  ),
99
  ],
 
3
  import gradio as gr
4
  import spaces
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
+ from huggingface_hub.inference._generated.types import TextGenerationStreamOutput, TextGenerationStreamOutputToken
7
  import os
8
  from threading import Thread
9
 
 
38
 
39
 
40
  @spaces.GPU
41
+ def stream_chat(message: str, history: list, temperature: float, max_length: int):
42
  print(f'message is - {message}')
43
  print(f'history is - {history}')
44
  conversation = []
 
49
  print(f"Conversation is -\n{conversation}")
50
 
51
  input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
52
+ streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
53
 
54
  generate_kwargs = dict(
55
+ max_length=max_length,
56
+ streamer=streamer,
57
  do_sample=True,
58
  top_k=1,
59
  temperature=temperature,
60
  repetition_penalty=1.2,
61
  )
62
+ gen_kwargs = {**input_ids, **generate_kwargs}
63
 
64
  with torch.no_grad():
65
+ thread = Thread(target=model.generate, kwargs=gen_kwargs)
66
+ thread.start()
67
+ for next_text in streamer:
68
+ yield TextGenerationStreamOutput(
69
+ index=0,
70
+ token=TextGenerationStreamOutputToken(
71
+ id=0,
72
+ logprob=0,
73
+ text=next_text,
74
+ special=False,
75
+ )
76
+ )
77
 
78
 
79
 
 
100
  ),
101
  gr.Slider(
102
  minimum=128,
103
+ maximum=8192,
104
  step=1,
105
  value=1024,
106
+ label="Max Length",
107
  render=False,
108
  ),
109
  ],