import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# Load the causal-LM checkpoint and its matching tokenizer once at import
# time, so every chat request reuses the same in-memory weights.
model = AutoModelForCausalLM.from_pretrained("Sigurdur/qa-icebreaker")
tokenizer = AutoTokenizer.from_pretrained("Sigurdur/qa-icebreaker")
def streaming_respond(question, history):
    """Stream a model answer for *question*, token by token.

    Args:
        question: The user's question text.
        history: Prior chat turns supplied by ``gr.ChatInterface`` (unused).

    Yields:
        str: The cumulative answer text, growing as tokens arrive.
    """
    # Wrap the question in the prompt template the model expects.
    input_ids = tokenizer.encode(
        f"### Question:\n{question}\n\n### Answer:\n", return_tensors="pt"
    )
    # skip_prompt=True keeps the echoed prompt out of the streamed output;
    # timeout guards against a stalled generation thread.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=100,
        # Fix: transformers silently ignores `temperature` unless sampling
        # is enabled, so the original always decoded greedily. Enable
        # sampling so temperature=0.7 actually takes effect.
        do_sample=True,
        temperature=0.7,
        num_beams=1,
    )
    # generate() blocks, so run it on a worker thread while this
    # generator consumes tokens from the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        # Yield the cumulative text so the UI shows a growing answer.
        yield "".join(outputs)
# Launch the chat UI. ChatInterface renders each value yielded by the
# generator, producing a streaming (incrementally updating) answer.
# (Fix: removed a trailing " |" scrape artifact that broke the syntax.)
gr.ChatInterface(streaming_respond).launch()