# icechat / app.py — by Sigurdur
# Last change: commit 26d6f1b ("Update app.py", verified), 931 bytes
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
# Hub repository providing both the causal-LM weights and the tokenizer that
# matches them; sharing one constant keeps the two loads pointed at the same
# checkpoint.
_CHECKPOINT = "Sigurdur/qa-icebreaker"
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
def streaming_respond(question, history):
    """Stream the model's answer to *question* for ``gr.ChatInterface``.

    The question is wrapped in a "### Question: / ### Answer:" prompt
    template. Generation runs on a background thread, and this generator
    yields the cumulative answer text as new chunks are decoded.

    Args:
        question: The user's latest chat message.
        history: Prior conversation turns passed in by ``gr.ChatInterface``.
            Unused here: each question is answered on its own.

    Yields:
        The answer decoded so far; each value extends the previous one.
    """
    prompt = f"### Question:\n{question}\n\n### Answer:\n"
    # Calling the tokenizer (instead of ``encode``) also returns an attention
    # mask, so ``generate`` does not have to infer one.
    inputs = tokenizer(prompt, return_tensors="pt")
    # skip_prompt: stream only newly generated text, not the echoed prompt.
    # timeout: reading from the streamer gives up after 10 s without new
    # text, so a stalled generation thread cannot block this loop forever.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=100,
        # temperature only applies when sampling; without do_sample=True,
        # generate() decodes greedily and ignores it (unless the checkpoint's
        # generation config already enables sampling -- not visible here).
        do_sample=True,
        temperature=0.7,
        num_beams=1,
    )
    # generate() blocks until it finishes, so run it on a worker thread and
    # consume the streamer here as text becomes available.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted only once generation has ended; reap the thread.
    worker.join()
# Build the chat UI at module level under the name ``demo`` (the variable
# Gradio's reload mode looks for), but start the server only when this file
# is executed as a script, so importing the module has no launch side effect.
demo = gr.ChatInterface(streaming_respond)

if __name__ == "__main__":
    demo.launch()