# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Unit tests for `baselines.py`."""

import copy
import functools
from typing import Generator

from absl.testing import absltest
from absl.testing import parameterized
import chex
from clrs._src import baselines
from clrs._src import dataset
from clrs._src import probing
from clrs._src import processors
from clrs._src import samplers
from clrs._src import specs
import haiku as hk
import jax
import numpy as np

_Array = np.ndarray


def _error(x, y):
  # Total absolute element-wise difference between two arrays; used below as a
  # cheap "did these parameters change at all?" metric.
  return np.sum(np.abs(x-y))


def _make_sampler(algo: str, length: int) -> samplers.Sampler:
  """Builds a sampler for `algo` using the CLRS-30 validation seed/sample count."""
  sampler, _ = samplers.build_sampler(
      algo,
      seed=samplers.CLRS30['val']['seed'],
      num_samples=samplers.CLRS30['val']['num_samples'],
      length=length,
  )
  return sampler


def _without_permutation(feedback):
  """Replace should-be permutations with pointers."""
  outputs = []
  for x in feedback.outputs:
    if x.type_ != specs.Type.SHOULD_BE_PERMUTATION:
      # Non-permutation outputs pass through unchanged.
      outputs.append(x)
      continue
    assert x.location == specs.Location.NODE
    # Re-wrap the same data as a plain POINTER output.
    outputs.append(probing.DataPoint(name=x.name, location=x.location,
                                     type_=specs.Type.POINTER, data=x.data))
  return feedback._replace(outputs=outputs)


def _make_iterable_sampler(
    algo: str, batch_size: int,
    length: int) -> Generator[samplers.Feedback, None, None]:
  """Yields an endless stream of feedback batches for `algo`.

  Each batch has should-be-permutation outputs rewritten as pointers
  (see `_without_permutation`), matching `_remove_permutation_from_spec`.
  """
  sampler = _make_sampler(algo, length)
  while True:
    yield _without_permutation(sampler.next(batch_size))


def _remove_permutation_from_spec(spec):
  """Modify spec to turn permutation type to pointer."""
  new_spec = {}
  for k in spec:
    # Spec entries are (stage, location, type) triples; only NODE-located
    # SHOULD_BE_PERMUTATION entries are rewritten to POINTER.
    if (spec[k][1] == specs.Location.NODE and
        spec[k][2] == specs.Type.SHOULD_BE_PERMUTATION):
      new_spec[k] = (spec[k][0], spec[k][1], specs.Type.POINTER)
    else:
      new_spec[k] = spec[k]
  return new_spec


