|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utilities for building models.""" |
|
|
|
from typing import Mapping, Optional, Tuple |
|
|
|
import numpy as np |
|
from scipy.spatial import transform |
|
import xarray |
|
|
|
|
|
def get_graph_spatial_features(
    *, node_lat: np.ndarray, node_lon: np.ndarray,
    senders: np.ndarray, receivers: np.ndarray,
    add_node_positions: bool,
    add_node_latitude: bool,
    add_node_longitude: bool,
    add_relative_positions: bool,
    relative_longitude_local_coordinates: bool,
    relative_latitude_local_coordinates: bool,
    sine_cosine_encoding: bool = False,
    encoding_num_freqs: int = 10,
    encoding_multiplicative_factor: float = 1.2,
    ) -> Tuple[np.ndarray, np.ndarray]:
  """Computes spatial features for the nodes and edges of a graph.

  Args:
    node_lat: Latitudes in the [-90, 90] interval of shape [num_nodes]
    node_lon: Longitudes in the [0, 360] interval of shape [num_nodes]
    senders: Sender indices of shape [num_edges]
    receivers: Receiver indices of shape [num_edges]
    add_node_positions: Add unit norm absolute positions.
    add_node_latitude: Add a feature for latitude (cos(90 - lat)).
        Note even if this is set to False, the model may be able to infer the
        latitude from relative features, unless
        `relative_latitude_local_coordinates` is also True, or if there is any
        bias on the relative edge sizes for different latitudes.
    add_node_longitude: Add features for longitude (cos(lon), sin(lon)).
        Note even if this is set to False, the model may be able to infer the
        longitude from relative features, unless
        `relative_longitude_local_coordinates` is also True, or if there is any
        bias on the relative edge sizes for different longitudes.
    add_relative_positions: Whether to add relative positions in R3 to the
        edges.
    relative_longitude_local_coordinates: If True, relative positions are
        computed in a local space where the receiver is at 0 longitude.
    relative_latitude_local_coordinates: If True, relative positions are
        computed in a local space where the receiver is at 0 latitude.
    sine_cosine_encoding: If True, we will transform the node/edge features
        with sine and cosine functions, similar to NERF.
    encoding_num_freqs: Number of frequencies used by the sine/cosine
        encoding. Each input feature expands to `2 * encoding_num_freqs`
        output features.
    encoding_multiplicative_factor: Base of the geometric progression of
        frequencies, i.e. frequency `k` is
        `encoding_multiplicative_factor ** k`.

  Returns:
    Arrays of shape: [num_nodes, num_features] and [num_edges, num_features]
    with node and edge features.

  """

  num_nodes = node_lat.shape[0]
  num_edges = senders.shape[0]
  dtype = node_lat.dtype
  node_phi, node_theta = lat_lon_deg_to_spherical(node_lat, node_lon)

  # Node features.
  node_features = []
  if add_node_positions:
    # Cartesian coordinates of points on the unit sphere.
    node_features.extend(spherical_to_cartesian(node_phi, node_theta))

  if add_node_latitude:
    # theta is the colatitude, so cos(theta) == sin(lat).
    node_features.append(np.cos(node_theta))

  if add_node_longitude:
    # Two components so the encoding is continuous across the 0/360 seam.
    node_features.append(np.cos(node_phi))
    node_features.append(np.sin(node_phi))

  if not node_features:
    node_features = np.zeros([num_nodes, 0], dtype=dtype)
  else:
    node_features = np.stack(node_features, axis=-1)

  # Edge features.
  edge_features = []

  if add_relative_positions:
    relative_position = get_relative_position_in_receiver_local_coordinates(
        node_phi=node_phi,
        node_theta=node_theta,
        senders=senders,
        receivers=receivers,
        latitude_local_coordinates=relative_latitude_local_coordinates,
        longitude_local_coordinates=relative_longitude_local_coordinates
        )

    # Euclidean length of every edge, kept as a trailing axis of size 1.
    relative_edge_distances = np.linalg.norm(
        relative_position, axis=-1, keepdims=True)

    # Normalize by the longest edge so that the features lie in [-1, 1].
    # A graph without edges has no maximum (`max()` of an empty array
    # raises); the arrays being divided are empty, so any factor will do.
    if relative_edge_distances.size:
      max_edge_distance = relative_edge_distances.max()
    else:
      max_edge_distance = 1.0
    edge_features.append(relative_edge_distances / max_edge_distance)
    edge_features.append(relative_position / max_edge_distance)

  if not edge_features:
    edge_features = np.zeros([num_edges, 0], dtype=dtype)
  else:
    edge_features = np.concatenate(edge_features, axis=-1)

  if sine_cosine_encoding:
    def sine_cosine_transform(x: np.ndarray) -> np.ndarray:
      freqs = encoding_multiplicative_factor**np.arange(encoding_num_freqs)
      phases = freqs * x[..., None]
      x_sin = np.sin(phases)
      x_cos = np.cos(phases)
      x_cat = np.concatenate([x_sin, x_cos], axis=-1)
      # Spell out the flattened width instead of using -1: numpy cannot
      # infer a -1 dimension when the leading dimension is 0.
      return x_cat.reshape([x.shape[0], int(np.prod(x_cat.shape[1:]))])

    node_features = sine_cosine_transform(node_features)
    edge_features = sine_cosine_transform(edge_features)

  return node_features, edge_features
|
|
|
|
|
def lat_lon_to_leading_axes(
    grid_xarray: xarray.DataArray) -> xarray.DataArray:
  """Moves the "lat" and "lon" dimensions to the front of the array.

  Args:
    grid_xarray: Array with "lat" and "lon" dimensions.

  Returns:
    The array transposed to ("lat", "lon", <all remaining dims>), with the
    remaining dimensions kept in their existing relative order.
  """
  leading_dims = ("lat", "lon")
  # The trailing Ellipsis stands for every dimension not named explicitly.
  return grid_xarray.transpose(*leading_dims, ...)
|
|
|
|
|
def restore_leading_axes(grid_xarray: xarray.DataArray) -> xarray.DataArray:
  """Moves the batch/time/level dimensions (those present) to the front.

  Args:
    grid_xarray: Array whose dimensions should be reordered.

  Returns:
    The array transposed so that whichever of "batch", "time" and "level"
    it has come first, in that order, followed by every other dimension in
    its current relative order.
  """
  current_dims = list(grid_xarray.dims)
  leading_dims = [
      key for key in ("batch", "time", "level") if key in current_dims]
  trailing_dims = [dim for dim in current_dims if dim not in leading_dims]
  return grid_xarray.transpose(*leading_dims, *trailing_dims)
|
|
|
|
|
def lat_lon_deg_to_spherical(node_lat: np.ndarray,
                             node_lon: np.ndarray,
                            ) -> Tuple[np.ndarray, np.ndarray]:
  """Converts latitude/longitude in degrees to spherical angles in radians.

  Args:
    node_lat: Latitudes in degrees.
    node_lon: Longitudes in degrees.

  Returns:
    Tuple `(phi, theta)` in radians, where `phi` is the longitude and
    `theta` is the colatitude (90 degrees minus the latitude).
  """
  colatitude_deg = 90 - node_lat
  return np.deg2rad(node_lon), np.deg2rad(colatitude_deg)
|
|
|
|
|
def spherical_to_lat_lon(phi: np.ndarray,
                         theta: np.ndarray,
                        ) -> Tuple[np.ndarray, np.ndarray]:
  """Converts spherical angles in radians to latitude/longitude in degrees.

  Args:
    phi: Longitude angles in radians.
    theta: Colatitude angles in radians.

  Returns:
    Tuple `(lat, lon)` in degrees, with `lat = 90 - theta` and `lon` wrapped
    into [0, 360) by a modulo.
  """
  colatitude_deg = np.rad2deg(theta)
  longitude_deg = np.rad2deg(phi)
  return 90 - colatitude_deg, np.mod(longitude_deg, 360)
|
|
|
|
|
def cartesian_to_spherical(x: np.ndarray,
                           y: np.ndarray,
                           z: np.ndarray,
                          ) -> Tuple[np.ndarray, np.ndarray]:
  """Converts cartesian coordinates to spherical angles.

  Args:
    x: X coordinates.
    y: Y coordinates.
    z: Z coordinates; `arccos` is applied to them directly, so they are
      treated as coordinates of points on the unit sphere.

  Returns:
    Tuple `(phi, theta)` in radians with `phi = arctan2(y, x)` and
    `theta = arccos(z)`. Values of `z` outside [-1, 1] yield NaN in `theta`.
  """
  # Suppress numpy's "invalid value" warning for z outside [-1, 1]; the
  # result for those entries is still NaN.
  with np.errstate(invalid="ignore"):
    polar_angle = np.arccos(z)
  azimuth = np.arctan2(y, x)
  return azimuth, polar_angle
|
|
|
|
|
def spherical_to_cartesian(
    phi: np.ndarray, theta: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Converts spherical angles to cartesian coordinates on the unit sphere.

  Args:
    phi: Azimuthal (longitude) angles in radians.
    theta: Polar angles in radians, measured from the +z axis.

  Returns:
    Tuple `(x, y, z)` of coordinate arrays.
  """
  sin_theta = np.sin(theta)
  x = np.cos(phi) * sin_theta
  y = np.sin(phi) * sin_theta
  z = np.cos(theta)
  return (x, y, z)
