# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
"""
Everything related to parsing the data JSONs into a UI-compatible format.
"""
import glob
import os
import re
from typing import Any, Dict, Optional, Tuple, Union

from tqdm import tqdm

from apps.common.auto_zip import AutoZip

ChatHistory = Dict[str, Any]
ParsedChatHistory = Dict[str, Any]
AllChats = Dict[str, Any]
Datasets = Dict[str, AllChats]
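# Shape notes (descriptive only, derived from the functions below):
#   ChatHistory       - one raw chat JSON as loaded from a dataset ZIP.
#   ParsedChatHistory - output of parse(): assistant_role, user_role,
#                       original_task, specified_task and a messages dict.
#   AllChats          - output of load_zip(): assistant_roles, user_roles and
#                       a matrix keyed by (assistant_role, user_role).
#   Datasets          - output of load_datasets(): dataset name -> AllChats.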
REPO_ROOT = os.path.realpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "../.."))


def parse(raw_chat: ChatHistory) -> Union[ParsedChatHistory, None]:
""" Gets the JSON raw chat data, validates it and transforms
into an easy to work with form.
Args:
raw_chat (ChatHistory): In-memory loaded JSON data file.
Returns:
Union[ParsedChatHistory, None]: Parsed chat data or None
if there were parsing errors.
"""
if "role_1" not in raw_chat:
return None
role_1 = raw_chat["role_1"]
if "_RoleType.ASSISTANT" not in role_1:
return None
assistant_role = role_1.split("_RoleType.ASSISTANT")
if len(assistant_role) < 1:
return None
if len(assistant_role[0]) <= 0:
return None
assistant_role = assistant_role[0]
role_2 = raw_chat["role_2"]
if "_RoleType.USER" not in role_2:
return None
user_role = role_2.split("_RoleType.USER")
if len(user_role) < 1:
return None
if len(user_role[0]) <= 0:
return None
user_role = user_role[0]
original_task = raw_chat["original_task"]
if len(original_task) <= 0:
return None
specified_task = raw_chat["specified_task"]
if len(specified_task) <= 0:
return None
messages = dict()
for key in raw_chat:
match = re.search("message_(?P<number>[0-9]+)", key)
if match:
number = int(match.group("number"))
messages[number] = raw_chat[key]
return dict(
assistant_role=assistant_role,
user_role=user_role,
original_task=original_task,
specified_task=specified_task,
messages=messages,
)
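
# A minimal sketch of the raw chat layout parse() expects (field values are
# illustrative assumptions, not taken from a real dataset file):
#
#   {
#       "role_1": "Programmer_RoleType.ASSISTANT",
#       "role_2": "Student_RoleType.USER",
#       "original_task": "...",
#       "specified_task": "...",
#       "message_1": {...},
#       "message_2": {...}
#   }
#
# parse() would return the two role names, both task strings, and
# {1: {...}, 2: {...}} as the messages dict.
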
def load_zip(zip_path: str) -> AllChats:
    """ Load all JSONs from a zip file and parse them.

    Args:
        zip_path (str): Path to the ZIP file.

    Returns:
        AllChats: A dictionary with all possible assistant and
            user roles and the matrix of chats.
    """
    zip_inst = AutoZip(zip_path)
    parsed_list = []
    for raw_chat in tqdm(iter(zip_inst)):
        parsed = parse(raw_chat)
        if parsed is None:
            continue
        parsed_list.append(parsed)

    # Collect the sorted sets of roles seen across all parsed chats.
    assistant_roles_set = set()
    user_roles_set = set()
    for parsed in parsed_list:
        assistant_roles_set.add(parsed['assistant_role'])
        user_roles_set.add(parsed['user_role'])
    assistant_roles = list(sorted(assistant_roles_set))
    user_roles = list(sorted(user_roles_set))

    # Group chats by (assistant_role, user_role) and index each group
    # by the chat's original task.
    matrix: Dict[Tuple[str, str], Dict[str, Dict]] = dict()
    for parsed in parsed_list:
        key = (parsed['assistant_role'], parsed['user_role'])
        original_task: str = parsed['original_task']
        new_item = {
            k: v
            for k, v in parsed.items()
            if k not in {'assistant_role', 'user_role', 'original_task'}
        }
        if key in matrix:
            matrix[key][original_task] = new_item
        else:
            matrix[key] = {original_task: new_item}

    return dict(
        assistant_roles=assistant_roles,
        user_roles=user_roles,
        matrix=matrix,
    )
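
# Hypothetical usage sketch (assumes a ZIP of chat JSONs exists at the given
# path; the role names and task string below are placeholders):
#
#   chats = load_zip("datasets/some_dataset.zip")
#   item = chats["matrix"][("<assistant role>", "<user role>")]["<original task>"]
#   item["specified_task"], item["messages"]  # per-chat payload
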
def load_datasets(path: Optional[str] = None) -> Datasets:
    """ Load all JSONs from a set of zip files and parse them.

    Args:
        path (Optional[str]): Path to the folder with ZIP datasets.
            Defaults to the "datasets" folder in the repo root.

    Returns:
        Datasets: A dictionary of dataset name and dataset contents.
    """
    if path is None:
        path = os.path.join(REPO_ROOT, "datasets")

    # Every "*.zip" file in the folder is treated as one dataset; the dataset
    # name is the file name without the ".zip" extension.
    filt = os.path.join(path, "*.zip")
    files = glob.glob(filt)
    datasets = {}
    for file_name in tqdm(files):
        name = os.path.splitext(os.path.basename(file_name))[0]
        datasets[name] = load_zip(file_name)
    return datasets
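
# Minimal self-check sketch: assumes at least one "*.zip" dataset is present
# under REPO_ROOT/datasets, or under a folder passed as the first CLI argument.
if __name__ == "__main__":
    import sys

    folder = sys.argv[1] if len(sys.argv) > 1 else None
    loaded = load_datasets(folder)
    for dataset_name, contents in loaded.items():
        print(dataset_name,
              f"assistant roles: {len(contents['assistant_roles'])},",
              f"user roles: {len(contents['user_roles'])},",
              f"role pairs: {len(contents['matrix'])}")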