DmitriiKhizbullin committed on
Commit
a25aa8f
·
1 Parent(s): 2b43e50

Cloned from Dmitrii space

Browse files
app.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
"""Entry point for the Space: download the Camel chat datasets and
launch the Gradio data-explorer web UI."""

from apps.data_explorer.data_explorer import construct_blocks, parse_arguments
from apps.data_explorer.downloader import download_data

if __name__ == "__main__":

    # Fetch the dataset ZIP archives before building the UI so the
    # explorer has data to show (see downloader.download_data).
    download_data()

    # Build the Blocks app from CLI arguments and serve it with
    # gradio's default launch settings.
    args = parse_arguments()
    blocks = construct_blocks(args.data_path, args.default_dataset)
    blocks.launch()
apps/data_explorer/data_explorer.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio-based web UI to explore the Camel dataset.
3
+ """
4
+
5
+ import argparse
6
+ import random
7
+ from typing import Dict, List, Optional, Tuple
8
+
9
+ import gradio as gr
10
+
11
+ from apps.data_explorer.loader import Datasets, load_datasets
12
+
13
+
14
def parse_arguments():
    """ Get command line arguments.

    Returns:
        argparse.Namespace: Parsed known arguments. Unknown arguments
        are printed to stdout and ignored.
    """

    parser = argparse.ArgumentParser("Camel data explorer")
    parser.add_argument(
        '--data-path', type=str, default=None,
        help='Path to the folder with ZIP datasets containing JSONs')
    parser.add_argument('--default-dataset', type=str, default=None,
                        help='Default dataset name selected from ZIPs')
    # NOTE: the original used `type=bool`, an argparse pitfall: any
    # non-empty string (including "False") parses as True. store_true
    # flags keep the same defaults and effective behavior.
    parser.add_argument('--share', action='store_true', default=False,
                        help='Expose the web UI to Gradio')
    parser.add_argument(
        '--server-name', type=str, default="0.0.0.0",
        help='localhost for local, 0.0.0.0 (default) for public')
    parser.add_argument('--server-port', type=int, default=8080,
                        help='Port to run the web page on')
    parser.add_argument('--inbrowser', action='store_true', default=False,
                        help='Open the web UI in the default browser on launch')
    parser.add_argument(
        '--concurrency-count', type=int, default=10,
        help='Number of concurrent threads at Gradio websocket queue. ' +
        'Increase to serve more requests but keep an eye on RAM usage.')
    args, unknown = parser.parse_known_args()
    if len(unknown) > 0:
        print("Unknown args: ", unknown)
    return args
40
+
41
+
42
def construct_ui(blocks, datasets: Datasets,
                 default_dataset: Optional[str] = None):
    """ Build Gradio UI and populate with chat data from JSONs.

    Args:
        blocks: Gradio blocks
        datasets (Datasets): Several parsed
            multi-JSON dataset with chats.
        default_dataset (Optional[str]): Default selection of the dataset.

    Returns:
        None
    """

    if default_dataset is None:
        default_dataset = "ai_society_chat"

    # Datasets that are hidden behind an AGREE/DECLINE disclaimer.
    misalignment_set_names = {"misalignment"}
    ordinary_datasets = [
        v for v in datasets.keys() if v not in misalignment_set_names
    ]
    misalignment_datasets = [
        v for v in datasets.keys() if v in misalignment_set_names
    ]
    # Fall back to the first ordinary dataset, then the first
    # misalignment dataset, when the requested default is absent.
    default_dataset_name = default_dataset \
        if default_dataset in datasets.keys() \
        else ordinary_datasets[0] if len(ordinary_datasets) > 0 \
        else misalignment_datasets[0] if len(misalignment_datasets) > 0 \
        else ""
    dataset_names = list(datasets.keys())

    with gr.Row().style():
        with gr.Column(scale=2):
            with gr.Row():
                # "NODEFAULT" placeholder guarantees a change event fires
                # when set_default_dataset() assigns the real default.
                dataset_dd = gr.Dropdown(dataset_names, label="Select dataset",
                                         value="NODEFAULT", interactive=True)
            with gr.Row():
                disclaimer_ta = gr.Markdown(
                    "## By clicking AGREE I consent to use the dataset "
                    "for purely educational and academic purposes and "
                    "not use it for any fraudulent activity; and I take "
                    "all the responsibility if the data is used in a "
                    "malicious application.", visible=False)
            with gr.Row():
                with gr.Column(scale=1):
                    accept_disclaimer_bn = gr.Button("AGREE", visible=False)
                with gr.Column(scale=1):
                    decline_disclaimer_bn = gr.Button("DECLINE", visible=False)
            with gr.Row():
                with gr.Column(scale=3):
                    assistant_dd = gr.Dropdown([], label="ASSISTANT", value="",
                                               interactive=True)
                with gr.Column(scale=3):
                    user_dd = gr.Dropdown([], label="USER", value="",
                                          interactive=True)
        with gr.Column(scale=1):
            gr.Markdown(
                "## CAMEL: Communicative Agents for \"Mind\" Exploration"
                " of Large Scale Language Model Society\n"
                "Github repo: [https://github.com/lightaime/camel]"
                "(https://github.com/lightaime/camel)\n"
                '<div style="display:flex; justify-content:center;">'
                '<img src="https://raw.githubusercontent.com/lightaime/camel/'
                'master/misc/logo.png" alt="Logo" style="max-width:50%;">'
                '</div>')

    task_dd = gr.Dropdown([], label="Original task", value="",
                          interactive=True)
    specified_task_ta = gr.TextArea(label="Specified task", lines=2)
    chatbot = gr.Chatbot()
    # Session state: whether the misalignment disclaimer was accepted.
    accepted_st = gr.State(False)

    def set_default_dataset() -> Dict:
        """ Trigger for app load.

        Returns:
            Dict: Update dict for dataset_dd.
        """
        return gr.update(value=default_dataset_name)

    def check_if_misalignment(dataset_name: str, accepted: bool) \
            -> Tuple[Dict, Dict, Dict]:
        """ Display AGREE/DECLINE if needed.

        Args:
            dataset_name (str): Name of the selected dataset.
            accepted (bool): Whether the disclaimer has been accepted.

        Returns:
            Tuple: Visibility updates for the buttons.
        """

        if dataset_name == "misalignment" and not accepted:
            return gr.update(visible=True), \
                gr.update(visible=True), gr.update(visible=True)
        else:
            return gr.update(visible=False), \
                gr.update(visible=False), gr.update(visible=False)

    def enable_misalignment() -> Tuple[bool, Dict, Dict, Dict]:
        """ Update the state of the accepted disclaimer.

        Returns:
            Tuple: New state and visibility updates for the buttons.
        """

        return True, gr.update(visible=False), \
            gr.update(visible=False), gr.update(visible=False)

    def disable_misalignment() -> Tuple[bool, Dict, Dict, Dict]:
        """ Update the state of the accepted disclaimer.

        Returns:
            Tuple: New state and visibility updates for the buttons.
        """

        return False, gr.update(visible=False), \
            gr.update(visible=False), gr.update(visible=False)

    def update_dataset_selection(dataset_name: str,
                                 accepted: bool) -> Tuple[Dict, Dict]:
        """ Update roles based on the selected dataset.

        Args:
            dataset_name (str): Name of the loaded .zip dataset.
            accepted (bool): If the disclaimer has been accepted.

        Returns:
            Tuple[Dict, Dict]: New Assistant and User roles.
        """

        if dataset_name == "misalignment" and not accepted:
            # If user did not accept the misalignment policy,
            # keep the old selection.
            return (gr.update(value="N/A",
                              choices=[]), gr.update(value="N/A", choices=[]))

        dataset = datasets[dataset_name]
        assistant_roles = dataset['assistant_roles']
        user_roles = dataset['user_roles']
        # Pick a random initial role pair so a chat shows immediately.
        assistant_role = random.choice(assistant_roles) \
            if len(assistant_roles) > 0 else ""
        user_role = random.choice(user_roles) if len(user_roles) > 0 else ""
        return (gr.update(value=assistant_role, choices=assistant_roles),
                gr.update(value=user_role, choices=user_roles))

    def roles_dd_change(dataset_name: str, assistant_role: str,
                        user_role: str) -> Dict:
        """ Update the displayed chat upon inputs change.

        Args:
            dataset_name (str): Name of the selected dataset.
            assistant_role (str): Assistant dropdown value.
            user_role (str): User dropdown value.

        Returns:
            Dict: New original roles state dictionary.
        """
        matrix = datasets[dataset_name]['matrix']
        if (assistant_role, user_role) in matrix:
            record: Dict[str, Dict] = matrix[(assistant_role, user_role)]
            original_task_options = list(record.keys())
            original_task = original_task_options[0]
        else:
            original_task = "N/A"
            original_task_options = []

        choices = gr.Dropdown.update(choices=original_task_options,
                                     value=original_task, interactive=True)
        return choices

    def build_chat_history(messages: Dict[int, Dict]) -> List[Tuple]:
        """ Structures chatbot contents from the loaded data.

        Args:
            messages (Dict[int, Dict]): Messages loaded from JSON.

        Returns:
            List[Tuple]: Chat history in chatbot UI element format.
        """
        history = []
        curr_qa = (None, None)
        # Pair each USER message with the next ASSISTANT reply,
        # walking messages in numeric-key order.
        for k in sorted(messages.keys()):
            msg = messages[k]
            content = msg['content']
            if msg['role_type'] == "USER":
                if curr_qa[0] is not None:
                    # Previous user message got no reply; flush it.
                    history.append(curr_qa)
                    curr_qa = (content, None)
                else:
                    curr_qa = (content, None)
            elif msg['role_type'] == "ASSISTANT":
                curr_qa = (curr_qa[0], content)
                history.append(curr_qa)
                curr_qa = (None, None)
            else:
                # Unknown role types are silently skipped.
                pass
        return history

    def task_dd_change(dataset_name: str, assistant_role: str, user_role: str,
                       original_task: str) -> Tuple[str, List]:
        """ Load task details and chatbot history into UI elements.

        Args:
            dataset_name (str): Name of the selected dataset.
            assistant_role (str): An assistant role.
            user_role (str): A user role.
            original_task (str): The original task.

        Returns:
            Tuple[str, List]: New contents of the specified task
            and chatbot history UI elements.
        """

        matrix = datasets[dataset_name]['matrix']
        if (assistant_role, user_role) in matrix:
            task_dict: Dict[str, Dict] = matrix[(assistant_role, user_role)]
            if original_task in task_dict:
                chat = task_dict[original_task]
                specified_task = chat['specified_task']
                history = build_chat_history(chat['messages'])
            else:
                specified_task = "N/A"
                history = []
        else:
            specified_task = "N/A"
            history = []
        return specified_task, history

    # Event wiring: a dataset change may first pop the disclaimer, then
    # refreshes the role dropdowns.
    dataset_dd.change(check_if_misalignment, [dataset_dd, accepted_st],
                      [disclaimer_ta, accept_disclaimer_bn,
                       decline_disclaimer_bn]) \
        .then(update_dataset_selection,
              [dataset_dd, accepted_st],
              [assistant_dd, user_dd])

    accept_disclaimer_bn.click(enable_misalignment, None, [
        accepted_st, disclaimer_ta, accept_disclaimer_bn, decline_disclaimer_bn
    ]) \
        .then(update_dataset_selection,
              [dataset_dd, accepted_st],
              [assistant_dd, user_dd])

    decline_disclaimer_bn.click(disable_misalignment, None, [
        accepted_st, disclaimer_ta, accept_disclaimer_bn, decline_disclaimer_bn
    ]) \
        .then(update_dataset_selection,
              [dataset_dd, accepted_st],
              [assistant_dd, user_dd])

    # Both role dropdowns trigger the same refresh of the task list.
    func_args = (roles_dd_change, [dataset_dd, assistant_dd, user_dd], task_dd)
    assistant_dd.change(*func_args)
    user_dd.change(*func_args)

    task_dd.change(task_dd_change,
                   [dataset_dd, assistant_dd, user_dd, task_dd],
                   [specified_task_ta, chatbot])

    # On app load select the default dataset, which cascades through the
    # change handlers above to populate the whole page.
    blocks.load(set_default_dataset, None, dataset_dd)
294
+
295
+
296
def construct_blocks(data_path: str, default_dataset: Optional[str]):
    """ Build the Gradio Blocks application without launching it.

    Args:
        data_path (str): Path to the set of ZIP datasets.
        default_dataset (Optional[str]): Name of the default dataset,
            without extension.

    Returns:
        gr.Blocks: The assembled, un-launched Blocks instance.
    """

    print("Loading the dataset...")
    loaded_datasets = load_datasets(data_path)
    print("Dataset is loaded")

    print("Getting Data Explorer web server online...")

    with gr.Blocks() as app_blocks:
        construct_ui(app_blocks, loaded_datasets, default_dataset)

    return app_blocks
318
+
319
+
320
def main():
    """ Entry point: parse CLI args, build the app, and serve it. """

    args = parse_arguments()
    blocks = construct_blocks(args.data_path, args.default_dataset)

    # Enable the request queue, then serve until interrupted.
    queued_app = blocks.queue(args.concurrency_count)
    queued_app.launch(share=args.share, inbrowser=args.inbrowser,
                      server_name=args.server_name,
                      server_port=args.server_port)

    print("Exiting.")


if __name__ == "__main__":
    main()
apps/data_explorer/downloader.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import urllib.request
3
+
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ REPO_ROOT = os.path.realpath(
7
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.."))
8
+
9
+
10
def download_data():
    """ Download the Camel chat dataset ZIPs into `<repo>/datasets/`.

    Tries the Hugging Face Hub first and falls back to the public GCS
    bucket on any failure. The misalignment dataset is fetched from GCS
    unconditionally.

    Returns:
        None
    """

    print("Downloading...")

    data_dir = os.path.join(REPO_ROOT, "datasets/")

    os.makedirs(data_dir, exist_ok=True)

    try:
        hf_hub_download(repo_id="camel-ai/ai_society", repo_type="dataset",
                        filename="ai_society_chat.zip", local_dir=data_dir,
                        local_dir_use_symlinks=False)

        hf_hub_download(repo_id="camel-ai/code", repo_type="dataset",
                        filename="code_chat.zip", local_dir=data_dir,
                        local_dir_use_symlinks=False)
    except Exception as ex:
        # A bare `except:` would also swallow KeyboardInterrupt and
        # SystemExit; catch Exception and report before falling back.
        print(f"Hub download failed ({ex}); falling back to GCS")
        for name in ("ai_society_chat.zip", "code_chat.zip"):
            data_url = ("https://storage.googleapis.com/"
                        f"camel-bucket/datasets/private/{name}")
            file_path = os.path.join(data_dir, os.path.split(data_url)[1])
            urllib.request.urlretrieve(data_url, file_path)

    data_url = ("https://storage.googleapis.com/"
                "camel-bucket/datasets/private/misalignment.zip")
    file_path = os.path.join(data_dir, os.path.split(data_url)[1])
    urllib.request.urlretrieve(data_url, file_path)

    print("Download done")


if __name__ == "__main__":
    download_data()
apps/data_explorer/loader.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Everything related to parsing the data JSONs into UI-compatible format.
3
+ """
4
+
5
+ import glob
6
+ import json
7
+ import os
8
+ import re
9
+ import zipfile
10
+ from typing import Any, Dict, List, Optional, Tuple, Union
11
+
12
+ from tqdm import tqdm
13
+
14
+ ChatHistory = Dict[str, Any]
15
+ ParsedChatHistory = Dict[str, Any]
16
+ AllChats = Dict[str, Any]
17
+ Datasets = Dict[str, AllChats]
18
+
19
+ REPO_ROOT = os.path.realpath(
20
+ os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.."))
21
+
22
+
23
class AutoZip:
    """ Iterator over the parsed JSON contents of a ZIP archive.

    Args:
        zip_path (str): Path to the ZIP file.
        ext (str): Only member files with this extension are yielded.
    """

    def __init__(self, zip_path: str, ext: str = ".json"):
        self.zip_path = zip_path
        self.zip = zipfile.ZipFile(zip_path, "r")
        self.fl = [f for f in self.zip.filelist if f.filename.endswith(ext)]
        # Initialize the cursor here: the original only set it in
        # __iter__, so calling next() on a fresh instance raised
        # AttributeError instead of behaving like an iterator.
        self.index = 0

    def __next__(self):
        """Return the parsed JSON of the next matching member."""
        if self.index >= len(self.fl):
            raise StopIteration
        finfo = self.fl[self.index]
        with self.zip.open(finfo) as f:
            raw_json = json.loads(f.read().decode("utf-8"))
        self.index += 1
        return raw_json

    def __len__(self):
        """Number of members matching the extension filter."""
        return len(self.fl)

    def __iter__(self):
        # Rewind so the archive can be iterated multiple times.
        self.index = 0
        return self
+
46
+
47
def parse(raw_chat: Dict[str, Any]) -> Union[Dict[str, Any], None]:
    """ Gets the JSON raw chat data, validates it and transforms
    into an easy to work with form.

    Args:
        raw_chat (Dict[str, Any]): In-memory loaded JSON data file
            (a `ChatHistory`).

    Returns:
        Union[Dict[str, Any], None]: Parsed chat data
            (a `ParsedChatHistory`) or None if there were
            parsing errors.
    """

    # Guard every required key with .get(): the original only checked
    # "role_1" and raised KeyError on otherwise-malformed JSONs instead
    # of skipping them.
    role_1 = raw_chat.get("role_1", "")
    if "_RoleType.ASSISTANT" not in role_1:
        return None
    assistant_role = role_1.split("_RoleType.ASSISTANT")[0]
    if len(assistant_role) <= 0:
        return None

    role_2 = raw_chat.get("role_2", "")
    if "_RoleType.USER" not in role_2:
        return None
    user_role = role_2.split("_RoleType.USER")[0]
    if len(user_role) <= 0:
        return None

    original_task = raw_chat.get("original_task", "")
    if len(original_task) <= 0:
        return None

    specified_task = raw_chat.get("specified_task", "")
    if len(specified_task) <= 0:
        return None

    # Collect "message_<N>" entries keyed by their integer index.
    messages = dict()
    for key in raw_chat:
        match = re.search(r"message_(?P<number>[0-9]+)", key)
        if match:
            number = int(match.group("number"))
            messages[number] = raw_chat[key]

    return dict(
        assistant_role=assistant_role,
        user_role=user_role,
        original_task=original_task,
        specified_task=specified_task,
        messages=messages,
    )
+ )
104
+
105
+
106
def load_zip(zip_path: str) -> AllChats:
    """ Load all JSONs from a zip file and parse them.

    Args:
        zip_path (str): Path to the ZIP file.

    Returns:
        AllChats: A dictionary with all possible assistant and
        user roles and the matrix of chats.
    """

    zip_inst = AutoZip(zip_path)
    parsed_list = []
    for raw_chat in tqdm(iter(zip_inst)):
        parsed = parse(raw_chat)
        if parsed is None:
            # Malformed chat JSONs are skipped, not fatal.
            continue
        parsed_list.append(parsed)

    assistant_roles = set()
    user_roles = set()
    for parsed in parsed_list:
        assistant_roles.add(parsed['assistant_role'])
        user_roles.add(parsed['user_role'])
    assistant_roles = list(sorted(assistant_roles))
    user_roles = list(sorted(user_roles))

    # (assistant_role, user_role) -> {original_task: chat record}.
    # The original annotation claimed List[Dict] values, but the code
    # stores one dict of chats keyed by original task.
    matrix: Dict[Tuple[str, str], Dict[str, Dict]] = dict()
    for parsed in parsed_list:
        key = (parsed['assistant_role'], parsed['user_role'])
        original_task = parsed['original_task']
        new_item = {
            k: v
            for k, v in parsed.items()
            if k not in {'assistant_role', 'user_role', 'original_task'}
        }
        matrix.setdefault(key, {})[original_task] = new_item

    return dict(
        assistant_roles=assistant_roles,
        user_roles=user_roles,
        matrix=matrix,
    )
151
+
152
+
153
def load_datasets(path: Optional[str] = None) -> Datasets:
    """ Load all JSONs from a set of zip files and parse them.

    Args:
        path (Optional[str]): Folder containing the ZIP datasets;
            defaults to `<repo>/datasets`.

    Returns:
        Datasets: A dictionary of dataset name and dataset contents.
    """

    search_dir = path if path is not None \
        else os.path.join(REPO_ROOT, "datasets")

    zip_files = glob.glob(os.path.join(search_dir, "*.zip"))
    # Dataset name is the ZIP file name without its extension.
    return {
        os.path.splitext(os.path.basename(zip_file))[0]: load_zip(zip_file)
        for zip_file in tqdm(zip_files)
    }
+ return datasets