# consistent_chat / app.py
import os
from threading import Thread
from typing import Iterator, List, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Prompts longer than this are truncated to their last MAX_INPUT_TOKEN_LENGTH tokens.
MAX_INPUT_TOKEN_LENGTH = 50
LICENSE = """
<p/>
---
A derivative work of [ConsistentAgents]() by Seonghee Lee.
"""
# Path to the locally fine-tuned Llama-2 checkpoint; loaded on GPU when available.
model_id = "./backprop_llama2_69_1e-05"
HF_ACCESS_TOKEN = os.getenv("HF_ACCESS_TOKEN")
if torch.cuda.is_available():
    # `token` supersedes the deprecated `use_auth_token` kwarg in recent transformers.
    model = AutoModelForCausalLM.from_pretrained(
        model_id, token=HF_ACCESS_TOKEN, torch_dtype=torch.float16, device_map="auto"
    )
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False
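# Hedged addition, not in the original commit: without a CPU fallback, `model`
# is undefined on machines without CUDA and generate() would raise a NameError.
# Full-precision CPU loading is slow but keeps the demo runnable.
if not torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(model_id, token=HF_ACCESS_TOKEN)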
@spaces.GPU
def generate(
    message: str,
    chat_history: List[Tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Rebuild the full conversation in the chat-template format the tokenizer expects.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.extend(
            [{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]
        )
    conversation.append({"role": "user", "content": message})
    # Tokenize with the model's chat template and left-truncate overlong prompts.
    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)
    # Stream tokens back as they are generated instead of waiting for the full reply.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so the streamer can be consumed here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
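# Hedged usage sketch, not in the original file: generate() is a plain Python
# generator, so it can be smoke-tested outside Gradio, e.g.:
#
#   for partial in generate("Hello!", chat_history=[], system_prompt="Be concise."):
#       print(partial)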
# Build the Gradio chat interface around the streaming generate() function.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
    ],
)
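# Hedged sketch, not in the original: ChatInterface maps `additional_inputs`
# positionally onto generate()'s extra parameters after (message, history), so
# the remaining sampling knobs could be exposed the same way, e.g.:
#
#   chat_interface = gr.ChatInterface(
#       fn=generate,
#       additional_inputs=[
#           gr.Textbox(label="System prompt", lines=6),
#           gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
#           gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.6),
#           gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
#           gr.Slider(label="Top-k", minimum=1, maximum=200, step=1, value=50),
#           gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
#       ],
#   )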
with gr.Blocks(css="style.css") as demo:
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)
if __name__ == "__main__":
    # Queue caps concurrent requests; server_name binds the app to a specific host IP.
    demo.queue(max_size=20).launch(server_name="10.79.12.70", share=True)