# Hosting-page metadata (not part of the module):
#   uploader: bpiyush
#   commit message: Upload folder using huggingface_hub
#   commit: eafbf97 (verified)
"""
Utilities for input-output loading/saving.
"""
from typing import Any, List
import yaml
import pickle
import json
import pandas as pd
class PrettySafeLoader(yaml.SafeLoader):
    """`yaml.SafeLoader` subclass that provides a tuple-building constructor."""

    def construct_python_tuple(self, node):
        """Build a tuple from a YAML sequence node.

        Args:
            node: YAML sequence node to convert

        Returns:
            tuple: the items produced by `construct_sequence` for `node`
        """
        items = self.construct_sequence(node)
        return tuple(items)
# Map the `python/tuple` YAML tag onto the loader's tuple-building method.
PrettySafeLoader.add_constructor(
    "tag:yaml.org,2002:python/tuple",
    PrettySafeLoader.construct_python_tuple,
)
def load_yml(path: str, loader_type: str = 'default') -> Any:
    """Read params from a yml file.

    Args:
        path (str): path to the .yml file
        loader_type (str, optional): type of loader used to load yml files;
            'default' selects `yaml.Loader`, 'safe' selects `PrettySafeLoader`.
            Defaults to 'default'.

    Returns:
        Any: object (typically dict) loaded from .yml file

    Raises:
        AssertionError: if `loader_type` is neither 'default' nor 'safe'
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; the exception type is kept so callers see no change.
    if loader_type not in ('default', 'safe'):
        raise AssertionError(
            f"loader_type must be 'default' or 'safe', got {loader_type!r}"
        )

    # NOTE(security): `yaml.Loader` can construct arbitrary Python objects.
    # Pass loader_type='safe' when the file comes from an untrusted source.
    loader = yaml.Loader if loader_type == 'default' else PrettySafeLoader
    with open(path, 'r') as f:
        return yaml.load(f, Loader=loader)
def save_yml(data: dict, path: str):
    """Dump `data` as block-style YAML into the file at `path`.

    Args:
        data (dict): data object to save
        path (str): path to the .yml file to write (opened in 'w' mode)
    """
    with open(path, mode='w') as sink:
        yaml.dump(data, stream=sink, default_flow_style=False)
def load_pkl(path: str, encoding: str = "ascii") -> Any:
    """Loads a .pkl file.

    Args:
        path (str): path to the .pkl file
        encoding (str, optional): encoding passed to `pickle.load`.
            Defaults to "ascii".

    Returns:
        Any: unpickled object
    """
    # NOTE(security): unpickling can execute arbitrary code; only load files
    # from trusted sources.
    # The context manager closes the file handle, which was previously leaked.
    with open(path, "rb") as f:
        return pickle.load(f, encoding=encoding)
def save_pkl(data: Any, path: str) -> None:
    """Pickle `data` and write it to the file at `path`.

    Args:
        data (Any): object to be saved
        path (str): location of the .pkl file to write (opened in 'wb' mode)
    """
    with open(path, mode='wb') as sink:
        pickle.Pickler(sink).dump(data)
def load_json(path: str) -> dict:
    """Parse the JSON document stored at `path` and return the result."""
    with open(path, mode='rb') as src:
        raw = src.read()
    return json.loads(raw)
def save_json(data: dict, path: str):
    """Serialize `data` as JSON into the file at `path` (opened in 'w' mode)."""
    with open(path, mode='w') as sink:
        json.dump(data, fp=sink)
def load_txt(path: str):
    """Read a .txt file and return its content split into lines.

    Args:
        path (str): path to the .txt file

    Returns:
        List: lines of the file, without line terminators
    """
    with open(path) as src:
        content = src.read()
    return content.splitlines()
def save_txt(data: List[str], path: str):
    """Writes data (lines) to a txt file.

    The lines are joined with "\\n"; no trailing newline is appended.

    Args:
        data (List[str]): list of strings, one per output line
        path (str): path to .txt file (opened in 'w' mode)

    Raises:
        AssertionError: if `data` is not a list
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; the exception type is kept so callers see no change.
    if not isinstance(data, list):
        raise AssertionError(
            f"data must be a list of strings, got {type(data).__name__}"
        )

    # Join before opening so a bad item fails without truncating the file.
    text = "\n".join(data)
    with open(path, "w") as f:
        f.write(text)
def read_spreadsheet(sheet_id, gid, url=None, drop_na=True, **kwargs):
    """Load a spreadsheet as a DataFrame via `pd.read_csv`.

    Args:
        sheet_id: spreadsheet id; only used to build the Google Sheets CSV
            export URL when `url` is None
        gid: value of the `gid` query parameter in the export URL; only used
            when `url` is None
        url (optional): location to read from instead of the built export
            URL. Defaults to None.
        drop_na (bool, optional): drop every row containing a NaN value.
            Defaults to True.
        **kwargs: forwarded to `pd.read_csv`

    Returns:
        pd.DataFrame: the loaded (and optionally NaN-filtered) table
    """
    if url is None:
        url = ''.join((
            'https://docs.google.com/spreadsheets/d/',
            sheet_id,
            f'/export?gid={gid}&format=csv',
        ))

    table = pd.read_csv(url, **kwargs)
    # drop all rows which have at least 1 NaN value
    return table.dropna(axis=0) if drop_na else table
def load_midi(file, rate=16000):
    """Synthesize a MIDI file with `pretty_midi`.

    Args:
        file: path to the MIDI file; must end with '.mid'
        rate (optional): value passed to `PrettyMIDI.synthesize` as `fs`.
            Defaults to 16000.

    Returns:
        tuple: (result of `synthesize`, rate)
    """
    import pretty_midi

    assert file.endswith('.mid')
    midi = pretty_midi.PrettyMIDI(file)
    return midi.synthesize(fs=rate), rate
def load_ptz(path):
    """Load a gzip-compressed `torch.save` file.

    Args:
        path: path to the gzip-compressed file

    Returns:
        the object returned by `torch.load`
    """
    import gzip
    import torch

    with gzip.open(path, mode='rb') as stream:
        return torch.load(stream)
def save_video(frames, path, fps=30):
    """Write `frames` to `path` with `imageio.mimwrite`.

    Args:
        frames: sequence of frames handed to `imageio.mimwrite`
            (presumably image arrays; verify against callers)
        path: output location
        fps (optional): frame rate forwarded to the writer. Defaults to 30.
    """
    import imageio

    writer_options = {"fps": fps}
    imageio.mimwrite(path, frames, **writer_options)