Wendong-Fan's picture
Update apps/data_explorer/data_explorer.py
521fc66 verified
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
"""
Gradio-based web UI to explore the Camel dataset.
"""
import argparse
import random
from typing import Dict, List, Optional, Tuple
import gradio as gr
from apps.data_explorer.loader import Datasets, load_datasets
def parse_arguments():
""" Get command line arguments. """
parser = argparse.ArgumentParser("Camel data explorer")
parser.add_argument(
'--data-path', type=str, default=None,
help='Path to the folder with ZIP datasets containing JSONs')
parser.add_argument('--default-dataset', type=str, default=None,
help='Default dataset name selected from ZIPs')
parser.add_argument('--share', type=bool, default=False,
help='Expose the web UI to Gradio')
parser.add_argument(
'--server-name', type=str, default="0.0.0.0",
help='localhost for local, 0.0.0.0 (default) for public')
parser.add_argument('--server-port', type=int, default=8080,
help='Port ot run the web page on')
parser.add_argument('--inbrowser', type=bool, default=False,
help='Open the web UI in the default browser on lunch')
parser.add_argument(
'--concurrency-count', type=int, default=10,
help='Number if concurrent threads at Gradio websocket queue. ' +
'Increase to serve more requests but keep an eye on RAM usage.')
args, unknown = parser.parse_known_args()
if len(unknown) > 0:
print("Unknown args: ", unknown)
return args
def construct_ui(blocks, datasets: Datasets,
default_dataset: Optional[str] = None):
""" Build Gradio UI and populate with chat data from JSONs.
Args:
blocks: Gradio blocks
datasets (Datasets): Several parsed
multi-JSON dataset with chats.
default_dataset (str): Default selection of the dataset.
Returns:
None
"""
if default_dataset is None:
default_dataset = "ai_society_chat"
misalignment_set_names = {"misalignment"}
ordinary_datasets = [
v for v in datasets.keys() if v not in misalignment_set_names
]
misalignment_datasets = [
v for v in datasets.keys() if v in misalignment_set_names
]
default_dataset_name = default_dataset \
if default_dataset in datasets.keys() \
else ordinary_datasets[0] if len(ordinary_datasets) > 0 \
else misalignment_datasets[0] if len(misalignment_datasets) > 0 \
else ""
dataset_names = list(datasets.keys())
with gr.Row():
with gr.Column(scale=2):
with gr.Row():
dataset_dd = gr.Dropdown(dataset_names, label="Select dataset",
value="NODEFAULT", interactive=True)
with gr.Row():
disclaimer_ta = gr.Markdown(
"## By clicking AGREE I consent to use the dataset "
"for purely educational and academic purposes and "
"not use it for any fraudulent activity; and I take "
"all the responsibility if the data is used in a "
"malicious application.", visible=False)
with gr.Row():
with gr.Column(scale=1):
accept_disclaimer_bn = gr.Button("AGREE", visible=False)
with gr.Column(scale=1):
decline_disclaimer_bn = gr.Button("DECLINE", visible=False)
with gr.Row():
with gr.Column(scale=3):
assistant_dd = gr.Dropdown([], label="ASSISTANT", value="",
interactive=True)
with gr.Column(scale=3):
user_dd = gr.Dropdown([], label="USER", value="",
interactive=True)
with gr.Column(scale=1):
gr.Markdown(
"## CAMEL: Communicative Agents for \"Mind\" Exploration"
" of Large Scale Language Model Society\n"
"Github repo: [https://github.com/lightaime/camel]"
"(https://github.com/lightaime/camel)\n"
'<div style="display:flex; justify-content:center;">'
'<img src="https://raw.githubusercontent.com/lightaime/camel/'
'master/misc/logo.png" alt="Logo" style="max-width:50%;">'
'</div>')
task_dd = gr.Dropdown([], label="Original task", value="",
interactive=True)
specified_task_ta = gr.TextArea(label="Specified task", lines=2)
chatbot = gr.Chatbot()
accepted_st = gr.State(False)
def set_default_dataset() -> Dict:
""" Trigger for app load.
Returns:
Dict: Update dict for dataset_dd.
"""
return gr.update(value=default_dataset_name)
def check_if_misalignment(dataset_name: str, accepted: bool) \
-> Tuple[Dict, Dict, Dict]:
""" Display AGREE/DECLINE if needed.
Returns:
Tuple: Visibility updates for the buttons.
"""
if dataset_name == "misalignment" and not accepted:
return gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True)
else:
return gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False)
def enable_misalignment() -> Tuple[bool, Dict, Dict, Dict]:
""" Update the state of the accepted disclaimer.
Returns:
Tuple: New state and visibility updates for the buttons.
"""
return True, gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False)
def disable_misalignment() -> Tuple[bool, Dict, Dict, Dict]:
""" Update the state of the accepted disclaimer.
Returns:
Tuple: New state and visibility updates for the buttons.
"""
return False, gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False)
def update_dataset_selection(dataset_name: str,
accepted: bool) -> Tuple[Dict, Dict]:
""" Update roles based on the selected dataset.
Args:
dataset_name (str): Name of the loaded .zip dataset.
accepted (bool): If the disclaimer thas been accepted.
Returns:
Tuple[Dict, Dict]: New Assistant and User roles.
"""
if dataset_name == "misalignment" and not accepted:
# If used did not accept the misalignment policy,
# keep the old selection.
return (gr.update(value="N/A",
choices=[]), gr.update(value="N/A", choices=[]))
dataset = datasets[dataset_name]
assistant_roles = dataset['assistant_roles']
user_roles = dataset['user_roles']
assistant_role = random.choice(assistant_roles) \
if len(assistant_roles) > 0 else ""
user_role = random.choice(user_roles) if len(user_roles) > 0 else ""
return (gr.update(value=assistant_role, choices=assistant_roles),
gr.update(value=user_role, choices=user_roles))
def roles_dd_change(dataset_name: str, assistant_role: str,
user_role: str) -> Dict:
""" Update the displayed chat upon inputs change.
Args:
assistant_role (str): Assistant dropdown value.
user_role (str): User dropdown value.
Returns:
Dict: New original roles state dictionary.
"""
matrix = datasets[dataset_name]['matrix']
if (assistant_role, user_role) in matrix:
record: Dict[str, Dict] = matrix[(assistant_role, user_role)]
original_task_options = list(record.keys())
original_task = original_task_options[0]
else:
original_task = "N/A"
original_task_options = []
choices = gr.Dropdown(choices=original_task_options,
value=original_task, interactive=True)
return choices
def build_chat_history(messages: Dict[int, Dict]) -> List[Tuple]:
""" Structures chatbot contents from the loaded data.
Args:
messages (Dict[int, Dict]): Messages loaded from JSON.
Returns:
List[Tuple]: Chat history in chatbot UI element format.
"""
history: List[Tuple] = []
curr_qa = (None, None)
for k in sorted(messages.keys()):
msg = messages[k]
content = msg['content']
if msg['role_type'] == "USER":
if curr_qa[0] is not None:
history.append(curr_qa)
curr_qa = (content, None)
else:
curr_qa = (content, None)
elif msg['role_type'] == "ASSISTANT":
curr_qa = (curr_qa[0], content)
history.append(curr_qa)
curr_qa = (None, None)
else:
pass
return history
def task_dd_change(dataset_name: str, assistant_role: str, user_role: str,
original_task: str) -> Tuple[str, List]:
""" Load task details and chatbot history into UI elements.
Args:
assistant_role (str): An assistan role.
user_role (str): An user role.
original_task (str): The original task.
Returns:
Tuple[str, List]: New contents of the specified task
and chatbot history UI elements.
"""
matrix = datasets[dataset_name]['matrix']
if (assistant_role, user_role) in matrix:
task_dict: Dict[str, Dict] = matrix[(assistant_role, user_role)]
if original_task in task_dict:
chat = task_dict[original_task]
specified_task = chat['specified_task']
history = build_chat_history(chat['messages'])
else:
specified_task = "N/A"
history = []
else:
specified_task = "N/A"
history = []
return specified_task, history
dataset_dd.change(check_if_misalignment, [dataset_dd, accepted_st],
[disclaimer_ta, accept_disclaimer_bn,
decline_disclaimer_bn]) \
.then(update_dataset_selection,
[dataset_dd, accepted_st],
[assistant_dd, user_dd])
accept_disclaimer_bn.click(enable_misalignment, None, [
accepted_st, disclaimer_ta, accept_disclaimer_bn, decline_disclaimer_bn
]) \
.then(update_dataset_selection,
[dataset_dd, accepted_st],
[assistant_dd, user_dd])
decline_disclaimer_bn.click(disable_misalignment, None, [
accepted_st, disclaimer_ta, accept_disclaimer_bn, decline_disclaimer_bn
]) \
.then(update_dataset_selection,
[dataset_dd, accepted_st],
[assistant_dd, user_dd])
func_args = (roles_dd_change, [dataset_dd, assistant_dd, user_dd], task_dd)
assistant_dd.change(*func_args)
user_dd.change(*func_args)
task_dd.change(task_dd_change,
[dataset_dd, assistant_dd, user_dd, task_dd],
[specified_task_ta, chatbot])
blocks.load(set_default_dataset, None, dataset_dd)
def construct_blocks(data_path: str, default_dataset: Optional[str]):
""" Construct Blocs app but do not launch it.
Args:
data_path (str): Path to the set of ZIP datasets.
default_dataset (Optional[str]): Name of the default dataset,
without extension.
Returns:
gr.Blocks: Blocks instance.
"""
print("Loading the dataset...")
datasets = load_datasets(data_path)
print("Dataset is loaded")
print("Getting Data Explorer web server online...")
with gr.Blocks() as blocks:
construct_ui(blocks, datasets, default_dataset)
return blocks
def main():
""" Entry point. """
args = parse_arguments()
blocks = construct_blocks(args.data_path, args.default_dataset)
blocks.queue(args.concurrency_count) \
.launch(share=args.share, inbrowser=args.inbrowser,
server_name=args.server_name, server_port=args.server_port)
print("Exiting.")
if __name__ == "__main__":
main()