Spaces:
Runtime error
Runtime error
import gradio as gr | |
import logging | |
import sys | |
import requests | |
import argparse | |
from app_module.configuration import get_index | |
from app_module.serve_utils import ( | |
disable_btn, no_change_btn, | |
downvote_last_response, enable_btn, flag_last_response, | |
get_window_url_params, upvote_last_response | |
) | |
from app_module.conversation import default_conversation | |
# Configure root logging once. basicConfig attaches a single StreamHandler
# writing to stdout with the format below (when the root logger has no
# handlers yet). Adding another stdout StreamHandler on top of it would print
# every record twice, the second time without this format.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)
# Page title handed to gr.Blocks(title=...) in build_demo().
_title = "PPAIA - Pet Parent Artificial Intelligence Assistant"

# Singletons shared by every Gradio event handler below: the index returned
# by get_index() and a chat engine built on top of it.
index = get_index()
chat_engine = index.as_chat_engine(verbose=True)

# Extra stylesheet handed to gr.Blocks(css=...) in build_demo().
with open("assets/custom.css", encoding="utf-8") as f:
    customCSS = f.read()
def load_demo(url_params, request: gr.Request):
    """Initialise a browser session when the page loads.

    ``url_params`` and ``request`` are part of the Gradio handler signature
    but are not used here.

    Returns a 6-tuple for the outputs registered with ``demo.load``: a fresh
    copy of the default conversation, then updates that show the chatbot and
    textbox, keep the upload button hidden, and show the submit button and
    the action-button row.
    """
    session_state = default_conversation.copy()
    logging.info("Loading demo")
    show = {"visible": True}
    hide = {"visible": False}
    return (
        session_state,
        gr.Chatbot.update(**show),
        gr.Textbox.update(**show),
        gr.UploadButton.update(**hide),
        gr.Button.update(**show),
        gr.Row.update(**show),
    )
def clear_history(request: gr.Request):
    """Start a new conversation and reset the chat engine.

    Returns outputs for ``[state, chatbot, textbox, upload_btn]`` (fresh
    state, its chatbot view, an empty textbox, no upload), followed by
    ``disable_btn`` for each of the five action buttons.
    """
    fresh_state = default_conversation.copy()
    chat_engine.reset()
    button_updates = tuple(disable_btn for _ in range(5))
    return (fresh_state, fresh_state.to_gradio_chatbot(), "", None, *button_updates)