|
|
|
|
|
def get_relative_position_in_receiver_local_coordinates(
    node_phi: np.ndarray,
    node_theta: np.ndarray,
    senders: np.ndarray,
    receivers: np.ndarray,
    latitude_local_coordinates: bool,
    longitude_local_coordinates: bool
    ) -> np.ndarray:
  """Returns relative position features for the edges.

  For every edge, the sender and receiver positions are optionally rotated
  into a local coordinate system defined by the receiver, and the feature is
  the sender position minus the receiver position in that rotated R^3 space.
  When neither flag is set, no rotation is applied and the plain difference
  of the positions is returned.

  Args:
    node_phi: [num_nodes] with azimuthal (longitude) angles in radians.
    node_theta: [num_nodes] with polar (colatitude) angles in radians.
    senders: [num_edges] with indices.
    receivers: [num_edges] with indices.
    latitude_local_coordinates: Whether to rotate edges such that the
      positions are computed with the receiver always at latitude 0.
    longitude_local_coordinates: Whether to rotate edges such that the
      positions are computed with the receiver always at longitude 0.

  Returns:
    Array of relative positions in R3 [num_edges, 3]
  """
  # Unit-sphere cartesian position of every node, gathered per edge.
  cartesian = np.stack(spherical_to_cartesian(node_phi, node_theta), axis=-1)
  sender_xyz = cartesian[senders]
  receiver_xyz = cartesian[receivers]

  # No local frame requested: difference in the global frame.
  if not latitude_local_coordinates and not longitude_local_coordinates:
    return sender_xyz - receiver_xyz

  # One rotation per node, each taking that node to its local reference
  # position; every edge then uses the rotation of its receiver.
  per_node_rotations = get_rotation_matrices_to_local_coordinates(
      reference_phi=node_phi,
      reference_theta=node_theta,
      rotate_latitude=latitude_local_coordinates,
      rotate_longitude=longitude_local_coordinates)
  per_edge_rotations = per_node_rotations[receivers]

  # Both endpoints are rotated with the same matrix, so lengths and angles
  # of the edge are preserved.
  rotated_senders = rotate_with_matrices(per_edge_rotations, sender_xyz)
  rotated_receivers = rotate_with_matrices(per_edge_rotations, receiver_xyz)
  return rotated_senders - rotated_receivers
