File size: 931 Bytes
8708772
 
 
 
26d6f1b
 
8708772
 
 
911ec5f
8708772
 
 
 
 
 
911ec5f
 
8708772
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Model identifier passed to ``from_pretrained`` for both the weights and the
# tokenizer. Kept in one constant so the two always load from the same source.
MODEL_ID = "Sigurdur/qa-icebreaker"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def streaming_respond(question, history):
    """Stream a model-generated answer to ``question`` for ``gr.ChatInterface``.

    The question is wrapped in a "### Question: / ### Answer:" prompt
    template. ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``, whose decoded chunks are yielded as they arrive.

    Args:
        question: The user's latest chat message.
        history: Prior chat turns supplied by ``gr.ChatInterface``. Unused;
            each question is answered independently.

    Yields:
        The cumulative answer text generated so far, so the chat UI can
        re-render the growing reply on every step.
    """
    prompt = f"### Question:\n{question}\n\n### Answer:\n"
    # Calling the tokenizer (instead of ``encode``) also returns an attention
    # mask, so ``generate`` does not have to infer one from the input ids.
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=100,
        # ``temperature`` only applies when sampling is enabled; without
        # ``do_sample=True`` decoding is greedy and the setting is ignored.
        do_sample=True,
        temperature=0.7,
        num_beams=1,
    )
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    answer = ""
    for chunk in streamer:
        answer += chunk
        yield answer
    # The streamer is exhausted once generation has finished; join to reap
    # the worker thread rather than leaving it dangling.
    worker.join()


# Wrap the streaming generator in a chat UI and start the Gradio server.
demo = gr.ChatInterface(streaming_respond)
demo.launch()