# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CLRS dataset."""
import dataclasses
import functools
from typing import Iterator
from clrs._src import probing
from clrs._src import samplers
from clrs._src import specs
import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
def _correct_axis_filtering(tensor, index, name):
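  """Select one sample; hint tensors are time-major, so the sample axis is 1."""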
if 'hint_' in name:
return tensor[:, index]
else:
return tensor[index]
@dataclasses.dataclass
class CLRSConfig(tfds.core.BuilderConfig):
"""Specify the split in the variant because they have different shapes."""
split: str = ''
DEFAULT_BUILDER_CONFIGS = []
def _build_default_builder_configs():
for split in ['train', 'val', 'test']:
for alg in specs.CLRS_30_ALGS:
DEFAULT_BUILDER_CONFIGS.append(
CLRSConfig(name=f'{alg}_{split}', split=split))
_build_default_builder_configs()
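# Config names follow the pattern '<algorithm>_<split>', e.g. 'bfs_train' or
# 'insertion_sort_val'; the algorithm name is recovered later by dropping the
# trailing split token (see `_create_data` and `_generate_examples`).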
class CLRSDataset(tfds.core.GeneratorBasedBuilder):
"""DatasetBuilder for my_dataset dataset."""
VERSION = tfds.core.Version('1.0.0')
RELEASE_NOTES = {
'1.0.0': 'Initial release.',
}
BUILDER_CONFIGS = DEFAULT_BUILDER_CONFIGS
_instantiated_dataset = None
_instantiated_dataset_name = ''
_instantiated_dataset_split = ''
def _num_samples(self, algorithm_name):
num_samples = samplers.CLRS30[self._builder_config.split]['num_samples'] # pytype: disable=attribute-error # always-use-return-annotations
if self._builder_config.split != 'train': # pytype: disable=attribute-error # always-use-return-annotations
# Generate more samples for those algorithms in which the number of
# signals is small.
num_samples *= specs.CLRS_30_ALGS_SETTINGS[algorithm_name][
'num_samples_multiplier']
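      # Illustrative arithmetic (numbers are assumptions, not read from the
      # code): a split providing 32 base samples and an algorithm with a
      # multiplier of 32 would yield 32 * 32 = 1024 generated samples.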
return num_samples
def _create_data(self, single_sample):
algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1])
num_samples = self._num_samples(algorithm_name)
sampler, _ = samplers.build_sampler(
algorithm_name,
seed=samplers.CLRS30[self._builder_config.split]['seed'], # pytype: disable=attribute-error # always-use-return-annotations
num_samples=num_samples,
length=samplers.CLRS30[self._builder_config.split]['length'], # pytype: disable=attribute-error # always-use-return-annotations
)
sampled_dataset = sampler.next(batch_size=1 if single_sample else None)
data = {'input_' + t.name: t.data for t in sampled_dataset.features.inputs}
# All other data points have input_, hint_, and output_ prefixes, so we
# guarantee that this key is unused.
data['lengths'] = sampled_dataset.features.lengths
data.update({'output_' + t.name: t.data for t in sampled_dataset.outputs})
data.update({
'hint_' + t.name: t.data for t in sampled_dataset.features.hints})
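    # `data` now maps flat keys of the form 'input_<name>', 'hint_<name>' and
    # 'output_<name>' (plus 'lengths') to arrays; `_preprocess` strips these
    # prefixes again when rebuilding DataPoints for training.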
self._instantiated_dataset = data
def _info(self) -> tfds.core.DatasetInfo:
if tf.io.gfile.exists(self.data_dir):
info = tfds.core.DatasetInfo(builder=self)
info.read_from_directory(self.data_dir)
return info
if (self._instantiated_dataset_name != self._builder_config.name
or self._instantiated_dataset_split != self._builder_config.split): # pytype: disable=attribute-error # always-use-return-annotations
self._create_data(single_sample=True)
data = {k: _correct_axis_filtering(v, 0, k)
for k, v in self._instantiated_dataset.items()}
data_info = {
k: tfds.features.Tensor(shape=v.shape, dtype=tf.dtypes.as_dtype(
v.dtype)) for k, v in data.items()}
return tfds.core.DatasetInfo(
builder=self,
features=tfds.features.FeaturesDict(data_info),
)
def _split_generators(self, dl_manager: tfds.download.DownloadManager):
"""Download the data and define splits."""
if (self._instantiated_dataset_name != self._builder_config.name
or self._instantiated_dataset_split != self._builder_config.split): # pytype: disable=attribute-error # always-use-return-annotations
self._create_data(single_sample=False)
self._instantiated_dataset_name = self._builder_config.name
self._instantiated_dataset_split = self._builder_config.split # pytype: disable=attribute-error # always-use-return-annotations
return {self._builder_config.split: self._generate_examples()} # pytype: disable=attribute-error # always-use-return-annotations
def _generate_examples(self):
"""Generator of examples for each split."""
algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1])
for i in range(self._num_samples(algorithm_name)):
data = {k: _correct_axis_filtering(v, i, k)
for k, v in self._instantiated_dataset.items()}
yield str(i), data
def _get_clrs_file_name():
return f'CLRS30_v{CLRSDataset.VERSION}.tar.gz'
def get_dataset_gcp_url():
return f'https://storage.googleapis.com/dm-clrs/{_get_clrs_file_name()}'
def get_clrs_folder():
return f'CLRS30_v{CLRSDataset.VERSION}'
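# With VERSION = '1.0.0', the three helpers above resolve to
# 'CLRS30_v1.0.0.tar.gz',
# 'https://storage.googleapis.com/dm-clrs/CLRS30_v1.0.0.tar.gz' and
# 'CLRS30_v1.0.0' respectively.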
def _preprocess(data_point, algorithm=None):
"""Convert sampled inputs into DataPoints."""
inputs = []
outputs = []
hints = []
lengths = None
for name, data in data_point.items():
if name == 'lengths':
lengths = data
continue
data_point_name = name.split('_')
name = '_'.join(data_point_name[1:])
(stage, location, dp_type) = specs.SPECS[algorithm][name]
assert stage == data_point_name[0]
if stage == specs.Stage.HINT:
data = tf.experimental.numpy.swapaxes(data, 0, 1)
dp = probing.DataPoint(name, location, dp_type, data)
if stage == specs.Stage.INPUT:
inputs.append(dp)
elif stage == specs.Stage.OUTPUT:
outputs.append(dp)
else:
hints.append(dp)
return samplers.Feedback(
samplers.Features(tuple(inputs), tuple(hints), lengths), tuple(outputs))
def create_dataset(folder, algorithm, split, batch_size):
dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
data_dir=folder, split=split)
num_samples = len(dataset) # Must be done here for correct size
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
return (dataset.map(lambda d: _preprocess(d, algorithm=algorithm)),
num_samples,
specs.SPECS[algorithm])
def _copy_hint(source, dest, i, start_source, start_dest, to_add):
"""Copy from full-sample hint to a hint chunk."""
assert np.all(dest[start_dest:, i:] == 0)
assert start_dest < dest.shape[0]
assert start_dest + to_add <= dest.shape[0]
assert start_source < source.shape[0]
assert start_source + to_add <= source.shape[0]
dest[start_dest:start_dest+to_add, i] = source[
start_source:start_source+to_add, i]
return dest
def _copy_io(source, dest, i, start_dest, to_add):
"""Copy from an input or output to an input or output chunk."""
assert np.all(dest[start_dest:, i:] == 0)
dest[start_dest:start_dest+to_add, i] = source[i]
return dest
def chunkify(dataset: Iterator[samplers.Feedback], chunk_length: int):
"""Generator of fixed-length chunks from full-trajectory samples.
Args:
dataset: full-sample dataset as numpy iterator.
chunk_length: time length of chunks.
Yields:
Fixed-timelength chunks of data. Each tensor of inputs, hints and outputs
has dimensions chunk_length x batch_size x ... Samples are not time-padded,
after the end of one sample immediately comes the next. Since different
samples can have different time lengths, the beginnings and ends of samples
within a batch do not need to coincide. For this reason, the chunked
dataset features include two chunk_length x batch_size int tensors,
`is_first` and `is_last`, that mark the beginning and end of each sample.
    For example, if `chunk_length`==6 and `batch_size`==2 and the first
full-sample batch had one sample of length 3 and one of length 5,
we would have a first chunked batch with the following `is_first` and
`is_last` tensors:
    is_first = [[1, 1]     is_last = [[0, 0]    ( sample id [[0 1]
                [0, 0]                [0, 0]                 [0 1]
                [0, 0]                [1, 0]                 [0 1]
                [1, 0]                [0, 0]                 [2 1]
                [0, 0]                [0, 1]                 [2 1]
                [0, 1]]               [0, 0]]                [2 3]] )
while the data in the inputs, outputs and hints tensors would correspond
to samples as identified by the sample_id indicated above for reference.
Notice that, while in the full-sample dataset inputs and outputs have
no time dimension, here they do; the input and output tensors are simply
repeated along each sample's time length.
"""
def _get_batch():
d = next(dataset)
return (d.features.inputs, d.features.hints, d.outputs,
d.features.lengths.astype(int))
inputs, hints, outputs, lengths = _get_batch()
for inp in inputs:
if inp.location in [specs.Location.NODE, specs.Location.EDGE]:
batch_size = inp.data.shape[0]
break
io_chunk = lambda x: np.zeros((chunk_length,) + x.shape, dtype=x.dtype)
chunk_inputs = jax.tree_util.tree_map(io_chunk, inputs)
chunk_outputs = jax.tree_util.tree_map(io_chunk, outputs)
hint_chunk = lambda x: np.zeros((chunk_length,) + x.shape[1:], dtype=x.dtype)
chunk_hints = jax.tree_util.tree_map(hint_chunk, hints)
inputs = [inputs]
hints = [hints]
outputs = [outputs]
left = [lengths.copy()]
lengths = [lengths.copy()]
while True:
# Create a new empty chunk
chunk_inputs = jax.tree_util.tree_map(np.zeros_like, chunk_inputs)
chunk_hints = jax.tree_util.tree_map(np.zeros_like, chunk_hints)
chunk_outputs = jax.tree_util.tree_map(np.zeros_like, chunk_outputs)
start_mark = np.zeros((chunk_length, batch_size), dtype=int)
end_mark = np.zeros((chunk_length, batch_size), dtype=int)
# Get enough data batches to fill the new chunk
while np.any(np.sum(left, axis=0) < chunk_length):
inp, hh, out, ll = _get_batch()
inputs.append(inp)
hints.append(hh)
outputs.append(out)
left.append(ll.copy())
lengths.append(ll.copy())
# Fill the chunk, one batch element at a time
for i in range(batch_size):
total, idx = 0, 0
while total < chunk_length:
to_add = min(left[idx][i], chunk_length - total)
if to_add:
start = lengths[idx][i] - left[idx][i]
assert start >= 0
f_io = functools.partial(_copy_io, i=i, start_dest=total,
to_add=to_add)
chunk_inputs = jax.tree_util.tree_map(f_io, inputs[idx], chunk_inputs)
chunk_outputs = jax.tree_util.tree_map(f_io, outputs[idx],
chunk_outputs)
f_hint = functools.partial(_copy_hint, i=i, start_source=start,
start_dest=total, to_add=to_add)
chunk_hints = jax.tree_util.tree_map(f_hint, hints[idx], chunk_hints)
if start == 0:
start_mark[total, i] = 1
total += to_add
left[idx][i] -= to_add
assert left[idx][i] >= 0
if left[idx][i] == 0:
end_mark[total - 1, i] = 1
idx += 1
assert total == chunk_length
while left and np.all(left[0] == 0):
inputs.pop(0)
hints.pop(0)
outputs.pop(0)
left.pop(0)
lengths.pop(0)
yield samplers.Feedback(
samplers.FeaturesChunked(chunk_inputs, chunk_hints,
start_mark, end_mark),
chunk_outputs)
def create_chunked_dataset(folder, algorithm, split, batch_size, chunk_length):
dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
data_dir=folder, split=split)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
dataset = dataset.map(lambda d: _preprocess(d, algorithm=algorithm))
dataset = dataset.as_numpy_iterator()
return chunkify(dataset, chunk_length), specs.SPECS[algorithm]
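

if __name__ == '__main__':
  # Minimal usage sketch, not part of the original module. It assumes the
  # CLRS30 TFDS records for 'bfs' already exist under the placeholder folder
  # '/tmp/CLRS30' (download and extract the archive from
  # `get_dataset_gcp_url()` first if they do not).
  folder = '/tmp/CLRS30'
  ds, num_samples, spec = create_dataset(
      folder=folder, algorithm='bfs', split='train', batch_size=32)
  feedback = next(ds.as_numpy_iterator())
  print('samples:', num_samples,
        'input tensors per feedback:', len(feedback.features.inputs))

  # Chunked variant: fixed time-length chunks of 16 steps.
  chunked_ds, _ = create_chunked_dataset(
      folder=folder, algorithm='bfs', split='train',
      batch_size=32, chunk_length=16)
  chunk = next(chunked_ds)
  print('is_first marks shape:', chunk.features.is_first.shape)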