|
|
|
|
|
def get_rotation_matrices_to_local_coordinates(
    reference_phi: np.ndarray,
    reference_theta: np.ndarray,
    rotate_latitude: bool,
    rotate_longitude: bool) -> np.ndarray:
  """Returns rotation matrices that move reference points to a local frame.

  The rotation matrix is built such that a vector in the same coordinate
  system at the reference point that points towards the pole before the
  rotation continues to point towards the pole after the rotation.

  Args:
    reference_phi: [leading_axis] Azimuthal (longitude) angles of the
      reference, in radians.
    reference_theta: [leading_axis] Polar (colatitude) angles of the
      reference, in radians.
    rotate_latitude: Whether to produce a rotation matrix that would rotate
      R^3 vectors to zero latitude.
    rotate_longitude: Whether to produce a rotation matrix that would rotate
      R^3 vectors to zero longitude.

  Returns:
    Matrices of shape [leading_axis, 3, 3] such that when applied to the
    reference position with
    `rotate_with_matrices(rotation_matrices, reference_pos)`

    * phi goes to 0. if "rotate_longitude" is True.

    * theta goes to np.pi / 2 if "rotate_latitude" is True.

    The rotation consists of (extrinsic rotations, applied in order):
    * rotate_latitude = False, rotate_longitude = True:
        Latitude preserving rotation (about z by -phi).
    * rotate_latitude = True, rotate_longitude = True:
        Latitude preserving rotation (about z by -phi), followed by
        longitude preserving rotation (about y by pi/2 - theta).
    * rotate_latitude = True, rotate_longitude = False:
        Latitude preserving rotation, followed by longitude preserving
        rotation, and the inverse of the latitude preserving rotation
        (about z by +phi). Note this is computationally different from
        rotating about the latitude only. We do it like this so the polar
        geodesic curve continues to be aligned with one of the axes after
        the rotation.

  Raises:
    ValueError: If both `rotate_latitude` and `rotate_longitude` are False.
  """
  if not (rotate_latitude or rotate_longitude):
    raise ValueError(
        "At least one of longitude and latitude should be rotated.")

  # Rotating about z by -phi brings the reference point to longitude 0
  # without changing its latitude.
  azimuthal_rotation = - reference_phi
  if not rotate_latitude:
    return transform.Rotation.from_euler("z", azimuthal_rotation).as_matrix()

  # With the point at longitude 0 (in the x-z plane), rotating about y by
  # (pi/2 - theta) brings it to the equator, i.e. theta -> pi/2.
  polar_rotation = - reference_theta + np.pi/2

  if rotate_longitude:
    return transform.Rotation.from_euler(
        "zy", np.stack([azimuthal_rotation, polar_rotation],
                       axis=1)).as_matrix()

  # Latitude only: undo the initial z rotation at the end so the original
  # longitude is restored.
  return transform.Rotation.from_euler(
      "zyz", np.stack(
          [azimuthal_rotation, polar_rotation, -azimuthal_rotation]
          , axis=1)).as_matrix()
