"""Gradio demo that compares two Keras chatbots side by side.

Every preset in ``model_presets`` is loaded and warmed up when this module is
imported. The UI shows two chat panes, each bound to a model picked from a
dropdown, plus a text box whose submissions are routed to the left pane, the
right pane, or both.
"""

import os
from enum import Enum

# Questions for Gradio
# - Chat share button is enabled by default but throws an error when clicked.
# - How to add local images in HTML? (https://github.com/gradio-app/gradio/issues/884)
# - How to allow Chatbot to fill the vertical space? (https://github.com/gradio-app/gradio/issues/4001)

# TODO:
# - Add the 1MB models, keras/gemma_1.1_instruct_7b_en
# - Add retry button, for each model individually
# - Add ability to route a message to a single model only.
# - log_applied_layout_map: make it work for Llama3CausalLM and LlamaCausalLM (vicuna)
# - display context length

# The backend is selected before any Keras package is imported.
os.environ["KERAS_BACKEND"] = "jax"

import gradio as gr
from gradio import ChatMessage
import keras_hub

from chatstate import ChatState
from models import (
    model_presets,
    load_model,
    model_labels,
    preset_to_website_url,
    get_appropriate_chat_template,
)


class TextRoute(Enum):
    """Which chat pane(s) a submitted message is sent to."""

    LEFT = 0
    RIGHT = 1
    BOTH = 2


model_labels_list = list(model_labels)

# load and warm up (compile) all the models
models = []
for preset in model_presets:
    model = load_model(preset)
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, "", chat_template)
    _, response = chat_state.send_message("Hello")
    print(f"model {preset} loaded and initialized.")
    print(f"The model responded: {response}")
    models.append(model)

# For local debugging
# model = keras_hub.models.Llama3CausalLM.from_preset(
#     # "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
#     "../misc-code/ari_tiny_llama3"
# )
# models = [model, model, model, model, model]


def chat_turn_assistant(
    message,
    sel,
    history,
    system_message,
    # max_tokens,
    # temperature,
    # top_p,
):
    """Generate the selected model's reply and append it to ``history``.

    Args:
        message: Text of the user message to answer.
        sel: Index into ``models`` / ``model_presets``.
        history: List of message dicts, each expandable into a
            ``ChatMessage``. Mutated in place.
        system_message: System prompt handed to the ``ChatState``.

    Returns:
        The same ``history`` list with the assistant reply appended.
    """
    model = models[sel]
    preset = model_presets[sel]
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, system_message, chat_template)

    # Replay the conversation so far into the fresh chat state.
    # NOTE(review): in the submit chain below, chat_turn_user has already
    # appended `message` to `history`, and send_message(message) is then
    # called with it again. Depending on ChatState this may put the newest
    # user turn into the prompt twice -- confirm against chatstate.py.
    for entry in history:
        entry = ChatMessage(**entry)
        if entry.role == "user":
            chat_state.add_to_history_as_user(entry.content)
        elif entry.role == "assistant":
            chat_state.add_to_history_as_model(entry.content)

    _, response = chat_state.send_message(message)
    history.append(ChatMessage(role="assistant", content=response))
    return history


def chat_turn_both_assistant(
    message, sel1, sel2, history1, history2, system_message
):
    """Run ``chat_turn_assistant`` for the left pane, then the right pane.

    Returns:
        A ``(history1, history2)`` tuple of the updated histories.
    """
    return (
        chat_turn_assistant(message, sel1, history1, system_message),
        chat_turn_assistant(message, sel2, history2, system_message),
    )


def chat_turn_user(message, history):
    """Append ``message`` to ``history`` as a user turn and return ``history``."""
    history.append(ChatMessage(role="user", content=message))
    return history


def chat_turn_both_user(message, history1, history2):
    """Append ``message`` as a user turn to both histories.

    Returns:
        A ``(history1, history2)`` tuple of the updated histories.
    """
    return (
        chat_turn_user(message, history1),
        chat_turn_user(message, history2),
    )


# Checked in order; the first substring found in the model name wins.
_BOT_ICONS = (
    ("gemma", "img/gemma.png"),
    ("llama", "img/meta.png"),
    ("vicuna", "img/vicuna.png"),
    ("mistral", "img/mistral.png"),
)


