PPAIA-public / app.py
Abdizriel's picture
Redo gradio code
e48a914
import gradio as gr
import logging
import sys
import requests
import argparse
from app_module.configuration import get_index
from app_module.serve_utils import (
disable_btn, no_change_btn,
downvote_last_response, enable_btn, flag_last_response,
get_window_url_params, upvote_last_response
)
from app_module.conversation import default_conversation
# Configure logging for the whole app. basicConfig(stream=...) already attaches
# a StreamHandler writing to stdout on the root logger, so no extra handler is
# added here (a second one made every record appear twice).
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s",
)
# Browser-tab title for the Blocks app.
_title = "PPAIA - Pet Parent Artificial Intelligence Assistant"
# Built once at import time; all request handlers in this module use this one
# chat_engine object.
# NOTE(review): because it is module-global, clear_history's
# chat_engine.reset() affects every connected session -- confirm intended.
index = get_index()
chat_engine = index.as_chat_engine(verbose=True)
# Custom stylesheet passed to gr.Blocks.
with open("assets/custom.css", "r", encoding="utf-8") as f:
    customCSS = f.read()
def load_demo(url_params, request: gr.Request):
    """Prepare a new browser session once the page has loaded.

    Args:
        url_params: Query-string parameters collected in the browser by the
            ``get_window_url_params`` JS hook; not used yet.
        request: Incoming Gradio request; not used.

    Returns:
        A fresh copy of ``default_conversation``, followed by visibility
        updates that reveal the chatbot, the textbox, the submit button and
        the feedback-button row while keeping the upload button hidden.
    """
    session = default_conversation.copy()
    logging.info("Loading demo")
    reveal_chatbot = gr.Chatbot.update(visible=True)
    reveal_textbox = gr.Textbox.update(visible=True)
    hide_upload = gr.UploadButton.update(visible=False)
    reveal_submit = gr.Button.update(visible=True)
    reveal_button_row = gr.Row.update(visible=True)
    return (
        session,
        reveal_chatbot,
        reveal_textbox,
        hide_upload,
        reveal_submit,
        reveal_button_row,
    )
def clear_history(request: gr.Request):
    """Start over: new conversation state and a reset chat engine.

    Args:
        request: Incoming Gradio request; not used.

    Returns:
        ``(state, chatbot_value, "", None)`` for a fresh copy of
        ``default_conversation``, followed by five ``disable_btn`` updates
        for the feedback/control buttons.
    """
    fresh_state = default_conversation.copy()
    chat_engine.reset()
    button_updates = (disable_btn,) * 5
    view = (fresh_state, fresh_state.to_gradio_chatbot(), "", None)
    return view + button_updates