|
|
|
|
|
def rotate_with_matrices(rotation_matrices: np.ndarray, positions: np.ndarray
                         ) -> np.ndarray:
  """Applies one matrix per row to the matching position vector.

  Args:
    rotation_matrices: Array of matrices, indexed as [batch, row, column].
    positions: Array of vectors, indexed as [batch, component].

  Returns:
    Array indexed as [batch, row], where entry `n` is the matrix-vector
    product `rotation_matrices[n] @ positions[n]`.
  """
  # n: batch entry, r: matrix row (output component), c: matrix column
  # (input component, summed over).
  batched_matvec = "nrc,nc->nr"
  return np.einsum(batched_matvec, rotation_matrices, positions)
|
|
|
|
|
def get_bipartite_graph_spatial_features(
    *,
    senders_node_lat: np.ndarray,
    senders_node_lon: np.ndarray,
    senders: np.ndarray,
    receivers_node_lat: np.ndarray,
    receivers_node_lon: np.ndarray,
    receivers: np.ndarray,
    add_node_positions: bool,
    add_node_latitude: bool,
    add_node_longitude: bool,
    add_relative_positions: bool,
    edge_normalization_factor: Optional[float] = None,
    relative_longitude_local_coordinates: bool,
    relative_latitude_local_coordinates: bool,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  """Computes spatial features for the nodes and edges of a bipartite graph.

  This function is almost identical to `get_graph_spatial_features`. The only
  difference is that sender nodes and receiver nodes can be in different
  arrays. This is necessary to enable combination with typed Graph.

  Args:
    senders_node_lat: Latitudes in the [-90, 90] interval of shape
      [num_sender_nodes]
    senders_node_lon: Longitudes in the [0, 360] interval of shape
      [num_sender_nodes]
    senders: Sender indices of shape [num_edges], indices in [0,
      num_sender_nodes)
    receivers_node_lat: Latitudes in the [-90, 90] interval of shape
      [num_receiver_nodes]
    receivers_node_lon: Longitudes in the [0, 360] interval of shape
      [num_receiver_nodes]
    receivers: Receiver indices of shape [num_edges], indices in [0,
      num_receiver_nodes)
    add_node_positions: Add unit norm absolute positions.
    add_node_latitude: Add a feature for latitude (cos(90 - lat)). Note even
      if this is set to False, the model may be able to infer the latitude
      from relative features, unless `relative_latitude_local_coordinates` is
      also True, or if there is any bias on the relative edge sizes for
      different latitudes.
    add_node_longitude: Add features for longitude (cos(lon), sin(lon)). Note
      even if this is set to False, the model may be able to infer the
      longitude from relative features, unless
      `relative_longitude_local_coordinates` is also True, or if there is any
      bias on the relative edge sizes for different longitudes.
    add_relative_positions: Whether to add relative positions in R3 to the
      edges.
    edge_normalization_factor: Allows explicitly controlling edge
      normalization. If None, defaults to max edge length. This supports
      using pre-trained model weights with a different graph structure to
      what it was trained on.
    relative_longitude_local_coordinates: If True, relative positions are
      computed in a local space where the receiver is at 0 longitude.
    relative_latitude_local_coordinates: If True, relative positions are
      computed in a local space where the receiver is at 0 latitude.

  Returns:
    Arrays of shape: [num_sender_nodes, num_features],
    [num_receiver_nodes, num_features] and [num_edges, num_features]
    with sender node, receiver node and edge features.

  """

  num_senders = senders_node_lat.shape[0]
  num_receivers = receivers_node_lat.shape[0]
  num_edges = senders.shape[0]
  dtype = senders_node_lat.dtype
  assert receivers_node_lat.dtype == dtype
  senders_node_phi, senders_node_theta = lat_lon_deg_to_spherical(
      senders_node_lat, senders_node_lon)
  receivers_node_phi, receivers_node_theta = lat_lon_deg_to_spherical(
      receivers_node_lat, receivers_node_lon)

  # Node features. Both lists are always filled in lockstep, so they hold
  # the same feature columns.
  senders_node_features = []
  receivers_node_features = []
  if add_node_positions:
    # Cartesian coordinates of points on the unit sphere.
    senders_node_features.extend(
        spherical_to_cartesian(senders_node_phi, senders_node_theta))
    receivers_node_features.extend(
        spherical_to_cartesian(receivers_node_phi, receivers_node_theta))

  if add_node_latitude:
    # theta is the colatitude, so cos(theta) == sin(lat).
    senders_node_features.append(np.cos(senders_node_theta))
    receivers_node_features.append(np.cos(receivers_node_theta))

  if add_node_longitude:
    # Two components so the encoding is continuous across the 0/360 seam.
    senders_node_features.append(np.cos(senders_node_phi))
    senders_node_features.append(np.sin(senders_node_phi))

    receivers_node_features.append(np.cos(receivers_node_phi))
    receivers_node_features.append(np.sin(receivers_node_phi))

  if not senders_node_features:
    senders_node_features = np.zeros([num_senders, 0], dtype=dtype)
    receivers_node_features = np.zeros([num_receivers, 0], dtype=dtype)
  else:
    senders_node_features = np.stack(senders_node_features, axis=-1)
    receivers_node_features = np.stack(receivers_node_features, axis=-1)

  # Edge features.
  edge_features = []

  if add_relative_positions:
    relative_position = (
        get_bipartite_relative_position_in_receiver_local_coordinates(
            senders_node_phi=senders_node_phi,
            senders_node_theta=senders_node_theta,
            receivers_node_phi=receivers_node_phi,
            receivers_node_theta=receivers_node_theta,
            senders=senders,
            receivers=receivers,
            latitude_local_coordinates=relative_latitude_local_coordinates,
            longitude_local_coordinates=relative_longitude_local_coordinates))

    # Euclidean length of every edge, kept as a trailing axis of size 1.
    relative_edge_distances = np.linalg.norm(
        relative_position, axis=-1, keepdims=True)

    if edge_normalization_factor is None:
      # Normalize by the longest edge so that the features lie in [-1, 1].
      # A graph without edges has no maximum (`max()` of an empty array
      # raises); the arrays being divided are empty, so any factor will do.
      if relative_edge_distances.size:
        edge_normalization_factor = relative_edge_distances.max()
      else:
        edge_normalization_factor = 1.0

    edge_features.append(relative_edge_distances / edge_normalization_factor)
    edge_features.append(relative_position / edge_normalization_factor)

  if not edge_features:
    edge_features = np.zeros([num_edges, 0], dtype=dtype)
  else:
    edge_features = np.concatenate(edge_features, axis=-1)

  return senders_node_features, receivers_node_features, edge_features
|
|
|
|
|
def get_bipartite_relative_position_in_receiver_local_coordinates(
    senders_node_phi: np.ndarray,
    senders_node_theta: np.ndarray,
    senders: np.ndarray,
    receivers_node_phi: np.ndarray,
    receivers_node_theta: np.ndarray,
    receivers: np.ndarray,
    latitude_local_coordinates: bool,
    longitude_local_coordinates: bool) -> np.ndarray:
  """Returns relative position features for the edges.

  This function is equivalent to
  `get_relative_position_in_receiver_local_coordinates`, but adapted to work
  with bipartite typed graphs.

  For every edge, the sender and receiver positions are optionally rotated
  into a local coordinate system defined by the receiver, and the feature is
  the sender position minus the receiver position in that rotated R^3 space.
  When neither flag is set, no rotation is applied and the plain difference
  of the positions is returned.

  Args:
    senders_node_phi: [num_sender_nodes] with azimuthal (longitude) angles.
    senders_node_theta: [num_sender_nodes] with polar (colatitude) angles.
    senders: [num_edges] with indices into sender nodes.
    receivers_node_phi: [num_receiver_nodes] with azimuthal (longitude)
      angles.
    receivers_node_theta: [num_receiver_nodes] with polar (colatitude)
      angles.
    receivers: [num_edges] with indices into receiver nodes.
    latitude_local_coordinates: Whether to rotate edges such that the
      positions are computed with the receiver always at latitude 0.
    longitude_local_coordinates: Whether to rotate edges such that the
      positions are computed with the receiver always at longitude 0.

  Returns:
    Array of relative positions in R3 [num_edges, 3]
  """
  # Unit-sphere cartesian positions of each node set, gathered per edge.
  sender_xyz = np.stack(
      spherical_to_cartesian(senders_node_phi, senders_node_theta),
      axis=-1)[senders]
  receiver_xyz = np.stack(
      spherical_to_cartesian(receivers_node_phi, receivers_node_theta),
      axis=-1)[receivers]

  # No local frame requested: difference in the global frame.
  if not latitude_local_coordinates and not longitude_local_coordinates:
    return sender_xyz - receiver_xyz

  # One rotation per receiver node, each taking that node to its local
  # reference position; every edge then uses the rotation of its receiver.
  per_receiver_rotations = get_rotation_matrices_to_local_coordinates(
      reference_phi=receivers_node_phi,
      reference_theta=receivers_node_theta,
      rotate_latitude=latitude_local_coordinates,
      rotate_longitude=longitude_local_coordinates)
  per_edge_rotations = per_receiver_rotations[receivers]

  # Both endpoints are rotated with the same matrix, so lengths and angles
  # of the edge are preserved.
  rotated_senders = rotate_with_matrices(per_edge_rotations, sender_xyz)
  rotated_receivers = rotate_with_matrices(per_edge_rotations, receiver_xyz)
  return rotated_senders - rotated_receivers
|
|
|
|
|
def variable_to_stacked(
    variable: xarray.Variable,
    sizes: Mapping[str, int],
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.Variable:
  """Converts an xarray.Variable to preserved_dims + ("channels",).

  Any dimensions other than those included in preserved_dims get stacked into
  a final "channels" dimension. If any of the preserved_dims are missing then
  they are added, with the data broadcast/tiled to match the sizes specified
  in `sizes`.

  Args:
    variable: An xarray.Variable.
    sizes: Mapping including sizes for any dimensions which are not present
      in `variable` but are needed for the output. This may be needed for
      example for a static variable with only ("lat", "lon") dims, or if you
      want to encode just the latitude coordinates (a variable with dims
      ("lat",)). It is only consulted for dimensions that `variable` lacks.
    preserved_dims: dimensions of variable to not be folded in channels.

  Returns:
    An xarray.Variable with dimensions preserved_dims + ("channels",).

  Raises:
    KeyError: If a preserved dimension is missing from both `variable` and
      `sizes`.
  """
  stack_to_channels_dims = [
      d for d in variable.dims if d not in preserved_dims]
  if stack_to_channels_dims:
    variable = variable.stack(channels=stack_to_channels_dims)
  # Test for membership rather than truthiness: a dimension that exists on
  # the variable with size 0 must keep that size instead of being looked up
  # in `sizes` as if it were missing.
  dims = {
      dim: variable.sizes[dim] if dim in variable.sizes else sizes[dim]
      for dim in preserved_dims
  }
  # With nothing stacked there is no "channels" dim yet; add one of size 1.
  dims["channels"] = variable.sizes.get("channels", 1)
  return variable.set_dims(dims)
|
|
|
|
|
def dataset_to_stacked(
    dataset: xarray.Dataset,
    sizes: Optional[Mapping[str, int]] = None,
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.DataArray:
  """Stacks all data variables of a dataset into one array.

  Every data variable is brought to the preserved_dims + ("channels",)
  layout with `variable_to_stacked`, and the results are concatenated along
  "channels" in alphabetical order of the variable names.

  Args:
    dataset: An xarray.Dataset.
    sizes: Mapping including sizes for any dimensions which are not present
      in the `dataset` but are needed for the output. See
      variable_to_stacked. When None (or empty), the sizes of `dataset` are
      used.
    preserved_dims: dimensions from the dataset that should not be folded in
      the predictions channels.

  Returns:
    An xarray.DataArray with dimensions preserved_dims + ("channels",).
    Existing coordinates for preserved_dims axes will be preserved, however
    there will be no coordinates for "channels".
  """
  fallback_sizes = sizes or dataset.sizes

  # Sorting the names makes the channel order deterministic.
  stacked_vars = []
  for var_name in sorted(dataset.data_vars):
    stacked_vars.append(
        variable_to_stacked(
            dataset.variables[var_name], fallback_sizes, preserved_dims))

  # Keep only the coordinates named after a preserved dimension.
  kept_coords = {}
  for coord_name, coord in dataset.coords.items():
    if coord_name in preserved_dims:
      kept_coords[coord_name] = coord

  stacked = xarray.Variable.concat(stacked_vars, dim="channels")
  return xarray.DataArray(data=stacked, coords=kept_coords)
|
|
|
|
|
def stacked_to_dataset(
    stacked_array: xarray.Variable,
    template_dataset: xarray.Dataset,
    preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
    ) -> xarray.Dataset:
  """The inverse of dataset_to_stacked.

  Splits the "channels" axis of `stacked_array` into one slice per template
  variable (in alphabetical order of the names), unstacks each slice into
  the non-preserved dimensions of that variable, and reorders the result to
  the dimension order of the template variable.

  Args:
    stacked_array: Data with a "channels" dimension, encoded the same as
      dataset_to_stacked would if it was asked to encode `template_dataset`.
    template_dataset: A template Dataset (or other mapping of DataArrays)
      demonstrating the shape of output required (variables, shapes,
      coordinates etc).
    preserved_dims: dimensions from the target_template that were not folded
      in the predictions channels. The preserved_dims need to be a subset of
      the dims of all the variables of template_dataset.

  Returns:
    An xarray.Dataset (or other mapping of DataArrays) with the same shape
    and type as template_dataset.

  Raises:
    ValueError: If a template variable lacks one of the preserved_dims, or if
      the number of channels in `stacked_array` does not match the number
      implied by `template_dataset`.
  """
  var_names = sorted(template_dataset.keys())

  # For each variable, the sizes of the dimensions that were folded into
  # "channels" (everything that is not a preserved dimension).
  folded_sizes = {}
  for name in var_names:
    template_var = template_dataset[name]
    if any(dim not in template_var.dims for dim in preserved_dims):
      raise ValueError(
          f"stacked_to_dataset requires all Variables to have {preserved_dims} "
          f"dimensions, but found only {template_var.dims}.")
    folded = {}
    for dim, size in template_var.sizes.items():
      if dim not in preserved_dims:
        folded[dim] = size
    folded_sizes[name] = folded

  # Number of channels each variable occupies: the product of its folded
  # sizes (1 when nothing was folded).
  channels = {}
  for name, unstack_sizes in folded_sizes.items():
    channels[name] = np.prod(list(unstack_sizes.values()), dtype=np.int64)

  total_expected_channels = sum(channels.values())
  found_channels = stacked_array.sizes["channels"]
  if total_expected_channels != found_channels:
    raise ValueError(
        f"Expected {total_expected_channels} channels but found "
        f"{found_channels}, when trying to convert a stacked array of shape "
        f"{stacked_array.sizes} to a dataset of shape {template_dataset}.")

  data_vars = {}
  start = 0
  for name in var_names:
    template_var = template_dataset[name]
    stop = start + channels[name]
    # Take this variable's contiguous run of channels, then restore its
    # folded dimensions and the template's dimension order.
    piece = stacked_array.isel({"channels": slice(start, stop)})
    start = stop
    piece = piece.unstack({"channels": folded_sizes[name]})
    piece = piece.transpose(*template_var.dims)
    data_vars[name] = xarray.DataArray(
        data=piece,
        coords=template_var.coords,
        name=template_var.name,
    )
  return type(template_dataset)(data_vars)
|
|