Update app.py
app.py CHANGED

@@ -1,6 +1,7 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 import torch
+import threading
 
 # Load model and tokenizer
 model_name = "Spestly/AwA-1.5B"
@@ -18,23 +19,30 @@ def generate_response_stream(message, history):
         f"### Instruction:\n{message}\n\n### Response:"
     )
 
+    # Tokenize the input instruction
     inputs = tokenizer(instruction, return_tensors="pt")
 
-
-
-    generated_ids = model.generate(
-        **inputs,
-        max_new_tokens=1000,
-        num_return_sequences=1,
-        temperature=0.7,
-        top_p=0.9,
-        do_sample=True,
-        streamer=None, # Replace this if the Transformers version supports streaming
-    )
+    # Create a TextIteratorStreamer for real-time output
+    streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
 
-
-
-
+    # Generate tokens in a separate thread
+    generation_thread = threading.Thread(
+        target=model.generate,
+        kwargs={
+            "input_ids": inputs["input_ids"],
+            "attention_mask": inputs["attention_mask"],
+            "max_new_tokens": 1000,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "do_sample": True,
+            "streamer": streamer,
+        }
+    )
+    generation_thread.start()
+
+    # Stream tokens as they are generated
+    for token in streamer:
+        yield token
 
 iface = gr.ChatInterface(
     fn=generate_response_stream,
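
A caveat on the streaming loop this commit adds: gr.ChatInterface treats each value yielded by its fn as the full reply so far, replacing what is currently displayed, so yielding raw chunks from the streamer will show only the newest fragment. The streamer is also created with skip_prompt left at its default of False, so the echoed prompt is streamed back along with the reply. Below is a minimal sketch of an accumulating variant, assuming the model and tokenizer loaded at the top of app.py; stream_reply is a hypothetical helper name, not part of this commit.

import threading

from transformers import TextIteratorStreamer

def stream_reply(instruction):
    # Sketch only: assumes `tokenizer` and `model` are the globals
    # loaded earlier in app.py from "Spestly/AwA-1.5B".
    inputs = tokenizer(instruction, return_tensors="pt")

    # skip_prompt=True keeps the echoed prompt out of the stream;
    # skip_special_tokens is forwarded to tokenizer.decode().
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # generate() blocks until completion, so run it in a background
    # thread and consume decoded chunks from the streamer here.
    thread = threading.Thread(
        target=model.generate,
        kwargs={**inputs, "max_new_tokens": 1000, "streamer": streamer},
    )
    thread.start()

    # Yield the running text, not the delta: gr.ChatInterface replaces
    # the displayed message with each yielded value.
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial
    thread.join()

Joining the thread after the streamer is exhausted simply ensures generation has fully finished before the generator returns; the loop itself already ends when the streamer signals end of generation.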