# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for `baselines.py`."""
import copy
import functools
from typing import Generator
from absl.testing import absltest
from absl.testing import parameterized
import chex
from clrs._src import baselines
from clrs._src import dataset
from clrs._src import probing
from clrs._src import processors
from clrs._src import samplers
from clrs._src import specs
import haiku as hk
import jax
import numpy as np


_Array = np.ndarray


def _error(x, y):
  return np.sum(np.abs(x - y))


def _make_sampler(algo: str, length: int) -> samplers.Sampler:
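  """Build a sampler for `algo` with the CLRS-30 validation seed and size."""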
sampler, _ = samplers.build_sampler(
algo,
seed=samplers.CLRS30['val']['seed'],
num_samples=samplers.CLRS30['val']['num_samples'],
length=length,
)
return sampler


def _without_permutation(feedback):
"""Replace should-be permutations with pointers."""
outputs = []
for x in feedback.outputs:
if x.type_ != specs.Type.SHOULD_BE_PERMUTATION:
outputs.append(x)
continue
assert x.location == specs.Location.NODE
outputs.append(probing.DataPoint(name=x.name, location=x.location,
type_=specs.Type.POINTER, data=x.data))
return feedback._replace(outputs=outputs)


def _make_iterable_sampler(
algo: str, batch_size: int,
length: int) -> Generator[samplers.Feedback, None, None]:
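  """Yield batches endlessly, with permutation outputs re-typed as pointers."""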
sampler = _make_sampler(algo, length)
while True:
yield _without_permutation(sampler.next(batch_size))


def _remove_permutation_from_spec(spec):
"""Modify spec to turn permutation type to pointer."""
new_spec = {}
for k in spec:
if (spec[k][1] == specs.Location.NODE and
spec[k][2] == specs.Type.SHOULD_BE_PERMUTATION):
new_spec[k] = (spec[k][0], spec[k][1], specs.Type.POINTER)
else:
new_spec[k] = spec[k]
return new_spec


class BaselinesTest(parameterized.TestCase):

  def test_full_vs_chunked(self):
"""Test that chunking does not affect gradients."""
batch_size = 4
length = 8
algo = 'insertion_sort'
spec = _remove_permutation_from_spec(specs.SPECS[algo])
rng_key = jax.random.PRNGKey(42)
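    # Three views of the same sample stream (identical seed): unchunked,
    # chunked at `length` steps, and chunked at double the length.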
full_ds = _make_iterable_sampler(algo, batch_size, length)
chunked_ds = dataset.chunkify(
_make_iterable_sampler(algo, batch_size, length),
length)
double_chunked_ds = dataset.chunkify(
_make_iterable_sampler(algo, batch_size, length),
length * 2)
full_batches = [next(full_ds) for _ in range(2)]
chunked_batches = [next(chunked_ds) for _ in range(2)]
double_chunk_batch = next(double_chunked_ds)
    with chex.fake_jit():  # jitting makes test longer
processor_factory = processors.get_processor_factory(
'mpnn', use_ln=False, nb_triplet_fts=0)
common_args = dict(processor_factory=processor_factory, hidden_dim=8,
learning_rate=0.01,
decode_hints=True, encode_hints=True)
b_full = baselines.BaselineModel(
spec, dummy_trajectory=full_batches[0], **common_args)
      b_full.init(full_batches[0].features, seed=42)  # pytype: disable=wrong-arg-types  # jax-ndarray
full_params = b_full.params
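      # `feedback` computes the loss and applies one optimizer step in place,
      # so parameters are restored between calls to keep comparisons aligned.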
full_loss_0 = b_full.feedback(rng_key, full_batches[0])
b_full.params = full_params
full_loss_1 = b_full.feedback(rng_key, full_batches[1])
new_full_params = b_full.params
b_chunked = baselines.BaselineModelChunked(
spec, dummy_trajectory=chunked_batches[0], **common_args)
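      # The chunked model is initialized with a doubly nested list of features
      # (a single group holding a single algorithm here).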
      b_chunked.init([[chunked_batches[0].features]], seed=42)  # pytype: disable=wrong-arg-types  # jax-ndarray
chunked_params = b_chunked.params
jax.tree_util.tree_map(np.testing.assert_array_equal, full_params,
chunked_params)
chunked_loss_0 = b_chunked.feedback(rng_key, chunked_batches[0])
b_chunked.params = chunked_params
chunked_loss_1 = b_chunked.feedback(rng_key, chunked_batches[1])
new_chunked_params = b_chunked.params
b_chunked.params = chunked_params
double_chunked_loss = b_chunked.feedback(rng_key, double_chunk_batch)

    # Test that losses match
np.testing.assert_allclose(full_loss_0, chunked_loss_0, rtol=1e-4)
np.testing.assert_allclose(full_loss_1, chunked_loss_1, rtol=1e-4)
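    # A double-length chunk spans both batches, so its per-chunk loss should
    # equal half the sum of the two single-chunk losses.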
np.testing.assert_allclose(full_loss_0 + full_loss_1,
2 * double_chunked_loss,
rtol=1e-4)

    # Test that gradients are the same (parameters changed equally).
# First check that gradients were not zero, i.e., parameters have changed.
param_change, _ = jax.tree_util.tree_flatten(
jax.tree_util.tree_map(_error, full_params, new_full_params))
self.assertGreater(np.mean(param_change), 0.1)
# Now check that full and chunked gradients are the same.
jax.tree_util.tree_map(
functools.partial(np.testing.assert_allclose, rtol=1e-4),
new_full_params, new_chunked_params)

  def test_multi_vs_single(self):
"""Test that multi = single when we only train one of the algorithms."""
batch_size = 4
length = 16
algos = ['insertion_sort', 'activity_selector', 'bfs']
spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
rng_key = jax.random.PRNGKey(42)
full_ds = [_make_iterable_sampler(algo, batch_size, length)
for algo in algos]
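    # Two consecutive batches per algorithm: one for each of the two
    # training steps taken below.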
full_batches = [next(ds) for ds in full_ds]
full_batches_2 = [next(ds) for ds in full_ds]
    with chex.fake_jit():  # jitting makes test longer
processor_factory = processors.get_processor_factory(
'mpnn', use_ln=False, nb_triplet_fts=0)
common_args = dict(processor_factory=processor_factory, hidden_dim=8,
learning_rate=0.01,
decode_hints=True, encode_hints=True)
b_single = baselines.BaselineModel(
spec[0], dummy_trajectory=full_batches[0], **common_args)
b_multi = baselines.BaselineModel(
spec, dummy_trajectory=full_batches, **common_args)
      b_single.init(full_batches[0].features, seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
      b_multi.init([f.features for f in full_batches], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
single_params = []
single_losses = []
multi_params = []
multi_losses = []
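      # Alternate parameter snapshots with training steps so single- and
      # multi-algorithm updates can be compared pairwise.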
single_params.append(copy.deepcopy(b_single.params))
single_losses.append(b_single.feedback(rng_key, full_batches[0]))
single_params.append(copy.deepcopy(b_single.params))
single_losses.append(b_single.feedback(rng_key, full_batches_2[0]))
single_params.append(copy.deepcopy(b_single.params))
multi_params.append(copy.deepcopy(b_multi.params))
multi_losses.append(b_multi.feedback(rng_key, full_batches[0],
algorithm_index=0))
multi_params.append(copy.deepcopy(b_multi.params))
multi_losses.append(b_multi.feedback(rng_key, full_batches_2[0],
algorithm_index=0))
multi_params.append(copy.deepcopy(b_multi.params))

    # Test that losses match
np.testing.assert_array_equal(single_losses, multi_losses)
# Test that loss decreased
    self.assertLess(single_losses[1], single_losses[0])

    # Test that param changes were the same in single and multi-algorithm
for single, multi in zip(single_params, multi_params):
      self.assertTrue(
          hk.data_structures.is_subset(subset=single, superset=multi))
for module_name, params in single.items():
jax.tree_util.tree_map(np.testing.assert_array_equal, params,
multi[module_name])

    # Test that params change for the trained algorithm, but not the others
for module_name, params in multi_params[0].items():
param_changes = jax.tree_util.tree_map(lambda a, b: np.sum(np.abs(a - b)),
params,
multi_params[1][module_name])
param_change = sum(param_changes.values())
      if module_name in single_params[0]:  # params of trained algorithm
        self.assertGreater(param_change, 1e-3)
      else:  # params of non-trained algorithms
        self.assertEqual(param_change, 0.0)

  @parameterized.parameters(True, False)
  def test_multi_algorithm_idx(self, is_chunked):
"""Test that algorithm selection works as intended."""
batch_size = 4
length = 8
algos = ['insertion_sort', 'activity_selector', 'bfs']
spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
rng_key = jax.random.PRNGKey(42)
if is_chunked:
ds = [dataset.chunkify(_make_iterable_sampler(algo, batch_size, length),
2 * length) for algo in algos]
else:
ds = [_make_iterable_sampler(algo, batch_size, length) for algo in algos]
batches = [next(d) for d in ds]
processor_factory = processors.get_processor_factory(
'mpnn', use_ln=False, nb_triplet_fts=0)
common_args = dict(processor_factory=processor_factory, hidden_dim=8,
learning_rate=0.01,
decode_hints=True, encode_hints=True)
if is_chunked:
baseline = baselines.BaselineModelChunked(
spec, dummy_trajectory=batches, **common_args)
      baseline.init([[f.features for f in batches]], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
else:
baseline = baselines.BaselineModel(
spec, dummy_trajectory=batches, **common_args)
      baseline.init([f.features for f in batches], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray

    # Find out what parameters change when we train each algorithm
def _change(x, y):
changes = {}
for module_name, params in x.items():
changes[module_name] = sum(
jax.tree_util.tree_map(
lambda a, b: np.sum(np.abs(a-b)), params, y[module_name]
).values())
return changes

    param_changes = []
for algo_idx in range(len(algos)):
init_params = copy.deepcopy(baseline.params)
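      # The chunked model addresses algorithms with an index pair, mirroring
      # the nested feature lists passed to init (the single group is index 0).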
_ = baseline.feedback(
rng_key,
batches[algo_idx],
algorithm_index=(0, algo_idx) if is_chunked else algo_idx)
param_changes.append(_change(init_params, baseline.params))

    # Test that non-changing parameters correspond to encoders/decoders
# associated with the non-trained algorithms
unchanged = [[k for k in pc if pc[k] == 0] for pc in param_changes]

    def _get_other_algos(algo_idx, modules):
return set([k for k in modules if '_construct_encoders_decoders' in k
and f'algo_{algo_idx}' not in k])

    for algo_idx in range(len(algos)):
expected_unchanged = _get_other_algos(algo_idx, baseline.params.keys())
self.assertNotEmpty(expected_unchanged)
self.assertSetEqual(expected_unchanged, set(unchanged[algo_idx]))


if __name__ == '__main__':
absltest.main()