|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Serialize and deserialize trees.""" |
|
|
|
import dataclasses |
|
import io |
|
import types |
|
from typing import Any, BinaryIO, Optional, TypeVar |
|
|
|
import numpy as np |
|
|
|
# Type variable linking the `typ` argument of `load` / `_convert_types` to
# their return type.
_T = TypeVar("_T")
|
|
|
|
|
def dump(dest: BinaryIO, value: Any) -> None:
  """Write a tree of dicts/lists/tuples/dataclasses to `dest` as an .npz archive.

  The tree is first flattened into a single mapping (see `_flatten`), saved
  with `np.savez` into an in-memory buffer, and the resulting bytes are then
  written to `dest` with a single `write` call.

  Args:
    dest: a binary file object to write to.
    value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and
      other basic types. Unions are not supported, other than Optional/None
      which is only supported in dataclasses, not in dicts, lists or tuples.
      All leaves must be coercible to a numpy array, and recoverable as a
      single arg to a type.
  """
  flat_tree = _flatten(value)
  # Build the whole archive in memory first so `dest` only sees one write.
  staging = io.BytesIO()
  np.savez(staging, **flat_tree)
  payload = staging.getvalue()
  dest.write(payload)
|
|
|
|
|
def load(source: BinaryIO, typ: type[_T]) -> _T:
  """Load from a file object and convert it to the specified type.

  Args:
    source: a binary file object to read from, containing an .npz archive
      such as the one written by `dump`.
    typ: a type object that acts as a schema for deserialization. It must
      match what was serialized. If a type is Any, it will be returned however
      numpy deserialized it, which is what you want for a tree of numpy arrays.

  Returns:
    the deserialized value as the specified type.
  """
  # For an .npz archive `np.load` returns an NpzFile that reads entries lazily
  # from an open zip handle. `_unflatten` reads every entry out of it, so the
  # archive can be closed before the type conversion runs.
  with np.load(source) as archive:
    tree = _unflatten(archive)
  return _convert_types(typ, tree)
|
|
|
|
|
# Separator used to join nested keys into a single flat key of the archive,
# e.g. {"a": {"b": x}} is stored under "a:b". `_flatten` asserts that no
# individual key contains it, and `_unflatten` splits on it.
_SEP = ":"
|
|
|
|
|
def _flatten(tree: Any) -> dict[str, Any]: |
|
"""Flatten a tree of dicts/dataclasses/lists/tuples to a single dict.""" |
|
if dataclasses.is_dataclass(tree): |
|
|
|
tree = {f.name: v for f in dataclasses.fields(tree) |
|
if (v := getattr(tree, f.name)) is not None} |
|
elif isinstance(tree, (list, tuple)): |
|
tree = dict(enumerate(tree)) |
|
|
|
assert isinstance(tree, dict) |
|
|
|
flat = {} |
|
for k, v in tree.items(): |
|
k = str(k) |
|
assert _SEP not in k |
|
if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)): |
|
for a, b in _flatten(v).items(): |
|
flat[f"{k}{_SEP}{a}"] = b |
|
else: |
|
assert v is not None |
|
flat[k] = v |
|
return flat |
|
|
|
|
|
def _unflatten(flat: dict[str, Any]) -> dict[str, Any]: |
|
"""Unflatten a dict to a tree of dicts.""" |
|
tree = {} |
|
for flat_key, v in flat.items(): |
|
node = tree |
|
keys = flat_key.split(_SEP) |
|
for k in keys[:-1]: |
|
if k not in node: |
|
node[k] = {} |
|
node = node[k] |
|
node[keys[-1]] = v |
|
return tree |
|
|
|
|
|
def _convert_types(typ: type[_T], value: Any) -> _T: |
|
"""Convert some structure into the given type. The structures must match.""" |
|
if typ in (Any, ...): |
|
return value |
|
|
|
if typ in (int, float, str, bool): |
|
return typ(value) |
|
|
|
if typ is np.ndarray: |
|
assert isinstance(value, np.ndarray) |
|
return value |
|
|
|
if dataclasses.is_dataclass(typ): |
|
kwargs = {} |
|
for f in dataclasses.fields(typ): |
|
|
|
|
|
|
|
|
|
if isinstance(f.type, (types.UnionType, type(Optional[int]))): |
|
constructors = [t for t in f.type.__args__ if t is not types.NoneType] |
|
if len(constructors) != 1: |
|
raise TypeError( |
|
"Optional works, Union with anything except None doesn't") |
|
if f.name not in value: |
|
kwargs[f.name] = None |
|
continue |
|
constructor = constructors[0] |
|
else: |
|
constructor = f.type |
|
|
|
if f.name in value: |
|
kwargs[f.name] = _convert_types(constructor, value[f.name]) |
|
else: |
|
raise ValueError(f"Missing value: {f.name}") |
|
return typ(**kwargs) |
|
|
|
base_type = getattr(typ, "__origin__", None) |
|
|
|
if base_type is dict: |
|
assert len(typ.__args__) == 2 |
|
key_type, value_type = typ.__args__ |
|
return {_convert_types(key_type, k): _convert_types(value_type, v) |
|
for k, v in value.items()} |
|
|
|
if base_type is list: |
|
assert len(typ.__args__) == 1 |
|
value_type = typ.__args__[0] |
|
return [_convert_types(value_type, v) |
|
for _, v in sorted(value.items(), key=lambda x: int(x[0]))] |
|
|
|
if base_type is tuple: |
|
if len(typ.__args__) == 2 and typ.__args__[1] == ...: |
|
|
|
value_type = typ.__args__[0] |
|
return tuple(_convert_types(value_type, v) |
|
for _, v in sorted(value.items(), key=lambda x: int(x[0]))) |
|
else: |
|
|
|
assert len(typ.__args__) == len(value) |
|
return tuple( |
|
_convert_types(t, v) |
|
for t, (_, v) in zip( |
|
typ.__args__, sorted(value.items(), key=lambda x: int(x[0])))) |
|
|
|
|
|
try: |
|
return typ(value) |
|
except TypeError as e: |
|
raise TypeError( |
|
"_convert_types expects the type argument to be a dataclass defined " |
|
"with types that are valid constructors (eg tuple is fine, Tuple " |
|
"isn't), and accept a numpy array as the sole argument.") from e |
|
|