|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utilities for working with trees of xarray.DataArray (including Datasets). |
|
|
|
Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library;
it won't work as a leaf node since it implements Mapping, but also won't work
as an internal node since tree doesn't know how to re-create it properly.

To fix this, we reimplement a subset of `map_structure`, exposing a Dataset's
constituent DataArrays as leaf nodes. This means a Dataset can be mapped over
as a generic container of DataArrays, while still preserving the result as a
Dataset where possible.

This is useful because in a few places we need to handle a general
Mapping[str, DataArray] (where the coordinates might not be compatible across
the constituent DataArrays) but also the special case of a Dataset nicely.

For the result of e.g. xarray_tree.map_structure(fn, dataset), if fn returns
None for some of the child DataArrays, they will be omitted from the returned
dataset. If any values other than DataArrays or None are returned, then we
don't attempt to return a Dataset and just return a plain dict of the results.
Similarly, if DataArrays are returned but with non-matching coordinates, it
will just return a plain dict of DataArrays.

Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py,
but `jax.tree_util.tree_map` is distinct from `xarray_tree.map_structure`,
as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the
latter exposes DataArrays as leaf nodes.
|
""" |
|
|
|
from typing import Any, Callable |
|
|
|
import xarray |
|
|
|
|
|
def map_structure(func: Callable[..., Any], *structures: Any) -> Any:
    """Maps `func` over the leaves of one or more parallel structures.

    Reimplements a subset of `tree.map_structure` in which an `xarray.Dataset`
    is treated as a container whose leaves are its DataArrays.

    Args:
      func: Callable invoked once per leaf position, with one positional
        argument per structure.
      *structures: One or more structures. The first one drives the traversal:
        the others are indexed with the same keys (dict, Dataset) or zipped
        element-wise (list, tuple, set), in which case iteration stops at the
        shortest sequence.

    Returns:
      Depending on the type of the first structure:
        * dict: a plain dict with the same keys and recursively mapped values.
        * list / tuple / set: a container of the same type as the first
          structure with recursively mapped elements. Namedtuples are rebuilt
          field by field.
        * xarray.Dataset: `func` is applied to each variable. If every result
          is a DataArray or None, the non-None results are renamed to their
          keys and merged with `join='exact'`; if that merge succeeds the
          merged Dataset is returned. Otherwise a plain dict of the results
          (including any None entries) is returned.
        * anything else: `func(*structures)`.

    Raises:
      TypeError: If `func` is not callable.
      ValueError: If no structures are given.
    """
    if not callable(func):
        raise TypeError(f'func must be callable, got: {func}')
    if not structures:
        raise ValueError('Must provide at least one structure')

    first = structures[0]

    # Builtin containers are tested first. An xarray.Dataset is none of these
    # types, so the order of the checks does not change which branch is taken.
    if isinstance(first, dict):
        return {k: map_structure(func, *[s[k] for s in structures])
                for k in first}

    if isinstance(first, (list, tuple, set)):
        mapped = (map_structure(func, *s) for s in zip(*structures))
        if isinstance(first, tuple) and hasattr(first, '_fields'):
            # Namedtuple constructors take one positional argument per field
            # rather than a single iterable, so the results are unpacked.
            return type(first)(*mapped)
        return type(first)(mapped)

    if isinstance(first, xarray.Dataset):
        data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
        if all(v is None or isinstance(v, xarray.DataArray)
               for v in data.values()):
            # Variables for which func returned None are dropped.
            data_arrays = [
                v.rename(k) for k, v in data.items() if v is not None]
            try:
                return xarray.merge(data_arrays, join='exact')
            except ValueError:
                # Deliberate fallback: when the merge is rejected, return the
                # results as a plain dict instead of raising.
                pass
        return data

    # Leaf node: apply func to the corresponding leaves of every structure.
    return func(*structures)
|
|