# dolphin / app.py
# Last change: "Update app.py" by nroggendorff (commit 8339753, verified); file size 2.03 kB.
import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that ends generation once a stop token is emitted.

    Args:
        stop_ids: Token ids that terminate generation. Defaults to
            ``(50256, 50295)``, the ids this app has always used.
            NOTE(review): confirm these ids exist in the vocabulary of the
            model actually loaded; ids outside it can never match.
    """

    def __init__(self, stop_ids=(50256, 50295)):
        super().__init__()
        # A frozenset gives O(1) membership and guards against later mutation.
        self.stop_ids = frozenset(int(stop_id) for stop_id in stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the newest token of the first sequence is a stop id.

        Only ``input_ids[0]`` is inspected; ``scores`` and ``kwargs`` are
        accepted for interface compatibility and ignored.
        """
        # int() pulls the scalar out once instead of comparing per stop id.
        return int(input_ids[0][-1]) in self.stop_ids
def _build_chatml_prompt(message, history):
    """Render the conversation as a ChatML prompt ending in an open assistant turn.

    Args:
        message: The new user message.
        history: Sequence of past turns, each indexable as
            ``[user_text, assistant_text]``.

    Returns:
        The system prompt, every past turn, and the new user message, followed
        by an opened (unterminated) assistant header for the model to complete.
    """
    segments = ["<|im_start|>system\nYou are Dolphin, a helpful AI assistant.<|im_end|>"]
    for turn in history:
        # Past assistant replies must be closed with <|im_end|>; otherwise the
        # reply text runs straight into the next "<|im_start|>user" header.
        segments.append(
            "\n<|im_start|>user\n" + turn[0]
            + "<|im_end|>\n<|im_start|>assistant\n" + (turn[1] or "")
            + "<|im_end|>"
        )
    # The current turn is left open so generation continues from here.
    segments.append(
        "\n<|im_start|>user\n" + message + "<|im_end|>\n<|im_start|>assistant\n"
    )
    return "".join(segments)


@spaces.GPU(duration=480)
def predict(message, history):
    """Stream the assistant's reply to ``message`` given the chat ``history``.

    The tokenizer and the 4-bit quantized model are loaded on every call
    (presumably because CUDA is only usable inside the ``spaces.GPU``
    context -- TODO confirm before hoisting to module level).

    Args:
        message: The new user message.
        history: Past turns, each indexable as ``[user_text, assistant_text]``.

    Yields:
        The reply accumulated so far, once per streamed text chunk.
    """
    torch.set_default_device("cuda")
    tokenizer = AutoTokenizer.from_pretrained(
        "cognitivecomputations/dolphin-2.8-mistral-7b-v02",
        trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        "cognitivecomputations/dolphin-2.8-mistral-7b-v02",
        torch_dtype="auto",
        load_in_4bit=True,
        trust_remote_code=True
    )

    prompt = _build_chatml_prompt(message, history)
    # The tokenizer returns a mapping (input ids plus any companion tensors);
    # it is expanded into generate()'s keyword arguments below.
    model_inputs = tokenizer([prompt], return_tensors="pt").to('cuda')
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        num_beams=1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()])
    )

    # generate() blocks until done, so it runs in a worker thread while this
    # generator drains the streamer and forwards text to the UI.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        # Safety net in case the end-of-turn marker leaks through as text.
        if '<|im_end|>' in partial_message:
            break
        yield partial_message
# Wrap `predict` in Gradio's chat UI and start serving it.
demo = gr.ChatInterface(predict)
demo.launch()