class BaselinesTest(parameterized.TestCase):
  """Tests consistency of BaselineModel / BaselineModelChunked training."""

  def test_full_vs_chunked(self):
    """Test that chunking does not affect gradients."""
    batch_size = 4
    length = 8
    algo = 'insertion_sort'
    spec = _remove_permutation_from_spec(specs.SPECS[algo])
    rng_key = jax.random.PRNGKey(42)

    # Three views of the same sampler stream: unchunked, chunked at the
    # trajectory length, and chunked at twice the length (so one double
    # chunk spans what the other streams see as two batches).
    full_ds = _make_iterable_sampler(algo, batch_size, length)
    chunked_ds = dataset.chunkify(
        _make_iterable_sampler(algo, batch_size, length),
        length)
    double_chunked_ds = dataset.chunkify(
        _make_iterable_sampler(algo, batch_size, length),
        length * 2)

    full_batches = [next(full_ds) for _ in range(2)]
    chunked_batches = [next(chunked_ds) for _ in range(2)]
    double_chunk_batch = next(double_chunked_ds)

    with chex.fake_jit():  # jitting makes test longer
      processor_factory = processors.get_processor_factory(
          'mpnn', use_ln=False, nb_triplet_fts=0)
      common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                         learning_rate=0.01,
                         decode_hints=True, encode_hints=True)

      b_full = baselines.BaselineModel(
          spec, dummy_trajectory=full_batches[0], **common_args)
      b_full.init(full_batches[0].features, seed=42)  # pytype: disable=wrong-arg-types  # jax-ndarray
      full_params = b_full.params
      # `feedback` performs a training step and mutates `.params` (verified by
      # the gradient check at the end of this test); params are restored to
      # their initial value before each step so both losses start from the
      # same parameters.
      full_loss_0 = b_full.feedback(rng_key, full_batches[0])
      b_full.params = full_params
      full_loss_1 = b_full.feedback(rng_key, full_batches[1])
      new_full_params = b_full.params

      b_chunked = baselines.BaselineModelChunked(
          spec, dummy_trajectory=chunked_batches[0], **common_args)
      b_chunked.init([[chunked_batches[0].features]], seed=42)  # pytype: disable=wrong-arg-types  # jax-ndarray
      chunked_params = b_chunked.params
      # Same seed must produce identical initial parameters in both models.
      jax.tree_util.tree_map(np.testing.assert_array_equal,
                             full_params, chunked_params)
      chunked_loss_0 = b_chunked.feedback(rng_key, chunked_batches[0])
      b_chunked.params = chunked_params
      chunked_loss_1 = b_chunked.feedback(rng_key, chunked_batches[1])
      new_chunked_params = b_chunked.params
      b_chunked.params = chunked_params
      double_chunked_loss = b_chunked.feedback(rng_key, double_chunk_batch)

    # Test that losses match
    np.testing.assert_allclose(full_loss_0, chunked_loss_0, rtol=1e-4)
    np.testing.assert_allclose(full_loss_1, chunked_loss_1, rtol=1e-4)
    # A double-length chunk covers both batches, so its (mean) loss should be
    # half their summed losses.
    np.testing.assert_allclose(full_loss_0 + full_loss_1,
                               2 * double_chunked_loss, rtol=1e-4)

    # Test that gradients are the same (parameters changed equally).
    # First check that gradients were not zero, i.e., parameters have changed.
    param_change, _ = jax.tree_util.tree_flatten(
        jax.tree_util.tree_map(_error, full_params, new_full_params))
    self.assertGreater(np.mean(param_change), 0.1)
    # Now check that full and chunked gradients are the same.
    jax.tree_util.tree_map(
        functools.partial(np.testing.assert_allclose, rtol=1e-4),
        new_full_params, new_chunked_params)

  def test_multi_vs_single(self):
    """Test that multi = single when we only train one of the algorithms."""
    batch_size = 4
    length = 16
    algos = ['insertion_sort', 'activity_selector', 'bfs']
    spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
    rng_key = jax.random.PRNGKey(42)

    full_ds = [_make_iterable_sampler(algo, batch_size, length)
               for algo in algos]
    # Two consecutive batches per algorithm; only the first algorithm's
    # batches are ever trained on below.
    full_batches = [next(ds) for ds in full_ds]
    full_batches_2 = [next(ds) for ds in full_ds]

    with chex.fake_jit():  # jitting makes test longer
      processor_factory = processors.get_processor_factory(
          'mpnn', use_ln=False, nb_triplet_fts=0)
      common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                         learning_rate=0.01,
                         decode_hints=True, encode_hints=True)

      b_single = baselines.BaselineModel(
          spec[0], dummy_trajectory=full_batches[0], **common_args)
      b_multi = baselines.BaselineModel(
          spec, dummy_trajectory=full_batches, **common_args)
      b_single.init(full_batches[0].features, seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
      b_multi.init([f.features for f in full_batches], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray

      single_params = []
      single_losses = []
      multi_params = []
      multi_losses = []

      # Snapshot params before/between/after two training steps on the first
      # algorithm; `deepcopy` because `feedback` mutates `.params` in place.
      single_params.append(copy.deepcopy(b_single.params))
      single_losses.append(b_single.feedback(rng_key, full_batches[0]))
      single_params.append(copy.deepcopy(b_single.params))
      single_losses.append(b_single.feedback(rng_key, full_batches_2[0]))
      single_params.append(copy.deepcopy(b_single.params))

      # Same two steps on the multi-algorithm model, training only index 0.
      multi_params.append(copy.deepcopy(b_multi.params))
      multi_losses.append(b_multi.feedback(rng_key, full_batches[0],
                                           algorithm_index=0))
      multi_params.append(copy.deepcopy(b_multi.params))
      multi_losses.append(b_multi.feedback(rng_key, full_batches_2[0],
                                           algorithm_index=0))
      multi_params.append(copy.deepcopy(b_multi.params))

    # Test that losses match
    np.testing.assert_array_equal(single_losses, multi_losses)
    # Test that loss decreased
    assert single_losses[1] < single_losses[0]

    # Test that param changes were the same in single and multi-algorithm
    for single, multi in zip(single_params, multi_params):
      # The single-algorithm model's modules are a subset of the multi-
      # algorithm model's modules, with identical values at every snapshot.
      assert hk.data_structures.is_subset(subset=single, superset=multi)
      for module_name, params in single.items():
        jax.tree_util.tree_map(np.testing.assert_array_equal, params,
                               multi[module_name])

    # Test that params change for the trained algorithm, but not the others
    for module_name, params in multi_params[0].items():
      param_changes = jax.tree_util.tree_map(lambda a, b: np.sum(np.abs(a - b)),
                                             params,
                                             multi_params[1][module_name])
      param_change = sum(param_changes.values())
      if module_name in single_params[0]:  # params of trained algorithm
        assert param_change > 1e-3
      else:  # params of non-trained algorithms
        assert param_change == 0.0

  @parameterized.parameters(True, False)
  def test_multi_algorithm_idx(self, is_chunked):
    """Test that algorithm selection works as intended."""
    batch_size = 4
    length = 8
    algos = ['insertion_sort', 'activity_selector', 'bfs']
    spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
    rng_key = jax.random.PRNGKey(42)

    if is_chunked:
      ds = [dataset.chunkify(_make_iterable_sampler(algo, batch_size, length),
                             2 * length) for algo in algos]
    else:
      ds = [_make_iterable_sampler(algo, batch_size, length) for algo in algos]
    batches = [next(d) for d in ds]

    processor_factory = processors.get_processor_factory(
        'mpnn', use_ln=False, nb_triplet_fts=0)
    common_args = dict(processor_factory=processor_factory, hidden_dim=8,
                       learning_rate=0.01,
                       decode_hints=True, encode_hints=True)

    if is_chunked:
      baseline = baselines.BaselineModelChunked(
          spec, dummy_trajectory=batches, **common_args)
      baseline.init([[f.features for f in batches]], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray
    else:
      baseline = baselines.BaselineModel(
          spec, dummy_trajectory=batches, **common_args)
      baseline.init([f.features for f in batches], seed=0)  # pytype: disable=wrong-arg-types  # jax-ndarray

    # Find out what parameters change when we train each algorithm
    def _change(x, y):
      # Per-module total absolute parameter difference between param trees
      # `x` and `y` (same structure assumed).
      changes = {}
      for module_name, params in x.items():
        changes[module_name] = sum(
            jax.tree_util.tree_map(
                lambda a, b: np.sum(np.abs(a-b)), params,
                y[module_name]).values())
      return changes

    param_changes = []
    for algo_idx in range(len(algos)):
      init_params = copy.deepcopy(baseline.params)
      # The chunked model indexes algorithms by (dataset, algorithm) pairs,
      # hence the `(0, algo_idx)` form when `is_chunked`.
      _ = baseline.feedback(
          rng_key, batches[algo_idx],
          algorithm_index=(0, algo_idx) if is_chunked else algo_idx)
      param_changes.append(_change(init_params, baseline.params))

    # Test that non-changing parameters correspond to encoders/decoders
    # associated with the non-trained algorithms
    unchanged = [[k for k in pc if pc[k] == 0] for pc in param_changes]

    def _get_other_algos(algo_idx, modules):
      # Module names of encoders/decoders belonging to algorithms other than
      # `algo_idx` (keyed on the 'algo_<i>' substring in the module name).
      return set([k for k in modules if '_construct_encoders_decoders' in k
                  and f'algo_{algo_idx}' not in k])

    for algo_idx in range(len(algos)):
      expected_unchanged = _get_other_algos(algo_idx, baseline.params.keys())
      self.assertNotEmpty(expected_unchanged)
      self.assertSetEqual(expected_unchanged, set(unchanged[algo_idx]))


if __name__ == '__main__':
  absltest.main()