# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Probing utilities. | |
The dataflow for an algorithm is represented by `(stage, loc, type, data)` | |
"probes" that are valid under that algorithm's spec (see `specs.py`). | |
When constructing probes, it is convenient to represent these fields in a nested | |
format (`ProbesDict`) to facilate efficient contest-based look-up. | |
""" | |

import functools
from typing import Dict, List, Tuple, Union

import attr
from clrs._src import specs
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

_Location = specs.Location
_Stage = specs.Stage
_Type = specs.Type
_OutputClass = specs.OutputClass
_Array = np.ndarray
_Data = Union[_Array, List[_Array]]
_DataOrType = Union[_Data, str]

ProbesDict = Dict[
    str, Dict[str, Dict[str, Dict[str, _DataOrType]]]]


def _convert_to_str(element):
  if isinstance(element, tf.Tensor):
    return element.numpy().decode('utf-8')
  elif isinstance(element, (np.ndarray, bytes)):
    return element.decode('utf-8')
  else:
    return element


# First annotation makes this object jax.jit/pmap friendly, second one makes
# this tf.data.Datasets friendly.
@jax.tree_util.register_pytree_node_class
@attr.define
class DataPoint:
  """Describes a data point."""

  _name: str
  _location: str
  _type_: str
  data: _Array

  @property
  def name(self):
    return _convert_to_str(self._name)

  @property
  def location(self):
    return _convert_to_str(self._location)

  @property
  def type_(self):
    return _convert_to_str(self._type_)

  def __repr__(self):
    s = f'DataPoint(name="{self.name}",\tlocation={self.location},\t'
    return s + f'type={self.type_},\tdata=Array{self.data.shape})'

  def tree_flatten(self):
    data = (self.data,)
    meta = (self.name, self.location, self.type_)
    return data, meta

  @classmethod
  def tree_unflatten(cls, meta, data):
    name, location, type_ = meta
    subdata, = data
    return DataPoint(name, location, type_, subdata)
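
# Illustrative sketch (hypothetical values): `attr.define` strips the leading
# underscore, so a `DataPoint` is constructed with plain keyword names, and the
# pytree registration above lets jax treat `data` as the only leaf.
#   dp = DataPoint(name='pred', location=_Location.NODE,
#                  type_=_Type.POINTER, data=np.zeros((1, 4)))
#   leaves, treedef = jax.tree_util.tree_flatten(dp)  # leaves == [dp.data]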


class ProbeError(Exception):
  pass


def initialize(spec: specs.Spec) -> ProbesDict:
  """Initializes an empty `ProbesDict` corresponding to the provided spec."""
  probes = dict()
  for stage in [_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT]:
    probes[stage] = {}
    for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
      probes[stage][loc] = {}

  for name in spec:
    stage, loc, t = spec[name]
    probes[stage][loc][name] = {}
    probes[stage][loc][name]['data'] = []
    probes[stage][loc][name]['type_'] = t

  # Pytype thinks initialize() returns a ProbesDict with a str for all final
  # values instead of _DataOrType.
  return probes  # pytype: disable=bad-return-type


def push(probes: ProbesDict, stage: str, next_probe):
  """Pushes a probe into an existing `ProbesDict`."""
  for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
    for name in probes[stage][loc]:
      if name not in next_probe:
        raise ProbeError(f'Missing probe for {name}.')
      if isinstance(probes[stage][loc][name]['data'], _Array):
        raise ProbeError('Attempting to push to finalized `ProbesDict`.')
      # Pytype thinks initialize() returns a ProbesDict with a str for all
      # final values instead of _DataOrType.
      probes[stage][loc][name]['data'].append(next_probe[name])  # pytype: disable=attribute-error


def finalize(probes: ProbesDict):
  """Finalizes a `ProbesDict` by stacking/squeezing the `data` field."""
  for stage in [_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT]:
    for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
      for name in probes[stage][loc]:
        if isinstance(probes[stage][loc][name]['data'], _Array):
          raise ProbeError('Attempting to re-finalize a finalized `ProbesDict`.')
        if stage == _Stage.HINT:
          # Hints are provided for each timestep. Stack them here.
          probes[stage][loc][name]['data'] = np.stack(
              probes[stage][loc][name]['data'])
        else:
          # Only one instance of input/output exists. Remove the leading axis.
          probes[stage][loc][name]['data'] = np.squeeze(
              np.array(probes[stage][loc][name]['data']))


def split_stages(
    probes: ProbesDict,
    spec: specs.Spec,
) -> Tuple[List[DataPoint], List[DataPoint], List[DataPoint]]:
  """Splits contents of `ProbesDict` into `DataPoint`s by stage."""
  inputs = []
  outputs = []
  hints = []

  for name in spec:
    stage, loc, t = spec[name]

    if stage not in probes:
      raise ProbeError(f'Missing stage {stage}.')
    if loc not in probes[stage]:
      raise ProbeError(f'Missing location {loc}.')
    if name not in probes[stage][loc]:
      raise ProbeError(f'Missing probe {name}.')
    if 'type_' not in probes[stage][loc][name]:
      raise ProbeError(f'Probe {name} missing attribute `type_`.')
    if 'data' not in probes[stage][loc][name]:
      raise ProbeError(f'Probe {name} missing attribute `data`.')
    if t != probes[stage][loc][name]['type_']:
      raise ProbeError(f'Probe {name} of incorrect type {t}.')

    data = probes[stage][loc][name]['data']
    if not isinstance(data, _Array):
      raise ProbeError((f'Invalid `data` for probe "{name}". ' +
                        'Did you forget to call `probing.finalize`?'))

    if t in [_Type.MASK, _Type.MASK_ONE, _Type.CATEGORICAL]:
      # pytype: disable=attribute-error
      if not ((data == 0) | (data == 1) | (data == -1)).all():
        raise ProbeError(f'Expected 0|1|-1 `data` for probe "{name}".')
      # pytype: enable=attribute-error
    if t in [_Type.MASK_ONE, _Type.CATEGORICAL
            ] and not np.all(np.sum(np.abs(data), -1) == 1):
      raise ProbeError(f'Expected one-hot `data` for probe "{name}".')

    dim_to_expand = 1 if stage == _Stage.HINT else 0
    data_point = DataPoint(name=name, location=loc, type_=t,
                           data=np.expand_dims(data, dim_to_expand))

    if stage == _Stage.INPUT:
      inputs.append(data_point)
    elif stage == _Stage.OUTPUT:
      outputs.append(data_point)
    else:
      hints.append(data_point)

  return inputs, outputs, hints
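
# Illustrative sketch (not a fixed API contract): a typical probing flow is
# initialize -> push (once per algorithm step) -> finalize -> split_stages.
# The variables `input_probes`, `step_probes`, `output_probes` and
# `hint_trajectory` below are hypothetical.
#
#   spec = specs.SPECS['insertion_sort']
#   probes = initialize(spec)
#   push(probes, _Stage.INPUT, next_probe=input_probes)
#   for step_probes in hint_trajectory:
#     push(probes, _Stage.HINT, next_probe=step_probes)
#   push(probes, _Stage.OUTPUT, next_probe=output_probes)
#   finalize(probes)
#   inputs, outputs, hints = split_stages(probes, spec)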


# pylint: disable=invalid-name


def array(A_pos: np.ndarray) -> np.ndarray:
  """Constructs an `array` probe."""
  probe = np.arange(A_pos.shape[0])
  for i in range(1, A_pos.shape[0]):
    probe[A_pos[i]] = A_pos[i - 1]
  return probe
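
# Worked example (illustrative): with A_pos = np.array([2, 0, 1]), node 2 is
# first in the order, node 0 second and node 1 third, so
#   array(A_pos) == np.array([2, 0, 2])
# i.e. each node points to its predecessor in the order and the first node
# points to itself.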


def array_cat(A: np.ndarray, n: int) -> np.ndarray:
  """Constructs an `array_cat` probe."""
  assert n > 0
  probe = np.zeros((A.shape[0], n))
  for i in range(A.shape[0]):
    probe[i, A[i]] = 1
  return probe


def heap(A_pos: np.ndarray, heap_size: int) -> np.ndarray:
  """Constructs a `heap` probe."""
  assert heap_size > 0
  probe = np.arange(A_pos.shape[0])
  for i in range(1, heap_size):
    probe[A_pos[i]] = A_pos[(i - 1) // 2]
  return probe
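
# Worked example (illustrative): with A_pos = np.array([3, 1, 0, 2]) and
# heap_size = 4, node 3 is the root, nodes 1 and 0 are its children, and node 2
# is a child of node 1, so
#   heap(A_pos, 4) == np.array([3, 3, 1, 3])
# i.e. each node points to its parent in the heap and the root points to
# itself.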


def graph(A: np.ndarray) -> np.ndarray:
  """Constructs a `graph` probe."""
  # Binarize the adjacency structure of A and add self-loops.
  probe = ((A + np.eye(A.shape[0])) != 0) * 1.0
  return probe
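
# Worked example (illustrative): with A = np.array([[0., 2.], [0., 0.]]),
#   graph(A) == np.array([[1., 1.],
#                         [0., 1.]])
# i.e. nonzero entries of A become 1 and self-loops appear on the diagonal.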


def mask_one(i: int, n: int) -> np.ndarray:
  """Constructs a `mask_one` probe."""
  assert n > i
  probe = np.zeros(n)
  probe[i] = 1
  return probe
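
# Worked example (illustrative):
#   mask_one(1, 3) == np.array([0., 1., 0.])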


def strings_id(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
  """Constructs a `strings_id` probe."""
  probe_T = np.zeros(T_pos.shape[0])
  probe_P = np.ones(P_pos.shape[0])
  return np.concatenate([probe_T, probe_P])


def strings_pair(pair_probe: np.ndarray) -> np.ndarray:
  """Constructs a `strings_pair` probe."""
  n = pair_probe.shape[0]
  m = pair_probe.shape[1]
  probe_ret = np.zeros((n + m, n + m))
  for i in range(0, n):
    for j in range(0, m):
      probe_ret[i, j + n] = pair_probe[i, j]
  return probe_ret
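
# Worked example (illustrative): with pair_probe = np.array([[5., 7.]]) (one
# text position, two pattern positions),
#   strings_pair(pair_probe) == np.array([[0., 5., 7.],
#                                         [0., 0., 0.],
#                                         [0., 0., 0.]])
# i.e. the pairwise values land in the text-rows/pattern-columns block of an
# (n + m) x (n + m) edge probe.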


def strings_pair_cat(pair_probe: np.ndarray, nb_classes: int) -> np.ndarray:
  """Constructs a `strings_pair_cat` probe."""
  assert nb_classes > 0
  n = pair_probe.shape[0]
  m = pair_probe.shape[1]

  # Add an extra class for 'this cell left blank.'
  probe_ret = np.zeros((n + m, n + m, nb_classes + 1))
  for i in range(0, n):
    for j in range(0, m):
      probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE

  # Fill the blank cells.
  for i_1 in range(0, n):
    for i_2 in range(0, n):
      probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED
  for j_1 in range(0, m):
    for x in range(0, n + m):
      probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED
  return probe_ret


def strings_pi(T_pos: np.ndarray, P_pos: np.ndarray,
               pi: np.ndarray) -> np.ndarray:
  """Constructs a `strings_pi` probe."""
  probe = np.arange(T_pos.shape[0] + P_pos.shape[0])
  for j in range(P_pos.shape[0]):
    probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + pi[P_pos[j]]
  return probe


def strings_pos(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
  """Constructs a `strings_pos` probe."""
  probe_T = np.copy(T_pos) * 1.0 / T_pos.shape[0]
  probe_P = np.copy(P_pos) * 1.0 / P_pos.shape[0]
  return np.concatenate([probe_T, probe_P])


def strings_pred(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
  """Constructs a `strings_pred` probe."""
  probe = np.arange(T_pos.shape[0] + P_pos.shape[0])
  for i in range(1, T_pos.shape[0]):
    probe[T_pos[i]] = T_pos[i - 1]
  for j in range(1, P_pos.shape[0]):
    probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + P_pos[j - 1]
  return probe


def predecessor_to_cyclic_predecessor_and_first(
    pointers: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Converts predecessor pointers to cyclic predecessor + first node mask.

  This function assumes that the pointers represent a linear order of the
  nodes (akin to a linked list), where each node points to its predecessor
  and the first node points to itself. It returns the same pointers, except
  that the first node points to the last, plus a `mask_one` marking the first
  node.

  Example:
  ```
  pointers = [2, 1, 1]
  P = [[0, 0, 1],
       [1, 0, 0],
       [0, 1, 0]],
  M = [0, 1, 0]
  ```

  Args:
    pointers: array of shape [N] containing pointers. The pointers are assumed
      to describe a linear order such that `pointers[i]` is the predecessor
      of node `i`.

  Returns:
    Permutation pointers `P` of shape [N, N] (a one-hot pointer encoding per
    node) and one-hot vector `M` of shape [N].
  """
  nb_nodes = pointers.shape[-1]
  pointers_one_hot = jax.nn.one_hot(pointers, nb_nodes)
  # Find the index of the last node: it's the node that no other node points
  # to.
  last = pointers_one_hot.sum(-2).argmin()
  # Find the first node: it should be the only node pointing to itself.
  first = pointers_one_hot.diagonal().argmax()
  mask = jax.nn.one_hot(first, nb_nodes)
  pointers_one_hot += mask[..., None] * jax.nn.one_hot(last, nb_nodes)
  pointers_one_hot -= mask[..., None] * mask
  return pointers_one_hot, mask
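
# Illustrative usage (matching the docstring example above):
#   P, M = predecessor_to_cyclic_predecessor_and_first(jnp.array([2, 1, 1]))
#   # P == [[0., 0., 1.], [1., 0., 0.], [0., 1., 0.]]  (the first node, 1, now
#   # points to the last node in the order, 0)
#   # M == [0., 1., 0.]                                (node 1 is the first)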