mztelus
September 2, 2023, 5:33pm
1
I’ve made a local dataset by transforming Common Voice’s audio files into spectrograms. As such, each sample holds a tensor of shape (513, 128) (along with some tabular features). Then, I saved this in two formats.
First, I used `torch.save` (along with a JSON file for the other features) to save them to the hard drive. This way, each sample is composed of two files.
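Roughly, the per-sample save looks like this (a minimal sketch; the helper name `save_sample` is just for illustration, `.spec` is the extension I use for the tensor files):

```python
import json
import torch

def save_sample(spec, features, stem):
    # One tensor file plus one JSON sidecar per sample
    # (save_sample is an illustrative helper, not library API).
    torch.save(spec, f"{stem}.spec")
    with open(f"{stem}.json", "w") as fp:
        json.dump(features, fp)
```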
Second, I used HF’s `datasets` package to save the same dataset in Arrow format:
```python
from datasets import Dataset

def arrow_generator(data_samples):
    # data_samples: the list of per-sample dicts prepared earlier
    def _gen():
        for rec in data_samples:
            yield {
                "id": rec["id"],
                "client_id": rec["client_id"],
                "locale": rec["locale"],
                "spectrogram": rec["spectrogram"],  # PyTorch tensor
            }
    return _gen

ds = Dataset.from_generator(arrow_generator(data_samples))
ds.save_to_disk("./arrow_dataset")
```
Comparing the two: HF’s Arrow format takes half the size on disk, but it takes 16x longer to read a batch (`load_from_disk` is 16 times slower than `torch.load`).
What I wanted to ask is whether this is expected, or am I doing something wrong?
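(For reference, a rough harness like the one below is how such a comparison can be timed; the batch size of 64 and the flat file layout are illustrative.)

```python
import time

import torch
from datasets import load_from_disk

start = time.perf_counter()
ds = load_from_disk("./arrow_dataset")
rows = ds[:64]  # one batch of rows, decoded to Python objects
print(f"arrow: {time.perf_counter() - start:.3f}s")

start = time.perf_counter()
# Illustrative flat layout; my real files are nested per locale.
specs = [torch.load(f"./individual_files/{i}.spec") for i in range(64)]
print(f"torch: {time.perf_counter() - start:.3f}s")
```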
mariosasko
2
Hi! `torch.save` uses `pickle` under the hood to serialize objects, and `pickle` is slower than Feather, a format very similar to ours (we plan to switch to Feather eventually), according to this blog post.
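For a feel of the format, a minimal Feather round trip with pyarrow looks something like this (a sketch; the column layout and file name are made up):

```python
import numpy as np
import pyarrow as pa
import pyarrow.feather as feather

# A one-row table with a flattened spectrogram column.
table = pa.table({"spectrogram": [np.random.rand(513 * 128)]})

feather.write_feather(table, "sample.feather")  # columnar on disk
loaded = feather.read_table("sample.feather")   # fast columnar read
```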
Do you mind profiling the `load_from_disk` call using the code below?
```python
import cProfile, pstats
from datasets import load_from_disk

with cProfile.Profile() as profiler:
    ds = load_from_disk(...)

stats = pstats.Stats(profiler).sort_stats("cumtime")
stats.print_stats()
```
That way, we can be sure this is an Arrow issue.
mztelus
September 7, 2023, 4:48pm
3
@mariosasko Thanks for your help. Before I report the stats, here’s my code:
```python
import cProfile, pstats
import glob
import json
import os

import torch
from datasets import load_from_disk
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, parent_folder):
        self.parent_folder = parent_folder
        json_files = glob.glob(os.path.join(parent_folder, '**/*.json'),
                               recursive=True)
        self.files = []
        for json_file in json_files:
            spec_file = json_file.replace('.json', '.spec')
            self.files.append((json_file, spec_file))

    def __getitem__(self, index):
        json_file, spec_file = self.files[index]
        with open(json_file, 'r') as fp:
            _json = json.load(fp)
        locale = os.path.relpath(os.path.dirname(json_file),
                                 self.parent_folder)
        id = os.path.splitext(os.path.basename(json_file))[0]
        _json["id"] = f"{locale}_{id}"
        _json["locale"] = locale
        _json["file"] = spec_file
        spec = torch.load(spec_file)
        return torch.squeeze(spec), _json

    def __len__(self):
        # was len(self.data), but the attribute is self.files
        return len(self.files)


with cProfile.Profile() as profiler:
    ds = load_from_disk("./arrow")
    for i in range(1000):
        sample = ds[i]
stats = pstats.Stats(profiler).sort_stats("cumtime")
stats.print_stats()

with cProfile.Profile() as profiler:
    ds = MyDataset("./individual_files")
    for i in range(1000):
        sample = ds[i]
stats = pstats.Stats(profiler).sort_stats("cumtime")
stats.print_stats()
```
And here’s the stats report:
Arrow
```
226695 function calls (201237 primitive calls) in 14.371 seconds

Ordered by: cumulative time

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
1000    0.004    0.000   14.314    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2776(__getitem__)
1000    0.004    0.000   14.309    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2750(_getitem)
1000    0.002    0.000   14.269    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:587(format_table)
1000    0.001    0.000   14.265    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:394(__call__)
1000    0.002    0.000   14.265    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:428(format_row)
1000    0.003    0.000   14.244    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:143(extract_row)
1000   14.238    0.014   14.238    0.014 {method 'to_pydict' of 'pyarrow.lib.Table' objects}
1       0.000    0.000    0.057    0.057 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/load.py:1820(load_from_disk)
1       0.000    0.000    0.057    0.057 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:1572(load_from_disk)
11319/473  0.012  0.000   0.035    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:128(deepcopy)
437/272    0.000  0.000   0.033    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:227(_deepcopy_dict)
1       0.000    0.000    0.033    0.033 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2601(with_format)
3121/4     0.004  0.000   0.033    0.008 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:259(_reconstruct)
536/137    0.001  0.000   0.032    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:201(_deepcopy_list)
134/1      0.000  0.000   0.032    0.032 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/table.py:167(__deepcopy__)
134/1      0.001  0.000   0.032    0.032 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/table.py:66(_deepcopy)
1000    0.002    0.000    0.030    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:547(query_table)
10798/424  0.003  0.000   0.027    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:264(<genexpr>)
```
Torch
```
4183801 function calls (4183741 primitive calls) in 2.283 seconds

Ordered by: cumulative time

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
1000    0.009    0.000    1.415    0.001 /tmp/ipykernel_56163/3477044752.py:14(__getitem__)
1000    0.008    0.000    1.252    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:671(load)
1000    0.009    0.000    0.886    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:1104(_load)
1000    0.007    0.000    0.868    0.001 {method 'load' of '_pickle.Unpickler' objects}
1       0.044    0.044    0.860    0.860 /tmp/ipykernel_56163/3477044752.py:6(__init__)
1000    0.002    0.000    0.826    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:1125(persistent_load)
1000    0.810    0.001    0.821    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:1109(load_tensor)
1       0.022    0.022    0.793    0.793 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:13(glob)
199849/199818  0.041  0.000  0.770  0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:53(_iglob)
30      0.000    0.000    0.376    0.013 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:93(_glob1)
60      0.031    0.001    0.305    0.005 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:162(_listdir)
399752  0.259    0.000    0.274    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:128(_iterdir)
30      0.075    0.002    0.218    0.007 /home/mehran/.conda/envs/whisper/lib/python3.10/fnmatch.py:54(filter)
200877  0.135    0.000    0.207    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/posixpath.py:71(join)
1000    0.003    0.000    0.152    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:66(_is_zipfile)
4000    0.148    0.000    0.148    0.000 {method 'read' of '_io.BufferedReader' objects}
31      0.000    0.000    0.147    0.005 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:121(_glob2)
59/30   0.000    0.000    0.147    0.005 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:167(_rlistdir)
```
I had to cut the lower lines of both reports since the platform did not let me post that many characters.
This does not show the 16x difference I talked about before, but I guess that was a slightly different scenario. For instance, I was accessing the records randomly then, which might have had some effect on the performance. Or I was reading more records then.
In any case, there’s a huge difference between the two. It would be great if I could somehow improve the performance with Arrow, since it conserves a lot of space on my drive (individual_files: 184 GB vs. arrow: 61 GB).
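(One thing I haven’t tried yet: reading in slices rather than row by row, which should amortize the Arrow-to-Python conversion over a whole batch. A sketch, with an arbitrary batch size of 100:)

```python
from datasets import load_from_disk

ds = load_from_disk("./arrow")
batch_size = 100  # arbitrary; tune to the memory budget
for start in range(0, 1000, batch_size):
    # Slicing returns a dict of column name -> list of values,
    # converting one contiguous chunk instead of one row at a time.
    batch = ds[start : start + batch_size]
```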
mztelus
September 7, 2023, 6:36pm
4
Reading some tutorials, I learned that I could convert the loaded dataset into PyTorch tensors directly using `load_from_disk(...).with_format("torch")`.
```python
with cProfile.Profile() as profiler:
    ds = load_from_disk("./arrow").with_format("torch")
    for i in range(1000):
        sample = ds[i]
stats = pstats.Stats(profiler).sort_stats("cumtime")
stats.print_stats()
```
This simple change improved the performance a lot:
```
6612931 function calls (269871 primitive calls) in 3.085 seconds

Ordered by: cumulative time

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
1000    0.004    0.000    2.920    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2776(__getitem__)
1000    0.004    0.000    2.916    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2750(_getitem)
1000    0.002    0.000    2.881    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:587(format_table)
1000    0.007    0.000    2.876    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:394(__call__)
1000    0.003    0.000    2.869    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:80(format_row)
203000/1000  0.075  0.000  2.697  0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:77(recursive_tensorize)
203000/1000  0.157  0.000  2.696  0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:372(map_nested)
1000    0.005    0.000    2.644    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:443(<listcomp>)
7000    0.004    0.000    2.637    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:340(_single_map_nested)
209000/7000  0.161  0.000  2.631  0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:70(_recursive_tensorize)
2000/1000    0.051  0.000  2.556  0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:74(<listcomp>)
207000  0.391    0.000    1.922    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:49(_tensorize)
204000  0.830    0.000    0.830    0.000 {built-in method torch.tensor}
607000  0.269    0.000    0.642    0.000 /home/mehran/.local/lib/python3.10/site-packages/numpy/core/numerictypes.py:356(issubdtype)
1214000 0.232    0.000    0.339    0.000 /home/mehran/.local/lib/python3.10/site-packages/numpy/core/numerictypes.py:282(issubclass_)
```
I’m not sure if there’s any more room left to improve the performance, but this alone solves my problem.
Thanks, @mariosasko, I could not have done it without you.
mariosasko
5
You should get better performance by using

```python
for sample in iter(ds):
    ...
```

instead of

```python
for i in range(1000):
    sample = ds[i]
```

Also, most of the time is spent in PyArrow’s `to_pydict` (Arrow objects need to be converted to the Python representation), and converting to Python is not optimized for most Arrow types, so we can’t do much about this.
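Putting the two together, the read loop could look like this (a sketch, reusing the "./arrow" path from above):

```python
from datasets import load_from_disk

ds = load_from_disk("./arrow").with_format("torch")
# Sequential iteration reads the rows in storage order instead of
# issuing one random lookup per index.
for sample in iter(ds):
    spec = sample["spectrogram"]
```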
Using `datasets.load_from_disk(path_to_dataset).with_format('torch')` did NOT help. It just shifted the cost from here to `ChunkedArray.to_numpy()` (pyarrow.ChunkedArray — Apache Arrow v18.0.0).
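One way to see how much of the time is pure conversion is to ask for raw Arrow rows and convert only the column you need yourself (a sketch; `with_format("arrow")` returns rows as `pyarrow.Table` objects):

```python
from datasets import load_from_disk

ds = load_from_disk(path_to_dataset).with_format("arrow")
row = ds[0]                       # a one-row pyarrow.Table
spec_column = row["spectrogram"]  # still Arrow; no Python-object
                                  # conversion has happened yet
```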