def add_text_http_bot(state, text, file, request: gr.Request):
    """Append the user's message and produce the assistant's reply.

    This is a Gradio *generator* handler: each ``yield`` pushes one UI update
    to ``[state, chatbot, textbox, upload_btn]`` plus the five feedback and
    control buttons.

    Args:
        state: Conversation object held in ``gr.State``.
        text: Text entered by the user.
        file: Value of the (hidden) upload button. It only allows an
            empty-text submission; its content is not sent to the engine.
        request: Incoming Gradio request; not used.

    Yields:
        ``(state, chatbot_value, textbox_value, upload_value, *button_updates)``.
    """
    if not text and file is None:
        # Nothing to send. This must be a ``yield``: a ``return value`` inside
        # a generator is swallowed and Gradio would receive no update.
        state.skip_next = True
        yield (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
        return

    # NOTE(review): an uploaded image is not forwarded to the chat engine yet.
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

    # Placeholder cursor shown in the assistant slot while waiting.
    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

    try:
        chat_history = state.to_chat_history()
        chat_response = chat_engine.chat(text, chat_history=chat_history)
    except requests.exceptions.RequestException:
        logging.exception("Chat engine request failed")
        state.messages[-1][-1] = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
        # Keep only Regenerate and Clear enabled so the user can retry.
        yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # The reply fully replaces the cursor, so it is stored as-is (the old
    # trailing ``[:-1]`` cut off the last character of the real answer).
    state.messages[-1][-1] = chat_response.response
    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
def regenerate_http_bot(state, request: gr.Request):
    """Discard the last assistant reply and ask the chat engine again.

    Messages are stored as ``[role, text]`` pairs, with the user turn
    immediately followed by the assistant turn. The prompt to re-send is
    therefore the text of the second-to-last message. (Previously
    ``messages[-1][0]`` was sent, which is the assistant *role* name.)

    Args:
        state: Conversation object held in ``gr.State``.
        request: Incoming Gradio request; not used.

    Yields:
        ``(state, chatbot_value, textbox_value, upload_value, *button_updates)``.
    """
    if len(state.messages) < 2:
        # No user/assistant exchange to regenerate; leave the UI unchanged.
        yield (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
        return

    prompt = state.messages[-2][-1]
    state.messages[-1][-1] = None
    state.skip_next = False
    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

    # Placeholder cursor shown in the assistant slot while waiting.
    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5

    try:
        chat_history = state.to_chat_history()
        chat_response = chat_engine.chat(prompt, chat_history=chat_history)
    except requests.exceptions.RequestException:
        logging.exception("Chat engine request failed during regenerate")
        state.messages[-1][-1] = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
        # Keep only Regenerate and Clear enabled so the user can retry.
        yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # Store the reply untouched (no trailing-character strip).
    state.messages[-1][-1] = chat_response.response
    yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
def build_demo():
    """Build the Gradio Blocks UI for the assistant.

    The layout is a chatbot panel, an input row and a row of feedback and
    control buttons. The input row holds a hidden image upload button, the
    textbox and a send button. Most components start hidden and are revealed
    by ``load_demo`` via ``demo.load``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    logging.info("Building demo")
    with gr.Blocks(title=_title, css=customCSS) as demo:
        state = gr.State()

        logging.info("Building demo - blocks")
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="ppaia_chatbot", visible=False, height="100%", show_label=False)
                with gr.Row():
                    with gr.Column(min_width=70, scale=1, visible=False):
                        upload_btn = gr.UploadButton(
                            "📁", file_types=["image"], visible=False)
                    with gr.Column(scale=12):
                        textbox = gr.Textbox(
                            show_label=False, placeholder="Enter text and press ENTER", visible=False, container=False
                        )
                    with gr.Column(min_width=70, scale=1):
                        submit_btn = gr.Button("✉️", visible=False)
                with gr.Row(visible=False) as button_row:
                    # Labels previously contained mis-encoded (mojibake) emoji.
                    upvote_btn = gr.Button("👍 Upvote")
                    downvote_btn = gr.Button("👎 Downvote")
                    flag_btn = gr.Button("🚩 Flag")
                    regenerate_btn = gr.Button("🔄 Regenerate")
                    clear_btn = gr.Button("🧹 Clear Conversation")

        # Filled in the browser by the get_window_url_params JS hook on load.
        url_params = gr.JSON(visible=False)

        logging.info("Building demo - buttons")
        btn_list = [upvote_btn, downvote_btn,
                    flag_btn, regenerate_btn, clear_btn]
        feedback_outputs = [textbox, upvote_btn, downvote_btn, flag_btn]
        chat_outputs = [state, chatbot, textbox, upload_btn] + btn_list

        logging.info("Building demo - button actions")
        upvote_btn.click(upvote_last_response, [state], feedback_outputs)
        downvote_btn.click(downvote_last_response, [state], feedback_outputs)
        flag_btn.click(flag_last_response, [state], feedback_outputs)
        regenerate_btn.click(regenerate_http_bot, [state], chat_outputs)
        clear_btn.click(clear_history, None, chat_outputs)
        # Pressing ENTER in the textbox and clicking the send button do the same thing.
        for send_event in (textbox.submit, submit_btn.click):
            send_event(add_text_http_bot, [state, textbox, upload_btn], chat_outputs)

        logging.info("Building demo - load")
        demo.load(load_demo,
                  [url_params],
                  [state, chatbot, textbox, upload_btn, submit_btn, button_row],
                  _js=get_window_url_params
                  )
    logging.info("Demo built")
    return demo
if __name__ == "__main__":
    # Command-line options for running the app.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--debug", action="store_true",
                        help="using debug mode")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=1)
    args = parser.parse_args()

    demo = build_demo()
    logging.info("Launching demo")
    # Forward the parsed options. They were previously parsed and then ignored.
    # --port defaults to None, which (per Gradio's launch() docs) lets Gradio
    # choose its default port.
    demo.queue(
        concurrency_count=args.concurrency_count, api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        debug=args.debug,
        share=False,
    )