def bot_icon_select(model_name):
    """Return the avatar image path matching ``model_name``.

    Falls back to ``"img/bot.png"`` when no known family name is a substring
    of ``model_name``.
    """
    for family, icon in _BOT_ICONS:
        if family in model_name:
            return icon
    # default
    return "img/bot.png"


def instantiate_select_box(sel, model_labels):
    """Build the model dropdown with entry ``sel`` selected.

    Args:
        sel: Index of the selected model; also indexes ``model_presets``.
        model_labels: Display labels; each choice's value is its position.
    """
    # NOTE(review): the original expression ended with `+ ""`, which hints
    # that link markup around the URL may have been lost -- confirm.
    return gr.Dropdown(
        choices=[(name, i) for i, name in enumerate(model_labels)],
        show_label=False,
        value=sel,
        info="Selected model: " + preset_to_website_url(model_presets[sel]),
    )


def instantiate_chatbot(sel, key):
    """Build a messages-type chatbot whose bot avatar matches preset ``sel``."""
    model_name = model_presets[sel]
    return gr.Chatbot(
        type="messages",
        key=key,
        show_label=False,
        show_share_button=False,
        show_copy_all_button=True,
        avatar_images=("img/usr.png", bot_icon_select(model_name)),
    )


_ARROW_ICONS = {
    TextRoute.LEFT: "img/arrowL.png",
    TextRoute.RIGHT: "img/arrowR.png",
    TextRoute.BOTH: "img/arrowRL.png",
}


def instantiate_arrow_button(route, text_route):
    """Build an arrow button that writes ``route`` into ``text_route`` on click.

    Args:
        route: The ``TextRoute`` this button switches to.
        text_route: Output component that receives ``route``.
    """
    button = gr.Button(
        "",
        size="sm",
        scale=0,
        min_width=40,
        icon=_ARROW_ICONS[route],
    )
    button.click(lambda: route, outputs=[text_route])
    return button


def instantiate_retry_button(route):
    """Build the retry button. ``route`` is accepted but currently unused."""
    return gr.Button(
        "",
        size="sm",
        scale=0,
        min_width=40,
        icon="img/retry.png",
    )


def instantiate_trash_button():
    """Build the trash (clear conversation) button."""
    return gr.Button(
        "",
        size="sm",
        scale=0,
        min_width=40,
        icon="img/trash.png",
    )


def instantiate_text_box():
    """Build the message input box with a submit button."""
    return gr.Textbox(label="Your message:", submit_btn=True, key="msg")


def instantiate_additional_settings():
    """Build the collapsed settings accordion and return its system-prompt box."""
    with gr.Accordion("Additional settings", open=False):
        system_message = gr.Textbox(
            label="System prompt",
            key="system_prompt",
            value="You are a helpful assistant and your name is Eliza.",
        )
    return system_message


def retry_fn(history):
    """Drop the last two messages so the user can resubmit.

    The last entry is assumed to be the assistant reply and the one before it
    the user message.

    Returns:
        ``(user_message_content, history)`` when at least two messages exist,
        otherwise a pair of ``gr.skip()`` values so nothing is updated.
    """
    if len(history) < 2:
        return gr.skip(), gr.skip()
    history.pop()  # assistant message
    user_msg = history.pop()  # user message
    return user_msg["content"], history


def retry_fn_both(history1, history2):
    """Apply ``retry_fn`` to both histories and merge the recovered texts.

    Returns:
        ``(msg, history1, history2)``. ``msg`` is the recovered user text;
        when both panes yield different texts they are joined with ``" / "``.
        When neither pane yields text, the first pane's skip value is passed
        through.
    """
    msg1, history1 = retry_fn(history1)
    msg2, history2 = retry_fn(history2)

    texts = [m for m in (msg1, msg2) if isinstance(m, str)]
    if not texts:
        msg = msg1
    elif len(texts) == 2 and msg1 != msg2:
        msg = msg1 + " / " + msg2
    else:
        msg = texts[0]
    return msg, history1, history2


sel1 = instantiate_select_box(0, model_labels_list)
sel2 = instantiate_select_box(1, model_labels_list)
chatbot1 = instantiate_chatbot(sel1.value, "chat1")
chatbot2 = instantiate_chatbot(sel2.value, "chat2")

