Spaces:
Running
Running
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | |
# Licensed under the Apache License, Version 2.0 (the “License”); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an “AS IS” BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | |
""" | |
Gradio-based web UI to explore the Camel dataset. | |
""" | |
import argparse | |
import random | |
from typing import Dict, List, Optional, Tuple | |
import gradio as gr | |
from apps.data_explorer.loader import Datasets, load_datasets | |
def _str_to_bool(value: str) -> bool:
    """ Convert a command line string into a boolean.

    Plain ``type=bool`` must not be used with argparse: ``bool("False")``
    is ``True`` because any non-empty string is truthy.

    Args:
        value (str): Raw value taken from the command line.

    Returns:
        bool: Parsed boolean value.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognized
            boolean spelling.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays falsy, as it was with ``bool("")``.
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value, got {!r}".format(value))


def parse_arguments():
    """ Get command line arguments.

    Unrecognized arguments are not an error: they are printed and ignored.

    Returns:
        argparse.Namespace: Parsed arguments with the attributes
            data_path, default_dataset, share, server_name, server_port,
            inbrowser and concurrency_count.
    """

    parser = argparse.ArgumentParser("Camel data explorer")
    parser.add_argument(
        '--data-path', type=str, default=None,
        help='Path to the folder with ZIP datasets containing JSONs')
    parser.add_argument('--default-dataset', type=str, default=None,
                        help='Default dataset name selected from ZIPs')
    # Boolean flags accept both "--share" and "--share true/false".
    parser.add_argument('--share', type=_str_to_bool, nargs='?', const=True,
                        default=False, help='Expose the web UI to Gradio')
    parser.add_argument(
        '--server-name', type=str, default="0.0.0.0",
        help='localhost for local, 0.0.0.0 (default) for public')
    parser.add_argument('--server-port', type=int, default=8080,
                        help='Port to run the web page on')
    parser.add_argument(
        '--inbrowser', type=_str_to_bool, nargs='?', const=True,
        default=False,
        help='Open the web UI in the default browser on launch')
    parser.add_argument(
        '--concurrency-count', type=int, default=10,
        help='Number of concurrent threads at Gradio websocket queue. ' +
        'Increase to serve more requests but keep an eye on RAM usage.')
    args, unknown = parser.parse_known_args()
    if unknown:
        print("Unknown args: ", unknown)
    return args
def construct_ui(blocks, datasets: Datasets,
                 default_dataset: Optional[str] = None):
    """ Build Gradio UI and populate with chat data from JSONs.

    Must be called inside an active ``gr.Blocks`` context: the components
    are created here and the event handlers are attached to them.

    Args:
        blocks: Gradio blocks
        datasets (Datasets): Several parsed multi-JSON datasets with chats,
            addressed by dataset name. Each dataset is read through its
            'assistant_roles', 'user_roles' and 'matrix' entries, where
            the matrix is keyed by (assistant_role, user_role) pairs.
        default_dataset (str): Default selection of the dataset.
            Falls back to "ai_society_chat" when not given.

    Returns:
        None
    """

    if default_dataset is None:
        default_dataset = "ai_society_chat"

    # Datasets which require accepting the disclaimer before browsing.
    misalignment_set_names = {"misalignment"}
    ordinary_datasets = [
        v for v in datasets.keys() if v not in misalignment_set_names
    ]
    misalignment_datasets = [
        v for v in datasets.keys() if v in misalignment_set_names
    ]

    # Prefer the requested dataset, then any ordinary dataset, then a
    # misalignment one; an empty name means that nothing is loaded.
    if default_dataset in datasets.keys():
        default_dataset_name = default_dataset
    elif ordinary_datasets:
        default_dataset_name = ordinary_datasets[0]
    elif misalignment_datasets:
        default_dataset_name = misalignment_datasets[0]
    else:
        default_dataset_name = ""
    dataset_names = list(datasets.keys())

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Row():
                # The placeholder value is replaced on app load by
                # set_default_dataset, which fires the change handlers.
                dataset_dd = gr.Dropdown(dataset_names, label="Select dataset",
                                         value="NODEFAULT", interactive=True)
            with gr.Row():
                disclaimer_ta = gr.Markdown(
                    "## By clicking AGREE I consent to use the dataset "
                    "for purely educational and academic purposes and "
                    "not use it for any fraudulent activity; and I take "
                    "all the responsibility if the data is used in a "
                    "malicious application.", visible=False)
            with gr.Row():
                with gr.Column(scale=1):
                    accept_disclaimer_bn = gr.Button("AGREE", visible=False)
                with gr.Column(scale=1):
                    decline_disclaimer_bn = gr.Button("DECLINE", visible=False)
            with gr.Row():
                with gr.Column(scale=3):
                    assistant_dd = gr.Dropdown([], label="ASSISTANT", value="",
                                               interactive=True)
                with gr.Column(scale=3):
                    user_dd = gr.Dropdown([], label="USER", value="",
                                          interactive=True)
        with gr.Column(scale=1):
            gr.Markdown(
                "## CAMEL: Communicative Agents for \"Mind\" Exploration"
                " of Large Scale Language Model Society\n"
                "Github repo: [https://github.com/lightaime/camel]"
                "(https://github.com/lightaime/camel)\n"
                '<div style="display:flex; justify-content:center;">'
                '<img src="https://raw.githubusercontent.com/lightaime/camel/'
                'master/misc/logo.png" alt="Logo" style="max-width:50%;">'
                '</div>')

    task_dd = gr.Dropdown([], label="Original task", value="",
                          interactive=True)
    specified_task_ta = gr.TextArea(label="Specified task", lines=2)
    chatbot = gr.Chatbot()
    accepted_st = gr.State(False)

    def needs_disclaimer(dataset_name: str, accepted: bool) -> bool:
        """ Tell if the dataset is locked behind the disclaimer.

        Args:
            dataset_name (str): Name of the selected dataset.
            accepted (bool): If the disclaimer has been accepted.

        Returns:
            bool: True if the disclaimer must be accepted first.
        """
        return dataset_name in misalignment_set_names and not accepted

    def get_matrix(dataset_name: str) -> Dict:
        """ Get the roles matrix of a dataset.

        Args:
            dataset_name (str): Name of the selected dataset.

        Returns:
            Dict: Matrix of the dataset, or an empty dict if a dataset
                with this name is not loaded.
        """
        if dataset_name not in datasets.keys():
            return {}
        return datasets[dataset_name]['matrix']

    def set_default_dataset() -> Dict:
        """ Trigger for app load.

        Returns:
            Dict: Update dict for dataset_dd.
        """
        return gr.update(value=default_dataset_name)

    def check_if_misalignment(dataset_name: str, accepted: bool) \
            -> Tuple[Dict, Dict, Dict]:
        """ Display AGREE/DECLINE if needed.

        Args:
            dataset_name (str): Name of the selected dataset.
            accepted (bool): If the disclaimer has been accepted.

        Returns:
            Tuple: Visibility updates for the disclaimer text
                and the buttons.
        """
        show = needs_disclaimer(dataset_name, accepted)
        return (gr.update(visible=show), gr.update(visible=show),
                gr.update(visible=show))

    def enable_misalignment() -> Tuple[bool, Dict, Dict, Dict]:
        """ Mark the disclaimer as accepted and hide it.

        Returns:
            Tuple: New state and visibility updates for the disclaimer
                text and the buttons.
        """
        return (True, gr.update(visible=False), gr.update(visible=False),
                gr.update(visible=False))

    def disable_misalignment() -> Tuple[bool, Dict, Dict, Dict]:
        """ Mark the disclaimer as declined and hide it.

        Returns:
            Tuple: New state and visibility updates for the disclaimer
                text and the buttons.
        """
        return (False, gr.update(visible=False), gr.update(visible=False),
                gr.update(visible=False))

    def update_dataset_selection(dataset_name: str,
                                 accepted: bool) -> Tuple[Dict, Dict]:
        """ Update roles based on the selected dataset.

        Args:
            dataset_name (str): Name of the loaded .zip dataset.
            accepted (bool): If the disclaimer has been accepted.

        Returns:
            Tuple[Dict, Dict]: New Assistant and User roles.
        """
        if needs_disclaimer(dataset_name, accepted) \
                or dataset_name not in datasets.keys():
            # Either the user did not accept the misalignment policy or
            # no such dataset is loaded: clear both role selections.
            return (gr.update(value="N/A", choices=[]),
                    gr.update(value="N/A", choices=[]))
        dataset = datasets[dataset_name]
        assistant_roles = dataset['assistant_roles']
        user_roles = dataset['user_roles']
        # Start from a random pair of roles.
        assistant_role = random.choice(assistant_roles) \
            if len(assistant_roles) > 0 else ""
        user_role = random.choice(user_roles) if len(user_roles) > 0 else ""
        return (gr.update(value=assistant_role, choices=assistant_roles),
                gr.update(value=user_role, choices=user_roles))

    def roles_dd_change(dataset_name: str, assistant_role: str,
                        user_role: str) -> gr.Dropdown:
        """ Update the original task options upon roles change.

        Args:
            dataset_name (str): Name of the selected dataset.
            assistant_role (str): Assistant dropdown value.
            user_role (str): User dropdown value.

        Returns:
            gr.Dropdown: Dropdown with the tasks available for the pair
                of roles; the first task is selected, or "N/A" if there
                are no tasks.
        """
        matrix = get_matrix(dataset_name)
        original_task_options: List[str] = []
        if (assistant_role, user_role) in matrix:
            record: Dict[str, Dict] = matrix[(assistant_role, user_role)]
            original_task_options = list(record.keys())
        original_task = original_task_options[0] \
            if original_task_options else "N/A"
        choices = gr.Dropdown(choices=original_task_options,
                              value=original_task, interactive=True)
        return choices

    def build_chat_history(messages: Dict[int, Dict]) -> List[Tuple]:
        """ Structures chatbot contents from the loaded data.

        Messages are visited in the order of their sorted keys. Messages
        with a role type other than "USER" or "ASSISTANT" are skipped.

        Args:
            messages (Dict[int, Dict]): Messages loaded from JSON. Each
                message provides 'role_type' and 'content'.

        Returns:
            List[Tuple]: Chat history in chatbot UI element format:
                (user message, assistant message) pairs, where a missing
                side is None.
        """
        history: List[Tuple] = []
        curr_qa = (None, None)
        for k in sorted(messages.keys()):
            msg = messages[k]
            content = msg['content']
            if msg['role_type'] == "USER":
                # Flush the previous user message if it got no answer.
                if curr_qa[0] is not None:
                    history.append(curr_qa)
                curr_qa = (content, None)
            elif msg['role_type'] == "ASSISTANT":
                curr_qa = (curr_qa[0], content)
                history.append(curr_qa)
                curr_qa = (None, None)
        # Keep the trailing user message which has no answer.
        if curr_qa[0] is not None:
            history.append(curr_qa)
        return history

    def task_dd_change(dataset_name: str, assistant_role: str, user_role: str,
                       original_task: str) -> Tuple[str, List]:
        """ Load task details and chatbot history into UI elements.

        Args:
            dataset_name (str): Name of the selected dataset.
            assistant_role (str): An assistant role.
            user_role (str): A user role.
            original_task (str): The original task.

        Returns:
            Tuple[str, List]: New contents of the specified task
                and chatbot history UI elements. ("N/A", []) if there is
                no chat for the selection.
        """
        specified_task = "N/A"
        history: List[Tuple] = []
        matrix = get_matrix(dataset_name)
        if (assistant_role, user_role) in matrix:
            task_dict: Dict[str, Dict] = matrix[(assistant_role, user_role)]
            if original_task in task_dict:
                chat = task_dict[original_task]
                specified_task = chat['specified_task']
                history = build_chat_history(chat['messages'])
        return specified_task, history

    # Dataset selection: show the disclaimer if needed, then reload roles.
    dataset_dd.change(check_if_misalignment, [dataset_dd, accepted_st],
                      [disclaimer_ta, accept_disclaimer_bn,
                       decline_disclaimer_bn]) \
        .then(update_dataset_selection,
              [dataset_dd, accepted_st],
              [assistant_dd, user_dd])

    # Both disclaimer buttons store the decision and reload the roles.
    accept_disclaimer_bn.click(enable_misalignment, None, [
        accepted_st, disclaimer_ta, accept_disclaimer_bn, decline_disclaimer_bn
    ]) \
        .then(update_dataset_selection,
              [dataset_dd, accepted_st],
              [assistant_dd, user_dd])

    decline_disclaimer_bn.click(disable_misalignment, None, [
        accepted_st, disclaimer_ta, accept_disclaimer_bn, decline_disclaimer_bn
    ]) \
        .then(update_dataset_selection,
              [dataset_dd, accepted_st],
              [assistant_dd, user_dd])

    # A change of either role refreshes the list of original tasks.
    func_args = (roles_dd_change, [dataset_dd, assistant_dd, user_dd], task_dd)
    assistant_dd.change(*func_args)
    user_dd.change(*func_args)

    task_dd.change(task_dd_change,
                   [dataset_dd, assistant_dd, user_dd, task_dd],
                   [specified_task_ta, chatbot])

    blocks.load(set_default_dataset, None, dataset_dd)
def construct_blocks(data_path: str, default_dataset: Optional[str]):
    """ Assemble the Blocks app without launching it.

    The datasets are loaded from data_path first, then the UI is built
    inside a new Blocks context. Progress is reported on stdout.

    Args:
        data_path (str): Path to the set of ZIP datasets.
        default_dataset (Optional[str]): Name of the default dataset,
            without extension.

    Returns:
        gr.Blocks: Blocks instance.
    """

    print("Loading the dataset...")
    loaded_datasets = load_datasets(data_path)
    print("Dataset is loaded")

    print("Getting Data Explorer web server online...")

    with gr.Blocks() as app:
        construct_ui(app, loaded_datasets, default_dataset)

    return app
def main():
    """ Entry point.

    Parses the command line, builds the app, puts it behind a queue
    sized by --concurrency-count and launches the web server.
    """
    args = parse_arguments()
    app = construct_blocks(args.data_path, args.default_dataset)

    queued_app = app.queue(args.concurrency_count)
    queued_app.launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_name=args.server_name,
        server_port=args.server_port,
    )

    print("Exiting.")


if __name__ == "__main__":
    main()