# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""CLRS dataset.""" | |
import dataclasses | |
import functools | |
from typing import Iterator | |
from clrs._src import probing | |
from clrs._src import samplers | |
from clrs._src import specs | |
import jax | |
import numpy as np | |
import tensorflow as tf | |
import tensorflow_datasets as tfds | |

def _correct_axis_filtering(tensor, index, name):
  if 'hint_' in name:
    return tensor[:, index]
  else:
    return tensor[index]

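# In the sampled data, hint tensors are time-major ([time, num_samples, ...]),
# while inputs, outputs and `lengths` are sample-major ([num_samples, ...]);
# `_correct_axis_filtering` above therefore selects a single sample along
# axis 1 for hints and along axis 0 for everything else.
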

@dataclasses.dataclass
class CLRSConfig(tfds.core.BuilderConfig):
  """Specify the split in the variant because they have different shapes."""
  split: str = ''


DEFAULT_BUILDER_CONFIGS = []


def _build_default_builder_configs():
  for split in ['train', 'val', 'test']:
    for alg in specs.CLRS_30_ALGS:
      DEFAULT_BUILDER_CONFIGS.append(
          CLRSConfig(name=f'{alg}_{split}', split=split))


_build_default_builder_configs()


class CLRSDataset(tfds.core.GeneratorBasedBuilder):
  """DatasetBuilder for the CLRS dataset."""

  VERSION = tfds.core.Version('1.0.0')
  RELEASE_NOTES = {
      '1.0.0': 'Initial release.',
  }
  BUILDER_CONFIGS = DEFAULT_BUILDER_CONFIGS

  _instantiated_dataset = None
  _instantiated_dataset_name = ''
  _instantiated_dataset_split = ''

  def _num_samples(self, algorithm_name):
    num_samples = samplers.CLRS30[self._builder_config.split]['num_samples']  # pytype: disable=attribute-error  # always-use-return-annotations
    if self._builder_config.split != 'train':  # pytype: disable=attribute-error  # always-use-return-annotations
      # Generate more samples for those algorithms in which the number of
      # signals is small.
      num_samples *= specs.CLRS_30_ALGS_SETTINGS[algorithm_name][
          'num_samples_multiplier']
    return num_samples

  def _create_data(self, single_sample):
    algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1])
    num_samples = self._num_samples(algorithm_name)
    sampler, _ = samplers.build_sampler(
        algorithm_name,
        seed=samplers.CLRS30[self._builder_config.split]['seed'],  # pytype: disable=attribute-error  # always-use-return-annotations
        num_samples=num_samples,
        length=samplers.CLRS30[self._builder_config.split]['length'],  # pytype: disable=attribute-error  # always-use-return-annotations
    )
    sampled_dataset = sampler.next(batch_size=1 if single_sample else None)
    data = {'input_' + t.name: t.data for t in sampled_dataset.features.inputs}
    # All other data points have input_, hint_, and output_ prefixes, so we
    # guarantee that this key is unused.
    data['lengths'] = sampled_dataset.features.lengths
    data.update({'output_' + t.name: t.data for t in sampled_dataset.outputs})
    data.update({
        'hint_' + t.name: t.data for t in sampled_dataset.features.hints})
    self._instantiated_dataset = data

  def _info(self) -> tfds.core.DatasetInfo:
    if tf.io.gfile.exists(self.data_dir):
      info = tfds.core.DatasetInfo(builder=self)
      info.read_from_directory(self.data_dir)
      return info
    if (self._instantiated_dataset_name != self._builder_config.name
        or self._instantiated_dataset_split != self._builder_config.split):  # pytype: disable=attribute-error  # always-use-return-annotations
      self._create_data(single_sample=True)
    data = {k: _correct_axis_filtering(v, 0, k)
            for k, v in self._instantiated_dataset.items()}
    data_info = {
        k: tfds.features.Tensor(shape=v.shape, dtype=tf.dtypes.as_dtype(
            v.dtype)) for k, v in data.items()}
    return tfds.core.DatasetInfo(
        builder=self,
        features=tfds.features.FeaturesDict(data_info),
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    """Download the data and define splits."""
    if (self._instantiated_dataset_name != self._builder_config.name
        or self._instantiated_dataset_split != self._builder_config.split):  # pytype: disable=attribute-error  # always-use-return-annotations
      self._create_data(single_sample=False)
      self._instantiated_dataset_name = self._builder_config.name
      self._instantiated_dataset_split = self._builder_config.split  # pytype: disable=attribute-error  # always-use-return-annotations
    return {self._builder_config.split: self._generate_examples()}  # pytype: disable=attribute-error  # always-use-return-annotations

  def _generate_examples(self):
    """Generator of examples for each split."""
    algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1])
    for i in range(self._num_samples(algorithm_name)):
      data = {k: _correct_axis_filtering(v, i, k)
              for k, v in self._instantiated_dataset.items()}
      yield str(i), data


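# Example (illustrative sketch; 'bfs' and '/tmp/CLRS30' are placeholder
# choices): importing this module registers the builder with TFDS, after
# which a single algorithm/split variant can be generated locally:
#
#   builder = tfds.builder('clrs_dataset/bfs_train', data_dir='/tmp/CLRS30')
#   builder.download_and_prepare()

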
def _get_clrs_file_name():
  return f'CLRS30_v{CLRSDataset.VERSION}.tar.gz'


def get_dataset_gcp_url():
  return f'https://storage.googleapis.com/dm-clrs/{_get_clrs_file_name()}'


def get_clrs_folder():
  return f'CLRS30_v{CLRSDataset.VERSION}'


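# Example (illustrative sketch; the destination directory is a placeholder):
# the URL and folder helpers above can be combined with standard-library tools
# to fetch and unpack a pre-built copy of the dataset:
#
#   import os
#   import tarfile
#   import urllib.request
#
#   out_dir = '/tmp'
#   archive = os.path.join(out_dir, _get_clrs_file_name())
#   urllib.request.urlretrieve(get_dataset_gcp_url(), archive)
#   with tarfile.open(archive) as tar:
#     tar.extractall(out_dir)  # Unpacks into a get_clrs_folder() subdirectory.

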
def _preprocess(data_point, algorithm=None):
  """Convert sampled inputs into DataPoints."""
  inputs = []
  outputs = []
  hints = []
  lengths = None

  for name, data in data_point.items():
    if name == 'lengths':
      lengths = data
      continue
    data_point_name = name.split('_')
    name = '_'.join(data_point_name[1:])
    (stage, location, dp_type) = specs.SPECS[algorithm][name]
    assert stage == data_point_name[0]
    if stage == specs.Stage.HINT:
      data = tf.experimental.numpy.swapaxes(data, 0, 1)
    dp = probing.DataPoint(name, location, dp_type, data)
    if stage == specs.Stage.INPUT:
      inputs.append(dp)
    elif stage == specs.Stage.OUTPUT:
      outputs.append(dp)
    else:
      hints.append(dp)

  return samplers.Feedback(
      samplers.Features(tuple(inputs), tuple(hints), lengths), tuple(outputs))


def create_dataset(folder, algorithm, split, batch_size):
  dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
                      data_dir=folder, split=split)
  num_samples = len(dataset)  # Must be done here for correct size
  dataset = dataset.repeat()
  dataset = dataset.batch(batch_size)
  return (dataset.map(lambda d: _preprocess(d, algorithm=algorithm)),
          num_samples,
          specs.SPECS[algorithm])


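# Example (illustrative sketch; the folder, algorithm and batch size are
# placeholders): iterate full-trajectory batches for one algorithm/split:
#
#   ds, num_samples, spec = create_dataset(
#       folder='/tmp/CLRS30', algorithm='bfs', split='train', batch_size=32)
#   feedback = next(ds.as_numpy_iterator())
#   # feedback.features holds inputs, hints and lengths;
#   # feedback.outputs holds the ground-truth outputs.

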
def _copy_hint(source, dest, i, start_source, start_dest, to_add):
  """Copy from full-sample hint to a hint chunk."""
  assert np.all(dest[start_dest:, i:] == 0)
  assert start_dest < dest.shape[0]
  assert start_dest + to_add <= dest.shape[0]
  assert start_source < source.shape[0]
  assert start_source + to_add <= source.shape[0]
  dest[start_dest:start_dest+to_add, i] = source[
      start_source:start_source+to_add, i]
  return dest


def _copy_io(source, dest, i, start_dest, to_add):
  """Copy from an input or output to an input or output chunk."""
  assert np.all(dest[start_dest:, i:] == 0)
  dest[start_dest:start_dest+to_add, i] = source[i]
  return dest


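# When assembling chunks below, `_copy_io` repeats a sample's time-less
# input/output tensor across the `to_add` chunk steps assigned to it, while
# `_copy_hint` copies the matching time window from the time-major hint
# tensor.

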
def chunkify(dataset: Iterator[samplers.Feedback], chunk_length: int):
  """Generator of fixed-length chunks from full-trajectory samples.

  Args:
    dataset: full-sample dataset as numpy iterator.
    chunk_length: time length of chunks.

  Yields:
    Fixed-timelength chunks of data. Each tensor of inputs, hints and outputs
    has dimensions chunk_length x batch_size x ... Samples are not time-padded:
    after the end of one sample immediately comes the next. Since different
    samples can have different time lengths, the beginnings and ends of samples
    within a batch do not need to coincide. For this reason, the chunked
    dataset features include two chunk_length x batch_size int tensors,
    `is_first` and `is_last`, that mark the beginning and end of each sample.
    For example, if `chunk_length`==6 and `batch_size`==2 and the first
    full-sample batch had one sample of length 3 and one of length 5,
    we would have a first chunked batch with the following `is_first` and
    `is_last` tensors:

    is_first = [[1, 1]    is_last = [[0, 0]    ( sample id [[0 1]
                [0, 0]               [0, 0]                 [0 1]
                [0, 0]               [1, 0]                 [0 1]
                [1, 0]               [0, 0]                 [2 1]
                [0, 0]               [0, 1]                 [2 1]
                [0, 1]]              [0, 0]]                [2 3]] )

    while the data in the inputs, outputs and hints tensors would correspond
    to samples as identified by the sample_id indicated above for reference.

    Notice that, while in the full-sample dataset inputs and outputs have
    no time dimension, here they do; the input and output tensors are simply
    repeated along each sample's time length.
  """

  def _get_batch():
    d = next(dataset)
    return (d.features.inputs, d.features.hints, d.outputs,
            d.features.lengths.astype(int))

  inputs, hints, outputs, lengths = _get_batch()
  for inp in inputs:
    if inp.location in [specs.Location.NODE, specs.Location.EDGE]:
      batch_size = inp.data.shape[0]
      break

  io_chunk = lambda x: np.zeros((chunk_length,) + x.shape, dtype=x.dtype)
  chunk_inputs = jax.tree_util.tree_map(io_chunk, inputs)
  chunk_outputs = jax.tree_util.tree_map(io_chunk, outputs)

  hint_chunk = lambda x: np.zeros((chunk_length,) + x.shape[1:], dtype=x.dtype)
  chunk_hints = jax.tree_util.tree_map(hint_chunk, hints)

  inputs = [inputs]
  hints = [hints]
  outputs = [outputs]
  left = [lengths.copy()]
  lengths = [lengths.copy()]

  while True:
    # Create a new empty chunk
    chunk_inputs = jax.tree_util.tree_map(np.zeros_like, chunk_inputs)
    chunk_hints = jax.tree_util.tree_map(np.zeros_like, chunk_hints)
    chunk_outputs = jax.tree_util.tree_map(np.zeros_like, chunk_outputs)
    start_mark = np.zeros((chunk_length, batch_size), dtype=int)
    end_mark = np.zeros((chunk_length, batch_size), dtype=int)

    # Get enough data batches to fill the new chunk
    while np.any(np.sum(left, axis=0) < chunk_length):
      inp, hh, out, ll = _get_batch()
      inputs.append(inp)
      hints.append(hh)
      outputs.append(out)
      left.append(ll.copy())
      lengths.append(ll.copy())

    # Fill the chunk, one batch element at a time
    for i in range(batch_size):
      total, idx = 0, 0
      while total < chunk_length:
        to_add = min(left[idx][i], chunk_length - total)
        if to_add:
          start = lengths[idx][i] - left[idx][i]
          assert start >= 0
          f_io = functools.partial(_copy_io, i=i, start_dest=total,
                                   to_add=to_add)
          chunk_inputs = jax.tree_util.tree_map(f_io, inputs[idx], chunk_inputs)
          chunk_outputs = jax.tree_util.tree_map(f_io, outputs[idx],
                                                 chunk_outputs)
          f_hint = functools.partial(_copy_hint, i=i, start_source=start,
                                     start_dest=total, to_add=to_add)
          chunk_hints = jax.tree_util.tree_map(f_hint, hints[idx], chunk_hints)
          if start == 0:
            start_mark[total, i] = 1
          total += to_add
          left[idx][i] -= to_add
          assert left[idx][i] >= 0
          if left[idx][i] == 0:
            end_mark[total - 1, i] = 1
        idx += 1
      assert total == chunk_length

    # Drop batches whose samples have all been fully consumed.
    while left and np.all(left[0] == 0):
      inputs.pop(0)
      hints.pop(0)
      outputs.pop(0)
      left.pop(0)
      lengths.pop(0)

    yield samplers.Feedback(
        samplers.FeaturesChunked(chunk_inputs, chunk_hints,
                                 start_mark, end_mark),
        chunk_outputs)


def create_chunked_dataset(folder, algorithm, split, batch_size, chunk_length):
  dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
                      data_dir=folder, split=split)
  dataset = dataset.repeat()
  dataset = dataset.batch(batch_size)
  dataset = dataset.map(lambda d: _preprocess(d, algorithm=algorithm))
  dataset = dataset.as_numpy_iterator()
  return chunkify(dataset, chunk_length), specs.SPECS[algorithm]


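# Example (illustrative sketch; the folder, algorithm, batch size and chunk
# length are placeholders): the chunked pipeline yields fixed-time-length
# batches with `is_first`/`is_last` markers:
#
#   chunked_ds, spec = create_chunked_dataset(
#       folder='/tmp/CLRS30', algorithm='bfs', split='train',
#       batch_size=32, chunk_length=16)
#   feedback = next(chunked_ds)
#   # feedback.features.is_first and feedback.features.is_last have shape
#   # [chunk_length, batch_size].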