def add_text_http_bot(state, text, file, request: gr.Request):
    """Append the user's message to the conversation and show the engine's answer.

    Generator used as a Gradio event handler. Every yielded tuple maps onto the
    outputs ``[state, chatbot, textbox, upload_btn]`` followed by the five
    action buttons (upvote, downvote, flag, regenerate, clear).

    Args:
        state: Conversation object for this session (held in ``gr.State``).
        text: Text typed into the textbox.
        file: Value of the upload button. It is not forwarded to the chat
            engine; only ``text`` is sent.
        request: Gradio request object (unused).
    """
    if not text and file is None:
        # Empty submission: clear the textbox and leave the buttons as they
        # are. This must be yielded, because a generator's return value never
        # reaches Gradio.
        state.skip_next = True
        yield (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
        return

    # Record the user turn plus an empty assistant turn that is filled in below.
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

    # Placeholder shown while the engine is working.
    # NOTE(review): "β" looks like a mis-encoded glyph (perhaps a cursor); confirm.
    state.messages[-1][-1] = "β"
    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

    try:
        chat_history = state.to_chat_history()
        chat_response = chat_engine.chat(text, chat_history=chat_history)
    except requests.exceptions.RequestException:
        logging.exception("Chat engine request failed")
        state.messages[-1][-1] = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
        # After a failure only regenerate and clear are offered.
        yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # The placeholder is replaced, not appended to, so the answer is shown
    # in full. Nothing is trimmed off its end.
    state.messages[-1][-1] = chat_response.response
    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
def regenerate_http_bot(state, request: gr.Request):
    """Ask the chat engine the most recent user question again.

    Generator used as a Gradio event handler. Every yielded tuple maps onto the
    outputs ``[state, chatbot, textbox, upload_btn]`` followed by the five
    action buttons (upvote, downvote, flag, regenerate, clear).

    Args:
        state: Conversation object for this session (held in ``gr.State``).
        request: Gradio request object (unused).
    """
    # A (user, assistant) pair is needed before anything can be regenerated.
    if len(state.messages) < 2:
        yield (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
        return

    # Entries appear to be [role, text] pairs: append_message(role, text)
    # creates them and messages[-1][-1] holds the reply text. The question is
    # therefore the text of the turn before the assistant's, not
    # messages[-1][0] (which is the assistant's role).
    question = state.messages[-2][1]

    state.messages[-1][-1] = None
    state.skip_next = False
    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

    # Placeholder shown while the engine is working.
    # NOTE(review): "β" looks like a mis-encoded glyph (perhaps a cursor); confirm.
    state.messages[-1][-1] = "β"
    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

    try:
        chat_history = state.to_chat_history()
        chat_response = chat_engine.chat(question, chat_history=chat_history)
    except requests.exceptions.RequestException:
        logging.exception("Chat engine request failed during regenerate")
        state.messages[-1][-1] = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
        # After a failure only regenerate and clear are offered.
        yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # The placeholder is replaced, so the answer is shown in full (no trimming).
    state.messages[-1][-1] = chat_response.response
    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
def build_demo():
    """Build the PPAIA Gradio Blocks app and wire up its event handlers.

    Most components are created hidden. ``load_demo``, registered with
    ``demo.load`` below, makes them visible once the page has loaded.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for queueing and
        launching it.
    """
    logging.info("Building demo")
    with gr.Blocks(title=_title, css=customCSS) as demo:
        state = gr.State()

        logging.info("Building demo - blocks")
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="ppaia_chatbot", visible=False, height="100%", show_label=False)
                with gr.Row():
                    # Hidden column: add_text_http_bot does not forward
                    # uploaded files to the chat engine.
                    with gr.Column(min_width=70, scale=1, visible=False):
                        upload_btn = gr.UploadButton(
                            "π", file_types=["image"], visible=False)
                    with gr.Column(scale=12):
                        textbox = gr.Textbox(
                            show_label=False, placeholder="Enter text and press ENTER", visible=False, container=False
                        )
                    with gr.Column(min_width=70, scale=1):
                        submit_btn = gr.Button("βοΈ", visible=False)
                # The action buttons start non-interactive, which is the same
                # state clear_history returns them to (disable_btn for all
                # five), because there is no answer yet to vote on, flag,
                # regenerate or clear. The chat handlers yield enable_btn
                # once a response is shown.
                # NOTE(review): the emoji prefixes in these labels look
                # mis-encoded in this copy of the file; confirm.
                with gr.Row(visible=False) as button_row:
                    upvote_btn = gr.Button("π Upvote", interactive=False)
                    downvote_btn = gr.Button("π Downvote", interactive=False)
                    flag_btn = gr.Button("π© Flag", interactive=False)
                    regenerate_btn = gr.Button("π Regenerate", interactive=False)
                    clear_btn = gr.Button(
                        "π§Ή Clear Conversation",
                        interactive=False,
                    )

        # Hidden input for load_demo. It is presumably filled with the page's
        # URL parameters by the get_window_url_params JS hook; confirm.
        url_params = gr.JSON(visible=False)

        logging.info("Building demo - buttons")
        btn_list = [upvote_btn, downvote_btn,
                    flag_btn, regenerate_btn, clear_btn]
        # The chat handlers all return [state, chatbot, textbox, upload_btn]
        # followed by one update per action button.
        chat_outputs = [state, chatbot, textbox, upload_btn] + btn_list
        feedback_outputs = [textbox, upvote_btn, downvote_btn, flag_btn]

        logging.info("Building demo - button actions")
        upvote_btn.click(upvote_last_response, [state], feedback_outputs)
        downvote_btn.click(downvote_last_response, [state], feedback_outputs)
        flag_btn.click(flag_last_response, [state], feedback_outputs)
        regenerate_btn.click(regenerate_http_bot, [state], chat_outputs)
        clear_btn.click(clear_history, None, chat_outputs)
        textbox.submit(
            add_text_http_bot,
            [state, textbox, upload_btn],
            chat_outputs,
        )
        submit_btn.click(
            add_text_http_bot,
            [state, textbox, upload_btn],
            chat_outputs,
        )

        logging.info("Building demo - load")
        demo.load(
            load_demo,
            [url_params],
            [state, chatbot, textbox, upload_btn, submit_btn, button_row],
            _js=get_window_url_params,
        )
    logging.info("Demo built")
    return demo
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--debug", action="store_true",
                        help="using debug mode")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=1)
    args = parser.parse_args()

    demo = build_demo()
    logging.info("Launching demo")
    # Pass the parsed CLI options through to Gradio so the flags take effect.
    # --port defaults to None, which lets Gradio choose its default port.
    demo.queue(concurrency_count=args.concurrency_count, api_open=False).launch(
        server_name=args.host,
        server_port=args.port,
        debug=args.debug,
        share=False,
    )