# to correctly align the left/right arrows
CSS = ".stick-to-the-right {align-items: end; justify-content: end}"

with gr.Blocks(fill_width=True, title="Keras demo", css=CSS) as demo:
    # Where do messages go
    text_route = gr.State(TextRoute.BOTH)

    with gr.Row():
        gr.Image(
            "img/keras_logo_k.png",
            width=80,
            height=80,
            min_width=80,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            show_share_button=False,
            interactive=False,
            scale=0,
            container=False,
        )
        # NOTE(review): the line breaks in this text were raw newlines inside
        # the string literal (a syntax error); they are kept here as "\n".
        # They look like remnants of heading / line-break markup -- restore
        # the intended HTML tags if so.
        gr.HTML(
            "\n\nBattle of the Keras chatbots on TPU\n\n"
            + "All the models are loaded into the TPU memory. "
            + "You can call any of them and compare their answers. "
            + "The entire chat\nhistory is fed to the models at every submission. "
            + "This demo is running on a Google TPU v5e 2x4 (8 cores) in bfloat16 precision."
        )

    with gr.Row():
        sel1.render()
        sel2.render()

    with gr.Row():
        chatbot1.render()
        chatbot2.render()

    @gr.render(inputs=text_route)
    def render_text_area(route):
        """Build the input area and wire its events for the current route.

        NOTE(review): the nesting of rows and columns below was reconstructed
        from a whitespace-collapsed source -- confirm the layout visually.
        """
        if route == TextRoute.BOTH:
            with gr.Row():
                msg = instantiate_text_box()
                with gr.Column(scale=0, min_width=100):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        trash = instantiate_trash_button()
            retry.click(
                retry_fn_both,
                inputs=[chatbot1, chatbot2],
                outputs=[msg, chatbot1, chatbot2],
            )
            trash.click(lambda: ("", [], []), outputs=[msg, chatbot1, chatbot2])
        elif route == TextRoute.LEFT:
            with gr.Row():
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
                with gr.Column(scale=1):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                        trash = instantiate_trash_button()
            retry.click(retry_fn, inputs=[chatbot1], outputs=[msg, chatbot1])
            trash.click(lambda: ("", []), outputs=[msg, chatbot1])
        elif route == TextRoute.RIGHT:
            with gr.Row():
                with gr.Column(scale=1, elem_classes="stick-to-the-right"):
                    with gr.Row(elem_classes="stick-to-the-right"):
                        retry = instantiate_retry_button(route)
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                    with gr.Row(elem_classes="stick-to-the-right"):
                        trash = instantiate_trash_button()
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
            retry.click(retry_fn, inputs=[chatbot2], outputs=[msg, chatbot2])
            trash.click(lambda: ("", []), outputs=[msg, chatbot2])

        system_message = instantiate_additional_settings()

        # Route the submitted message to the left, right or both chatbots
        if route == TextRoute.LEFT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot1], outputs=[chatbot1]
            ).then(
                chat_turn_assistant,
                [msg, sel1, chatbot1, system_message],
                outputs=[chatbot1],
            )
        elif route == TextRoute.RIGHT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot2], outputs=[chatbot2]
            ).then(
                chat_turn_assistant,
                [msg, sel2, chatbot2, system_message],
                outputs=[chatbot2],
            )
        elif route == TextRoute.BOTH:
            submission = msg.submit(
                chat_turn_both_user,
                inputs=[msg, chatbot1, chatbot2],
                outputs=[chatbot1, chatbot2],
            ).then(
                chat_turn_both_assistant,
                [msg, sel1, sel2, chatbot1, chatbot2, system_message],
                outputs=[chatbot1, chatbot2],
            )

        # In all cases reset text box after submission
        submission.then(lambda: "", outputs=msg)

    # Picking a model rebuilds that pane's chatbot (new avatar), then rebuilds
    # the dropdown so its info line shows the newly selected preset.
    sel1.select(
        lambda sel: instantiate_chatbot(sel, "chat1"),
        inputs=[sel1],
        outputs=[chatbot1],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel1],
        outputs=[sel1],
    )

    sel2.select(
        lambda sel: instantiate_chatbot(sel, "chat2"),
        inputs=[sel2],
        outputs=[chatbot2],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel2],
        outputs=[sel2],
    )


if __name__ == "__main__":
    demo.launch()