Spestly committed on
Commit 6c71f71 · verified · 1 Parent(s): 32260ac

Update app.py

Files changed (1)
  1. app.py +23 -15
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
+import threading
 
 # Load model and tokenizer
 model_name = "Spestly/AwA-1.5B"
@@ -18,23 +19,30 @@ def generate_response_stream(message, history):
         f"### Instruction:\n{message}\n\n### Response:"
     )
 
+    # Tokenize the input instruction
     inputs = tokenizer(instruction, return_tensors="pt")
 
-    with torch.no_grad():
-        # Generate tokens one at a time
-        generated_ids = model.generate(
-            **inputs,
-            max_new_tokens=1000,
-            num_return_sequences=1,
-            temperature=0.7,
-            top_p=0.9,
-            do_sample=True,
-            streamer=None,  # Replace this if the Transformers version supports streaming
-        )
+    # Create a TextIteratorStreamer for real-time output
+    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
 
-        # Decode and yield response tokens incrementally
-        for token_id in generated_ids[0]:
-            yield tokenizer.decode(token_id, skip_special_tokens=True)
+    # Generate tokens in a separate thread
+    generation_thread = threading.Thread(
+        target=model.generate,
+        kwargs={
+            "input_ids": inputs["input_ids"],
+            "attention_mask": inputs["attention_mask"],
+            "max_new_tokens": 1000,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "do_sample": True,
+            "streamer": streamer,
+        }
+    )
+    generation_thread.start()
+
+    # Stream tokens as they are generated
+    for token in streamer:
+        yield token
 
 iface = gr.ChatInterface(
     fn=generate_response_stream,
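
For context, here is the streaming pattern the new code relies on, reduced to a self-contained sketch. model.generate blocks until generation finishes, so it runs in a background thread while the main generator drains the TextIteratorStreamer. Two details below are illustrative additions, not part of the commit: skip_prompt=True, which keeps the echoed instruction out of the stream, and the accumulation loop, since gr.ChatInterface treats each yielded value as the complete reply so far (yielding bare tokens, as the committed code does, would display only the latest fragment).

import threading

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_name = "Spestly/AwA-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def stream_reply(message, history):
    inputs = tokenizer(message, return_tensors="pt")
    # skip_prompt=True is an illustrative extra: it keeps the prompt text
    # out of the streamed output (the committed code streams it back too)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until done, so run it in a background thread
    thread = threading.Thread(
        target=model.generate,
        kwargs={**inputs, "max_new_tokens": 1000, "streamer": streamer},
    )
    thread.start()
    # gr.ChatInterface replaces the shown reply with each yielded value,
    # so yield the accumulated text rather than individual chunks
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
    thread.join()

gr.ChatInterface(fn=stream_reply).launch()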