Upload 25 files
- graphcast/autoregressive.py +312 -0
- graphcast/casting.py +205 -0
- graphcast/checkpoint.py +170 -0
- graphcast/checkpoint_test.py +124 -0
- graphcast/data_utils.py +359 -0
- graphcast/data_utils_test.py +310 -0
- graphcast/deep_typed_graph_net.py +391 -0
- graphcast/graphcast.py +796 -0
- graphcast/grid_mesh_connectivity.py +133 -0
- graphcast/grid_mesh_connectivity_test.py +74 -0
- graphcast/icosahedral_mesh.py +281 -0
- graphcast/icosahedral_mesh_test.py +131 -0
- graphcast/losses.py +179 -0
- graphcast/model_utils.py +724 -0
- graphcast/normalization.py +196 -0
- graphcast/predictor_base.py +170 -0
- graphcast/rollout.py +269 -0
- graphcast/solar_radiation.py +605 -0
- graphcast/solar_radiation_test.py +240 -0
- graphcast/typed_graph.py +97 -0
- graphcast/typed_graph_net.py +317 -0
- graphcast/xarray_jax.py +810 -0
- graphcast/xarray_jax_test.py +526 -0
- graphcast/xarray_tree.py +70 -0
- graphcast/xarray_tree_test.py +95 -0
graphcast/autoregressive.py
ADDED
@@ -0,0 +1,312 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Predictor wrapping a one-step Predictor to make autoregressive predictions.
"""

from typing import Optional, cast

from absl import logging
from graphcast import predictor_base
from graphcast import xarray_jax
from graphcast import xarray_tree
import haiku as hk
import jax
import xarray


def _unflatten_and_expand_time(flat_variables, tree_def, time_coords):
  variables = jax.tree_util.tree_unflatten(tree_def, flat_variables)
  return variables.expand_dims(time=time_coords, axis=0)


def _get_flat_arrays_and_single_timestep_treedef(variables):
  flat_arrays = jax.tree_util.tree_leaves(variables.transpose('time', ...))
  _, treedef = jax.tree_util.tree_flatten(variables.isel(time=0, drop=True))
  return flat_arrays, treedef


class Predictor(predictor_base.Predictor):
  """Wraps a one-step Predictor to make multi-step predictions autoregressively.

  The wrapped Predictor will be used to predict a single timestep conditional
  on the inputs passed to the outer Predictor. Its predictions are then
  passed back in as inputs at the next timestep, for as many timesteps as are
  requested in the targets_template. (When multiple timesteps of input are
  used, a rolling window of inputs is maintained with new predictions
  concatenated onto the end).

  You may ask for additional variables to be predicted as targets which aren't
  used as inputs. These will be predicted as output variables only and not fed
  back in autoregressively. All target variables must be time-dependent however.

  You may also specify static (non-time-dependent) inputs which will be passed
  in at each timestep but are not predicted.

  At present, any time-dependent inputs must also be present as targets so they
  can be passed in autoregressively.

  The loss of the wrapped one-step Predictor is averaged over all timesteps to
  give a loss for the autoregressive Predictor.
  """

  def __init__(
      self,
      predictor: predictor_base.Predictor,
      noise_level: Optional[float] = None,
      gradient_checkpointing: bool = False,
      ):
    """Initializes an autoregressive predictor wrapper.

    Args:
      predictor: A predictor to wrap in an auto-regressive way.
      noise_level: Optional value that multiplies the standard normal noise
        added to the time-dependent variables of the predictor inputs. In
        particular, no noise is added to the predictions that are fed back
        auto-regressively. Defaults to not adding noise.
      gradient_checkpointing: If True, gradient checkpointing will be
        used at each step of the computation to save on memory. Roughly this
        should make the backwards pass two times more expensive, and the time
        per step, counting the forward pass, should only increase by about 50%.
        Note this parameter will be ignored with a warning if the scan sequence
        length is 1.
    """
    self._predictor = predictor
    self._noise_level = noise_level
    self._gradient_checkpointing = gradient_checkpointing

  def _get_and_validate_constant_inputs(self, inputs, targets, forcings):
    constant_inputs = inputs.drop_vars(targets.keys(), errors='ignore')
    constant_inputs = constant_inputs.drop_vars(
        forcings.keys(), errors='ignore')
    for name, var in constant_inputs.items():
      if 'time' in var.dims:
        raise ValueError(
            f'Time-dependent input variable {name} must either be a forcing '
            'variable, or a target variable to allow for auto-regressive '
            'feedback.')
    return constant_inputs

  def _validate_targets_and_forcings(self, targets, forcings):
    for name, var in targets.items():
      if 'time' not in var.dims:
        raise ValueError(f'Target variable {name} must be time-dependent.')

    for name, var in forcings.items():
      if 'time' not in var.dims:
        raise ValueError(f'Forcing variable {name} must be time-dependent.')

    overlap = forcings.keys() & targets.keys()
    if overlap:
      raise ValueError('The following were specified as both targets and '
                       f'forcings, which isn\'t allowed: {overlap}')

  def _update_inputs(self, inputs, next_frame):
    num_inputs = inputs.dims['time']

    predicted_or_forced_inputs = next_frame[list(inputs.keys())]

    # Combining datasets with inputs and target time stamps aligns them.
    # Only keep the num_inputs trailing frames for use as next inputs.
    return (xarray.concat([inputs, predicted_or_forced_inputs], dim='time')
            .tail(time=num_inputs)
            # Update the time coordinate to reset the lead times for
            # next AR iteration.
            .assign_coords(time=inputs.coords['time']))

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **kwargs) -> xarray.Dataset:
    """Calls the Predictor.

    Args:
      inputs: input variables used to make predictions. Inputs can include both
        time-dependent and time-independent variables. Any time-dependent
        input variables must also be present in the targets_template or the
        forcings.
      targets_template: A target template containing information about which
        variables should be predicted and the time alignment of the
        predictions. All target variables must be time-dependent. The number
        of time frames is used to set the number of unroll steps of the AR
        predictor (multiple unrolls of the inner predictor for one time step
        in the targets are not supported yet).
      forcings: Variables that will be fed to the model. The variables
        should not overlap with the target ones. The time coordinates of the
        forcing variables should match the target ones.
        Forcing variables which are also present in the inputs will be used to
        supply ground-truth values for those inputs when they are passed to the
        underlying predictor at timesteps beyond the first timestep.
      **kwargs: Additional arguments passed along to the inner Predictor.

    Returns:
      predictions: the model predictions matching the target template.

    Raises:
      ValueError: if the time coordinates of the inputs and targets do not
        differ by a constant time step.
    """

    constant_inputs = self._get_and_validate_constant_inputs(
        inputs, targets_template, forcings)
    self._validate_targets_and_forcings(targets_template, forcings)

    # After the above checks, the remaining inputs must be time-dependent:
    inputs = inputs.drop_vars(constant_inputs.keys())

    # A predictions template only including the next time to predict.
    target_template = targets_template.isel(time=[0])

    flat_forcings, forcings_treedef = (
        _get_flat_arrays_and_single_timestep_treedef(forcings))
    scan_variables = flat_forcings

    def one_step_prediction(inputs, scan_variables):

      flat_forcings = scan_variables
      forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef,
                                            target_template.coords['time'])

      # Add constant inputs:
      all_inputs = xarray.merge([constant_inputs, inputs])
      predictions: xarray.Dataset = self._predictor(
          all_inputs, target_template,
          forcings=forcings,
          **kwargs)

      next_frame = xarray.merge([predictions, forcings])
      next_inputs = self._update_inputs(inputs, next_frame)

      # Drop the length-1 time dimension, since scan will concat all the outputs
      # for different times along a new leading time dimension:
      predictions = predictions.squeeze('time', drop=True)
      # We return the prediction flattened into plain jax arrays, because the
      # extra leading dimension added by scan prevents the tree_util
      # registrations in xarray_jax from unflattening them back into an
      # xarray.Dataset automatically:
      flat_pred = jax.tree_util.tree_leaves(predictions)
      return next_inputs, flat_pred

    if self._gradient_checkpointing:
      scan_length = targets_template.dims['time']
      if scan_length <= 1:
        logging.warning(
            'Skipping gradient checkpointing for sequence length of 1')
      else:
        # Just in case we take gradients (e.g. for control), although
        # in most cases this will just be for a forward pass.
        one_step_prediction = hk.remat(one_step_prediction)

    # Loop (without unroll) with hk states in cell (jax.lax.scan won't do).
    _, flat_preds = hk.scan(one_step_prediction, inputs, scan_variables)

    # The result of scan will have an extra leading axis on all arrays,
    # corresponding to the target times in this case. We need to be prepared for
    # it when unflattening the arrays back into a Dataset:
    scan_result_template = (
        target_template.squeeze('time', drop=True)
        .expand_dims(time=targets_template.coords['time'], axis=0))
    _, scan_result_treedef = jax.tree_util.tree_flatten(scan_result_template)
    predictions = jax.tree_util.tree_unflatten(scan_result_treedef, flat_preds)
    return predictions

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **kwargs
           ) -> predictor_base.LossAndDiagnostics:
    """The mean of the per-timestep losses of the underlying predictor."""
    if targets.sizes['time'] == 1:
      # If there is only a single target timestep then we don't need any
      # autoregressive feedback and can delegate the loss directly to the
      # underlying single-step predictor. This means the underlying predictor
      # doesn't need to implement .loss_and_predictions.
      return self._predictor.loss(inputs, targets, forcings, **kwargs)

    constant_inputs = self._get_and_validate_constant_inputs(
        inputs, targets, forcings)
    self._validate_targets_and_forcings(targets, forcings)
    # After the above checks, the remaining inputs must be time-dependent:
    inputs = inputs.drop_vars(constant_inputs.keys())

    if self._noise_level:
      def add_noise(x):
        return x + self._noise_level * jax.random.normal(
            hk.next_rng_key(), shape=x.shape)
      # Add noise to time-dependent variables of the inputs.
      inputs = jax.tree_map(add_noise, inputs)

    # The per-timestep targets passed by scan to one_step_loss below will have
    # no leading time axis. We need a treedef without the time axis to use
    # inside one_step_loss to unflatten it back into a dataset:
    flat_targets, target_treedef = _get_flat_arrays_and_single_timestep_treedef(
        targets)
    scan_variables = flat_targets

    flat_forcings, forcings_treedef = (
        _get_flat_arrays_and_single_timestep_treedef(forcings))
    scan_variables = (flat_targets, flat_forcings)

    def one_step_loss(inputs, scan_variables):
      flat_target, flat_forcings = scan_variables
      forcings = _unflatten_and_expand_time(flat_forcings, forcings_treedef,
                                            targets.coords['time'][:1])

      target = _unflatten_and_expand_time(flat_target, target_treedef,
                                          targets.coords['time'][:1])

      # Add constant inputs:
      all_inputs = xarray.merge([constant_inputs, inputs])

      (loss, diagnostics), predictions = self._predictor.loss_and_predictions(
          all_inputs,
          target,
          forcings=forcings,
          **kwargs)

      # Unwrap to jax arrays shape (batch,):
      loss, diagnostics = xarray_tree.map_structure(
          xarray_jax.unwrap_data, (loss, diagnostics))

      predictions = cast(xarray.Dataset, predictions)  # Keeps pytype happy.
      next_frame = xarray.merge([predictions, forcings])
      next_inputs = self._update_inputs(inputs, next_frame)

      return next_inputs, (loss, diagnostics)

    if self._gradient_checkpointing:
      scan_length = targets.dims['time']
      if scan_length <= 1:
        logging.warning(
            'Skipping gradient checkpointing for sequence length of 1')
      else:
        one_step_loss = hk.remat(one_step_loss)

    # We can pass inputs (the initial state of the loop) in directly as a
    # Dataset because the shape we pass in to scan is the same as the shape scan
    # passes to the inner function. But, for scan_variables, we must flatten the
    # targets (and unflatten them inside the inner function) because they are
    # passed to the inner function per-timestep without the original time axis.
    # The same applies to the optional forcings.
    _, (per_timestep_losses, per_timestep_diagnostics) = hk.scan(
        one_step_loss, inputs, scan_variables)

    # Re-wrap loss and diagnostics as DataArray and average them over time:
    (loss, diagnostics) = jax.tree_util.tree_map(
        lambda x: xarray_jax.DataArray(x, dims=('time', 'batch')).mean(  # pylint: disable=g-long-lambda
            'time', skipna=False),
        (per_timestep_losses, per_timestep_diagnostics))

    return loss, diagnostics
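For orientation, a minimal usage sketch of this wrapper follows. It is not part of the upload: `one_step_model`, `inputs`, `targets`, `targets_template` and `forcings` are assumed placeholders, and in practice the calls run inside a Haiku transform, since the wrapper relies on `hk.scan`/`hk.remat`.

    # Hypothetical sketch: unroll an assumed one-step predictor over the lead
    # times in targets_template (all names are placeholders, see note above).
    from graphcast import autoregressive

    ar_predictor = autoregressive.Predictor(
        predictor=one_step_model,        # any predictor_base.Predictor
        gradient_checkpointing=True,     # remat each scan step to save memory
    )
    # One inner-predictor step per timestep of targets_template; predictions
    # are fed back in as inputs for the next step.
    predictions = ar_predictor(inputs, targets_template, forcings=forcings)
    # Training loss: mean over lead times of the wrapped predictor's loss.
    loss, diagnostics = ar_predictor.loss(inputs, targets, forcings=forcings)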
graphcast/casting.py
ADDED
@@ -0,0 +1,205 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrappers that take care of casting."""

import contextlib
from typing import Any, Mapping, Tuple

import chex
from graphcast import predictor_base
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import xarray


PyTree = Any


class Bfloat16Cast(predictor_base.Predictor):
  """Wrapper that casts all inputs to bfloat16 and outputs to targets dtype."""

  def __init__(self, predictor: predictor_base.Predictor, enabled: bool = True):
    """Inits the wrapper.

    Args:
      predictor: predictor being wrapped.
      enabled: disables the wrapper if False, for simpler hyperparameter scans.

    """
    self._enabled = enabled
    self._predictor = predictor

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **kwargs
               ) -> xarray.Dataset:
    if not self._enabled:
      return self._predictor(inputs, targets_template, forcings, **kwargs)

    with bfloat16_variable_view():
      predictions = self._predictor(
          *_all_inputs_to_bfloat16(inputs, targets_template, forcings),
          **kwargs,)

    predictions_dtype = infer_floating_dtype(predictions)  # pytype: disable=wrong-arg-types
    if predictions_dtype != jnp.bfloat16:
      raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')

    targets_dtype = infer_floating_dtype(targets_template)  # pytype: disable=wrong-arg-types
    return tree_map_cast(
        predictions, input_dtype=jnp.bfloat16, output_dtype=targets_dtype)

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **kwargs,
           ) -> predictor_base.LossAndDiagnostics:
    if not self._enabled:
      return self._predictor.loss(inputs, targets, forcings, **kwargs)

    with bfloat16_variable_view():
      loss, scalars = self._predictor.loss(
          *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)

    if loss.dtype != jnp.bfloat16:
      raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')

    targets_dtype = infer_floating_dtype(targets)  # pytype: disable=wrong-arg-types

    # Note that casting back the loss to e.g. float32 should not affect data
    # types of the backwards pass, because the first thing the backwards pass
    # should do is to go backwards the casting op and cast back to bfloat16
    # (and xprofs seem to confirm this).
    return tree_map_cast((loss, scalars),
                         input_dtype=jnp.bfloat16, output_dtype=targets_dtype)

  def loss_and_predictions(  # pytype: disable=signature-mismatch  # jax-ndarray
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
      **kwargs,
      ) -> Tuple[predictor_base.LossAndDiagnostics,
                 xarray.Dataset]:
    if not self._enabled:
      return self._predictor.loss_and_predictions(inputs, targets, forcings,  # pytype: disable=bad-return-type  # jax-ndarray
                                                  **kwargs)

    with bfloat16_variable_view():
      (loss, scalars), predictions = self._predictor.loss_and_predictions(
          *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)

    if loss.dtype != jnp.bfloat16:
      raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')

    predictions_dtype = infer_floating_dtype(predictions)  # pytype: disable=wrong-arg-types
    if predictions_dtype != jnp.bfloat16:
      raise ValueError(f'Expected bfloat16 output, got {predictions_dtype}')

    targets_dtype = infer_floating_dtype(targets)  # pytype: disable=wrong-arg-types
    return tree_map_cast(((loss, scalars), predictions),
                         input_dtype=jnp.bfloat16, output_dtype=targets_dtype)


def infer_floating_dtype(data_vars: Mapping[str, chex.Array]) -> np.dtype:
  """Infers a floating dtype from an input mapping of data."""
  dtypes = {
      v.dtype
      for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
  if len(dtypes) != 1:
    dtypes_and_shapes = {
        k: (v.dtype, v.shape)
        for k, v in data_vars.items() if jnp.issubdtype(v.dtype, np.floating)}
    raise ValueError(
        f'Did not find exactly one floating dtype {dtypes} in input variables:'
        f'{dtypes_and_shapes}')
  return list(dtypes)[0]


def _all_inputs_to_bfloat16(
    inputs: xarray.Dataset,
    targets: xarray.Dataset,
    forcings: xarray.Dataset,
    ) -> Tuple[xarray.Dataset,
               xarray.Dataset,
               xarray.Dataset]:
  return (inputs.astype(jnp.bfloat16),
          jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets),
          forcings.astype(jnp.bfloat16))


def tree_map_cast(inputs: PyTree, input_dtype: np.dtype, output_dtype: np.dtype,
                  ) -> PyTree:
  def cast_fn(x):
    if x.dtype == input_dtype:
      return x.astype(output_dtype)
  return jax.tree_map(cast_fn, inputs)


@contextlib.contextmanager
def bfloat16_variable_view(enabled: bool = True):
  """Context for Haiku modules with float32 params, but bfloat16 activations.

  It works as follows:
  * Every time a variable is requested to be created/set as np.bfloat16,
    it will create an underlying float32 variable, instead.
  * Every time a variable is requested as bfloat16, it will check the
    variable is of float32 type, and cast the variable to bfloat16.

  Note the gradients are still computed and accumulated as float32, because
  the params returned by init are float32, so the gradient function with
  respect to the params will already include an implicit casting to float32.

  Args:
    enabled: Only enables bfloat16 behavior if True.

  Yields:
    None
  """

  if enabled:
    with hk.custom_creator(
        _bfloat16_creator, state=True), hk.custom_getter(
            _bfloat16_getter, state=True), hk.custom_setter(
                _bfloat16_setter):
      yield
  else:
    yield


def _bfloat16_creator(next_creator, shape, dtype, init, context):
  """Creates float32 variables when bfloat16 is requested."""
  if context.original_dtype == jnp.bfloat16:
    dtype = jnp.float32
  return next_creator(shape, dtype, init)


def _bfloat16_getter(next_getter, value, context):
  """Casts float32 to bfloat16 when bfloat16 was originally requested."""
  if context.original_dtype == jnp.bfloat16:
    assert value.dtype == jnp.float32
    value = value.astype(jnp.bfloat16)
  return next_getter(value)


def _bfloat16_setter(next_setter, value, context):
  """Casts bfloat16 to float32 when bfloat16 was originally set."""
  if context.original_dtype == jnp.bfloat16:
    value = value.astype(jnp.float32)
  return next_setter(value)
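Bfloat16Cast is meant to wrap another predictor. A hedged composition sketch follows; `model`, `inputs`, `targets_template` and `forcings` stand in for objects not defined in this diff.

    # Hypothetical sketch: run the wrapped model with bfloat16 activations
    # while keeping float32 parameters, then cast outputs back to the
    # targets' floating dtype.
    from graphcast import casting

    predictor = casting.Bfloat16Cast(model, enabled=True)  # `model` assumed
    predictions = predictor(inputs, targets_template, forcings)
    # Inside the call: inputs are cast to bfloat16, bfloat16_variable_view()
    # stores params as float32 while exposing them to the model as bfloat16,
    # and the bfloat16 outputs are cast back to targets_template's dtype.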
graphcast/checkpoint.py
ADDED
@@ -0,0 +1,170 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Serialize and deserialize trees."""

import dataclasses
import io
import types
from typing import Any, BinaryIO, Optional, TypeVar

import numpy as np

_T = TypeVar("_T")


def dump(dest: BinaryIO, value: Any) -> None:
  """Dump a tree of dicts/dataclasses to a file object.

  Args:
    dest: a file object to write to.
    value: A tree of dicts, lists, tuples and dataclasses of numpy arrays and
      other basic types. Unions are not supported, other than Optional/None
      which is only supported in dataclasses, not in dicts, lists or tuples.
      All leaves must be coercible to a numpy array, and recoverable as a
      single arg to a type.
  """
  buffer = io.BytesIO()  # In case the destination doesn't support seeking.
  np.savez(buffer, **_flatten(value))
  dest.write(buffer.getvalue())


def load(source: BinaryIO, typ: type[_T]) -> _T:
  """Load from a file object and convert it to the specified type.

  Args:
    source: a file object to read from.
    typ: a type object that acts as a schema for deserialization. It must match
      what was serialized. If a type is Any, it will be returned however numpy
      serialized it, which is what you want for a tree of numpy arrays.

  Returns:
    the deserialized value as the specified type.
  """
  return _convert_types(typ, _unflatten(np.load(source)))


_SEP = ":"


def _flatten(tree: Any) -> dict[str, Any]:
  """Flatten a tree of dicts/dataclasses/lists/tuples to a single dict."""
  if dataclasses.is_dataclass(tree):
    # Don't use dataclasses.asdict as it is recursive so skips dropping None.
    tree = {f.name: v for f in dataclasses.fields(tree)
            if (v := getattr(tree, f.name)) is not None}
  elif isinstance(tree, (list, tuple)):
    tree = dict(enumerate(tree))

  assert isinstance(tree, dict)

  flat = {}
  for k, v in tree.items():
    k = str(k)
    assert _SEP not in k
    if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)):
      for a, b in _flatten(v).items():
        flat[f"{k}{_SEP}{a}"] = b
    else:
      assert v is not None
      flat[k] = v
  return flat


def _unflatten(flat: dict[str, Any]) -> dict[str, Any]:
  """Unflatten a dict to a tree of dicts."""
  tree = {}
  for flat_key, v in flat.items():
    node = tree
    keys = flat_key.split(_SEP)
    for k in keys[:-1]:
      if k not in node:
        node[k] = {}
      node = node[k]
    node[keys[-1]] = v
  return tree


def _convert_types(typ: type[_T], value: Any) -> _T:
  """Convert some structure into the given type. The structures must match."""
  if typ in (Any, ...):
    return value

  if typ in (int, float, str, bool):
    return typ(value)

  if typ is np.ndarray:
    assert isinstance(value, np.ndarray)
    return value

  if dataclasses.is_dataclass(typ):
    kwargs = {}
    for f in dataclasses.fields(typ):
      # Only support Optional for dataclasses, as numpy can't serialize it
      # directly (without pickle), and dataclasses are the only case where we
      # can know the full set of values and types and therefore know the
      # non-existence must mean None.
      if isinstance(f.type, (types.UnionType, type(Optional[int]))):
        constructors = [t for t in f.type.__args__ if t is not types.NoneType]
        if len(constructors) != 1:
          raise TypeError(
              "Optional works, Union with anything except None doesn't")
        if f.name not in value:
          kwargs[f.name] = None
          continue
        constructor = constructors[0]
      else:
        constructor = f.type

      if f.name in value:
        kwargs[f.name] = _convert_types(constructor, value[f.name])
      else:
        raise ValueError(f"Missing value: {f.name}")
    return typ(**kwargs)

  base_type = getattr(typ, "__origin__", None)

  if base_type is dict:
    assert len(typ.__args__) == 2
    key_type, value_type = typ.__args__
    return {_convert_types(key_type, k): _convert_types(value_type, v)
            for k, v in value.items()}

  if base_type is list:
    assert len(typ.__args__) == 1
    value_type = typ.__args__[0]
    return [_convert_types(value_type, v)
            for _, v in sorted(value.items(), key=lambda x: int(x[0]))]

  if base_type is tuple:
    if len(typ.__args__) == 2 and typ.__args__[1] == ...:
      # An arbitrary length tuple of a single type, eg: tuple[int, ...]
      value_type = typ.__args__[0]
      return tuple(_convert_types(value_type, v)
                   for _, v in sorted(value.items(), key=lambda x: int(x[0])))
    else:
      # A fixed length tuple of arbitrary types, eg: tuple[int, str, float]
      assert len(typ.__args__) == len(value)
      return tuple(
          _convert_types(t, v)
          for t, (_, v) in zip(
              typ.__args__, sorted(value.items(), key=lambda x: int(x[0]))))

  # This is probably unreachable with reasonable serializable inputs.
  try:
    return typ(value)
  except TypeError as e:
    raise TypeError(
        "_convert_types expects the type argument to be a dataclass defined "
        "with types that are valid constructors (eg tuple is fine, Tuple "
        "isn't), and accept a numpy array as the sole argument.") from e
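A minimal dump/load round trip looks like the following sketch; the `TinyCheckpoint` schema is illustrative only, and `checkpoint_test.py` below exercises a much larger one.

    # Minimal sketch, assuming a small hypothetical dataclass schema.
    import dataclasses
    import io
    import numpy as np
    from graphcast import checkpoint

    @dataclasses.dataclass
    class TinyCheckpoint:  # illustrative schema, not part of the upload
      weights: dict[str, np.ndarray]
      name: str

    buf = io.BytesIO()
    checkpoint.dump(buf, TinyCheckpoint(weights={"w": np.arange(3)},
                                        name="demo"))
    buf.seek(0)
    restored = checkpoint.load(buf, TinyCheckpoint)  # typ acts as the schema
    assert restored.name == "demo"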
graphcast/checkpoint_test.py
ADDED
@@ -0,0 +1,124 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Check that the checkpoint serialization is reversible."""

import dataclasses
import io
from typing import Any, Optional, Union

from absl.testing import absltest
from graphcast import checkpoint
import numpy as np


@dataclasses.dataclass
class SubConfig:
  a: int
  b: str


@dataclasses.dataclass
class Config:
  bt: bool
  bf: bool
  i: int
  f: float
  o1: Optional[int]
  o2: Optional[int]
  o3: Union[int, None]
  o4: Union[int, None]
  o5: int | None
  o6: int | None
  li: list[int]
  ls: list[str]
  ldc: list[SubConfig]
  tf: tuple[float, ...]
  ts: tuple[str, ...]
  t: tuple[str, int, SubConfig]
  tdc: tuple[SubConfig, ...]
  dsi: dict[str, int]
  dss: dict[str, str]
  dis: dict[int, str]
  dsdis: dict[str, dict[int, str]]
  dc: SubConfig
  dco: Optional[SubConfig]
  ddc: dict[str, SubConfig]


@dataclasses.dataclass
class Checkpoint:
  params: dict[str, Any]
  config: Config


class DataclassTest(absltest.TestCase):

  def test_serialize_dataclass(self):
    ckpt = Checkpoint(
        params={
            "layer1": {
                "w": np.arange(10).reshape(2, 5),
                "b": np.array([2, 6]),
            },
            "layer2": {
                "w": np.arange(8).reshape(2, 4),
                "b": np.array([2, 6]),
            },
            "blah": np.array([3, 9]),
        },
        config=Config(
            bt=True,
            bf=False,
            i=42,
            f=3.14,
            o1=1,
            o2=None,
            o3=2,
            o4=None,
            o5=3,
            o6=None,
            li=[12, 9, 7, 15, 16, 14, 1, 6, 11, 4, 10, 5, 13, 3, 8, 2],
            ls=list("qhjfdxtpzgemryoikwvblcaus"),
            ldc=[SubConfig(1, "hello"), SubConfig(2, "world")],
            tf=(1, 4, 2, 10, 5, 9, 13, 16, 15, 8, 12, 7, 11, 14, 3, 6),
            ts=("hello", "world"),
            t=("foo", 42, SubConfig(1, "bar")),
            tdc=(SubConfig(1, "hello"), SubConfig(2, "world")),
            dsi={"a": 1, "b": 2, "c": 3},
            dss={"d": "e", "f": "g"},
            dis={1: "a", 2: "b", 3: "c"},
            dsdis={"a": {1: "hello", 2: "world"}, "b": {1: "world"}},
            dc=SubConfig(1, "hello"),
            dco=None,
            ddc={"a": SubConfig(1, "hello"), "b": SubConfig(2, "world")},
        ))

    buffer = io.BytesIO()
    checkpoint.dump(buffer, ckpt)
    buffer.seek(0)
    ckpt2 = checkpoint.load(buffer, Checkpoint)
    np.testing.assert_array_equal(ckpt.params["layer1"]["w"],
                                  ckpt2.params["layer1"]["w"])
    np.testing.assert_array_equal(ckpt.params["layer1"]["b"],
                                  ckpt2.params["layer1"]["b"])
    np.testing.assert_array_equal(ckpt.params["layer2"]["w"],
                                  ckpt2.params["layer2"]["w"])
    np.testing.assert_array_equal(ckpt.params["layer2"]["b"],
                                  ckpt2.params["layer2"]["b"])
    np.testing.assert_array_equal(ckpt.params["blah"], ckpt2.params["blah"])
    self.assertEqual(ckpt.config, ckpt2.config)


if __name__ == "__main__":
  absltest.main()
graphcast/data_utils.py
ADDED
@@ -0,0 +1,359 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset utilities."""

from typing import Any, Mapping, Sequence, Tuple, Union

from graphcast import solar_radiation
import numpy as np
import pandas as pd
import xarray

TimedeltaLike = Any  # Something convertible to pd.Timedelta.
TimedeltaStr = str  # A string convertible to pd.Timedelta.

TargetLeadTimes = Union[
    TimedeltaLike,
    Sequence[TimedeltaLike],
    slice  # with TimedeltaLike as its start and stop.
]

_SEC_PER_HOUR = 3600
_HOUR_PER_DAY = 24
SEC_PER_DAY = _SEC_PER_HOUR * _HOUR_PER_DAY
_AVG_DAY_PER_YEAR = 365.24219
AVG_SEC_PER_YEAR = SEC_PER_DAY * _AVG_DAY_PER_YEAR

DAY_PROGRESS = "day_progress"
YEAR_PROGRESS = "year_progress"
_DERIVED_VARS = {
    DAY_PROGRESS,
    f"{DAY_PROGRESS}_sin",
    f"{DAY_PROGRESS}_cos",
    YEAR_PROGRESS,
    f"{YEAR_PROGRESS}_sin",
    f"{YEAR_PROGRESS}_cos",
}
TISR = "toa_incident_solar_radiation"


def get_year_progress(seconds_since_epoch: np.ndarray) -> np.ndarray:
  """Computes year progress for times in seconds.

  Args:
    seconds_since_epoch: Times in seconds since the "epoch" (the point at which
      UNIX time starts).

  Returns:
    Year progress normalized to be in the [0, 1) interval for each time point.
  """

  # Start with the pure integer division, and then float at the very end.
  # We will try to keep as much precision as possible.
  years_since_epoch = (
      seconds_since_epoch / SEC_PER_DAY / np.float64(_AVG_DAY_PER_YEAR)
  )
  # Note depending on how these ops are done, we may end up with a "weak_type"
  # which can cause issues in subtle ways that are hard to track down here.
  # In any case, casting to float32 should get rid of the weak type.
  # [0, 1.) Interval.
  return np.mod(years_since_epoch, 1.0).astype(np.float32)


def get_day_progress(
    seconds_since_epoch: np.ndarray,
    longitude: np.ndarray,
) -> np.ndarray:
  """Computes day progress for times in seconds at each longitude.

  Args:
    seconds_since_epoch: 1D array of times in seconds since the 'epoch' (the
      point at which UNIX time starts).
    longitude: 1D array of longitudes at which day progress is computed.

  Returns:
    2D array of day progress values normalized to be in the [0, 1) interval
    for each time point at each longitude.
  """

  # [0.0, 1.0) Interval.
  day_progress_greenwich = (
      np.mod(seconds_since_epoch, SEC_PER_DAY) / SEC_PER_DAY
  )

  # Offset the day progress to the longitude of each point on Earth.
  longitude_offsets = np.deg2rad(longitude) / (2 * np.pi)
  day_progress = np.mod(
      day_progress_greenwich[..., np.newaxis] + longitude_offsets, 1.0
  )
  return day_progress.astype(np.float32)


def featurize_progress(
    name: str, dims: Sequence[str], progress: np.ndarray
) -> Mapping[str, xarray.Variable]:
  """Derives features used by ML models from the `progress` variable.

  Args:
    name: Base variable name from which features are derived.
    dims: List of the output feature dimensions, e.g. ("day", "lon").
    progress: Progress variable values.

  Returns:
    Dictionary of xarray variables derived from the `progress` values. It
    includes the original `progress` variable along with its sin and cos
    transformations.

  Raises:
    ValueError if the number of feature dimensions is not equal to the number
    of data dimensions.
  """
  if len(dims) != progress.ndim:
    raise ValueError(
        f"Number of feature dimensions ({len(dims)}) must be equal to the"
        f" number of data dimensions: {progress.ndim}."
    )
  progress_phase = progress * (2 * np.pi)
  return {
      name: xarray.Variable(dims, progress),
      name + "_sin": xarray.Variable(dims, np.sin(progress_phase)),
      name + "_cos": xarray.Variable(dims, np.cos(progress_phase)),
  }


def add_derived_vars(data: xarray.Dataset) -> None:
  """Adds year and day progress features to `data` in place if missing.

  Args:
    data: Xarray dataset to which derived features will be added.

  Raises:
    ValueError if `datetime` or `lon` are not in `data` coordinates.
  """

  for coord in ("datetime", "lon"):
    if coord not in data.coords:
      raise ValueError(f"'{coord}' must be in `data` coordinates.")

  # Compute seconds since epoch.
  # Note `data.coords["datetime"].astype("datetime64[s]").astype(np.int64)`
  # does not work as xarray always casts dates into nanoseconds!
  seconds_since_epoch = (
      data.coords["datetime"].data.astype("datetime64[s]").astype(np.int64)
  )
  batch_dim = ("batch",) if "batch" in data.dims else ()

  # Add year progress features if missing.
  if YEAR_PROGRESS not in data.data_vars:
    year_progress = get_year_progress(seconds_since_epoch)
    data.update(
        featurize_progress(
            name=YEAR_PROGRESS,
            dims=batch_dim + ("time",),
            progress=year_progress,
        )
    )

  # Add day progress features if missing.
  if DAY_PROGRESS not in data.data_vars:
    longitude_coord = data.coords["lon"]
    day_progress = get_day_progress(seconds_since_epoch, longitude_coord.data)
    data.update(
        featurize_progress(
            name=DAY_PROGRESS,
            dims=batch_dim + ("time",) + longitude_coord.dims,
            progress=day_progress,
        )
    )


def add_tisr_var(data: xarray.Dataset) -> None:
  """Adds TISR feature to `data` in place if missing.

  Args:
    data: Xarray dataset to which TISR feature will be added.

  Raises:
    ValueError if `datetime`, `lat`, or `lon` are not in `data` coordinates.
  """

  if TISR in data.data_vars:
    return

  for coord in ("datetime", "lat", "lon"):
    if coord not in data.coords:
      raise ValueError(f"'{coord}' must be in `data` coordinates.")

  # Remove `batch` dimension of size one if present. An error will be raised if
  # the `batch` dimension exists and has size greater than one.
  data_no_batch = data.squeeze("batch") if "batch" in data.dims else data

  tisr = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
      data_no_batch, use_jit=True
  )

  if "batch" in data.dims:
    tisr = tisr.expand_dims("batch", axis=0)

  data.update({TISR: tisr})


def extract_input_target_times(
    dataset: xarray.Dataset,
    input_duration: TimedeltaLike,
    target_lead_times: TargetLeadTimes,
) -> Tuple[xarray.Dataset, xarray.Dataset]:
  """Extracts inputs and targets for prediction, from a Dataset with a time dim.

  The input period is assumed to be contiguous (specified by a duration), but
  the targets can be a list of arbitrary lead times.

  Examples:

    # Use 18 hours of data as inputs, and two specific lead times as targets:
    # 3 days and 5 days after the final input.
    extract_inputs_targets(
        dataset,
        input_duration='18h',
        target_lead_times=('3d', '5d')
    )

    # Use 1 day of data as input, and all lead times between 6 hours and
    # 24 hours inclusive as targets. Demonstrates a friendlier supported string
    # syntax.
    extract_inputs_targets(
        dataset,
        input_duration='1 day',
        target_lead_times=slice('6 hours', '24 hours')
    )

    # Just use a single target lead time of 3 days:
    extract_inputs_targets(
        dataset,
        input_duration='24h',
        target_lead_times='3d'
    )

  Args:
    dataset: An xarray.Dataset with a 'time' dimension whose coordinates are
      timedeltas. It's assumed that the time coordinates have a fixed offset /
      time resolution, and that the input_duration and target_lead_times are
      multiples of this.
    input_duration: pandas.Timedelta or something convertible to it (e.g. a
      shorthand string like '6h' or '5d12h').
    target_lead_times: Either a single lead time, a slice with start and stop
      (inclusive) lead times, or a sequence of lead times. Lead times should be
      Timedeltas (or something convertible to). They are given relative to the
      final input timestep, and should be positive.

  Returns:
    inputs:
    targets:
      Two datasets with the same shape as the input dataset except that a
      selection has been made from the time axis, and the origin of the
      time coordinate will be shifted to refer to lead times relative to the
      final input timestep. So for inputs the times will end at lead time 0,
      for targets the time coordinates will refer to the lead times requested.
  """

  (target_lead_times, target_duration
   ) = _process_target_lead_times_and_get_duration(target_lead_times)

  # Shift the coordinates for the time axis so that a timedelta of zero
  # corresponds to the forecast reference time. That is, the final timestep
  # that's available as input to the forecast, with all following timesteps
  # forming the target period which needs to be predicted.
  # This means the time coordinates are now forecast lead times.
  time = dataset.coords["time"]
  dataset = dataset.assign_coords(time=time + target_duration - time[-1])

  # Slice out targets:
  targets = dataset.sel({"time": target_lead_times})

  input_duration = pd.Timedelta(input_duration)
  # Both endpoints are inclusive with label-based slicing, so we offset by a
  # small epsilon to make one of the endpoints non-inclusive:
  zero = pd.Timedelta(0)
  epsilon = pd.Timedelta(1, "ns")
  inputs = dataset.sel({"time": slice(-input_duration + epsilon, zero)})
  return inputs, targets


def _process_target_lead_times_and_get_duration(
    target_lead_times: TargetLeadTimes) -> TimedeltaLike:
  """Returns the minimum duration for the target lead times."""
  if isinstance(target_lead_times, slice):
    # A slice of lead times. xarray already accepts timedelta-like values for
    # the begin/end/step of the slice.
    if target_lead_times.start is None:
      # If the start isn't specified, we assume it starts at the next timestep
      # after lead time 0 (lead time 0 is the final input timestep):
      target_lead_times = slice(
          pd.Timedelta(1, "ns"), target_lead_times.stop, target_lead_times.step
      )
    target_duration = pd.Timedelta(target_lead_times.stop)
  else:
    if not isinstance(target_lead_times, (list, tuple, set)):
      # A single lead time, which we wrap as a length-1 array to ensure there
      # still remains a time dimension (here of length 1) for consistency.
      target_lead_times = [target_lead_times]

    # A list of multiple (not necessarily contiguous) lead times:
    target_lead_times = [pd.Timedelta(x) for x in target_lead_times]
    target_lead_times.sort()
    target_duration = target_lead_times[-1]
  return target_lead_times, target_duration


def extract_inputs_targets_forcings(
    dataset: xarray.Dataset,
    *,
    input_variables: Tuple[str, ...],
    target_variables: Tuple[str, ...],
    forcing_variables: Tuple[str, ...],
    pressure_levels: Tuple[int, ...],
    input_duration: TimedeltaLike,
    target_lead_times: TargetLeadTimes,
) -> Tuple[xarray.Dataset, xarray.Dataset, xarray.Dataset]:
  """Extracts inputs, targets and forcings according to requirements."""
  dataset = dataset.sel(level=list(pressure_levels))

  # "Forcings" include derived variables that do not exist in the original ERA5
  # or HRES datasets, as well as other variables (e.g. tisr) that need to be
  # computed manually for the target lead times. Compute the requested ones.
  if set(forcing_variables) & _DERIVED_VARS:
    add_derived_vars(dataset)
  if set(forcing_variables) & {TISR}:
    add_tisr_var(dataset)

  # `datetime` is needed by add_derived_vars but breaks autoregressive rollouts.
  dataset = dataset.drop_vars("datetime")

  inputs, targets = extract_input_target_times(
      dataset,
      input_duration=input_duration,
      target_lead_times=target_lead_times)

  if set(forcing_variables) & set(target_variables):
    raise ValueError(
        f"Forcing variables {forcing_variables} should not "
        f"overlap with target variables {target_variables}."
    )

  inputs = inputs[list(input_variables)]
  # The forcing uses the same time coordinates as the target.
  forcings = targets[list(forcing_variables)]
  targets = targets[list(target_variables)]

  return inputs, targets, forcings
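Putting the pieces together, `extract_inputs_targets_forcings` is typically called on an example dataset before feeding the model. A hedged sketch follows; `example_batch` and the variable lists are placeholders chosen for illustration, not values defined in this diff.

    # Hypothetical usage sketch for data_utils.extract_inputs_targets_forcings.
    # `example_batch` is assumed to be an xarray.Dataset with 'time', 'level',
    # 'lat', 'lon' dims and a 'datetime' coordinate.
    from graphcast import data_utils

    inputs, targets, forcings = data_utils.extract_inputs_targets_forcings(
        example_batch,
        input_variables=("2m_temperature", "geopotential"),
        target_variables=("2m_temperature",),
        forcing_variables=("toa_incident_solar_radiation",
                           "year_progress_sin", "year_progress_cos"),
        pressure_levels=(500, 850),
        input_duration="12h",                  # two 6-hourly input frames
        target_lead_times=slice("6h", "24h"),  # four 6-hourly target frames
    )
    # Derived forcings (day/year progress, TISR) are added in place if missing,
    # then inputs/targets/forcings are sliced out with lead-time coordinates.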
graphcast/data_utils_test.py
ADDED
@@ -0,0 +1,310 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `data_utils.py`."""

import datetime
from absl.testing import absltest
from absl.testing import parameterized
from graphcast import data_utils
import numpy as np
import xarray as xa


class DataUtilsTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    # Fix the seed for reproducibility.
    np.random.seed(0)

  def test_year_progress_is_zero_at_year_start_or_end(self):
    year_progress = data_utils.get_year_progress(
        np.array([
            0,
            data_utils.AVG_SEC_PER_YEAR,
            data_utils.AVG_SEC_PER_YEAR * 42,  # 42 years.
        ])
    )
    np.testing.assert_array_equal(year_progress, np.zeros(year_progress.shape))

  def test_year_progress_is_almost_one_before_year_ends(self):
    year_progress = data_utils.get_year_progress(
        np.array([
            data_utils.AVG_SEC_PER_YEAR - 1,
            (data_utils.AVG_SEC_PER_YEAR - 1) * 42,  # ~42 years
        ])
    )
    with self.subTest("Year progress values are close to 1"):
      self.assertTrue(np.all(year_progress > 0.999))
    with self.subTest("Year progress values != 1"):
      self.assertTrue(np.all(year_progress < 1.0))

  def test_day_progress_computes_for_all_times_and_longitudes(self):
    times = np.random.randint(low=0, high=1e10, size=10)
    longitudes = np.arange(0, 360.0, 1.0)
    day_progress = data_utils.get_day_progress(times, longitudes)
    with self.subTest("Day progress is computed for all times and longitudes"):
      self.assertSequenceEqual(
          day_progress.shape, (len(times), len(longitudes))
      )

  @parameterized.named_parameters(
      dict(
          testcase_name="random_date_1",
          year=1988,
          month=11,
          day=7,
          hour=2,
          minute=45,
          second=34,
      ),
      dict(
          testcase_name="random_date_2",
          year=2022,
          month=3,
          day=12,
          hour=7,
          minute=1,
          second=0,
      ),
  )
  def test_day_progress_is_in_between_zero_and_one(
      self, year, month, day, hour, minute, second
  ):
    # Datetime from a timestamp.
    dt = datetime.datetime(year, month, day, hour, minute, second)
    # Epoch time.
    epoch_time = datetime.datetime(1970, 1, 1)
    # Seconds since epoch.
    seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()])

    # Longitudes with 1 degree resolution.
    longitudes = np.arange(0, 360.0, 1.0)

    day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes)
    with self.subTest("Day progress >= 0"):
      self.assertTrue(np.all(day_progress >= 0.0))
    with self.subTest("Day progress < 1"):
      self.assertTrue(np.all(day_progress < 1.0))

  def test_day_progress_is_zero_at_day_start_or_end(self):
    day_progress = data_utils.get_day_progress(
        seconds_since_epoch=np.array([
            0,
            data_utils.SEC_PER_DAY,
            data_utils.SEC_PER_DAY * 42,  # 42 days.
        ]),
        longitude=np.array([0.0]),
    )
    np.testing.assert_array_equal(day_progress, np.zeros(day_progress.shape))

  def test_day_progress_specific_value(self):
    day_progress = data_utils.get_day_progress(
        seconds_since_epoch=np.array([123]),
        longitude=np.array([0.0]),
    )
    np.testing.assert_array_almost_equal(
        day_progress, np.array([[0.00142361]]), decimal=6
    )

  def test_featurize_progress_valid_values_and_dimensions(self):
    day_progress = np.array([0.0, 0.45, 0.213])
    feature_dimensions = ("time",)
    progress_features = data_utils.featurize_progress(
        name="day_progress", dims=feature_dimensions, progress=day_progress
    )
    for feature in progress_features.values():
      with self.subTest(f"Valid dimensions for {feature}"):
        self.assertSequenceEqual(feature.dims, feature_dimensions)

    with self.subTest("Valid values for day_progress"):
      np.testing.assert_array_equal(
          day_progress, progress_features["day_progress"].values
      )

    with self.subTest("Valid values for day_progress_sin"):
      np.testing.assert_array_almost_equal(
          np.array([0.0, 0.30901699, 0.97309851]),
          progress_features["day_progress_sin"].values,
          decimal=6,
      )

    with self.subTest("Valid values for day_progress_cos"):
      np.testing.assert_array_almost_equal(
          np.array([1.0, -0.95105652, 0.23038943]),
          progress_features["day_progress_cos"].values,
          decimal=6,
      )

  def test_featurize_progress_invalid_dimensions(self):
    year_progress = np.array([0.0, 0.45, 0.213])
    feature_dimensions = ("time", "longitude")
    with self.assertRaises(ValueError):
      data_utils.featurize_progress(
          name="year_progress", dims=feature_dimensions, progress=year_progress
      )

  def test_add_derived_vars_variables_added(self):
    data = xa.Dataset(
        data_vars={
            "var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3))
        },
        coords={
            "lon": np.array([0.0, 0.5]),
            "datetime": np.array([
                datetime.datetime(2021, 1, 1),
                datetime.datetime(2023, 1, 1),
                datetime.datetime(2023, 1, 3),
            ]),
        },
    )
    data_utils.add_derived_vars(data)
    all_variables = set(data.variables)

    with self.subTest("Original value was not removed"):
      self.assertIn("var1", all_variables)
    with self.subTest("Year progress feature was added"):
      self.assertIn(data_utils.YEAR_PROGRESS, all_variables)
    with self.subTest("Day progress feature was added"):
      self.assertIn(data_utils.DAY_PROGRESS, all_variables)

  def test_add_derived_vars_existing_vars_not_overridden(self):
    dims = ["x", "lon", "datetime"]
    data = xa.Dataset(
        data_vars={
            "var1": (dims, 8 * np.random.randn(2, 2, 3)),
            data_utils.YEAR_PROGRESS: (dims, np.full((2, 2, 3), 0.111)),
            data_utils.DAY_PROGRESS: (dims, np.full((2, 2, 3), 0.222)),
        },
        coords={
            "lon": np.array([0.0, 0.5]),
            "datetime": np.array([
                datetime.datetime(2021, 1, 1),
                datetime.datetime(2023, 1, 1),
                datetime.datetime(2023, 1, 3),
            ]),
        },
    )

    data_utils.add_derived_vars(data)

    with self.subTest("Year progress feature was not overridden"):
      np.testing.assert_allclose(data[data_utils.YEAR_PROGRESS], 0.111)
    with self.subTest("Day progress feature was not overridden"):
      np.testing.assert_allclose(data[data_utils.DAY_PROGRESS], 0.222)

  @parameterized.named_parameters(
      dict(testcase_name="missing_datetime", coord_name="lon"),
      dict(testcase_name="missing_lon", coord_name="datetime"),
  )
  def test_add_derived_vars_missing_coordinate_raises_value_error(
      self, coord_name
  ):
    with self.subTest(f"Missing {coord_name} coordinate"):
      data = xa.Dataset(
          data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))},
          coords={
              coord_name: np.array([0.0, 0.5]),
          },
      )
      with self.assertRaises(ValueError):
        data_utils.add_derived_vars(data)

  def test_add_tisr_var_variable_added(self):
    data = xa.Dataset(
        data_vars={
            "var1": (["time", "lat", "lon"], np.full((2, 2, 2), 8.0))
        },
        coords={
            "lat": np.array([2.0, 1.0]),
            "lon": np.array([0.0, 0.5]),
            "time": np.array([100, 200], dtype="timedelta64[s]"),
            "datetime": xa.Variable(
                "time", np.array([10, 20], dtype="datetime64[D]")
            ),
        },
    )

    data_utils.add_tisr_var(data)

    self.assertIn(data_utils.TISR, set(data.variables))

  def test_add_tisr_var_existing_var_not_overridden(self):
    dims = ["time", "lat", "lon"]
    data = xa.Dataset(
        data_vars={
            "var1": (dims, np.full((2, 2, 2), 8.0)),
            data_utils.TISR: (dims, np.full((2, 2, 2), 1200.0)),
        },
        coords={
            "lat": np.array([2.0, 1.0]),
            "lon": np.array([0.0, 0.5]),
            "time": np.array([100, 200], dtype="timedelta64[s]"),
            "datetime": xa.Variable(
                "time", np.array([10, 20], dtype="datetime64[D]")
            ),
        },
    )

    data_utils.add_derived_vars(data)

    np.testing.assert_allclose(data[data_utils.TISR], 1200.0)

  def test_add_tisr_var_works_with_batch_dim_size_one(self):
    data = xa.Dataset(
        data_vars={
            "var1": (
                ["batch", "time", "lat", "lon"],
                np.full((1, 2, 2, 2), 8.0),
            )
        },
        coords={
            "lat": np.array([2.0, 1.0]),
            "lon": np.array([0.0, 0.5]),
            "time": np.array([100, 200], dtype="timedelta64[s]"),
            "datetime": xa.Variable(
                ("batch", "time"), np.array([[10, 20]], dtype="datetime64[D]")
            ),
        },
    )

    data_utils.add_tisr_var(data)

    self.assertIn(data_utils.TISR, set(data.variables))

  def test_add_tisr_var_fails_with_batch_dim_size_greater_than_one(self):
    data = xa.Dataset(
        data_vars={
            "var1": (
                ["batch", "time", "lat", "lon"],
                np.full((2, 2, 2, 2), 8.0),
            )
        },
        coords={
            "lat": np.array([2.0, 1.0]),
            "lon": np.array([0.0, 0.5]),
            "time": np.array([100, 200], dtype="timedelta64[s]"),
            "datetime": xa.Variable(
                ("batch", "time"),
                np.array([[10, 20], [100, 200]], dtype="datetime64[D]"),
            ),
        },
    )

    with self.assertRaisesRegex(ValueError, r"cannot select a dimension"):
      data_utils.add_tisr_var(data)


if __name__ == "__main__":
  absltest.main()
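A small sketch (editorial note, not part of the uploaded file) of the sin/cos encoding that the `featurize_progress` tests above check: a progress fraction in [0, 1) is mapped onto the unit circle so that values near 0 and near 1 encode to nearby points, which is why the expected arrays in the tests look the way they do.

# Minimal reproduction of the expected sin/cos values used in
# test_featurize_progress_valid_values_and_dimensions (assumes only numpy).
import numpy as np

day_progress = np.array([0.0, 0.45, 0.213])
day_progress_sin = np.sin(day_progress * 2 * np.pi)  # -> [0., 0.309..., 0.973...]
day_progress_cos = np.cos(day_progress * 2 * np.pi)  # -> [1., -0.951..., 0.230...]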
graphcast/deep_typed_graph_net.py
ADDED
@@ -0,0 +1,391 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""JAX implementation of Graph Networks Simulator.

Generalization to TypedGraphs of the deep Graph Neural Network from:

@inproceedings{pfaff2021learning,
  title={Learning Mesh-Based Simulation with Graph Networks},
  author={Pfaff, Tobias and Fortunato, Meire and Sanchez-Gonzalez, Alvaro and
  Battaglia, Peter},
  booktitle={International Conference on Learning Representations},
  year={2021}
}

@inproceedings{sanchez2020learning,
  title={Learning to simulate complex physics with graph networks},
  author={Sanchez-Gonzalez, Alvaro and Godwin, Jonathan and Pfaff, Tobias and
  Ying, Rex and Leskovec, Jure and Battaglia, Peter},
  booktitle={International conference on machine learning},
  pages={8459--8468},
  year={2020},
  organization={PMLR}
}
"""

from typing import Mapping, Optional

from graphcast import typed_graph
from graphcast import typed_graph_net
import haiku as hk
import jax
import jax.numpy as jnp
import jraph


class DeepTypedGraphNet(hk.Module):
  """Deep Graph Neural Network.

  It works with TypedGraphs with typed nodes and edges. It runs message
  passing on all of the node sets and all of the edge sets in the graph. For
  each message passing step a `typed_graph_net.InteractionNetwork` is used to
  update the full TypedGraph by using different MLPs for each of the node sets
  and each of the edge sets.

  If embed_{nodes,edges} is specified the node/edge features will be embedded
  into a fixed dimensionality before running the first step of message passing.

  If {node,edge}_output_size is specified the final node/edge features will be
  embedded into the specified output size.

  This class may be used for shared or unshared message passing:
  * num_message_passing_steps = N, num_processor_repetitions = 1, gives
    N layers of message passing with fully unshared weights:
    [W_1, W_2, ..., W_N] (default)
  * num_message_passing_steps = 1, num_processor_repetitions = M, gives
    M layers of message passing with fully shared weights:
    [W_1] * M
  * num_message_passing_steps = N, num_processor_repetitions = M, gives
    M*N layers of message passing with both shared and unshared message passing
    such that the weights used at each iteration are:
    [W_1, W_2, ..., W_N] * M

  """

  def __init__(self,
               *,
               node_latent_size: Mapping[str, int],
               edge_latent_size: Mapping[str, int],
               mlp_hidden_size: int,
               mlp_num_hidden_layers: int,
               num_message_passing_steps: int,
               num_processor_repetitions: int = 1,
               embed_nodes: bool = True,
               embed_edges: bool = True,
               node_output_size: Optional[Mapping[str, int]] = None,
               edge_output_size: Optional[Mapping[str, int]] = None,
               include_sent_messages_in_node_update: bool = False,
               use_layer_norm: bool = True,
               activation: str = "relu",
               f32_aggregation: bool = False,
               aggregate_edges_for_nodes_fn: str = "segment_sum",
               aggregate_normalization: Optional[float] = None,
               name: str = "DeepTypedGraphNet"):
    """Inits the model.

    Args:
      node_latent_size: Size of the node latent representations.
      edge_latent_size: Size of the edge latent representations.
      mlp_hidden_size: Hidden layer size for all MLPs.
      mlp_num_hidden_layers: Number of hidden layers in all MLPs.
      num_message_passing_steps: Number of unshared message passing steps
        in the processor steps.
      num_processor_repetitions: Number of times that the same processor is
        applied sequentially.
      embed_nodes: If False, the node embedder will be omitted.
      embed_edges: If False, the edge embedder will be omitted.
      node_output_size: Size of the output node representations for
        each node type. For node types not specified here, the latent node
        representation from the output of the processor will be returned.
      edge_output_size: Size of the output edge representations for
        each edge type. For edge types not specified here, the latent edge
        representation from the output of the processor will be returned.
      include_sent_messages_in_node_update: Whether to include pooled sent
        messages from each node in the node update.
      use_layer_norm: Whether it uses layer norm or not.
      activation: name of activation function.
      f32_aggregation: Use float32 in the edge aggregation.
      aggregate_edges_for_nodes_fn: function used to aggregate messages to each
        node.
      aggregate_normalization: An optional constant that normalizes the output
        of aggregate_edges_for_nodes_fn. For context, this can be used to
        reduce the shock the model undergoes when switching resolution, which
        increases the number of edges connected to a node. In particular, this
        is useful when using segment_sum, but should not be combined with
        segment_mean.
      name: Name of the model.
    """

    super().__init__(name=name)

    self._node_latent_size = node_latent_size
    self._edge_latent_size = edge_latent_size
    self._mlp_hidden_size = mlp_hidden_size
    self._mlp_num_hidden_layers = mlp_num_hidden_layers
    self._num_message_passing_steps = num_message_passing_steps
    self._num_processor_repetitions = num_processor_repetitions
    self._embed_nodes = embed_nodes
    self._embed_edges = embed_edges
    self._node_output_size = node_output_size
    self._edge_output_size = edge_output_size
    self._include_sent_messages_in_node_update = (
        include_sent_messages_in_node_update)
    self._use_layer_norm = use_layer_norm
    self._activation = _get_activation_fn(activation)
    self._initialized = False
    self._f32_aggregation = f32_aggregation
    self._aggregate_edges_for_nodes_fn = _get_aggregate_edges_for_nodes_fn(
        aggregate_edges_for_nodes_fn)
    self._aggregate_normalization = aggregate_normalization

    if aggregate_normalization:
      # using aggregate_normalization only makes sense with segment_sum.
      assert aggregate_edges_for_nodes_fn == "segment_sum"

  def __call__(self,
               input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
    """Forward pass of the learnable dynamics model."""
    self._networks_builder(input_graph)

    # Embed input features (if applicable).
    latent_graph_0 = self._embed(input_graph)

    # Do `m` message passing steps in the latent graphs.
    latent_graph_m = self._process(latent_graph_0)

    # Compute outputs from the last latent graph (if applicable).
    return self._output(latent_graph_m)

  def _networks_builder(self, graph_template):
    if self._initialized:
      return
    self._initialized = True

    def build_mlp(name, output_size):
      mlp = hk.nets.MLP(
          output_sizes=[self._mlp_hidden_size] * self._mlp_num_hidden_layers + [
              output_size], name=name + "_mlp", activation=self._activation)
      return jraph.concatenated_args(mlp)

    def build_mlp_with_maybe_layer_norm(name, output_size):
      network = build_mlp(name, output_size)
      if self._use_layer_norm:
        layer_norm = hk.LayerNorm(
            axis=-1, create_scale=True, create_offset=True,
            name=name + "_layer_norm")
        network = hk.Sequential([network, layer_norm])
      return jraph.concatenated_args(network)

    # The embedder graph network independently embeds edge and node features.
    if self._embed_edges:
      embed_edge_fn = _build_update_fns_for_edge_types(
          build_mlp_with_maybe_layer_norm,
          graph_template,
          "encoder_edges_",
          output_sizes=self._edge_latent_size)
    else:
      embed_edge_fn = None
    if self._embed_nodes:
      embed_node_fn = _build_update_fns_for_node_types(
          build_mlp_with_maybe_layer_norm,
          graph_template,
          "encoder_nodes_",
          output_sizes=self._node_latent_size)
    else:
      embed_node_fn = None
    embedder_kwargs = dict(
        embed_edge_fn=embed_edge_fn,
        embed_node_fn=embed_node_fn,
    )
    self._embedder_network = typed_graph_net.GraphMapFeatures(
        **embedder_kwargs)

    if self._f32_aggregation:
      def aggregate_fn(data, *args, **kwargs):
        dtype = data.dtype
        data = data.astype(jnp.float32)
        output = self._aggregate_edges_for_nodes_fn(data, *args, **kwargs)
        if self._aggregate_normalization:
          output = output / self._aggregate_normalization
        output = output.astype(dtype)
        return output

    else:
      def aggregate_fn(data, *args, **kwargs):
        output = self._aggregate_edges_for_nodes_fn(data, *args, **kwargs)
        if self._aggregate_normalization:
          output = output / self._aggregate_normalization
        return output

    # Create `num_message_passing_steps` graph networks with unshared parameters
    # that update the node and edge latent features.
    # Note that we can use `modules.InteractionNetwork` because
    # it also outputs the messages as updated edge latent features.
    self._processor_networks = []
    for step_i in range(self._num_message_passing_steps):
      self._processor_networks.append(
          typed_graph_net.InteractionNetwork(
              update_edge_fn=_build_update_fns_for_edge_types(
                  build_mlp_with_maybe_layer_norm,
                  graph_template,
                  f"processor_edges_{step_i}_",
                  output_sizes=self._edge_latent_size),
              update_node_fn=_build_update_fns_for_node_types(
                  build_mlp_with_maybe_layer_norm,
                  graph_template,
                  f"processor_nodes_{step_i}_",
                  output_sizes=self._node_latent_size),
              aggregate_edges_for_nodes_fn=aggregate_fn,
              include_sent_messages_in_node_update=(
                  self._include_sent_messages_in_node_update),
          ))

    # The output MLPs convert edge/node latent features into the output sizes.
    output_kwargs = dict(
        embed_edge_fn=_build_update_fns_for_edge_types(
            build_mlp, graph_template, "decoder_edges_", self._edge_output_size)
        if self._edge_output_size else None,
        embed_node_fn=_build_update_fns_for_node_types(
            build_mlp, graph_template, "decoder_nodes_", self._node_output_size)
        if self._node_output_size else None,)
    self._output_network = typed_graph_net.GraphMapFeatures(
        **output_kwargs)

  def _embed(
      self, input_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
    """Embeds the input graph features into a latent graph."""

    # Copy the context to all of the node types, if applicable.
    context_features = input_graph.context.features
    if jax.tree_util.tree_leaves(context_features):
      # This code assumes a single input feature array for the context and for
      # each node type.
      assert len(jax.tree_util.tree_leaves(context_features)) == 1
      new_nodes = {}
      for node_set_name, node_set in input_graph.nodes.items():
        node_features = node_set.features
        broadcasted_context = jnp.repeat(
            context_features, node_set.n_node, axis=0,
            total_repeat_length=node_features.shape[0])
        new_nodes[node_set_name] = node_set._replace(
            features=jnp.concatenate(
                [node_features, broadcasted_context], axis=-1))
      input_graph = input_graph._replace(
          nodes=new_nodes,
          context=input_graph.context._replace(features=()))

    # Embeds the node and edge features.
    latent_graph_0 = self._embedder_network(input_graph)
    return latent_graph_0

  def _process(
      self, latent_graph_0: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
    """Processes the latent graph with several steps of message passing."""

    # Do `num_message_passing_steps` with each of the `self._processor_networks`
    # with unshared weights, and repeat that `self._num_processor_repetitions`
    # times.
    latent_graph = latent_graph_0
    for unused_repetition_i in range(self._num_processor_repetitions):
      for processor_network in self._processor_networks:
        latent_graph = self._process_step(processor_network, latent_graph)

    return latent_graph

  def _process_step(
      self, processor_network_k,
      latent_graph_prev_k: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
    """Single step of message passing with node/edge residual connections."""

    # One step of message passing.
    latent_graph_k = processor_network_k(latent_graph_prev_k)

    # Add residuals.
    nodes_with_residuals = {}
    for k, prev_set in latent_graph_prev_k.nodes.items():
      nodes_with_residuals[k] = prev_set._replace(
          features=prev_set.features + latent_graph_k.nodes[k].features)

    edges_with_residuals = {}
    for k, prev_set in latent_graph_prev_k.edges.items():
      edges_with_residuals[k] = prev_set._replace(
          features=prev_set.features + latent_graph_k.edges[k].features)

    latent_graph_k = latent_graph_k._replace(
        nodes=nodes_with_residuals, edges=edges_with_residuals)
    return latent_graph_k

  def _output(self,
              latent_graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
    """Produces the output from the latent graph."""
    return self._output_network(latent_graph)


def _build_update_fns_for_node_types(
    builder_fn, graph_template, prefix, output_sizes=None):
  """Builds an update function for all node types or a subset of them."""

  output_fns = {}
  for node_set_name in graph_template.nodes.keys():
    if output_sizes is None:
      # Use the default output size for all types.
      output_size = None
    else:
      # Otherwise, ignore any type that does not have an explicit output size.
      if node_set_name in output_sizes:
        output_size = output_sizes[node_set_name]
      else:
        continue
    output_fns[node_set_name] = builder_fn(
        f"{prefix}{node_set_name}", output_size)
  return output_fns


def _build_update_fns_for_edge_types(
    builder_fn, graph_template, prefix, output_sizes=None):
  """Builds an update function for all edge types or a subset of them."""
  output_fns = {}
  for edge_set_key in graph_template.edges.keys():
    edge_set_name = edge_set_key.name
    if output_sizes is None:
      # Use the default output size for all types.
      output_size = None
    else:
      # Otherwise, ignore any type that does not have an explicit output size.
      if edge_set_name in output_sizes:
        output_size = output_sizes[edge_set_name]
      else:
        continue
    output_fns[edge_set_name] = builder_fn(
        f"{prefix}{edge_set_name}", output_size)
  return output_fns


def _get_activation_fn(name):
  """Return activation function corresponding to function_name."""
  if name == "identity":
    return lambda x: x
  if hasattr(jax.nn, name):
    return getattr(jax.nn, name)
  if hasattr(jnp, name):
    return getattr(jnp, name)
  raise ValueError(f"Unknown activation function {name} specified.")


def _get_aggregate_edges_for_nodes_fn(name):
  """Return aggregate_edges_for_nodes_fn corresponding to function_name."""
  if hasattr(jraph, name):
    return getattr(jraph, name)
  raise ValueError(
      f"Unknown aggregate_edges_for_nodes_fn function {name} specified.")
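A construction sketch (editorial note, not part of the uploaded file) illustrating the weight-sharing options described in the DeepTypedGraphNet docstring above. The latent sizes and the node/edge set names ("mesh_nodes", "mesh") are illustrative assumptions; they must match the TypedGraph the module is later called on, and as with any Haiku module it must be instantiated inside an hk.transform.

# Two ways to get 16 message-passing layers, per the docstring above.
from graphcast import deep_typed_graph_net

# 16 layers with fully unshared weights: [W_1, ..., W_16].
unshared_processor = deep_typed_graph_net.DeepTypedGraphNet(
    node_latent_size=dict(mesh_nodes=512),
    edge_latent_size=dict(mesh=512),
    mlp_hidden_size=512,
    mlp_num_hidden_layers=1,
    num_message_passing_steps=16,
    num_processor_repetitions=1,
)

# 16 layers sharing a single set of weights: [W_1] * 16.
shared_processor = deep_typed_graph_net.DeepTypedGraphNet(
    node_latent_size=dict(mesh_nodes=512),
    edge_latent_size=dict(mesh=512),
    mlp_hidden_size=512,
    mlp_num_hidden_layers=1,
    num_message_passing_steps=1,
    num_processor_repetitions=16,
)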
graphcast/graphcast.py
ADDED
@@ -0,0 +1,796 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A predictor that runs multiple graph neural networks on mesh data.

It learns to interpolate between the grid and the mesh nodes, with the loss
and the rollouts ultimately computed at the grid level.

It uses ideas similar to those in Keisler (2022):

Reference:
  https://arxiv.org/pdf/2202.07575.pdf

It assumes data across time and level is stacked, and operates only in
a 2D mesh over latitudes and longitudes.
"""

from typing import Any, Callable, Mapping, Optional

import chex
from graphcast import deep_typed_graph_net
from graphcast import grid_mesh_connectivity
from graphcast import icosahedral_mesh
from graphcast import losses
from graphcast import model_utils
from graphcast import predictor_base
from graphcast import typed_graph
from graphcast import xarray_jax
import jax.numpy as jnp
import jraph
import numpy as np
import xarray

Kwargs = Mapping[str, Any]

GNN = Callable[[jraph.GraphsTuple], jraph.GraphsTuple]


# https://www.ecmwf.int/en/forecasts/dataset/ecmwf-reanalysis-v5
PRESSURE_LEVELS_ERA5_37 = (
    1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 125, 150, 175, 200, 225, 250, 300,
    350, 400, 450, 500, 550, 600, 650, 700, 750, 775, 800, 825, 850, 875, 900,
    925, 950, 975, 1000)

# https://www.ecmwf.int/en/forecasts/datasets/set-i
PRESSURE_LEVELS_HRES_25 = (
    1, 2, 3, 5, 7, 10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 400, 500, 600,
    700, 800, 850, 900, 925, 950, 1000)

# https://agupubs.onlinelibrary.wiley.com/doi/full/10.1029/2020MS002203
PRESSURE_LEVELS_WEATHERBENCH_13 = (
    50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000)

PRESSURE_LEVELS = {
    13: PRESSURE_LEVELS_WEATHERBENCH_13,
    25: PRESSURE_LEVELS_HRES_25,
    37: PRESSURE_LEVELS_ERA5_37,
}

# The list of all possible atmospheric variables. Taken from:
# https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#ERA5:datadocumentation-Table9
ALL_ATMOSPHERIC_VARS = (
    "potential_vorticity",
    "specific_rain_water_content",
    "specific_snow_water_content",
    "geopotential",
    "temperature",
    "u_component_of_wind",
    "v_component_of_wind",
    "specific_humidity",
    "vertical_velocity",
    "vorticity",
    "divergence",
    "relative_humidity",
    "ozone_mass_mixing_ratio",
    "specific_cloud_liquid_water_content",
    "specific_cloud_ice_water_content",
    "fraction_of_cloud_cover",
)

TARGET_SURFACE_VARS = (
    "2m_temperature",
    "mean_sea_level_pressure",
    "10m_v_component_of_wind",
    "10m_u_component_of_wind",
    "total_precipitation_6hr",
)
TARGET_SURFACE_NO_PRECIP_VARS = (
    "2m_temperature",
    "mean_sea_level_pressure",
    "10m_v_component_of_wind",
    "10m_u_component_of_wind",
)
TARGET_ATMOSPHERIC_VARS = (
    "temperature",
    "geopotential",
    "u_component_of_wind",
    "v_component_of_wind",
    "vertical_velocity",
    "specific_humidity",
)
TARGET_ATMOSPHERIC_NO_W_VARS = (
    "temperature",
    "geopotential",
    "u_component_of_wind",
    "v_component_of_wind",
    "specific_humidity",
)
EXTERNAL_FORCING_VARS = (
    "toa_incident_solar_radiation",
)
GENERATED_FORCING_VARS = (
    "year_progress_sin",
    "year_progress_cos",
    "day_progress_sin",
    "day_progress_cos",
)
FORCING_VARS = EXTERNAL_FORCING_VARS + GENERATED_FORCING_VARS
STATIC_VARS = (
    "geopotential_at_surface",
    "land_sea_mask",
)


@chex.dataclass(frozen=True, eq=True)
class TaskConfig:
  """Defines inputs and targets on which a model is trained and/or evaluated."""
  input_variables: tuple[str, ...]
  # Target variables which the model is expected to predict.
  target_variables: tuple[str, ...]
  forcing_variables: tuple[str, ...]
  pressure_levels: tuple[int, ...]
  input_duration: str

TASK = TaskConfig(
    input_variables=(
        TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
        STATIC_VARS),
    target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS_ERA5_37,
    input_duration="12h",
)
TASK_13 = TaskConfig(
    input_variables=(
        TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
        STATIC_VARS),
    target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13,
    input_duration="12h",
)
TASK_13_PRECIP_OUT = TaskConfig(
    input_variables=(
        TARGET_SURFACE_NO_PRECIP_VARS + TARGET_ATMOSPHERIC_VARS + FORCING_VARS +
        STATIC_VARS),
    target_variables=TARGET_SURFACE_VARS + TARGET_ATMOSPHERIC_VARS,
    forcing_variables=FORCING_VARS,
    pressure_levels=PRESSURE_LEVELS_WEATHERBENCH_13,
    input_duration="12h",
)


@chex.dataclass(frozen=True, eq=True)
class ModelConfig:
  """Defines the architecture of the GraphCast neural network.

  Properties:
    resolution: The resolution of the data, in degrees (e.g. 0.25 or 1.0).
    mesh_size: How many refinements to do on the multi-mesh.
    gnn_msg_steps: How many Graph Network message passing steps to do.
    latent_size: How many latent features to include in the various MLPs.
    hidden_layers: How many hidden layers for each MLP.
    radius_query_fraction_edge_length: Scalar that will be multiplied by the
      length of the longest edge of the finest mesh to define the radius of
      connectivity to use in the Grid2Mesh graph. Reasonable values are
      between 0.6 and 1. 0.6 reduces the number of grid points feeding into
      multiple mesh nodes and therefore reduces edge count and memory use, but
      1 gives better predictions.
    mesh2grid_edge_normalization_factor: Allows explicitly controlling edge
      normalization for mesh2grid edges. If None, defaults to max edge length.
      This supports using pre-trained model weights with a different graph
      structure to what it was trained on.
  """
  resolution: float
  mesh_size: int
  latent_size: int
  gnn_msg_steps: int
  hidden_layers: int
  radius_query_fraction_edge_length: float
  mesh2grid_edge_normalization_factor: Optional[float] = None


@chex.dataclass(frozen=True, eq=True)
class CheckPoint:
  params: dict[str, Any]
  model_config: ModelConfig
  task_config: TaskConfig
  description: str
  license: str


class GraphCast(predictor_base.Predictor):
  """GraphCast Predictor.

  The model works on graphs that take into account:
  * Mesh nodes: nodes for the vertices of the mesh.
  * Grid nodes: nodes for the points of the grid.
  * Nodes: When referring to just "nodes", this means the joint set of
    both mesh nodes, concatenated with grid nodes.

  The model works with 3 graphs:
  * Grid2Mesh graph: Graph that contains all nodes. This graph is strictly
    bipartite with edges going from grid nodes to mesh nodes using a
    fixed radius query. The grid2mesh_gnn will operate in this graph. The output
    of this stage will be a latent representation for the mesh nodes, and a
    latent representation for the grid nodes.
  * Mesh graph: Graph that contains mesh nodes only. The mesh_gnn will
    operate in this graph. It will update the latent state of the mesh nodes
    only.
  * Mesh2Grid graph: Graph that contains all nodes. This graph is strictly
    bipartite with edges going from mesh nodes to grid nodes such that each grid
    node is connected to 3 nodes of the mesh triangular face that contains
    the grid point. The mesh2grid_gnn will operate in this graph. It will
    process the updated latent state of the mesh nodes, and the latent state
    of the grid nodes, to produce the final output for the grid nodes.

  The model is built on top of `TypedGraph`s so the different types of nodes and
  edges can be stored and treated separately.

  """

  def __init__(self, model_config: ModelConfig, task_config: TaskConfig):
    """Initializes the predictor."""
    self._spatial_features_kwargs = dict(
        add_node_positions=False,
        add_node_latitude=True,
        add_node_longitude=True,
        add_relative_positions=True,
        relative_longitude_local_coordinates=True,
        relative_latitude_local_coordinates=True,
    )

    # Specification of the multimesh.
    self._meshes = (
        icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
            splits=model_config.mesh_size))

    # Encoder, which moves data from the grid to the mesh with a single message
    # passing step.
    self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
        embed_nodes=True,  # Embed raw features of the grid and mesh nodes.
        embed_edges=True,  # Embed raw features of the grid2mesh edges.
        edge_latent_size=dict(grid2mesh=model_config.latent_size),
        node_latent_size=dict(
            mesh_nodes=model_config.latent_size,
            grid_nodes=model_config.latent_size),
        mlp_hidden_size=model_config.latent_size,
        mlp_num_hidden_layers=model_config.hidden_layers,
        num_message_passing_steps=1,
        use_layer_norm=True,
        include_sent_messages_in_node_update=False,
        activation="swish",
        f32_aggregation=True,
        aggregate_normalization=None,
        name="grid2mesh_gnn",
    )

    # Processor, which performs message passing on the multi-mesh.
    self._mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
        embed_nodes=False,  # Node features already embedded by previous layers.
        embed_edges=True,  # Embed raw features of the multi-mesh edges.
        node_latent_size=dict(mesh_nodes=model_config.latent_size),
        edge_latent_size=dict(mesh=model_config.latent_size),
        mlp_hidden_size=model_config.latent_size,
        mlp_num_hidden_layers=model_config.hidden_layers,
        num_message_passing_steps=model_config.gnn_msg_steps,
        use_layer_norm=True,
        include_sent_messages_in_node_update=False,
        activation="swish",
        f32_aggregation=False,
        name="mesh_gnn",
    )

    num_surface_vars = len(
        set(task_config.target_variables) - set(ALL_ATMOSPHERIC_VARS))
    num_atmospheric_vars = len(
        set(task_config.target_variables) & set(ALL_ATMOSPHERIC_VARS))
    num_outputs = (num_surface_vars +
                   len(task_config.pressure_levels) * num_atmospheric_vars)

    # Decoder, which moves data from the mesh back into the grid with a single
    # message passing step.
    self._mesh2grid_gnn = deep_typed_graph_net.DeepTypedGraphNet(
        # Require a specific node dimensionality for the grid node outputs.
        node_output_size=dict(grid_nodes=num_outputs),
        embed_nodes=False,  # Node features already embedded by previous layers.
        embed_edges=True,  # Embed raw features of the mesh2grid edges.
        edge_latent_size=dict(mesh2grid=model_config.latent_size),
        node_latent_size=dict(
            mesh_nodes=model_config.latent_size,
            grid_nodes=model_config.latent_size),
        mlp_hidden_size=model_config.latent_size,
        mlp_num_hidden_layers=model_config.hidden_layers,
        num_message_passing_steps=1,
        use_layer_norm=True,
        include_sent_messages_in_node_update=False,
        activation="swish",
        f32_aggregation=False,
        name="mesh2grid_gnn",
    )

    # Obtain the query radius in absolute units for the unit-sphere for the
    # grid2mesh model, by rescaling the `radius_query_fraction_edge_length`.
    self._query_radius = (_get_max_edge_distance(self._finest_mesh)
                          * model_config.radius_query_fraction_edge_length)
    self._mesh2grid_edge_normalization_factor = (
        model_config.mesh2grid_edge_normalization_factor
    )

    # Other initialization is delayed until the first call (`_maybe_init`)
    # when we get some sample data so we know the lat/lon values.
    self._initialized = False

    # A "_init_mesh_properties":
    # This one could be initialized at init but we delay it for consistency too.
    self._num_mesh_nodes = None  # num_mesh_nodes
    self._mesh_nodes_lat = None  # [num_mesh_nodes]
    self._mesh_nodes_lon = None  # [num_mesh_nodes]

    # A "_init_grid_properties":
    self._grid_lat = None  # [num_lat_points]
    self._grid_lon = None  # [num_lon_points]
    self._num_grid_nodes = None  # num_lat_points * num_lon_points
    self._grid_nodes_lat = None  # [num_grid_nodes]
    self._grid_nodes_lon = None  # [num_grid_nodes]

    # A "_init_{grid2mesh,processor,mesh2grid}_graph"
    self._grid2mesh_graph_structure = None
    self._mesh_graph_structure = None
    self._mesh2grid_graph_structure = None

  @property
  def _finest_mesh(self):
    return self._meshes[-1]

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               is_training: bool = False,
               ) -> xarray.Dataset:
    self._maybe_init(inputs)

    # Convert all input data into flat vectors for each of the grid nodes.
    # xarray (batch, time, lat, lon, level, multiple vars, forcings)
    # -> [num_grid_nodes, batch, num_channels]
    grid_node_features = self._inputs_to_grid_node_features(inputs, forcings)

    # Transfer data from the grid to the mesh,
    # [num_mesh_nodes, batch, latent_size], [num_grid_nodes, batch, latent_size]
    (latent_mesh_nodes, latent_grid_nodes
     ) = self._run_grid2mesh_gnn(grid_node_features)

    # Run message passing in the multimesh.
    # [num_mesh_nodes, batch, latent_size]
    updated_latent_mesh_nodes = self._run_mesh_gnn(latent_mesh_nodes)

    # Transfer data from the mesh to the grid.
    # [num_grid_nodes, batch, output_size]
    output_grid_nodes = self._run_mesh2grid_gnn(
        updated_latent_mesh_nodes, latent_grid_nodes)

    # Convert output flat vectors for the grid nodes to the format of the output.
    # [num_grid_nodes, batch, output_size] ->
    # xarray (batch, one time step, lat, lon, level, multiple vars)
    return self._grid_node_outputs_to_prediction(
        output_grid_nodes, targets_template)

  def loss_and_predictions(  # pytype: disable=signature-mismatch  # jax-ndarray
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
  ) -> tuple[predictor_base.LossAndDiagnostics, xarray.Dataset]:
    # Forward pass.
    predictions = self(
        inputs, targets_template=targets, forcings=forcings, is_training=True)
    # Compute loss.
    loss = losses.weighted_mse_per_level(
        predictions, targets,
        per_variable_weights={
            # Any variables not specified here are weighted as 1.0.
            # A single-level variable, but an important headline variable
            # and also one which we have struggled to get good performance
            # on at short lead times, so leaving it weighted at 1.0, equal
            # to the multi-level variables:
            "2m_temperature": 1.0,
            # New single-level variables, which we don't weight too highly
            # to avoid hurting performance on other variables.
            "10m_u_component_of_wind": 0.1,
            "10m_v_component_of_wind": 0.1,
            "mean_sea_level_pressure": 0.1,
            "total_precipitation_6hr": 0.1,
        })
    return loss, predictions  # pytype: disable=bad-return-type  # jax-ndarray

  def loss(  # pytype: disable=signature-mismatch  # jax-ndarray
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
  ) -> predictor_base.LossAndDiagnostics:
    loss, _ = self.loss_and_predictions(inputs, targets, forcings)
    return loss  # pytype: disable=bad-return-type  # jax-ndarray

  def _maybe_init(self, sample_inputs: xarray.Dataset):
    """Inits everything that has a dependency on the input coordinates."""
    if not self._initialized:
      self._init_mesh_properties()
      self._init_grid_properties(
          grid_lat=sample_inputs.lat, grid_lon=sample_inputs.lon)
      self._grid2mesh_graph_structure = self._init_grid2mesh_graph()
      self._mesh_graph_structure = self._init_mesh_graph()
      self._mesh2grid_graph_structure = self._init_mesh2grid_graph()

      self._initialized = True

  def _init_mesh_properties(self):
    """Inits static properties that have to do with mesh nodes."""
    self._num_mesh_nodes = self._finest_mesh.vertices.shape[0]
    mesh_phi, mesh_theta = model_utils.cartesian_to_spherical(
        self._finest_mesh.vertices[:, 0],
        self._finest_mesh.vertices[:, 1],
        self._finest_mesh.vertices[:, 2])
    (
        mesh_nodes_lat,
        mesh_nodes_lon,
    ) = model_utils.spherical_to_lat_lon(
        phi=mesh_phi, theta=mesh_theta)
    # Convert to f32 to ensure the lat/lon features aren't in f64.
    self._mesh_nodes_lat = mesh_nodes_lat.astype(np.float32)
    self._mesh_nodes_lon = mesh_nodes_lon.astype(np.float32)

  def _init_grid_properties(self, grid_lat: np.ndarray, grid_lon: np.ndarray):
    """Inits static properties that have to do with grid nodes."""
    self._grid_lat = grid_lat.astype(np.float32)
    self._grid_lon = grid_lon.astype(np.float32)
    # Initialize the counters.
    self._num_grid_nodes = grid_lat.shape[0] * grid_lon.shape[0]

    # Initialize lat and lon for the grid.
    grid_nodes_lon, grid_nodes_lat = np.meshgrid(grid_lon, grid_lat)
    self._grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32)
    self._grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32)

  def _init_grid2mesh_graph(self) -> typed_graph.TypedGraph:
    """Build Grid2Mesh graph."""

    # Create some edges according to distance between mesh and grid nodes.
    assert self._grid_lat is not None and self._grid_lon is not None
    (grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices(
        grid_latitude=self._grid_lat,
        grid_longitude=self._grid_lon,
        mesh=self._finest_mesh,
        radius=self._query_radius)

    # Edges sending info from grid to mesh.
    senders = grid_indices
    receivers = mesh_indices

    # Precompute structural node and edge features according to config options.
    # Structural features are those that depend on the fixed values of the
    # latitude and longitudes of the nodes.
    (senders_node_features, receivers_node_features,
     edge_features) = model_utils.get_bipartite_graph_spatial_features(
         senders_node_lat=self._grid_nodes_lat,
         senders_node_lon=self._grid_nodes_lon,
         receivers_node_lat=self._mesh_nodes_lat,
         receivers_node_lon=self._mesh_nodes_lon,
         senders=senders,
         receivers=receivers,
         edge_normalization_factor=None,
         **self._spatial_features_kwargs,
     )

    n_grid_node = np.array([self._num_grid_nodes])
    n_mesh_node = np.array([self._num_mesh_nodes])
    n_edge = np.array([mesh_indices.shape[0]])
    grid_node_set = typed_graph.NodeSet(
        n_node=n_grid_node, features=senders_node_features)
    mesh_node_set = typed_graph.NodeSet(
        n_node=n_mesh_node, features=receivers_node_features)
    edge_set = typed_graph.EdgeSet(
        n_edge=n_edge,
        indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
        features=edge_features)
    nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
    edges = {
        typed_graph.EdgeSetKey("grid2mesh", ("grid_nodes", "mesh_nodes")):
            edge_set
    }
    grid2mesh_graph = typed_graph.TypedGraph(
        context=typed_graph.Context(n_graph=np.array([1]), features=()),
        nodes=nodes,
        edges=edges)
    return grid2mesh_graph

  def _init_mesh_graph(self) -> typed_graph.TypedGraph:
    """Build Mesh graph."""
    merged_mesh = icosahedral_mesh.merge_meshes(self._meshes)

    # Work simply on the mesh edges.
    senders, receivers = icosahedral_mesh.faces_to_edges(merged_mesh.faces)

    # Precompute structural node and edge features according to config options.
    # Structural features are those that depend on the fixed values of the
    # latitude and longitudes of the nodes.
    assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
    node_features, edge_features = model_utils.get_graph_spatial_features(
        node_lat=self._mesh_nodes_lat,
        node_lon=self._mesh_nodes_lon,
        senders=senders,
        receivers=receivers,
        **self._spatial_features_kwargs,
    )

    n_mesh_node = np.array([self._num_mesh_nodes])
    n_edge = np.array([senders.shape[0]])
    assert n_mesh_node == len(node_features)
    mesh_node_set = typed_graph.NodeSet(
        n_node=n_mesh_node, features=node_features)
    edge_set = typed_graph.EdgeSet(
        n_edge=n_edge,
        indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
        features=edge_features)
    nodes = {"mesh_nodes": mesh_node_set}
    edges = {
        typed_graph.EdgeSetKey("mesh", ("mesh_nodes", "mesh_nodes")): edge_set
    }
    mesh_graph = typed_graph.TypedGraph(
        context=typed_graph.Context(n_graph=np.array([1]), features=()),
        nodes=nodes,
        edges=edges)

    return mesh_graph

  def _init_mesh2grid_graph(self) -> typed_graph.TypedGraph:
    """Build Mesh2Grid graph."""

    # Create some edges according to how the grid nodes are contained by
    # mesh triangles.
    (grid_indices,
     mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices(
         grid_latitude=self._grid_lat,
         grid_longitude=self._grid_lon,
         mesh=self._finest_mesh)

    # Edges sending info from mesh to grid.
    senders = mesh_indices
    receivers = grid_indices

    # Precompute structural node and edge features according to config options.
    assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None
    (senders_node_features, receivers_node_features,
     edge_features) = model_utils.get_bipartite_graph_spatial_features(
         senders_node_lat=self._mesh_nodes_lat,
         senders_node_lon=self._mesh_nodes_lon,
         receivers_node_lat=self._grid_nodes_lat,
         receivers_node_lon=self._grid_nodes_lon,
         senders=senders,
         receivers=receivers,
         edge_normalization_factor=self._mesh2grid_edge_normalization_factor,
         **self._spatial_features_kwargs,
     )

    n_grid_node = np.array([self._num_grid_nodes])
    n_mesh_node = np.array([self._num_mesh_nodes])
    n_edge = np.array([senders.shape[0]])
    grid_node_set = typed_graph.NodeSet(
        n_node=n_grid_node, features=receivers_node_features)
    mesh_node_set = typed_graph.NodeSet(
        n_node=n_mesh_node, features=senders_node_features)
    edge_set = typed_graph.EdgeSet(
        n_edge=n_edge,
indices=typed_graph.EdgesIndices(senders=senders, receivers=receivers),
|
597 |
+
features=edge_features)
|
598 |
+
nodes = {"grid_nodes": grid_node_set, "mesh_nodes": mesh_node_set}
|
599 |
+
edges = {
|
600 |
+
typed_graph.EdgeSetKey("mesh2grid", ("mesh_nodes", "grid_nodes")):
|
601 |
+
edge_set
|
602 |
+
}
|
603 |
+
mesh2grid_graph = typed_graph.TypedGraph(
|
604 |
+
context=typed_graph.Context(n_graph=np.array([1]), features=()),
|
605 |
+
nodes=nodes,
|
606 |
+
edges=edges)
|
607 |
+
return mesh2grid_graph
|
608 |
+
|
609 |
+
def _run_grid2mesh_gnn(self, grid_node_features: chex.Array,
|
610 |
+
) -> tuple[chex.Array, chex.Array]:
|
611 |
+
"""Runs the grid2mesh_gnn, extracting latent mesh and grid nodes."""
|
612 |
+
|
613 |
+
# Concatenate node structural features with input features.
|
614 |
+
batch_size = grid_node_features.shape[1]
|
615 |
+
|
616 |
+
grid2mesh_graph = self._grid2mesh_graph_structure
|
617 |
+
assert grid2mesh_graph is not None
|
618 |
+
grid_nodes = grid2mesh_graph.nodes["grid_nodes"]
|
619 |
+
mesh_nodes = grid2mesh_graph.nodes["mesh_nodes"]
|
620 |
+
new_grid_nodes = grid_nodes._replace(
|
621 |
+
features=jnp.concatenate([
|
622 |
+
grid_node_features,
|
623 |
+
_add_batch_second_axis(
|
624 |
+
grid_nodes.features.astype(grid_node_features.dtype),
|
625 |
+
batch_size)
|
626 |
+
],
|
627 |
+
axis=-1))
|
628 |
+
|
629 |
+
# To make sure capacity of the embedded is identical for the grid nodes and
|
630 |
+
# the mesh nodes, we also append some dummy zero input features for the
|
631 |
+
# mesh nodes.
|
632 |
+
dummy_mesh_node_features = jnp.zeros(
|
633 |
+
(self._num_mesh_nodes,) + grid_node_features.shape[1:],
|
634 |
+
dtype=grid_node_features.dtype)
|
635 |
+
new_mesh_nodes = mesh_nodes._replace(
|
636 |
+
features=jnp.concatenate([
|
637 |
+
dummy_mesh_node_features,
|
638 |
+
_add_batch_second_axis(
|
639 |
+
mesh_nodes.features.astype(dummy_mesh_node_features.dtype),
|
640 |
+
batch_size)
|
641 |
+
],
|
642 |
+
axis=-1))
|
643 |
+
|
644 |
+
# Broadcast edge structural features to the required batch size.
|
645 |
+
grid2mesh_edges_key = grid2mesh_graph.edge_key_by_name("grid2mesh")
|
646 |
+
edges = grid2mesh_graph.edges[grid2mesh_edges_key]
|
647 |
+
|
648 |
+
new_edges = edges._replace(
|
649 |
+
features=_add_batch_second_axis(
|
650 |
+
edges.features.astype(dummy_mesh_node_features.dtype), batch_size))
|
651 |
+
|
652 |
+
input_graph = self._grid2mesh_graph_structure._replace(
|
653 |
+
edges={grid2mesh_edges_key: new_edges},
|
654 |
+
nodes={
|
655 |
+
"grid_nodes": new_grid_nodes,
|
656 |
+
"mesh_nodes": new_mesh_nodes
|
657 |
+
})
|
658 |
+
|
659 |
+
# Run the GNN.
|
660 |
+
grid2mesh_out = self._grid2mesh_gnn(input_graph)
|
661 |
+
latent_mesh_nodes = grid2mesh_out.nodes["mesh_nodes"].features
|
662 |
+
latent_grid_nodes = grid2mesh_out.nodes["grid_nodes"].features
|
663 |
+
return latent_mesh_nodes, latent_grid_nodes
|
664 |
+
|
665 |
+
def _run_mesh_gnn(self, latent_mesh_nodes: chex.Array) -> chex.Array:
|
666 |
+
"""Runs the mesh_gnn, extracting updated latent mesh nodes."""
|
667 |
+
|
668 |
+
# Add the structural edge features of this graph. Note we don't need
|
669 |
+
# to add the structural node features, because these are already part of
|
670 |
+
# the latent state, via the original Grid2Mesh gnn, however, we need
|
671 |
+
# the edge ones, because it is the first time we are seeing this particular
|
672 |
+
# set of edges.
|
673 |
+
batch_size = latent_mesh_nodes.shape[1]
|
674 |
+
|
675 |
+
mesh_graph = self._mesh_graph_structure
|
676 |
+
assert mesh_graph is not None
|
677 |
+
mesh_edges_key = mesh_graph.edge_key_by_name("mesh")
|
678 |
+
edges = mesh_graph.edges[mesh_edges_key]
|
679 |
+
|
680 |
+
# We are assuming here that the mesh gnn uses a single set of edge keys
|
681 |
+
# named "mesh" for the edges and that it uses a single set of nodes named
|
682 |
+
# "mesh_nodes"
|
683 |
+
msg = ("The setup currently requires to only have one kind of edge in the"
|
684 |
+
" mesh GNN.")
|
685 |
+
assert len(mesh_graph.edges) == 1, msg
|
686 |
+
|
687 |
+
new_edges = edges._replace(
|
688 |
+
features=_add_batch_second_axis(
|
689 |
+
edges.features.astype(latent_mesh_nodes.dtype), batch_size))
|
690 |
+
|
691 |
+
nodes = mesh_graph.nodes["mesh_nodes"]
|
692 |
+
nodes = nodes._replace(features=latent_mesh_nodes)
|
693 |
+
|
694 |
+
input_graph = mesh_graph._replace(
|
695 |
+
edges={mesh_edges_key: new_edges}, nodes={"mesh_nodes": nodes})
|
696 |
+
|
697 |
+
# Run the GNN.
|
698 |
+
return self._mesh_gnn(input_graph).nodes["mesh_nodes"].features
|
699 |
+
|
700 |
+
def _run_mesh2grid_gnn(self,
|
701 |
+
updated_latent_mesh_nodes: chex.Array,
|
702 |
+
latent_grid_nodes: chex.Array,
|
703 |
+
) -> chex.Array:
|
704 |
+
"""Runs the mesh2grid_gnn, extracting the output grid nodes."""
|
705 |
+
|
706 |
+
# Add the structural edge features of this graph. Note we don't need
|
707 |
+
# to add the structural node features, because these are already part of
|
708 |
+
# the latent state, via the original Grid2Mesh gnn, however, we need
|
709 |
+
# the edge ones, because it is the first time we are seeing this particular
|
710 |
+
# set of edges.
|
711 |
+
batch_size = updated_latent_mesh_nodes.shape[1]
|
712 |
+
|
713 |
+
mesh2grid_graph = self._mesh2grid_graph_structure
|
714 |
+
assert mesh2grid_graph is not None
|
715 |
+
mesh_nodes = mesh2grid_graph.nodes["mesh_nodes"]
|
716 |
+
grid_nodes = mesh2grid_graph.nodes["grid_nodes"]
|
717 |
+
new_mesh_nodes = mesh_nodes._replace(features=updated_latent_mesh_nodes)
|
718 |
+
new_grid_nodes = grid_nodes._replace(features=latent_grid_nodes)
|
719 |
+
mesh2grid_key = mesh2grid_graph.edge_key_by_name("mesh2grid")
|
720 |
+
edges = mesh2grid_graph.edges[mesh2grid_key]
|
721 |
+
|
722 |
+
new_edges = edges._replace(
|
723 |
+
features=_add_batch_second_axis(
|
724 |
+
edges.features.astype(latent_grid_nodes.dtype), batch_size))
|
725 |
+
|
726 |
+
input_graph = mesh2grid_graph._replace(
|
727 |
+
edges={mesh2grid_key: new_edges},
|
728 |
+
nodes={
|
729 |
+
"mesh_nodes": new_mesh_nodes,
|
730 |
+
"grid_nodes": new_grid_nodes
|
731 |
+
})
|
732 |
+
|
733 |
+
# Run the GNN.
|
734 |
+
output_graph = self._mesh2grid_gnn(input_graph)
|
735 |
+
output_grid_nodes = output_graph.nodes["grid_nodes"].features
|
736 |
+
|
737 |
+
return output_grid_nodes
|
738 |
+
|
739 |
+
def _inputs_to_grid_node_features(
|
740 |
+
self,
|
741 |
+
inputs: xarray.Dataset,
|
742 |
+
forcings: xarray.Dataset,
|
743 |
+
) -> chex.Array:
|
744 |
+
"""xarrays -> [num_grid_nodes, batch, num_channels]."""
|
745 |
+
|
746 |
+
# xarray `Dataset` (batch, time, lat, lon, level, multiple vars)
|
747 |
+
# to xarray `DataArray` (batch, lat, lon, channels)
|
748 |
+
stacked_inputs = model_utils.dataset_to_stacked(inputs)
|
749 |
+
stacked_forcings = model_utils.dataset_to_stacked(forcings)
|
750 |
+
stacked_inputs = xarray.concat(
|
751 |
+
[stacked_inputs, stacked_forcings], dim="channels")
|
752 |
+
|
753 |
+
# xarray `DataArray` (batch, lat, lon, channels)
|
754 |
+
# to single numpy array with shape [lat_lon_node, batch, channels]
|
755 |
+
grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes(
|
756 |
+
stacked_inputs)
|
757 |
+
return xarray_jax.unwrap(grid_xarray_lat_lon_leading.data).reshape(
|
758 |
+
(-1,) + grid_xarray_lat_lon_leading.data.shape[2:])
|
759 |
+
|
760 |
+
def _grid_node_outputs_to_prediction(
|
761 |
+
self,
|
762 |
+
grid_node_outputs: chex.Array,
|
763 |
+
targets_template: xarray.Dataset,
|
764 |
+
) -> xarray.Dataset:
|
765 |
+
"""[num_grid_nodes, batch, num_outputs] -> xarray."""
|
766 |
+
|
767 |
+
# numpy array with shape [lat_lon_node, batch, channels]
|
768 |
+
# to xarray `DataArray` (batch, lat, lon, channels)
|
769 |
+
assert self._grid_lat is not None and self._grid_lon is not None
|
770 |
+
grid_shape = (self._grid_lat.shape[0], self._grid_lon.shape[0])
|
771 |
+
grid_outputs_lat_lon_leading = grid_node_outputs.reshape(
|
772 |
+
grid_shape + grid_node_outputs.shape[1:])
|
773 |
+
dims = ("lat", "lon", "batch", "channels")
|
774 |
+
grid_xarray_lat_lon_leading = xarray_jax.DataArray(
|
775 |
+
data=grid_outputs_lat_lon_leading,
|
776 |
+
dims=dims)
|
777 |
+
grid_xarray = model_utils.restore_leading_axes(grid_xarray_lat_lon_leading)
|
778 |
+
|
779 |
+
# xarray `DataArray` (batch, lat, lon, channels)
|
780 |
+
# to xarray `Dataset` (batch, one time step, lat, lon, level, multiple vars)
|
781 |
+
return model_utils.stacked_to_dataset(
|
782 |
+
grid_xarray.variable, targets_template)
|
783 |
+
|
784 |
+
|
785 |
+
def _add_batch_second_axis(data, batch_size):
|
786 |
+
# data [leading_dim, trailing_dim]
|
787 |
+
assert data.ndim == 2
|
788 |
+
ones = jnp.ones([batch_size, 1], dtype=data.dtype)
|
789 |
+
return data[:, None] * ones # [leading_dim, batch, trailing_dim]
|
790 |
+
|
791 |
+
|
792 |
+
def _get_max_edge_distance(mesh):
|
793 |
+
senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces)
|
794 |
+
edge_distances = np.linalg.norm(
|
795 |
+
mesh.vertices[senders] - mesh.vertices[receivers], axis=-1)
|
796 |
+
return edge_distances.max()
|
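A minimal sketch (not part of the uploaded files) of the shape bookkeeping that `_add_batch_second_axis` relies on above; the node, feature and batch sizes are arbitrary example values.

import jax.numpy as jnp

structural = jnp.ones([4, 7])                    # [num_nodes, num_features]
ones = jnp.ones([3, 1], dtype=structural.dtype)  # batch_size = 3
batched = structural[:, None] * ones             # [num_nodes, batch, num_features]
assert batched.shape == (4, 3, 7)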
graphcast/grid_mesh_connectivity.py
ADDED
@@ -0,0 +1,133 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for converting from regular grids on a sphere, to triangular meshes."""

from graphcast import icosahedral_mesh
import numpy as np
import scipy
import trimesh


def _grid_lat_lon_to_coordinates(
    grid_latitude: np.ndarray, grid_longitude: np.ndarray) -> np.ndarray:
  """Lat [num_lat] lon [num_lon] to 3d coordinates [num_lat, num_lon, 3]."""
  # Convert to spherical coordinates phi and theta defined in the grid.
  # Each [num_latitude_points, num_longitude_points]
  phi_grid, theta_grid = np.meshgrid(
      np.deg2rad(grid_longitude),
      np.deg2rad(90 - grid_latitude))

  # [num_latitude_points, num_longitude_points, 3]
  # Note this assumes unit radius, since for now we model the earth as a
  # sphere of unit radius, and keep any vertical dimension as a regular grid.
  return np.stack(
      [np.cos(phi_grid)*np.sin(theta_grid),
       np.sin(phi_grid)*np.sin(theta_grid),
       np.cos(theta_grid)], axis=-1)


def radius_query_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh,
    radius: float) -> tuple[np.ndarray, np.ndarray]:
  """Returns mesh-grid edge indices for radius query.

  Args:
    grid_latitude: Latitude values for the grid [num_lat_points]
    grid_longitude: Longitude values for the grid [num_lon_points]
    mesh: Mesh object.
    radius: Radius of connectivity in R3. for a sphere of unit radius.

  Returns:
    tuple with `grid_indices` and `mesh_indices` indicating edges between the
    grid and the mesh such that the distances in a straight line (not geodesic)
    are smaller than or equal to `radius`.
    * grid_indices: Indices of shape [num_edges], that index into a
      [num_lat_points, num_lon_points] grid, after flattening the leading axes.
    * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
  """

  # [num_grid_points=num_lat_points * num_lon_points, 3]
  grid_positions = _grid_lat_lon_to_coordinates(
      grid_latitude, grid_longitude).reshape([-1, 3])

  # [num_mesh_points, 3]
  mesh_positions = mesh.vertices
  kd_tree = scipy.spatial.cKDTree(mesh_positions)

  # [num_grid_points, num_mesh_points_per_grid_point]
  # Note `num_mesh_points_per_grid_point` is not constant, so this is a list
  # of arrays, rather than a 2d array.
  query_indices = kd_tree.query_ball_point(x=grid_positions, r=radius)

  grid_edge_indices = []
  mesh_edge_indices = []
  for grid_index, mesh_neighbors in enumerate(query_indices):
    grid_edge_indices.append(np.repeat(grid_index, len(mesh_neighbors)))
    mesh_edge_indices.append(mesh_neighbors)

  # [num_edges]
  grid_edge_indices = np.concatenate(grid_edge_indices, axis=0).astype(int)
  mesh_edge_indices = np.concatenate(mesh_edge_indices, axis=0).astype(int)

  return grid_edge_indices, mesh_edge_indices


def in_mesh_triangle_indices(
    *,
    grid_latitude: np.ndarray,
    grid_longitude: np.ndarray,
    mesh: icosahedral_mesh.TriangularMesh) -> tuple[np.ndarray, np.ndarray]:
  """Returns mesh-grid edge indices for grid points contained in mesh triangles.

  Args:
    grid_latitude: Latitude values for the grid [num_lat_points]
    grid_longitude: Longitude values for the grid [num_lon_points]
    mesh: Mesh object.

  Returns:
    tuple with `grid_indices` and `mesh_indices` indicating edges between the
    grid and the mesh vertices of the triangle that contain each grid point.
    The number of edges is always num_lat_points * num_lon_points * 3
    * grid_indices: Indices of shape [num_edges], that index into a
      [num_lat_points, num_lon_points] grid, after flattening the leading axes.
    * mesh_indices: Indices of shape [num_edges], that index into mesh.vertices.
  """

  # [num_grid_points=num_lat_points * num_lon_points, 3]
  grid_positions = _grid_lat_lon_to_coordinates(
      grid_latitude, grid_longitude).reshape([-1, 3])

  mesh_trimesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces)

  # [num_grid_points] with mesh face indices for each grid point.
  _, _, query_face_indices = trimesh.proximity.closest_point(
      mesh_trimesh, grid_positions)

  # [num_grid_points, 3] with mesh node indices for each grid point.
  mesh_edge_indices = mesh.faces[query_face_indices]

  # [num_grid_points, 3] with grid node indices, where every row simply contains
  # the row (grid_point) index.
  grid_indices = np.arange(grid_positions.shape[0])
  grid_edge_indices = np.tile(grid_indices.reshape([-1, 1]), [1, 3])

  # Flatten to get a regular list.
  # [num_edges=num_grid_points*3]
  mesh_edge_indices = mesh_edge_indices.reshape([-1])
  grid_edge_indices = grid_edge_indices.reshape([-1])

  return grid_edge_indices, mesh_edge_indices
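An illustrative usage sketch (not part of the uploaded files) of the two connectivity queries above; the grid resolution, number of mesh splits and radius are arbitrary example values.

import numpy as np
from graphcast import grid_mesh_connectivity, icosahedral_mesh

grid_lat = np.linspace(-75., 75., 6)
grid_lon = np.arange(12) * 30.
mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
    splits=2)[-1]

# Grid -> mesh edges within a straight-line (chordal) radius in R3.
grid_idx, mesh_idx = grid_mesh_connectivity.radius_query_indices(
    grid_latitude=grid_lat, grid_longitude=grid_lon, mesh=mesh, radius=0.3)

# Mesh -> grid edges: each grid point connects to the 3 vertices of the
# triangle containing it, so there are exactly num_lat * num_lon * 3 edges.
g_idx, m_idx = grid_mesh_connectivity.in_mesh_triangle_indices(
    grid_latitude=grid_lat, grid_longitude=grid_lon, mesh=mesh)
assert g_idx.shape[0] == grid_lat.shape[0] * grid_lon.shape[0] * 3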
graphcast/grid_mesh_connectivity_test.py
ADDED
@@ -0,0 +1,74 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for graphcast.grid_mesh_connectivity."""

from absl.testing import absltest
from graphcast import grid_mesh_connectivity
from graphcast import icosahedral_mesh
import numpy as np


class GridMeshConnectivityTest(absltest.TestCase):

  def test_grid_lat_lon_to_coordinates(self):

    # Intervals of 30 degrees.
    grid_latitude = np.array([-45., 0., 45])
    grid_longitude = np.array([0., 90., 180., 270.])

    inv_sqrt2 = 1 / np.sqrt(2)
    expected_coordinates = np.array([
        [[inv_sqrt2, 0., -inv_sqrt2],
         [0., inv_sqrt2, -inv_sqrt2],
         [-inv_sqrt2, 0., -inv_sqrt2],
         [0., -inv_sqrt2, -inv_sqrt2]],
        [[1., 0., 0.],
         [0., 1., 0.],
         [-1., 0., 0.],
         [0., -1., 0.]],
        [[inv_sqrt2, 0., inv_sqrt2],
         [0., inv_sqrt2, inv_sqrt2],
         [-inv_sqrt2, 0., inv_sqrt2],
         [0., -inv_sqrt2, inv_sqrt2]],
    ])

    coordinates = grid_mesh_connectivity._grid_lat_lon_to_coordinates(
        grid_latitude, grid_longitude)
    np.testing.assert_allclose(expected_coordinates, coordinates, atol=1e-15)

  def test_radius_query_indices_smoke(self):
    # TODO(alvarosg): Add non-smoke test?
    grid_latitude = np.linspace(-75, 75, 6)
    grid_longitude = np.arange(12) * 30.
    mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
        splits=3)[-1]
    grid_mesh_connectivity.radius_query_indices(
        grid_latitude=grid_latitude,
        grid_longitude=grid_longitude,
        mesh=mesh, radius=0.2)

  def test_in_mesh_triangle_indices_smoke(self):
    # TODO(alvarosg): Add non-smoke test?
    grid_latitude = np.linspace(-75, 75, 6)
    grid_longitude = np.arange(12) * 30.
    mesh = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
        splits=3)[-1]
    grid_mesh_connectivity.in_mesh_triangle_indices(
        grid_latitude=grid_latitude,
        grid_longitude=grid_longitude,
        mesh=mesh)


if __name__ == "__main__":
  absltest.main()
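A small worked check (not part of the uploaded files) of the coordinate convention the expected values above encode: with phi = lon and theta = 90 - lat, the point at lat=45, lon=0 maps to (1/sqrt(2), 0, 1/sqrt(2)).

import numpy as np

lat, lon = 45., 0.
phi, theta = np.deg2rad(lon), np.deg2rad(90. - lat)
xyz = np.array([np.cos(phi) * np.sin(theta),
                np.sin(phi) * np.sin(theta),
                np.cos(theta)])
np.testing.assert_allclose(xyz, [1 / np.sqrt(2), 0., 1 / np.sqrt(2)], atol=1e-12)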
graphcast/icosahedral_mesh.py
ADDED
@@ -0,0 +1,281 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for creating icosahedral meshes."""

import itertools
from typing import List, NamedTuple, Sequence, Tuple

import numpy as np
from scipy.spatial import transform


class TriangularMesh(NamedTuple):
  """Data structure for triangular meshes.

  Attributes:
    vertices: spatial positions of the vertices of the mesh of shape
        [num_vertices, num_dims].
    faces: triangular faces of the mesh of shape [num_faces, 3]. Contains
        integer indices into `vertices`.

  """
  vertices: np.ndarray
  faces: np.ndarray


def merge_meshes(
    mesh_list: Sequence[TriangularMesh]) -> TriangularMesh:
  """Merges all meshes into one. Assumes the last mesh is the finest.

  Args:
    mesh_list: Sequence of meshes, from coarse to fine refinement levels. The
      vertices and faces may contain those from preceding, coarser levels.

  Returns:
    `TriangularMesh` for which the vertices correspond to the highest
    resolution mesh in the hierarchy, and the faces are the join set of the
    faces at all levels of the hierarchy.
  """
  for mesh_i, mesh_ip1 in itertools.pairwise(mesh_list):
    num_nodes_mesh_i = mesh_i.vertices.shape[0]
    assert np.allclose(mesh_i.vertices, mesh_ip1.vertices[:num_nodes_mesh_i])

  return TriangularMesh(
      vertices=mesh_list[-1].vertices,
      faces=np.concatenate([mesh.faces for mesh in mesh_list], axis=0))


def get_hierarchy_of_triangular_meshes_for_sphere(
    splits: int) -> List[TriangularMesh]:
  """Returns a sequence of meshes, each with triangularization sphere.

  Starting with a regular icosahedron (12 vertices, 20 faces, 30 edges) with
  circumscribed unit sphere. Then, each triangular face is iteratively
  subdivided into 4 triangular faces `splits` times. The new vertices are then
  projected back onto the unit sphere. All resulting meshes are returned in a
  list, from lowest to highest resolution.

  The vertices in each face are specified in counter-clockwise order as
  observed from the outside the icosahedron.

  Args:
     splits: How many times to split each triangle.
  Returns:
     Sequence of `TriangularMesh`s of length `splits + 1` each with:

       vertices: [num_vertices, 3] vertex positions in 3D, all with unit norm.
       faces: [num_faces, 3] with triangular faces joining sets of 3 vertices.
           Each row contains three indices into the vertices array, indicating
           the vertices adjacent to the face. Always with positive orientation
           (counterclock-wise when looking from the outside).
  """
  current_mesh = get_icosahedron()
  output_meshes = [current_mesh]
  for _ in range(splits):
    current_mesh = _two_split_unit_sphere_triangle_faces(current_mesh)
    output_meshes.append(current_mesh)
  return output_meshes


def get_icosahedron() -> TriangularMesh:
  """Returns a regular icosahedral mesh with circumscribed unit sphere.

  See https://en.wikipedia.org/wiki/Regular_icosahedron#Cartesian_coordinates
  for details on the construction of the regular icosahedron.

  The vertices in each face are specified in counter-clockwise order as observed
  from the outside of the icosahedron.

  Returns:
     TriangularMesh with:

     vertices: [num_vertices=12, 3] vertex positions in 3D, all with unit norm.
     faces: [num_faces=20, 3] with triangular faces joining sets of 3 vertices.
         Each row contains three indices into the vertices array, indicating
         the vertices adjacent to the face. Always with positive orientation (
         counterclock-wise when looking from the outside).

  """
  phi = (1 + np.sqrt(5)) / 2
  vertices = []
  for c1 in [1., -1.]:
    for c2 in [phi, -phi]:
      vertices.append((c1, c2, 0.))
      vertices.append((0., c1, c2))
      vertices.append((c2, 0., c1))

  vertices = np.array(vertices, dtype=np.float32)
  vertices /= np.linalg.norm([1., phi])

  # I did this manually, checking the orientation one by one.
  faces = [(0, 1, 2),
           (0, 6, 1),
           (8, 0, 2),
           (8, 4, 0),
           (3, 8, 2),
           (3, 2, 7),
           (7, 2, 1),
           (0, 4, 6),
           (4, 11, 6),
           (6, 11, 5),
           (1, 5, 7),
           (4, 10, 11),
           (4, 8, 10),
           (10, 8, 3),
           (10, 3, 9),
           (11, 10, 9),
           (11, 9, 5),
           (5, 9, 7),
           (9, 3, 7),
           (1, 6, 5),
           ]

  # By default the top is an aris parallel to the Y axis.
  # Need to rotate around the y axis by half the supplementary to the
  # angle between faces divided by two to get the desired orientation.
  #                            /O\  (top arist)
  #                           /   \                             Z
  #         (adjacent face)  /     \  (adjacent face)           ^
  #                         / angle_between_faces \             |
  #                        /                       \            |
  #                       /                         \          YO-----> X
  # This results in:
  #  (adjacent face is now top plane)
  #  ----------------------O\  (top arist)
  #                           \
  #                            \
  #                             \     (adjacent face)
  #                              \
  #                               \
  #                                \

  angle_between_faces = 2 * np.arcsin(phi / np.sqrt(3))
  rotation_angle = (np.pi - angle_between_faces) / 2
  rotation = transform.Rotation.from_euler(seq="y", angles=rotation_angle)
  rotation_matrix = rotation.as_matrix()
  vertices = np.dot(vertices, rotation_matrix)

  return TriangularMesh(vertices=vertices.astype(np.float32),
                        faces=np.array(faces, dtype=np.int32))


def _two_split_unit_sphere_triangle_faces(
    triangular_mesh: TriangularMesh) -> TriangularMesh:
  """Splits each triangular face into 4 triangles keeping the orientation."""

  # Every time we split a triangle into 4 we will be adding 3 extra vertices,
  # located at the edge centres.
  # This class handles the positioning of the new vertices, and avoids creating
  # duplicates.
  new_vertices_builder = _ChildVerticesBuilder(triangular_mesh.vertices)

  new_faces = []
  for ind1, ind2, ind3 in triangular_mesh.faces:
    # Transform each triangular face into 4 triangles,
    # preserving the orientation.
    #                    ind3
    #                   /    \
    #                  /      \
    #                 /   #3   \
    #                /          \
    #             ind31 -------- ind23
    #             /    \        /    \
    #            /      \  #4  /      \
    #           /   #1   \    /   #2   \
    #          /          \  /          \
    #        ind1 -------- ind12 -------- ind2
    ind12 = new_vertices_builder.get_new_child_vertex_index((ind1, ind2))
    ind23 = new_vertices_builder.get_new_child_vertex_index((ind2, ind3))
    ind31 = new_vertices_builder.get_new_child_vertex_index((ind3, ind1))
    # Note how each of the 4 triangular new faces specifies the order of the
    # vertices to preserve the orientation of the original face. As the input
    # face should always be counter-clockwise as specified in the diagram,
    # this means child faces should also be counter-clockwise.
    new_faces.extend([[ind1, ind12, ind31],  # 1
                      [ind12, ind2, ind23],  # 2
                      [ind31, ind23, ind3],  # 3
                      [ind12, ind23, ind31],  # 4
                      ])
  return TriangularMesh(vertices=new_vertices_builder.get_all_vertices(),
                        faces=np.array(new_faces, dtype=np.int32))


class _ChildVerticesBuilder(object):
  """Bookkeeping of new child vertices added to an existing set of vertices."""

  def __init__(self, parent_vertices):

    # Because the same new vertex will be required when splitting adjacent
    # triangles (which share an edge) we keep them in a hash table indexed by
    # sorted indices of the vertices adjacent to the edge, to avoid creating
    # duplicated child vertices.
    self._child_vertices_index_mapping = {}
    self._parent_vertices = parent_vertices
    # We start with all previous vertices.
    self._all_vertices_list = list(parent_vertices)

  def _get_child_vertex_key(self, parent_vertex_indices):
    return tuple(sorted(parent_vertex_indices))

  def _create_child_vertex(self, parent_vertex_indices):
    """Creates a new vertex."""
    # Position for new vertex is the middle point, between the parent points,
    # projected to unit sphere.
    child_vertex_position = self._parent_vertices[
        list(parent_vertex_indices)].mean(0)
    child_vertex_position /= np.linalg.norm(child_vertex_position)

    # Add the vertex to the output list. The index for this new vertex will
    # match the length of the list before adding it.
    child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
    self._child_vertices_index_mapping[child_vertex_key] = len(
        self._all_vertices_list)
    self._all_vertices_list.append(child_vertex_position)

  def get_new_child_vertex_index(self, parent_vertex_indices):
    """Returns index for a child vertex, creating it if necessary."""
    # Get the key to see if we already have a new vertex in the middle.
    child_vertex_key = self._get_child_vertex_key(parent_vertex_indices)
    if child_vertex_key not in self._child_vertices_index_mapping:
      self._create_child_vertex(parent_vertex_indices)
    return self._child_vertices_index_mapping[child_vertex_key]

  def get_all_vertices(self):
    """Returns an array with old vertices."""
    return np.array(self._all_vertices_list)


def faces_to_edges(faces: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
  """Transforms polygonal faces to sender and receiver indices.

  It does so by transforming every face into N_i edges. Such if the triangular
  face has indices [0, 1, 2], three edges are added 0->1, 1->2, and 2->0.

  If all faces have consistent orientation, and the surface represented by the
  faces is closed, then every edge in a polygon with a certain orientation
  is also part of another polygon with the opposite orientation. In this
  situation, the edges returned by the method are always bidirectional.

  Args:
    faces: Integer array of shape [num_faces, 3]. Contains node indices
        adjacent to each face.
  Returns:
    Tuple with sender/receiver indices, each of shape [num_edges=num_faces*3].

  """
  assert faces.ndim == 2
  assert faces.shape[-1] == 3
  senders = np.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]])
  receivers = np.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]])
  return senders, receivers
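An illustrative usage sketch (not part of the uploaded files) of the mesh-building utilities above; the number of splits is an arbitrary example value.

import numpy as np
from graphcast import icosahedral_mesh

meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
    splits=2)
# Vertex counts grow as 12, 12 + 20*3//2 = 42, 42 + 80*3//2 = 162.
assert [m.vertices.shape[0] for m in meshes] == [12, 42, 162]

merged = icosahedral_mesh.merge_meshes(meshes)
senders, receivers = icosahedral_mesh.faces_to_edges(merged.faces)
# Every triangular face contributes exactly 3 directed edges.
assert senders.shape[0] == merged.faces.shape[0] * 3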
graphcast/icosahedral_mesh_test.py
ADDED
@@ -0,0 +1,131 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for icosahedral_mesh."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
from graphcast import icosahedral_mesh
import numpy as np


def _get_mesh_spec(splits: int):
  """Returns size of the final icosahedral mesh resulting from the splitting."""
  num_vertices = 12
  num_faces = 20
  for _ in range(splits):
    # Each previous face adds three new vertices, but each vertex is shared
    # by two faces.
    num_vertices += num_faces * 3 // 2
    num_faces *= 4
  return num_vertices, num_faces


class IcosahedralMeshTest(parameterized.TestCase):

  def test_icosahedron(self):
    mesh = icosahedral_mesh.get_icosahedron()
    _assert_valid_mesh(
        mesh, num_expected_vertices=12, num_expected_faces=20)

  @parameterized.parameters(list(range(5)))
  def test_get_hierarchy_of_triangular_meshes_for_sphere(self, splits):
    meshes = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
        splits=splits)
    prev_vertices = None
    for mesh_i, mesh in enumerate(meshes):
      # Check that `mesh` is valid.
      num_expected_vertices, num_expected_faces = _get_mesh_spec(mesh_i)
      _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces)

      # Check that the first N vertices from this mesh match all of the
      # vertices from the previous mesh.
      if prev_vertices is not None:
        leading_mesh_vertices = mesh.vertices[:prev_vertices.shape[0]]
        np.testing.assert_array_equal(leading_mesh_vertices, prev_vertices)

      # Increase the expected/previous values for the next iteration.
      if mesh_i < len(meshes) - 1:
        prev_vertices = mesh.vertices

  @parameterized.parameters(list(range(4)))
  def test_merge_meshes(self, splits):
    mesh_hierarchy = (
        icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere(
            splits=splits))
    mesh = icosahedral_mesh.merge_meshes(mesh_hierarchy)

    expected_faces = np.concatenate([m.faces for m in mesh_hierarchy], axis=0)
    np.testing.assert_array_equal(mesh.vertices, mesh_hierarchy[-1].vertices)
    np.testing.assert_array_equal(mesh.faces, expected_faces)

  def test_faces_to_edges(self):

    faces = np.array([[0, 1, 2],
                      [3, 4, 5]])

    # This also documents the order of the edges returned by the method.
    expected_edges = np.array(
        [[0, 1],
         [3, 4],
         [1, 2],
         [4, 5],
         [2, 0],
         [5, 3]])
    expected_senders = expected_edges[:, 0]
    expected_receivers = expected_edges[:, 1]

    senders, receivers = icosahedral_mesh.faces_to_edges(faces)

    np.testing.assert_array_equal(senders, expected_senders)
    np.testing.assert_array_equal(receivers, expected_receivers)


def _assert_valid_mesh(mesh, num_expected_vertices, num_expected_faces):
  vertices = mesh.vertices
  faces = mesh.faces
  chex.assert_shape(vertices, [num_expected_vertices, 3])
  chex.assert_shape(faces, [num_expected_faces, 3])

  # Vertices norm should be 1.
  vertices_norm = np.linalg.norm(vertices, axis=-1)
  np.testing.assert_allclose(vertices_norm, 1., rtol=1e-6)

  _assert_positive_face_orientation(vertices, faces)


def _assert_positive_face_orientation(vertices, faces):

  # Obtain a unit vector that points, in the direction of the face.
  face_orientation = np.cross(vertices[faces[:, 1]] - vertices[faces[:, 0]],
                              vertices[faces[:, 2]] - vertices[faces[:, 1]])
  face_orientation /= np.linalg.norm(face_orientation, axis=-1, keepdims=True)

  # And a unit vector pointing from the origin to the center of the face.
  face_centers = vertices[faces].mean(1)
  face_centers /= np.linalg.norm(face_centers, axis=-1, keepdims=True)

  # Positive orientation means those two vectors should be parallel
  # (dot product, 1), and not anti-parallel (dot product, -1).
  dot_center_orientation = np.einsum("ik,ik->i", face_orientation, face_centers)

  # Check that the face normal is parallel to the vector that joins the center
  # of the face to the center of the sphere. Note we need a small tolerance
  # because some discretizations are not exactly uniform, so it will not be
  # exactly parallel.
  np.testing.assert_allclose(dot_center_orientation, 1., atol=6e-4)


if __name__ == "__main__":
  absltest.main()
graphcast/losses.py
ADDED
@@ -0,0 +1,179 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loss functions (and terms for use in loss functions) used for weather."""

from typing import Mapping

from graphcast import xarray_tree
import numpy as np
from typing_extensions import Protocol
import xarray


LossAndDiagnostics = tuple[xarray.DataArray, xarray.Dataset]


class LossFunction(Protocol):
  """A loss function.

  This is a protocol so it's fine to use a plain function which 'quacks like'
  this. This is just to document the interface.
  """

  def __call__(self,
               predictions: xarray.Dataset,
               targets: xarray.Dataset,
               **optional_kwargs) -> LossAndDiagnostics:
    """Computes a loss function.

    Args:
      predictions: Dataset of predictions.
      targets: Dataset of targets.
      **optional_kwargs: Implementations may support extra optional kwargs.

    Returns:
      loss: A DataArray with dimensions ('batch',) containing losses for each
        element of the batch. These will be averaged to give the final
        loss, locally and across replicas.
      diagnostics: Mapping of additional quantities to log by name alongside the
        loss. These will will typically correspond to terms in the loss. They
        should also have dimensions ('batch',) and will be averaged over the
        batch before logging.
    """


def weighted_mse_per_level(
    predictions: xarray.Dataset,
    targets: xarray.Dataset,
    per_variable_weights: Mapping[str, float],
) -> LossAndDiagnostics:
  """Latitude- and pressure-level-weighted MSE loss."""
  def loss(prediction, target):
    loss = (prediction - target)**2
    loss *= normalized_latitude_weights(target).astype(loss.dtype)
    if 'level' in target.dims:
      loss *= normalized_level_weights(target).astype(loss.dtype)
    return _mean_preserving_batch(loss)

  losses = xarray_tree.map_structure(loss, predictions, targets)
  return sum_per_variable_losses(losses, per_variable_weights)


def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
  return x.mean([d for d in x.dims if d != 'batch'], skipna=False)


def sum_per_variable_losses(
    per_variable_losses: Mapping[str, xarray.DataArray],
    weights: Mapping[str, float],
) -> LossAndDiagnostics:
  """Weighted sum of per-variable losses."""
  if not set(weights.keys()).issubset(set(per_variable_losses.keys())):
    raise ValueError(
        'Passing a weight that does not correspond to any variable '
        f'{set(weights.keys())-set(per_variable_losses.keys())}')

  weighted_per_variable_losses = {
      name: loss * weights.get(name, 1)
      for name, loss in per_variable_losses.items()
  }
  total = xarray.concat(
      weighted_per_variable_losses.values(), dim='variable', join='exact').sum(
          'variable', skipna=False)
  return total, per_variable_losses  # pytype: disable=bad-return-type


def normalized_level_weights(data: xarray.DataArray) -> xarray.DataArray:
  """Weights proportional to pressure at each level."""
  level = data.coords['level']
  return level / level.mean(skipna=False)


def normalized_latitude_weights(data: xarray.DataArray) -> xarray.DataArray:
  """Weights based on latitude, roughly proportional to grid cell area.

  This method supports two use cases only (both for equispaced values):
  * Latitude values such that the closest value to the pole is at latitude
    (90 - d_lat/2), where d_lat is the difference between contiguous latitudes.
    For example: [-89, -87, -85, ..., 85, 87, 89]) (d_lat = 2)
    In this case each point with `lat` value represents a sphere slice between
    `lat - d_lat/2` and `lat + d_lat/2`, and the area of this slice would be
    proportional to:
    `sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)`, and
    we can simply omit the term `2 * sin(d_lat/2)` which is just a constant
    that cancels during normalization.
  * Latitude values that fall exactly at the poles.
    For example: [-90, -88, -86, ..., 86, 88, 90]) (d_lat = 2)
    In this case each point with `lat` value also represents
    a sphere slice between `lat - d_lat/2` and `lat + d_lat/2`,
    except for the points at the poles, that represent a slice between
    `90 - d_lat/2` and `90` or, `-90` and `-90 + d_lat/2`.
    The areas of the first type of point are still proportional to:
    * sin(lat + d_lat/2) - sin(lat - d_lat/2) = 2 * sin(d_lat/2) * cos(lat)
    but for the points at the poles now is:
    * sin(90) - sin(90 - d_lat/2) = 2 * sin(d_lat/4) ^ 2
    and we will be using these weights, depending on whether we are looking at
    pole cells, or non-pole cells (omitting the common factor of 2 which will be
    absorbed by the normalization).

  It can be shown via a limit, or simple geometry, that in the small angles
  regime, the proportion of area per pole-point is equal to 1/8th
  the proportion of area covered by each of the nearest non-pole point, and we
  test for this in the test.

  Args:
    data: `DataArray` with latitude coordinates.
  Returns:
    Unit mean latitude weights.
  """
  latitude = data.coords['lat']

  if np.any(np.isclose(np.abs(latitude), 90.)):
    weights = _weight_for_latitude_vector_with_poles(latitude)
  else:
    weights = _weight_for_latitude_vector_without_poles(latitude)

  return weights / weights.mean(skipna=False)


def _weight_for_latitude_vector_without_poles(latitude):
  """Weights for uniform latitudes of the form [+-90-+d/2, ..., -+90+-d/2]."""
  delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
  if (not np.isclose(np.max(latitude), 90 - delta_latitude/2) or
      not np.isclose(np.min(latitude), -90 + delta_latitude/2)):
    raise ValueError(
        f'Latitude vector {latitude} does not start/end at '
        '+- (90 - delta_latitude/2) degrees.')
  return np.cos(np.deg2rad(latitude))


def _weight_for_latitude_vector_with_poles(latitude):
  """Weights for uniform latitudes of the form [+- 90, ..., -+90]."""
  delta_latitude = np.abs(_check_uniform_spacing_and_get_delta(latitude))
  if (not np.isclose(np.max(latitude), 90.) or
      not np.isclose(np.min(latitude), -90.)):
    raise ValueError(
        f'Latitude vector {latitude} does not start/end at +- 90 degrees.')
  weights = np.cos(np.deg2rad(latitude)) * np.sin(np.deg2rad(delta_latitude/2))
  # The two checks above enough to guarantee that latitudes are sorted, so
  # the extremes are the poles
  weights[[0, -1]] = np.sin(np.deg2rad(delta_latitude/4)) ** 2
  return weights


def _check_uniform_spacing_and_get_delta(vector):
  diff = np.diff(vector)
  if not np.all(np.isclose(diff[0], diff)):
    raise ValueError(f'Vector {diff} is not uniformly spaced.')
  return diff[0]
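A small numeric check (not part of the uploaded files) of the claim in the `normalized_latitude_weights` docstring above that, in the small-angle regime, a pole point gets roughly 1/8th of the weight of its nearest non-pole neighbour.

import numpy as np

d_lat = 0.25  # degrees; any sufficiently small spacing works.
pole_weight = np.sin(np.deg2rad(d_lat / 4)) ** 2
nearest_weight = np.cos(np.deg2rad(90. - d_lat)) * np.sin(np.deg2rad(d_lat / 2))
np.testing.assert_allclose(pole_weight / nearest_weight, 1 / 8, rtol=1e-3)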
graphcast/model_utils.py
ADDED
@@ -0,0 +1,724 @@
1 |
+
# Copyright 2023 DeepMind Technologies Limited.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS-IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Utilities for building models."""
|
15 |
+
|
16 |
+
from typing import Mapping, Optional, Tuple
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
from scipy.spatial import transform
|
20 |
+
import xarray
|
21 |
+
|
22 |
+
|
23 |
+
def get_graph_spatial_features(
|
24 |
+
*, node_lat: np.ndarray, node_lon: np.ndarray,
|
25 |
+
senders: np.ndarray, receivers: np.ndarray,
|
26 |
+
add_node_positions: bool,
|
27 |
+
add_node_latitude: bool,
|
28 |
+
add_node_longitude: bool,
|
29 |
+
add_relative_positions: bool,
|
30 |
+
relative_longitude_local_coordinates: bool,
|
31 |
+
relative_latitude_local_coordinates: bool,
|
32 |
+
sine_cosine_encoding: bool = False,
|
33 |
+
encoding_num_freqs: int = 10,
|
34 |
+
encoding_multiplicative_factor: float = 1.2,
|
35 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
36 |
+
"""Computes spatial features for the nodes.
|
37 |
+
|
38 |
+
Args:
|
39 |
+
node_lat: Latitudes in the [-90, 90] interval of shape [num_nodes]
|
40 |
+
node_lon: Longitudes in the [0, 360] interval of shape [num_nodes]
|
41 |
+
senders: Sender indices of shape [num_edges]
|
42 |
+
receivers: Receiver indices of shape [num_edges]
|
43 |
+
add_node_positions: Add unit norm absolute positions.
|
44 |
+
add_node_latitude: Add a feature for latitude (cos(90 - lat))
|
45 |
+
Note even if this is set to False, the model may be able to infer the
|
46 |
+
longitude from relative features, unless
|
47 |
+
`relative_latitude_local_coordinates` is also True, or if there is any
|
48 |
+
bias on the relative edge sizes for different latitudes.
|
49 |
+
add_node_longitude: Add features for longitude (cos(lon), sin(lon)).
|
50 |
+
Note even if this is set to False, the model may be able to infer the
|
51 |
+
longitude from relative features, unless
|
52 |
+
`relative_longitude_local_coordinates` is also True, or if there is any
|
53 |
+
bias on the relative edge sizes for different longitudes.
|
54 |
+
add_relative_positions: Whether to add relative positions in R3 to the edges.
|
55 |
+
relative_longitude_local_coordinates: If True, relative positions are
|
56 |
+
computed in a local space where the receiver is at 0 longitude.
|
57 |
+
relative_latitude_local_coordinates: If True, relative positions are
|
58 |
+
computed in a local space where the receiver is at 0 latitude.
|
59 |
+
sine_cosine_encoding: If True, we will transform the node/edge features
|
60 |
+
with sine and cosine functions, similar to NERF.
|
61 |
+
encoding_num_freqs: frequency parameter
|
62 |
+
encoding_multiplicative_factor: used for calculating the frequency.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
Arrays of shape: [num_nodes, num_features] and [num_edges, num_features].
|
66 |
+
with node and edge features.
|
67 |
+
|
68 |
+
"""
|
69 |
+
|
70 |
+
num_nodes = node_lat.shape[0]
|
71 |
+
num_edges = senders.shape[0]
|
72 |
+
dtype = node_lat.dtype
|
73 |
+
node_phi, node_theta = lat_lon_deg_to_spherical(node_lat, node_lon)
|
74 |
+
|
75 |
+
# Computing some node features.
|
76 |
+
node_features = []
|
77 |
+
if add_node_positions:
|
78 |
+
# Already in [-1, 1.] range.
|
79 |
+
node_features.extend(spherical_to_cartesian(node_phi, node_theta))
|
80 |
+
|
81 |
+
if add_node_latitude:
|
82 |
+
# Using the cos of theta.
|
83 |
+
# From 1. (north pole) to -1 (south pole).
|
84 |
+
node_features.append(np.cos(node_theta))
|
85 |
+
|
86 |
+
if add_node_longitude:
|
87 |
+
# Using the cos and sin, which is already normalized.
|
88 |
+
node_features.append(np.cos(node_phi))
|
89 |
+
node_features.append(np.sin(node_phi))
|
90 |
+
|
91 |
+
if not node_features:
|
92 |
+
node_features = np.zeros([num_nodes, 0], dtype=dtype)
|
93 |
+
else:
|
94 |
+
node_features = np.stack(node_features, axis=-1)
|
95 |
+
|
96 |
+
# Computing some edge features.
|
97 |
+
edge_features = []
|
98 |
+
|
99 |
+
if add_relative_positions:
|
100 |
+
|
101 |
+
relative_position = get_relative_position_in_receiver_local_coordinates(
|
102 |
+
node_phi=node_phi,
|
103 |
+
node_theta=node_theta,
|
104 |
+
senders=senders,
|
105 |
+
receivers=receivers,
|
106 |
+
latitude_local_coordinates=relative_latitude_local_coordinates,
|
107 |
+
longitude_local_coordinates=relative_longitude_local_coordinates
|
108 |
+
)
|
109 |
+
|
110 |
+
# Note this is L2 distance in 3d space, rather than geodesic distance.
|
111 |
+
relative_edge_distances = np.linalg.norm(
|
112 |
+
relative_position, axis=-1, keepdims=True)
|
113 |
+
|
114 |
+
# Normalize to the maximum edge distance. Note that we expect to always
|
115 |
+
# have an edge that goes in the opposite direction of any given edge
|
116 |
+
# so the distribution of relative positions should be symmetric around
|
117 |
+
# zero. So by scaling by the maximum length, we expect all relative
|
118 |
+
# positions to fall in the [-1., 1.] interval, and all relative distances
|
119 |
+
# to fall in the [0., 1.] interval.
|
120 |
+
max_edge_distance = relative_edge_distances.max()
|
121 |
+
edge_features.append(relative_edge_distances / max_edge_distance)
|
122 |
+
edge_features.append(relative_position / max_edge_distance)
|
123 |
+
|
124 |
+
if not edge_features:
|
125 |
+
edge_features = np.zeros([num_edges, 0], dtype=dtype)
|
126 |
+
else:
|
127 |
+
edge_features = np.concatenate(edge_features, axis=-1)
|
128 |
+
|
129 |
+
if sine_cosine_encoding:
|
130 |
+
def sine_cosine_transform(x: np.ndarray) -> np.ndarray:
|
131 |
+
freqs = encoding_multiplicative_factor**np.arange(encoding_num_freqs)
|
132 |
+
phases = freqs * x[..., None]
|
133 |
+
x_sin = np.sin(phases)
|
134 |
+
x_cos = np.cos(phases)
|
135 |
+
x_cat = np.concatenate([x_sin, x_cos], axis=-1)
|
136 |
+
return x_cat.reshape([x.shape[0], -1])
|
137 |
+
|
138 |
+
node_features = sine_cosine_transform(node_features)
|
139 |
+
edge_features = sine_cosine_transform(edge_features)
|
140 |
+
|
141 |
+
return node_features, edge_features
|
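As a usage illustration (an editorial sketch, not part of the library; the tiny graph below is made up), `get_graph_spatial_features` can be called on a small bidirectional graph to see the feature shapes it produces:

import numpy as np
from graphcast import model_utils

node_lat = np.array([0.0, 45.0, -30.0])
node_lon = np.array([0.0, 90.0, 200.0])
senders = np.array([0, 1, 1, 2, 2, 0])
receivers = np.array([1, 0, 2, 1, 0, 2])

node_features, edge_features = model_utils.get_graph_spatial_features(
    node_lat=node_lat, node_lon=node_lon,
    senders=senders, receivers=receivers,
    add_node_positions=False,
    add_node_latitude=True,
    add_node_longitude=True,
    add_relative_positions=True,
    relative_longitude_local_coordinates=True,
    relative_latitude_local_coordinates=True)
# 1 latitude feature + 2 longitude features per node; 1 normalized distance
# + 3 relative coordinates per edge.
print(node_features.shape, edge_features.shape)  # (3, 3) (6, 4)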
142 |
+
|
143 |
+
|
144 |
+
def lat_lon_to_leading_axes(
|
145 |
+
grid_xarray: xarray.DataArray) -> xarray.DataArray:
|
146 |
+
"""Reorders xarray so lat/lon axes come first."""
|
147 |
+
# leading + ["lat", "lon"] + trailing
|
148 |
+
# to
|
149 |
+
# ["lat", "lon"] + leading + trailing
|
150 |
+
return grid_xarray.transpose("lat", "lon", ...)
|
151 |
+
|
152 |
+
|
153 |
+
def restore_leading_axes(grid_xarray: xarray.DataArray) -> xarray.DataArray:
|
154 |
+
"""Reorders xarray so batch/time/level axes come first (if present)."""
|
155 |
+
|
156 |
+
# ["lat", "lon"] + [(batch,) (time,) (level,)] + trailing
|
157 |
+
# to
|
158 |
+
# [(batch,) (time,) (level,)] + ["lat", "lon"] + trailing
|
159 |
+
|
160 |
+
input_dims = list(grid_xarray.dims)
|
161 |
+
output_dims = list(input_dims)
|
162 |
+
for leading_key in ["level", "time", "batch"]: # reverse order for insert
|
163 |
+
if leading_key in input_dims:
|
164 |
+
output_dims.remove(leading_key)
|
165 |
+
output_dims.insert(0, leading_key)
|
166 |
+
return grid_xarray.transpose(*output_dims)
|
167 |
+
|
168 |
+
|
169 |
+
def lat_lon_deg_to_spherical(node_lat: np.ndarray,
|
170 |
+
node_lon: np.ndarray,
|
171 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
172 |
+
phi = np.deg2rad(node_lon)
|
173 |
+
theta = np.deg2rad(90 - node_lat)
|
174 |
+
return phi, theta
|
175 |
+
|
176 |
+
|
177 |
+
def spherical_to_lat_lon(phi: np.ndarray,
|
178 |
+
theta: np.ndarray,
|
179 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
180 |
+
lon = np.mod(np.rad2deg(phi), 360)
|
181 |
+
lat = 90 - np.rad2deg(theta)
|
182 |
+
return lat, lon
|
183 |
+
|
184 |
+
|
185 |
+
def cartesian_to_spherical(x: np.ndarray,
|
186 |
+
y: np.ndarray,
|
187 |
+
z: np.ndarray,
|
188 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
189 |
+
phi = np.arctan2(y, x)
|
190 |
+
with np.errstate(invalid="ignore"): # circumventing b/253179568
|
191 |
+
theta = np.arccos(z) # Assuming unit radius.
|
192 |
+
return phi, theta
|
193 |
+
|
194 |
+
|
195 |
+
def spherical_to_cartesian(
|
196 |
+
phi: np.ndarray, theta: np.ndarray
|
197 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
198 |
+
# Assuming unit radius.
|
199 |
+
return (np.cos(phi)*np.sin(theta),
|
200 |
+
np.sin(phi)*np.sin(theta),
|
201 |
+
np.cos(theta))
|
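A quick round-trip sketch (editorial example, values made up) tying the four coordinate-conversion helpers above together:

import numpy as np
from graphcast import model_utils

lat = np.array([37.5, -12.0])
lon = np.array([10.0, 250.0])
phi, theta = model_utils.lat_lon_deg_to_spherical(lat, lon)
x, y, z = model_utils.spherical_to_cartesian(phi, theta)
phi2, theta2 = model_utils.cartesian_to_spherical(x, y, z)
lat2, lon2 = model_utils.spherical_to_lat_lon(phi2, theta2)
np.testing.assert_allclose(lat2, lat)  # the round trip recovers the inputs
np.testing.assert_allclose(lon2, lon)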
202 |
+
|
203 |
+
|
204 |
+
def get_relative_position_in_receiver_local_coordinates(
|
205 |
+
node_phi: np.ndarray,
|
206 |
+
node_theta: np.ndarray,
|
207 |
+
senders: np.ndarray,
|
208 |
+
receivers: np.ndarray,
|
209 |
+
latitude_local_coordinates: bool,
|
210 |
+
longitude_local_coordinates: bool
|
211 |
+
) -> np.ndarray:
|
212 |
+
"""Returns relative position features for the edges.
|
213 |
+
|
214 |
+
The relative positions will be computed in a rotated space for a local
|
215 |
+
coordinate system as defined by the receiver. The relative positions are
|
216 |
+
simply obtained by subtracting the receiver position from the sender position in
|
217 |
+
that local coordinate system after the rotation in R^3.
|
218 |
+
|
219 |
+
Args:
|
220 |
+
node_phi: [num_nodes] with azimuthal angles.
|
221 |
+
node_theta: [num_nodes] with polar angles.
|
222 |
+
senders: [num_edges] with indices.
|
223 |
+
receivers: [num_edges] with indices.
|
224 |
+
latitude_local_coordinates: Whether to rotate edges such that the
|
225 |
+
positions are computed with the receiver always at latitude 0.
|
226 |
+
longitude_local_coordinates: Whether to rotate edges such that the
|
227 |
+
positions are computed with the receiver always at longitude 0.
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
Array of relative positions in R3 [num_edges, 3]
|
231 |
+
"""
|
232 |
+
|
233 |
+
node_pos = np.stack(spherical_to_cartesian(node_phi, node_theta), axis=-1)
|
234 |
+
|
235 |
+
# No rotation in this case.
|
236 |
+
if not (latitude_local_coordinates or longitude_local_coordinates):
|
237 |
+
return node_pos[senders] - node_pos[receivers]
|
238 |
+
|
239 |
+
# Get rotation matrices for the local space for every node.
|
240 |
+
rotation_matrices = get_rotation_matrices_to_local_coordinates(
|
241 |
+
reference_phi=node_phi,
|
242 |
+
reference_theta=node_theta,
|
243 |
+
rotate_latitude=latitude_local_coordinates,
|
244 |
+
rotate_longitude=longitude_local_coordinates)
|
245 |
+
|
246 |
+
# Each edge will be rotated according to the rotation matrix of its receiver
|
247 |
+
# node.
|
248 |
+
edge_rotation_matrices = rotation_matrices[receivers]
|
249 |
+
|
250 |
+
# Rotate all nodes to the rotated space of the corresponding edge.
|
251 |
+
# Note for receivers we can also do the matmul first and the gather second:
|
252 |
+
# ```
|
253 |
+
# receiver_pos_in_rotated_space = rotate_with_matrices(
|
254 |
+
# rotation_matrices, node_pos)[receivers]
|
255 |
+
# ```
|
256 |
+
# which is more efficient, however, we do gather first to keep it more
|
257 |
+
# symmetric with the sender computation.
|
258 |
+
receiver_pos_in_rotated_space = rotate_with_matrices(
|
259 |
+
edge_rotation_matrices, node_pos[receivers])
|
260 |
+
sender_pos_in_in_rotated_space = rotate_with_matrices(
|
261 |
+
edge_rotation_matrices, node_pos[senders])
|
262 |
+
# Note, here, that because the rotated space is chosen according to the
|
263 |
+
# receiver, if:
|
264 |
+
# * latitude_local_coordinates = True: latitude for the receivers will be
|
265 |
+
# 0, that is the z coordinate will always be 0.
|
266 |
+
# * longitude_local_coordinates = True: longitude for the receivers will be
|
267 |
+
# 0, that is the y coordinate will be 0.
|
268 |
+
|
269 |
+
# Now we can just subtract.
|
270 |
+
# Note we are rotating to a local coordinate system, where the y-z axes are
|
271 |
+
# parallel to a tangent plane to the sphere, but still remain in a 3d space.
|
272 |
+
# Note that if both `latitude_local_coordinates` and
|
273 |
+
# `longitude_local_coordinates` are True, and edges are short,
|
274 |
+
# then the difference in x coordinate between sender and receiver
|
275 |
+
# should be small, so we could consider dropping the new x coordinate if
|
276 |
+
# we wanted to restrict to the tangent plane, however in doing so
|
277 |
+
# we would lose information about the curvature of the mesh, which may be
|
278 |
+
# important for very coarse meshes.
|
279 |
+
return sender_pos_in_in_rotated_space - receiver_pos_in_rotated_space
|
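A small editorial sketch (values made up, not part of the module) of the property described in the comments above: with both local-coordinate flags enabled, each edge's relative position is expressed in a frame where its receiver sits at latitude 0 / longitude 0:

import numpy as np
from graphcast import model_utils

node_lat = np.array([10.0, 20.0])
node_lon = np.array([30.0, 40.0])
node_phi, node_theta = model_utils.lat_lon_deg_to_spherical(node_lat, node_lon)
relative_position = (
    model_utils.get_relative_position_in_receiver_local_coordinates(
        node_phi=node_phi, node_theta=node_theta,
        senders=np.array([0]), receivers=np.array([1]),
        latitude_local_coordinates=True,
        longitude_local_coordinates=True))
print(relative_position.shape)  # (1, 3): sender seen from the receiver's frame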
280 |
+
|
281 |
+
|
282 |
+
def get_rotation_matrices_to_local_coordinates(
|
283 |
+
reference_phi: np.ndarray,
|
284 |
+
reference_theta: np.ndarray,
|
285 |
+
rotate_latitude: bool,
|
286 |
+
rotate_longitude: bool) -> np.ndarray:
|
287 |
+
|
288 |
+
"""Returns a rotation matrix to rotate to a point based on a reference vector.
|
289 |
+
|
290 |
+
The rotation matrix is built such that a vector in the
|
291 |
+
same coordinate system at the reference point that points towards the pole
|
292 |
+
before the rotation, continues to point towards the pole after the rotation.
|
293 |
+
|
294 |
+
Args:
|
295 |
+
reference_phi: [leading_axis] Azimuthal angles of the reference.
|
296 |
+
reference_theta: [leading_axis] Polar angles of the reference.
|
297 |
+
rotate_latitude: Whether to produce a rotation matrix that would rotate
|
298 |
+
R^3 vectors to zero latitude.
|
299 |
+
rotate_longitude: Whether to produce a rotation matrix that would rotate
|
300 |
+
R^3 vectors to zero longitude.
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
Matrices of shape [leading_axis] such that when applied to the reference
|
304 |
+
position with `rotate_with_matrices(rotation_matrices, reference_pos)`
|
305 |
+
|
306 |
+
* phi goes to 0. if "rotate_longitude" is True.
|
307 |
+
|
308 |
+
* theta goes to np.pi / 2 if "rotate_latitude" is True.
|
309 |
+
|
310 |
+
The rotation consists of:
|
311 |
+
* rotate_latitude = False, rotate_longitude = True:
|
312 |
+
Latitude preserving rotation.
|
313 |
+
* rotate_latitude = True, rotate_longitude = True:
|
314 |
+
Latitude preserving rotation, followed by longitude preserving
|
315 |
+
rotation.
|
316 |
+
* rotate_latitude = True, rotate_longitude = False:
|
317 |
+
Latitude preserving rotation, followed by longitude preserving
|
318 |
+
rotation, and the inverse of the latitude preserving rotation. Note
|
319 |
+
this is computationally different from applying only a longitude
|
320 |
+
preserving rotation. We do it like this so that the polar geodesic curve
|
321 |
+
continues to be aligned with one of the axes after the rotation.
|
322 |
+
|
323 |
+
"""
|
324 |
+
|
325 |
+
if rotate_longitude and rotate_latitude:
|
326 |
+
|
327 |
+
# We first rotate around the z axis "minus the azimuthal angle", to get the
|
328 |
+
# point with zero longitude
|
329 |
+
azimuthal_rotation = - reference_phi
|
330 |
+
|
331 |
+
# Then we will do a polar rotation (which can be done along the y
|
332 |
+
# axis now that we are at longitude 0.), "minus the polar angle plus pi/2"
|
333 |
+
# to get the point with zero latitude.
|
334 |
+
polar_rotation = - reference_theta + np.pi/2
|
335 |
+
|
336 |
+
return transform.Rotation.from_euler(
|
337 |
+
"zy", np.stack([azimuthal_rotation, polar_rotation],
|
338 |
+
axis=1)).as_matrix()
|
339 |
+
elif rotate_longitude:
|
340 |
+
# Just like the previous case, but applying only the azimuthal rotation.
|
341 |
+
azimuthal_rotation = - reference_phi
|
342 |
+
return transform.Rotation.from_euler("z", -reference_phi).as_matrix()
|
343 |
+
elif rotate_latitude:
|
344 |
+
# Just like the first case, but after doing the polar rotation, undoing
|
345 |
+
# the azimuthal rotation.
|
346 |
+
azimuthal_rotation = - reference_phi
|
347 |
+
polar_rotation = - reference_theta + np.pi/2
|
348 |
+
|
349 |
+
return transform.Rotation.from_euler(
|
350 |
+
"zyz", np.stack(
|
351 |
+
[azimuthal_rotation, polar_rotation, -azimuthal_rotation]
|
352 |
+
, axis=1)).as_matrix()
|
353 |
+
else:
|
354 |
+
raise ValueError(
|
355 |
+
"At least one of longitude and latitude should be rotated.")
|
356 |
+
|
357 |
+
|
358 |
+
def rotate_with_matrices(rotation_matrices: np.ndarray, positions: np.ndarray
|
359 |
+
) -> np.ndarray:
|
360 |
+
return np.einsum("bji,bi->bj", rotation_matrices, positions)
|
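An editorial sanity-check sketch (the reference point below is made up) of the behaviour documented in `get_rotation_matrices_to_local_coordinates` above: rotating a reference point with its own rotation matrix sends it to latitude 0 / longitude 0, i.e. the Cartesian point (1, 0, 0):

import numpy as np
from graphcast import model_utils

phi, theta = model_utils.lat_lon_deg_to_spherical(
    np.array([40.0]), np.array([100.0]))
reference_pos = np.stack(
    model_utils.spherical_to_cartesian(phi, theta), axis=-1)
rotation_matrices = model_utils.get_rotation_matrices_to_local_coordinates(
    reference_phi=phi, reference_theta=theta,
    rotate_latitude=True, rotate_longitude=True)
rotated = model_utils.rotate_with_matrices(rotation_matrices, reference_pos)
np.testing.assert_allclose(rotated, [[1.0, 0.0, 0.0]], atol=1e-12)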
361 |
+
|
362 |
+
|
363 |
+
def get_bipartite_graph_spatial_features(
|
364 |
+
*,
|
365 |
+
senders_node_lat: np.ndarray,
|
366 |
+
senders_node_lon: np.ndarray,
|
367 |
+
senders: np.ndarray,
|
368 |
+
receivers_node_lat: np.ndarray,
|
369 |
+
receivers_node_lon: np.ndarray,
|
370 |
+
receivers: np.ndarray,
|
371 |
+
add_node_positions: bool,
|
372 |
+
add_node_latitude: bool,
|
373 |
+
add_node_longitude: bool,
|
374 |
+
add_relative_positions: bool,
|
375 |
+
edge_normalization_factor: Optional[float] = None,
|
376 |
+
relative_longitude_local_coordinates: bool,
|
377 |
+
relative_latitude_local_coordinates: bool,
|
378 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
379 |
+
"""Computes spatial features for the nodes.
|
380 |
+
|
381 |
+
This function is almost identical to `get_graph_spatial_features`. The only
|
382 |
+
difference is that sender nodes and receiver nodes can be in different arrays.
|
383 |
+
This is necessary to enable combination with typed graphs.
|
384 |
+
|
385 |
+
Args:
|
386 |
+
senders_node_lat: Latitudes in the [-90, 90] interval of shape
|
387 |
+
[num_sender_nodes]
|
388 |
+
senders_node_lon: Longitudes in the [0, 360] interval of shape
|
389 |
+
[num_sender_nodes]
|
390 |
+
senders: Sender indices of shape [num_edges], indices in [0,
|
391 |
+
num_sender_nodes)
|
392 |
+
receivers_node_lat: Latitudes in the [-90, 90] interval of shape
|
393 |
+
[num_receiver_nodes]
|
394 |
+
receivers_node_lon: Longitudes in the [0, 360] interval of shape
|
395 |
+
[num_receiver_nodes]
|
396 |
+
receivers: Receiver indices of shape [num_edges], indices in [0,
|
397 |
+
num_receiver_nodes)
|
398 |
+
add_node_positions: Add unit norm absolute positions.
|
399 |
+
add_node_latitude: Add a feature for latitude (cos(90 - lat)) Note even if
|
400 |
+
this is set to False, the model may be able to infer the latitude from
|
401 |
+
relative features, unless `relative_latitude_local_coordinates` is also
|
402 |
+
True, or if there is any bias on the relative edge sizes for different
|
403 |
+
latitudes.
|
404 |
+
add_node_longitude: Add features for longitude (cos(lon), sin(lon)). Note
|
405 |
+
even if this is set to False, the model may be able to infer the longitude
|
406 |
+
from relative features, unless `relative_longitude_local_coordinates` is
|
407 |
+
also True, or if there is any bias on the relative edge sizes for
|
408 |
+
different longitudes.
|
409 |
+
add_relative_positions: Whether to add relative positions in R3 to the edges.
|
410 |
+
edge_normalization_factor: Allows explicitly controlling edge normalization.
|
411 |
+
If None, defaults to max edge length. This supports using pre-trained
|
412 |
+
model weights with a different graph structure to what it was trained on.
|
413 |
+
relative_longitude_local_coordinates: If True, relative positions are
|
414 |
+
computed in a local space where the receiver is at 0 longitude.
|
415 |
+
relative_latitude_local_coordinates: If True, relative positions are
|
416 |
+
computed in a local space where the receiver is at 0 latitude.
|
417 |
+
|
418 |
+
Returns:
|
419 |
+
Arrays of shape: [num_nodes, num_features] and [num_edges, num_features].
|
420 |
+
with node and edge features.
|
421 |
+
|
422 |
+
"""
|
423 |
+
|
424 |
+
num_senders = senders_node_lat.shape[0]
|
425 |
+
num_receivers = receivers_node_lat.shape[0]
|
426 |
+
num_edges = senders.shape[0]
|
427 |
+
dtype = senders_node_lat.dtype
|
428 |
+
assert receivers_node_lat.dtype == dtype
|
429 |
+
senders_node_phi, senders_node_theta = lat_lon_deg_to_spherical(
|
430 |
+
senders_node_lat, senders_node_lon)
|
431 |
+
receivers_node_phi, receivers_node_theta = lat_lon_deg_to_spherical(
|
432 |
+
receivers_node_lat, receivers_node_lon)
|
433 |
+
|
434 |
+
# Computing some node features.
|
435 |
+
senders_node_features = []
|
436 |
+
receivers_node_features = []
|
437 |
+
if add_node_positions:
|
438 |
+
# Already in [-1, 1.] range.
|
439 |
+
senders_node_features.extend(
|
440 |
+
spherical_to_cartesian(senders_node_phi, senders_node_theta))
|
441 |
+
receivers_node_features.extend(
|
442 |
+
spherical_to_cartesian(receivers_node_phi, receivers_node_theta))
|
443 |
+
|
444 |
+
if add_node_latitude:
|
445 |
+
# Using the cos of theta.
|
446 |
+
# From 1. (north pole) to -1 (south pole).
|
447 |
+
senders_node_features.append(np.cos(senders_node_theta))
|
448 |
+
receivers_node_features.append(np.cos(receivers_node_theta))
|
449 |
+
|
450 |
+
if add_node_longitude:
|
451 |
+
# Using the cos and sin, which is already normalized.
|
452 |
+
senders_node_features.append(np.cos(senders_node_phi))
|
453 |
+
senders_node_features.append(np.sin(senders_node_phi))
|
454 |
+
|
455 |
+
receivers_node_features.append(np.cos(receivers_node_phi))
|
456 |
+
receivers_node_features.append(np.sin(receivers_node_phi))
|
457 |
+
|
458 |
+
if not senders_node_features:
|
459 |
+
senders_node_features = np.zeros([num_senders, 0], dtype=dtype)
|
460 |
+
receivers_node_features = np.zeros([num_receivers, 0], dtype=dtype)
|
461 |
+
else:
|
462 |
+
senders_node_features = np.stack(senders_node_features, axis=-1)
|
463 |
+
receivers_node_features = np.stack(receivers_node_features, axis=-1)
|
464 |
+
|
465 |
+
# Computing some edge features.
|
466 |
+
edge_features = []
|
467 |
+
|
468 |
+
if add_relative_positions:
|
469 |
+
|
470 |
+
relative_position = get_bipartite_relative_position_in_receiver_local_coordinates( # pylint: disable=line-too-long
|
471 |
+
senders_node_phi=senders_node_phi,
|
472 |
+
senders_node_theta=senders_node_theta,
|
473 |
+
receivers_node_phi=receivers_node_phi,
|
474 |
+
receivers_node_theta=receivers_node_theta,
|
475 |
+
senders=senders,
|
476 |
+
receivers=receivers,
|
477 |
+
latitude_local_coordinates=relative_latitude_local_coordinates,
|
478 |
+
longitude_local_coordinates=relative_longitude_local_coordinates)
|
479 |
+
|
480 |
+
# Note this is L2 distance in 3d space, rather than geodesic distance.
|
481 |
+
relative_edge_distances = np.linalg.norm(
|
482 |
+
relative_position, axis=-1, keepdims=True)
|
483 |
+
|
484 |
+
if edge_normalization_factor is None:
|
485 |
+
# Normalize to the maximum edge distance. Note that we expect to always
|
486 |
+
# have an edge that goes in the opposite direction of any given edge
|
487 |
+
# so the distribution of relative positions should be symmetric around
|
488 |
+
# zero. So by scaling by the maximum length, we expect all relative
|
489 |
+
# positions to fall in the [-1., 1.] interval, and all relative distances
|
490 |
+
# to fall in the [0., 1.] interval.
|
491 |
+
edge_normalization_factor = relative_edge_distances.max()
|
492 |
+
|
493 |
+
edge_features.append(relative_edge_distances / edge_normalization_factor)
|
494 |
+
edge_features.append(relative_position / edge_normalization_factor)
|
495 |
+
|
496 |
+
if not edge_features:
|
497 |
+
edge_features = np.zeros([num_edges, 0], dtype=dtype)
|
498 |
+
else:
|
499 |
+
edge_features = np.concatenate(edge_features, axis=-1)
|
500 |
+
|
501 |
+
return senders_node_features, receivers_node_features, edge_features
|
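A usage sketch for the bipartite variant (editorial example; the "grid" and "mesh" node sets and the connectivity below are made up), e.g. for grid-to-mesh edges where senders and receivers live in different node arrays:

import numpy as np
from graphcast import model_utils

grid_lat = np.array([0.0, 10.0, 20.0])    # hypothetical sender nodes
grid_lon = np.array([0.0, 120.0, 240.0])
mesh_lat = np.array([5.0, 15.0])          # hypothetical receiver nodes
mesh_lon = np.array([60.0, 180.0])
senders = np.array([0, 1, 2, 2])
receivers = np.array([0, 0, 1, 0])

grid_features, mesh_features, edge_features = (
    model_utils.get_bipartite_graph_spatial_features(
        senders_node_lat=grid_lat, senders_node_lon=grid_lon, senders=senders,
        receivers_node_lat=mesh_lat, receivers_node_lon=mesh_lon,
        receivers=receivers,
        add_node_positions=False,
        add_node_latitude=True,
        add_node_longitude=True,
        add_relative_positions=True,
        edge_normalization_factor=None,
        relative_longitude_local_coordinates=True,
        relative_latitude_local_coordinates=True))
print(grid_features.shape, mesh_features.shape, edge_features.shape)
# (3, 3) (2, 3) (4, 4)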
502 |
+
|
503 |
+
|
504 |
+
def get_bipartite_relative_position_in_receiver_local_coordinates(
|
505 |
+
senders_node_phi: np.ndarray,
|
506 |
+
senders_node_theta: np.ndarray,
|
507 |
+
senders: np.ndarray,
|
508 |
+
receivers_node_phi: np.ndarray,
|
509 |
+
receivers_node_theta: np.ndarray,
|
510 |
+
receivers: np.ndarray,
|
511 |
+
latitude_local_coordinates: bool,
|
512 |
+
longitude_local_coordinates: bool) -> np.ndarray:
|
513 |
+
"""Returns relative position features for the edges.
|
514 |
+
|
515 |
+
This function is equivalent to
|
516 |
+
`get_relative_position_in_receiver_local_coordinates`, but adapted to work
|
517 |
+
with bipartite typed graphs.
|
518 |
+
|
519 |
+
The relative positions will be computed in a rotated space for a local
|
520 |
+
coordinate system as defined by the receiver. The relative positions are
|
521 |
+
simply obtained by subtracting the receiver position from the sender position in
|
522 |
+
that local coordinate system after the rotation in R^3.
|
523 |
+
|
524 |
+
Args:
|
525 |
+
senders_node_phi: [num_sender_nodes] with azimuthal angles.
|
526 |
+
senders_node_theta: [num_sender_nodes] with polar angles.
|
527 |
+
senders: [num_edges] with indices into sender nodes.
|
528 |
+
receivers_node_phi: [num_receiver_nodes] with azimuthal angles.
|
529 |
+
receivers_node_theta: [num_receiver_nodes] with polar angles.
|
530 |
+
receivers: [num_edges] with indices into receiver nodes.
|
531 |
+
latitude_local_coordinates: Whether to rotate edges such that the
|
532 |
+
positions are computed with the receiver always at latitude 0.
|
533 |
+
longitude_local_coordinates: Whether to rotate edges such that the
|
534 |
+
positions are computed with the receiver always at longitude 0.
|
535 |
+
|
536 |
+
Returns:
|
537 |
+
Array of relative positions in R3 [num_edges, 3]
|
538 |
+
"""
|
539 |
+
|
540 |
+
senders_node_pos = np.stack(
|
541 |
+
spherical_to_cartesian(senders_node_phi, senders_node_theta), axis=-1)
|
542 |
+
|
543 |
+
receivers_node_pos = np.stack(
|
544 |
+
spherical_to_cartesian(receivers_node_phi, receivers_node_theta), axis=-1)
|
545 |
+
|
546 |
+
# No rotation in this case.
|
547 |
+
if not (latitude_local_coordinates or longitude_local_coordinates):
|
548 |
+
return senders_node_pos[senders] - receivers_node_pos[receivers]
|
549 |
+
|
550 |
+
# Get rotation matrices for the local space for every receiver node.
|
551 |
+
receiver_rotation_matrices = get_rotation_matrices_to_local_coordinates(
|
552 |
+
reference_phi=receivers_node_phi,
|
553 |
+
reference_theta=receivers_node_theta,
|
554 |
+
rotate_latitude=latitude_local_coordinates,
|
555 |
+
rotate_longitude=longitude_local_coordinates)
|
556 |
+
|
557 |
+
# Each edge will be rotated according to the rotation matrix of its receiver
|
558 |
+
# node.
|
559 |
+
edge_rotation_matrices = receiver_rotation_matrices[receivers]
|
560 |
+
|
561 |
+
# Rotate all nodes to the rotated space of the corresponding edge.
|
562 |
+
# Note for receivers we can also do the matmul first and the gather second:
|
563 |
+
# ```
|
564 |
+
# receiver_pos_in_rotated_space = rotate_with_matrices(
|
565 |
+
# rotation_matrices, node_pos)[receivers]
|
566 |
+
# ```
|
567 |
+
# which is more efficient, however, we do gather first to keep it more
|
568 |
+
# symmetric with the sender computation.
|
569 |
+
receiver_pos_in_rotated_space = rotate_with_matrices(
|
570 |
+
edge_rotation_matrices, receivers_node_pos[receivers])
|
571 |
+
sender_pos_in_in_rotated_space = rotate_with_matrices(
|
572 |
+
edge_rotation_matrices, senders_node_pos[senders])
|
573 |
+
# Note, here, that because the rotated space is chosen according to the
|
574 |
+
# receiver, if:
|
575 |
+
# * latitude_local_coordinates = True: latitude for the receivers will be
|
576 |
+
# 0, that is the z coordinate will always be 0.
|
577 |
+
# * longitude_local_coordinates = True: longitude for the receivers will be
|
578 |
+
# 0, that is the y coordinate will be 0.
|
579 |
+
|
580 |
+
# Now we can just subtract.
|
581 |
+
# Note we are rotating to a local coordinate system, where the y-z axes are
|
582 |
+
# parallel to a tangent plane to the sphere, but still remain in a 3d space.
|
583 |
+
# Note that if both `latitude_local_coordinates` and
|
584 |
+
# `longitude_local_coordinates` are True, and edges are short,
|
585 |
+
# then the difference in x coordinate between sender and receiver
|
586 |
+
# should be small, so we could consider dropping the new x coordinate if
|
587 |
+
# we wanted to restrict to the tangent plane, however in doing so
|
588 |
+
# we would lose information about the curvature of the mesh, which may be
|
589 |
+
# important for very coarse meshes.
|
590 |
+
return sender_pos_in_in_rotated_space - receiver_pos_in_rotated_space
|
591 |
+
|
592 |
+
|
593 |
+
def variable_to_stacked(
|
594 |
+
variable: xarray.Variable,
|
595 |
+
sizes: Mapping[str, int],
|
596 |
+
preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
|
597 |
+
) -> xarray.Variable:
|
598 |
+
"""Converts an xarray.Variable to preserved_dims + ("channels",).
|
599 |
+
|
600 |
+
Any dimensions other than those included in preserved_dims get stacked into a
|
601 |
+
final "channels" dimension. If any of the preserved_dims are missing then they
|
602 |
+
are added, with the data broadcast/tiled to match the sizes specified in
|
603 |
+
`sizes`.
|
604 |
+
|
605 |
+
Args:
|
606 |
+
variable: An xarray.Variable.
|
607 |
+
sizes: Mapping including sizes for any dimensions which are not present in
|
608 |
+
`variable` but are needed for the output. This may be needed for example
|
609 |
+
for a static variable with only ("lat", "lon") dims, or if you want to
|
610 |
+
encode just the latitude coordinates (a variable with dims ("lat",)).
|
611 |
+
preserved_dims: dimensions of variable to not be folded in channels.
|
612 |
+
|
613 |
+
Returns:
|
614 |
+
An xarray.Variable with dimensions preserved_dims + ("channels",).
|
615 |
+
"""
|
616 |
+
stack_to_channels_dims = [
|
617 |
+
d for d in variable.dims if d not in preserved_dims]
|
618 |
+
if stack_to_channels_dims:
|
619 |
+
variable = variable.stack(channels=stack_to_channels_dims)
|
620 |
+
dims = {dim: variable.sizes.get(dim) or sizes[dim] for dim in preserved_dims}
|
621 |
+
dims["channels"] = variable.sizes.get("channels", 1)
|
622 |
+
return variable.set_dims(dims)
|
623 |
+
|
624 |
+
|
625 |
+
def dataset_to_stacked(
|
626 |
+
dataset: xarray.Dataset,
|
627 |
+
sizes: Optional[Mapping[str, int]] = None,
|
628 |
+
preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
|
629 |
+
) -> xarray.DataArray:
|
630 |
+
"""Converts an xarray.Dataset to a single stacked array.
|
631 |
+
|
632 |
+
This takes each constituent data_var, converts it into BHWC layout
|
633 |
+
using `variable_to_stacked`, then concats them all along the channels axis.
|
634 |
+
|
635 |
+
Args:
|
636 |
+
dataset: An xarray.Dataset.
|
637 |
+
sizes: Mapping including sizes for any dimensions which are not present in
|
638 |
+
the `dataset` but are needed for the output. See variable_to_stacked.
|
639 |
+
preserved_dims: dimensions from the dataset that should not be folded in
|
640 |
+
the predictions channels.
|
641 |
+
|
642 |
+
Returns:
|
643 |
+
An xarray.DataArray with dimensions preserved_dims + ("channels",).
|
644 |
+
Existing coordinates for preserved_dims axes will be preserved, however
|
645 |
+
there will be no coordinates for "channels".
|
646 |
+
"""
|
647 |
+
data_vars = [
|
648 |
+
variable_to_stacked(dataset.variables[name], sizes or dataset.sizes,
|
649 |
+
preserved_dims)
|
650 |
+
for name in sorted(dataset.data_vars.keys())
|
651 |
+
]
|
652 |
+
coords = {
|
653 |
+
dim: coord
|
654 |
+
for dim, coord in dataset.coords.items()
|
655 |
+
if dim in preserved_dims
|
656 |
+
}
|
657 |
+
return xarray.DataArray(
|
658 |
+
data=xarray.Variable.concat(data_vars, dim="channels"), coords=coords)
|
659 |
+
|
660 |
+
|
661 |
+
def stacked_to_dataset(
|
662 |
+
stacked_array: xarray.Variable,
|
663 |
+
template_dataset: xarray.Dataset,
|
664 |
+
preserved_dims: Tuple[str, ...] = ("batch", "lat", "lon"),
|
665 |
+
) -> xarray.Dataset:
|
666 |
+
"""The inverse of dataset_to_stacked.
|
667 |
+
|
668 |
+
Requires a template dataset to demonstrate the variables/shapes/coordinates
|
669 |
+
required.
|
670 |
+
All variables must have preserved_dims dimensions.
|
671 |
+
|
672 |
+
Args:
|
673 |
+
stacked_array: Data in BHWC layout, encoded the same as dataset_to_stacked
|
674 |
+
would if it was asked to encode `template_dataset`.
|
675 |
+
template_dataset: A template Dataset (or other mapping of DataArrays)
|
676 |
+
demonstrating the shape of output required (variables, shapes,
|
677 |
+
coordinates etc).
|
678 |
+
preserved_dims: dimensions from the template_dataset that were not folded in
|
679 |
+
the predictions channels. The preserved_dims need to be a subset of the
|
680 |
+
dims of all the variables of template_dataset.
|
681 |
+
|
682 |
+
Returns:
|
683 |
+
An xarray.Dataset (or other mapping of DataArrays) with the same shape and
|
684 |
+
type as template_dataset.
|
685 |
+
"""
|
686 |
+
unstack_from_channels_sizes = {}
|
687 |
+
var_names = sorted(template_dataset.keys())
|
688 |
+
for name in var_names:
|
689 |
+
template_var = template_dataset[name]
|
690 |
+
if not all(dim in template_var.dims for dim in preserved_dims):
|
691 |
+
raise ValueError(
|
692 |
+
f"stacked_to_dataset requires all Variables to have {preserved_dims} "
|
693 |
+
f"dimensions, but found only {template_var.dims}.")
|
694 |
+
unstack_from_channels_sizes[name] = {
|
695 |
+
dim: size for dim, size in template_var.sizes.items()
|
696 |
+
if dim not in preserved_dims}
|
697 |
+
|
698 |
+
channels = {name: np.prod(list(unstack_sizes.values()), dtype=np.int64)
|
699 |
+
for name, unstack_sizes in unstack_from_channels_sizes.items()}
|
700 |
+
total_expected_channels = sum(channels.values())
|
701 |
+
found_channels = stacked_array.sizes["channels"]
|
702 |
+
if total_expected_channels != found_channels:
|
703 |
+
raise ValueError(
|
704 |
+
f"Expected {total_expected_channels} channels but found "
|
705 |
+
f"{found_channels}, when trying to convert a stacked array of shape "
|
706 |
+
f"{stacked_array.sizes} to a dataset of shape {template_dataset}.")
|
707 |
+
|
708 |
+
data_vars = {}
|
709 |
+
index = 0
|
710 |
+
for name in var_names:
|
711 |
+
template_var = template_dataset[name]
|
712 |
+
var = stacked_array.isel({"channels": slice(index, index + channels[name])})
|
713 |
+
index += channels[name]
|
714 |
+
var = var.unstack({"channels": unstack_from_channels_sizes[name]})
|
715 |
+
var = var.transpose(*template_var.dims)
|
716 |
+
data_vars[name] = xarray.DataArray(
|
717 |
+
data=var,
|
718 |
+
coords=template_var.coords,
|
719 |
+
# This might not always be the same as the name it's keyed under; it
|
720 |
+
# will refer to the original variable name, whereas the key might be
|
721 |
+
# some alias e.g. temperature_850 under which it should be logged:
|
722 |
+
name=template_var.name,
|
723 |
+
)
|
724 |
+
return type(template_dataset)(data_vars) # pytype:disable=not-callable,wrong-arg-count
|
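A round-trip sketch for the stacking helpers above (editorial example; the toy variables and sizes are made up): `dataset_to_stacked` folds everything but batch/lat/lon into a channels axis, and `stacked_to_dataset` inverts it given a template:

import numpy as np
import xarray
from graphcast import model_utils

dataset = xarray.Dataset({
    "temperature": (("batch", "level", "lat", "lon"), np.zeros((1, 2, 4, 8))),
    "precipitation": (("batch", "time", "lat", "lon"), np.zeros((1, 3, 4, 8))),
})

stacked = model_utils.dataset_to_stacked(dataset)
print(stacked.sizes["channels"])  # 2 levels + 3 times = 5 channels

roundtrip = model_utils.stacked_to_dataset(stacked.variable, dataset)
assert set(roundtrip.data_vars) == {"temperature", "precipitation"}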
graphcast/normalization.py
ADDED
@@ -0,0 +1,196 @@
1 |
+
# Copyright 2023 DeepMind Technologies Limited.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS-IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Wrappers for Predictors which allow them to work with normalized data.
|
15 |
+
|
16 |
+
The Predictor which is wrapped sees normalized inputs and targets, and makes
|
17 |
+
normalized predictions. The wrapper handles translating the predictions back
|
18 |
+
to the original domain.
|
19 |
+
"""
|
20 |
+
|
21 |
+
import logging
|
22 |
+
from typing import Optional, Tuple
|
23 |
+
|
24 |
+
from graphcast import predictor_base
|
25 |
+
from graphcast import xarray_tree
|
26 |
+
import xarray
|
27 |
+
|
28 |
+
|
29 |
+
def normalize(values: xarray.Dataset,
|
30 |
+
scales: xarray.Dataset,
|
31 |
+
locations: Optional[xarray.Dataset],
|
32 |
+
) -> xarray.Dataset:
|
33 |
+
"""Normalize variables using the given scales and (optionally) locations."""
|
34 |
+
def normalize_array(array):
|
35 |
+
if array.name is None:
|
36 |
+
raise ValueError(
|
37 |
+
"Can't look up normalization constants because array has no name.")
|
38 |
+
if locations is not None:
|
39 |
+
if array.name in locations:
|
40 |
+
array = array - locations[array.name].astype(array.dtype)
|
41 |
+
else:
|
42 |
+
logging.warning('No normalization location found for %s', array.name)
|
43 |
+
if array.name in scales:
|
44 |
+
array = array / scales[array.name].astype(array.dtype)
|
45 |
+
else:
|
46 |
+
logging.warning('No normalization scale found for %s', array.name)
|
47 |
+
return array
|
48 |
+
return xarray_tree.map_structure(normalize_array, values)
|
49 |
+
|
50 |
+
|
51 |
+
def unnormalize(values: xarray.Dataset,
|
52 |
+
scales: xarray.Dataset,
|
53 |
+
locations: Optional[xarray.Dataset],
|
54 |
+
) -> xarray.Dataset:
|
55 |
+
"""Unnormalize variables using the given scales and (optionally) locations."""
|
56 |
+
def unnormalize_array(array):
|
57 |
+
if array.name is None:
|
58 |
+
raise ValueError(
|
59 |
+
"Can't look up normalization constants because array has no name.")
|
60 |
+
if array.name in scales:
|
61 |
+
array = array * scales[array.name].astype(array.dtype)
|
62 |
+
else:
|
63 |
+
logging.warning('No normalization scale found for %s', array.name)
|
64 |
+
if locations is not None:
|
65 |
+
if array.name in locations:
|
66 |
+
array = array + locations[array.name].astype(array.dtype)
|
67 |
+
else:
|
68 |
+
logging.warning('No normalization location found for %s', array.name)
|
69 |
+
return array
|
70 |
+
return xarray_tree.map_structure(unnormalize_array, values)
|
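A small editorial sketch (variable name and statistics made up) of how `normalize` and `unnormalize` invert each other for variables with matching entries in the scale/location Datasets:

import numpy as np
import xarray
from graphcast import normalization

values = xarray.Dataset(
    {"temperature": ("x", np.array([280.0, 300.0]))})
scales = xarray.Dataset({"temperature": xarray.DataArray(15.0)})
locations = xarray.Dataset({"temperature": xarray.DataArray(290.0)})

norm = normalization.normalize(values, scales, locations)
back = normalization.unnormalize(norm, scales, locations)
np.testing.assert_allclose(back["temperature"].data,
                           values["temperature"].data)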
71 |
+
|
72 |
+
|
73 |
+
class InputsAndResiduals(predictor_base.Predictor):
|
74 |
+
"""Wraps with a residual connection, normalizing inputs and target residuals.
|
75 |
+
|
76 |
+
The inner predictor is given inputs that are normalized using `locations`
|
77 |
+
and `scales` to roughly zero-mean unit variance.
|
78 |
+
|
79 |
+
For target variables that are present in the inputs, the inner predictor is
|
80 |
+
trained to predict residuals (target - last_frame_of_input) that have been
|
81 |
+
normalized using `residual_scales` (and optionally `residual_locations`) to
|
82 |
+
roughly unit variance / zero mean.
|
83 |
+
|
84 |
+
This replaces `residual.Predictor` in the case where you want normalization
|
85 |
+
that's based on the scales of the residuals.
|
86 |
+
|
87 |
+
Since we return the underlying predictor's loss on the normalized residuals,
|
88 |
+
if the underlying predictor is a sum of per-variable losses, the normalization
|
89 |
+
will affect the relative weighting of the per-variable loss terms (hopefully
|
90 |
+
in a good way).
|
91 |
+
|
92 |
+
For target variables *not* present in the inputs, the inner predictor is
|
93 |
+
trained to predict targets directly, that have been normalized in the same
|
94 |
+
way as the inputs.
|
95 |
+
|
96 |
+
The transforms applied to the targets (the residual connection and the
|
97 |
+
normalization) are applied in reverse to the predictions before returning
|
98 |
+
them.
|
99 |
+
"""
|
100 |
+
|
101 |
+
def __init__(
|
102 |
+
self,
|
103 |
+
predictor: predictor_base.Predictor,
|
104 |
+
stddev_by_level: xarray.Dataset,
|
105 |
+
mean_by_level: xarray.Dataset,
|
106 |
+
diffs_stddev_by_level: xarray.Dataset):
|
107 |
+
self._predictor = predictor
|
108 |
+
self._scales = stddev_by_level
|
109 |
+
self._locations = mean_by_level
|
110 |
+
self._residual_scales = diffs_stddev_by_level
|
111 |
+
self._residual_locations = None
|
112 |
+
|
113 |
+
def _unnormalize_prediction_and_add_input(self, inputs, norm_prediction):
|
114 |
+
if norm_prediction.sizes.get('time') != 1:
|
115 |
+
raise ValueError(
|
116 |
+
'normalization.InputsAndResiduals only supports predicting a '
|
117 |
+
'single timestep.')
|
118 |
+
if norm_prediction.name in inputs:
|
119 |
+
# Residuals are assumed to be predicted as normalized (unit variance),
|
120 |
+
# but the scale and location they need mapping to is that of the residuals
|
121 |
+
# not of the values themselves.
|
122 |
+
prediction = unnormalize(
|
123 |
+
norm_prediction, self._residual_scales, self._residual_locations)
|
124 |
+
# A prediction for which we have a corresponding input -- we are
|
125 |
+
# predicting the residual:
|
126 |
+
last_input = inputs[norm_prediction.name].isel(time=-1)
|
127 |
+
prediction = prediction + last_input
|
128 |
+
return prediction
|
129 |
+
else:
|
130 |
+
# A predicted variable which is not an input variable. We are predicting
|
131 |
+
# it directly, so unnormalize it directly to the target scale/location:
|
132 |
+
return unnormalize(norm_prediction, self._scales, self._locations)
|
133 |
+
|
134 |
+
def _subtract_input_and_normalize_target(self, inputs, target):
|
135 |
+
if target.sizes.get('time') != 1:
|
136 |
+
raise ValueError(
|
137 |
+
'normalization.InputsAndResiduals only supports wrapping predictors '
|
138 |
+
'that predict a single timestep.')
|
139 |
+
if target.name in inputs:
|
140 |
+
target_residual = target
|
141 |
+
last_input = inputs[target.name].isel(time=-1)
|
142 |
+
target_residual = target_residual - last_input
|
143 |
+
return normalize(
|
144 |
+
target_residual, self._residual_scales, self._residual_locations)
|
145 |
+
else:
|
146 |
+
return normalize(target, self._scales, self._locations)
|
147 |
+
|
148 |
+
def __call__(self,
|
149 |
+
inputs: xarray.Dataset,
|
150 |
+
targets_template: xarray.Dataset,
|
151 |
+
forcings: xarray.Dataset,
|
152 |
+
**kwargs
|
153 |
+
) -> xarray.Dataset:
|
154 |
+
norm_inputs = normalize(inputs, self._scales, self._locations)
|
155 |
+
norm_forcings = normalize(forcings, self._scales, self._locations)
|
156 |
+
norm_predictions = self._predictor(
|
157 |
+
norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
|
158 |
+
return xarray_tree.map_structure(
|
159 |
+
lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
|
160 |
+
norm_predictions)
|
161 |
+
|
162 |
+
def loss(self,
|
163 |
+
inputs: xarray.Dataset,
|
164 |
+
targets: xarray.Dataset,
|
165 |
+
forcings: xarray.Dataset,
|
166 |
+
**kwargs,
|
167 |
+
) -> predictor_base.LossAndDiagnostics:
|
168 |
+
"""Returns the loss computed on normalized inputs and targets."""
|
169 |
+
norm_inputs = normalize(inputs, self._scales, self._locations)
|
170 |
+
norm_forcings = normalize(forcings, self._scales, self._locations)
|
171 |
+
norm_target_residuals = xarray_tree.map_structure(
|
172 |
+
lambda t: self._subtract_input_and_normalize_target(inputs, t),
|
173 |
+
targets)
|
174 |
+
return self._predictor.loss(
|
175 |
+
norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
|
176 |
+
|
177 |
+
def loss_and_predictions( # pytype: disable=signature-mismatch # jax-ndarray
|
178 |
+
self,
|
179 |
+
inputs: xarray.Dataset,
|
180 |
+
targets: xarray.Dataset,
|
181 |
+
forcings: xarray.Dataset,
|
182 |
+
**kwargs,
|
183 |
+
) -> Tuple[predictor_base.LossAndDiagnostics,
|
184 |
+
xarray.Dataset]:
|
185 |
+
"""The loss computed on normalized data, with unnormalized predictions."""
|
186 |
+
norm_inputs = normalize(inputs, self._scales, self._locations)
|
187 |
+
norm_forcings = normalize(forcings, self._scales, self._locations)
|
188 |
+
norm_target_residuals = xarray_tree.map_structure(
|
189 |
+
lambda t: self._subtract_input_and_normalize_target(inputs, t),
|
190 |
+
targets)
|
191 |
+
(loss, scalars), norm_predictions = self._predictor.loss_and_predictions(
|
192 |
+
norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
|
193 |
+
predictions = xarray_tree.map_structure(
|
194 |
+
lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
|
195 |
+
norm_predictions)
|
196 |
+
return (loss, scalars), predictions
|
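A construction sketch (editorial; every name below is hypothetical) showing how this wrapper is typically assembled around an inner one-step predictor, with per-level statistics loaded from wherever they are stored:

# Editorial sketch -- `inner_predictor` and the three statistics Datasets
# (stddev_by_level, mean_by_level, diffs_stddev_by_level) are assumed to be
# created/loaded elsewhere, e.g. with xarray.load_dataset(...).
from graphcast import normalization

predictor = normalization.InputsAndResiduals(
    predictor=inner_predictor,
    stddev_by_level=stddev_by_level,
    mean_by_level=mean_by_level,
    diffs_stddev_by_level=diffs_stddev_by_level)

# The wrapper exposes the same xarray-based Predictor API:
#   predictions = predictor(inputs, targets_template, forcings=forcings)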
graphcast/predictor_base.py
ADDED
@@ -0,0 +1,170 @@
1 |
+
# Copyright 2023 DeepMind Technologies Limited.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS-IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Abstract base classes for an xarray-based Predictor API."""
|
15 |
+
|
16 |
+
import abc
|
17 |
+
|
18 |
+
from typing import Tuple
|
19 |
+
|
20 |
+
from graphcast import losses
|
21 |
+
from graphcast import xarray_jax
|
22 |
+
import jax.numpy as jnp
|
23 |
+
import xarray
|
24 |
+
|
25 |
+
LossAndDiagnostics = losses.LossAndDiagnostics
|
26 |
+
|
27 |
+
|
28 |
+
class Predictor(abc.ABC):
|
29 |
+
"""A possibly-trainable predictor of weather, exposing an xarray-based API.
|
30 |
+
|
31 |
+
Typically wraps an underlying JAX model and handles translating the xarray
|
32 |
+
Dataset values to and from plain JAX arrays that are convenient for input to
|
33 |
+
(and output from) the underlying model.
|
34 |
+
|
35 |
+
Different subclasses may exist to wrap different kinds of underlying model,
|
36 |
+
e.g. models taking stacked inputs/outputs, models taking separate 2D and 3D
|
37 |
+
inputs/outputs, autoregressive models.
|
38 |
+
|
39 |
+
You can also implement a specific model directly as a Predictor if you want,
|
40 |
+
for example if it has quite specific/unique requirements for its input/output
|
41 |
+
or loss function, or if it's convenient to implement directly using xarray.
|
42 |
+
"""
|
43 |
+
|
44 |
+
@abc.abstractmethod
|
45 |
+
def __call__(self,
|
46 |
+
inputs: xarray.Dataset,
|
47 |
+
targets_template: xarray.Dataset,
|
48 |
+
forcings: xarray.Dataset,
|
49 |
+
**optional_kwargs
|
50 |
+
) -> xarray.Dataset:
|
51 |
+
"""Makes predictions.
|
52 |
+
|
53 |
+
This is only used by the Experiment for inference / evaluation, with
|
54 |
+
training going via the .loss method. So it should default to making
|
55 |
+
predictions for evaluation, although you can also support making predictions
|
56 |
+
for use in the loss via an is_training argument -- see
|
57 |
+
LossFunctionPredictor which helps with that.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
inputs: An xarray.Dataset of inputs.
|
61 |
+
targets_template: An xarray.Dataset or other mapping of xarray.DataArrays,
|
62 |
+
with the same shape as the targets, to demonstrate what kind of
|
63 |
+
predictions are required. You can use this to determine which variables,
|
64 |
+
levels and lead times must be predicted.
|
65 |
+
You are free to raise an error if you don't support predicting what is
|
66 |
+
requested.
|
67 |
+
forcings: An xarray.Dataset of forcings terms. Forcings are variables
|
68 |
+
that can be fed to the model, but do not need to be predicted. This is
|
69 |
+
often because this variable can be computed analytically (e.g. the toa
|
70 |
+
radiation of the sun is mostly a function of geometry) or are considered
|
71 |
+
to be controlled for the experiment (e.g., impose a scenario of CO2
|
72 |
+
emission into the atmosphere). Unlike `inputs`, the `forcings` can
|
73 |
+
include information "from the future", that is, information at target
|
74 |
+
times specified in the `targets_template`.
|
75 |
+
**optional_kwargs: Implementations may support extra optional kwargs,
|
76 |
+
provided they set appropriate defaults for them.
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
Predictions, as an xarray.Dataset or other mapping of DataArrays which
|
80 |
+
is capable of being evaluated against targets with shape given by
|
81 |
+
targets_template.
|
82 |
+
For probabilistic predictors which can return multiple samples from a
|
83 |
+
predictive distribution, these should (by convention) be returned along
|
84 |
+
an additional 'sample' dimension.
|
85 |
+
"""
|
86 |
+
|
87 |
+
def loss(self,
|
88 |
+
inputs: xarray.Dataset,
|
89 |
+
targets: xarray.Dataset,
|
90 |
+
forcings: xarray.Dataset,
|
91 |
+
**optional_kwargs,
|
92 |
+
) -> LossAndDiagnostics:
|
93 |
+
"""Computes a training loss, for predictors that are trainable.
|
94 |
+
|
95 |
+
Why make this the Predictor's responsibility, rather than letting callers
|
96 |
+
compute their own loss function using predictions obtained from
|
97 |
+
Predictor.__call__?
|
98 |
+
|
99 |
+
Doing it this way gives Predictors more control over their training setup.
|
100 |
+
For example, some predictors may wish to train using different targets to
|
101 |
+
the ones they predict at evaluation time -- perhaps different lead times and
|
102 |
+
variables, perhaps training to predict transformed versions of targets
|
103 |
+
where the transform needs to be inverted at evaluation time, etc.
|
104 |
+
|
105 |
+
It's also necessary for generative models (VAEs, GANs, ...) where the
|
106 |
+
training loss is more complex and isn't expressible as a parameter-free
|
107 |
+
function of predictions and targets.
|
108 |
+
|
109 |
+
Args:
|
110 |
+
inputs: An xarray.Dataset.
|
111 |
+
targets: An xarray.Dataset or other mapping of xarray.DataArrays. See
|
112 |
+
docs on __call__ for an explanation about the targets.
|
113 |
+
forcings: xarray.Dataset of forcing terms.
|
114 |
+
**optional_kwargs: Implementations may support extra optional kwargs,
|
115 |
+
provided they set appropriate defaults for them.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
loss: A DataArray with dimensions ('batch',) containing losses for each
|
119 |
+
element of the batch. These will be averaged to give the final
|
120 |
+
loss, locally and across replicas.
|
121 |
+
diagnostics: Mapping of additional quantities to log by name alongside the
|
122 |
+
loss. These will typically correspond to terms in the loss. They
|
123 |
+
should also have dimensions ('batch',) and will be averaged over the
|
124 |
+
batch before logging.
|
125 |
+
You need not include the loss itself in this dict; it will be added for
|
126 |
+
you.
|
127 |
+
"""
|
128 |
+
del targets, forcings, optional_kwargs
|
129 |
+
batch_size = inputs.sizes['batch']
|
130 |
+
dummy_loss = xarray_jax.DataArray(jnp.zeros(batch_size), dims=('batch',))
|
131 |
+
return dummy_loss, {} # pytype: disable=bad-return-type
|
132 |
+
|
133 |
+
def loss_and_predictions(
|
134 |
+
self,
|
135 |
+
inputs: xarray.Dataset,
|
136 |
+
targets: xarray.Dataset,
|
137 |
+
forcings: xarray.Dataset,
|
138 |
+
**optional_kwargs,
|
139 |
+
) -> Tuple[LossAndDiagnostics, xarray.Dataset]:
|
140 |
+
"""Like .loss but also returns corresponding predictions.
|
141 |
+
|
142 |
+
Implementing this is optional as it's not used directly by the Experiment,
|
143 |
+
but it is required by autoregressive.Predictor when applying an inner
|
144 |
+
Predictor autoregressively at training time; we need a loss at each step but
|
145 |
+
also predictions to feed back in for the next step.
|
146 |
+
|
147 |
+
Note the loss itself may not be directly regressing the predictions towards
|
148 |
+
targets, the loss may be computed in terms of transformed predictions and
|
149 |
+
targets (or in some other way). For this reason we can't always cleanly
|
150 |
+
separate this into step 1: get predictions, step 2: compute loss from them,
|
151 |
+
hence the need for this combined method.
|
152 |
+
|
153 |
+
Args:
|
154 |
+
inputs:
|
155 |
+
targets:
|
156 |
+
forcings:
|
157 |
+
**optional_kwargs:
|
158 |
+
As for self.loss.
|
159 |
+
|
160 |
+
Returns:
|
161 |
+
(loss, diagnostics)
|
162 |
+
As for self.loss
|
163 |
+
predictions:
|
164 |
+
The predictions which the loss relates to. These should be of the same
|
165 |
+
shape as what you would get from
|
166 |
+
`self.__call__(inputs, targets_template=targets)`, and should be in the
|
167 |
+
same 'domain' as the inputs (i.e. they shouldn't be transformed
|
168 |
+
differently to how the predictor expects its inputs).
|
169 |
+
"""
|
170 |
+
raise NotImplementedError
|
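To make the API concrete, here is an editorial sketch (not part of the library) of a trivial "persistence" Predictor that repeats the last input frame at every requested target time; it assumes every target variable is also present in the inputs and ignores dimension-ordering details:

import xarray

from graphcast import predictor_base


class PersistencePredictor(predictor_base.Predictor):
  """Toy example: predicts the last input frame for every target time."""

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **optional_kwargs) -> xarray.Dataset:
    del forcings, optional_kwargs
    # Assumes all variables in `targets_template` also appear in `inputs`.
    last_frame = inputs[list(targets_template.data_vars)].isel(
        time=-1, drop=True)
    # Repeat the last frame along the requested target times.
    return last_frame.expand_dims(
        time=targets_template.coords["time"].data)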
graphcast/rollout.py
ADDED
@@ -0,0 +1,269 @@
1 |
+
# Copyright 2023 DeepMind Technologies Limited.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS-IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Utils for rolling out models."""
|
15 |
+
|
16 |
+
from typing import Iterator
|
17 |
+
|
18 |
+
from absl import logging
|
19 |
+
import chex
|
20 |
+
import dask.array
|
21 |
+
from graphcast import xarray_tree
|
22 |
+
import jax
|
23 |
+
import numpy as np
|
24 |
+
import typing_extensions
|
25 |
+
import xarray
|
26 |
+
|
27 |
+
|
28 |
+
class PredictorFn(typing_extensions.Protocol):
|
29 |
+
"""Functional version of base.Predictor.__call__ with explicit rng."""
|
30 |
+
|
31 |
+
def __call__(
|
32 |
+
self, rng: chex.PRNGKey, inputs: xarray.Dataset,
|
33 |
+
targets_template: xarray.Dataset,
|
34 |
+
forcings: xarray.Dataset,
|
35 |
+
**optional_kwargs,
|
36 |
+
) -> xarray.Dataset:
|
37 |
+
...
|
38 |
+
|
39 |
+
|
40 |
+
def chunked_prediction(
|
41 |
+
predictor_fn: PredictorFn,
|
42 |
+
rng: chex.PRNGKey,
|
43 |
+
inputs: xarray.Dataset,
|
44 |
+
targets_template: xarray.Dataset,
|
45 |
+
forcings: xarray.Dataset,
|
46 |
+
num_steps_per_chunk: int = 1,
|
47 |
+
verbose: bool = False,
|
48 |
+
) -> xarray.Dataset:
|
49 |
+
"""Outputs a long trajectory by iteratively concatenating chunked predictions.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
predictor_fn: Function to use to make predictions for each chunk.
|
53 |
+
rng: Random key.
|
54 |
+
inputs: Inputs for the model.
|
55 |
+
targets_template: Template for the target prediction, requires targets
|
56 |
+
equispaced in time.
|
57 |
+
forcings: Optional forcing for the model.
|
58 |
+
num_steps_per_chunk: How many of the steps in `targets_template` to predict
|
59 |
+
at each call of `predictor_fn`. It must evenly divide the number of
|
60 |
+
steps in `targets_template`.
|
61 |
+
verbose: Whether to log the current chunk being predicted.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
Predictions for the targets template.
|
65 |
+
|
66 |
+
"""
|
67 |
+
chunks_list = []
|
68 |
+
for prediction_chunk in chunked_prediction_generator(
|
69 |
+
predictor_fn=predictor_fn,
|
70 |
+
rng=rng,
|
71 |
+
inputs=inputs,
|
72 |
+
targets_template=targets_template,
|
73 |
+
forcings=forcings,
|
74 |
+
num_steps_per_chunk=num_steps_per_chunk,
|
75 |
+
verbose=verbose):
|
76 |
+
chunks_list.append(jax.device_get(prediction_chunk))
|
77 |
+
return xarray.concat(chunks_list, dim="time")
|
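A call-site sketch (editorial; every name below is hypothetical): `run_forward_jitted` stands in for a jitted forward function with the `(rng, inputs, targets_template, forcings)` signature of `PredictorFn`, and the evaluation datasets are assumed to be prepared elsewhere:

# Editorial sketch -- run_forward_jitted, eval_inputs, eval_targets_template
# and eval_forcings are assumed to exist already.
import jax
from graphcast import rollout

predictions = rollout.chunked_prediction(
    predictor_fn=run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets_template,
    forcings=eval_forcings,
    num_steps_per_chunk=1,
    verbose=True)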
78 |
+
|
79 |
+
|
80 |
+
def chunked_prediction_generator(
|
81 |
+
predictor_fn: PredictorFn,
|
82 |
+
rng: chex.PRNGKey,
|
83 |
+
inputs: xarray.Dataset,
|
84 |
+
targets_template: xarray.Dataset,
|
85 |
+
forcings: xarray.Dataset,
|
86 |
+
num_steps_per_chunk: int = 1,
|
87 |
+
verbose: bool = False,
|
88 |
+
) -> Iterator[xarray.Dataset]:
|
89 |
+
"""Outputs a long trajectory by yielding chunked predictions.
|
90 |
+
|
91 |
+
Args:
|
92 |
+
predictor_fn: Function to use to make predictions for each chunk.
|
93 |
+
rng: Random key.
|
94 |
+
inputs: Inputs for the model.
|
95 |
+
targets_template: Template for the target prediction, requires targets
|
96 |
+
equispaced in time.
|
97 |
+
forcings: Optional forcing for the model.
|
98 |
+
num_steps_per_chunk: How many of the steps in `targets_template` to predict
|
99 |
+
at each call of `predictor_fn`. It must evenly divide the number of
|
100 |
+
steps in `targets_template`.
|
101 |
+
verbose: Whether to log the current chunk being predicted.
|
102 |
+
|
103 |
+
Yields:
|
104 |
+
The predictions for each chunked step of the chunked rollout, such that
|
105 |
+
if all predictions are concatenated in time this would match the targets
|
106 |
+
template in structure.
|
107 |
+
|
108 |
+
"""
|
109 |
+
|
110 |
+
# Create copies to avoid mutating inputs.
|
111 |
+
inputs = xarray.Dataset(inputs)
|
112 |
+
targets_template = xarray.Dataset(targets_template)
|
113 |
+
forcings = xarray.Dataset(forcings)
|
114 |
+
|
115 |
+
if "datetime" in inputs.coords:
|
116 |
+
del inputs.coords["datetime"]
|
117 |
+
|
118 |
+
if "datetime" in targets_template.coords:
|
119 |
+
output_datetime = targets_template.coords["datetime"]
|
120 |
+
del targets_template.coords["datetime"]
|
121 |
+
else:
|
122 |
+
output_datetime = None
|
123 |
+
|
124 |
+
if "datetime" in forcings.coords:
|
125 |
+
del forcings.coords["datetime"]
|
126 |
+
|
127 |
+
num_target_steps = targets_template.dims["time"]
|
128 |
+
num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk)
|
129 |
+
if remainder != 0:
|
130 |
+
raise ValueError(
|
131 |
+
f"The number of steps per chunk {num_steps_per_chunk} must "
|
132 |
+
f"evenly divide the number of target steps {num_target_steps} ")
|
133 |
+
|
134 |
+
if len(np.unique(np.diff(targets_template.coords["time"].data))) > 1:
|
135 |
+
raise ValueError("The targets time coordinates must be evenly spaced")
|
136 |
+
|
137 |
+
# Our template targets will always have a time axis corresponding for the
|
138 |
+
# timedeltas for the first chunk.
|
139 |
+
targets_chunk_time = targets_template.time.isel(
|
140 |
+
time=slice(0, num_steps_per_chunk))
|
141 |
+
|
142 |
+
current_inputs = inputs
|
143 |
+
for chunk_index in range(num_chunks):
|
144 |
+
if verbose:
|
145 |
+
logging.info("Chunk %d/%d", chunk_index, num_chunks)
|
146 |
+
logging.flush()
|
147 |
+
|
148 |
+
# Select targets for the time period that we are predicting for this chunk.
|
149 |
+
target_offset = num_steps_per_chunk * chunk_index
|
150 |
+
target_slice = slice(target_offset, target_offset + num_steps_per_chunk)
|
151 |
+
current_targets_template = targets_template.isel(time=target_slice)
|
152 |
+
|
153 |
+
# Replace the timedelta, by the one corresponding to the first chunk, so we
|
154 |
+
# don't recompile at every iteration, keeping the
|
155 |
+
actual_target_time = current_targets_template.coords["time"]
|
156 |
+
current_targets_template = current_targets_template.assign_coords(
|
157 |
+
time=targets_chunk_time).compute()
|
158 |
+
|
159 |
+
current_forcings = forcings.isel(time=target_slice)
|
160 |
+
current_forcings = current_forcings.assign_coords(time=targets_chunk_time)
|
161 |
+
current_forcings = current_forcings.compute()
|
162 |
+
# Make predictions for the chunk.
|
163 |
+
rng, this_rng = jax.random.split(rng)
|
164 |
+
predictions = predictor_fn(
|
165 |
+
rng=this_rng,
|
166 |
+
inputs=current_inputs,
|
167 |
+
targets_template=current_targets_template,
|
168 |
+
forcings=current_forcings)
|
169 |
+
|
170 |
+
next_frame = xarray.merge([predictions, current_forcings])
|
171 |
+
|
172 |
+
next_inputs = _get_next_inputs(current_inputs, next_frame)
|
173 |
+
|
174 |
+
# Shift timedelta coordinates, so we don't recompile at every iteration.
|
175 |
+
next_inputs = next_inputs.assign_coords(time=current_inputs.coords["time"])
|
176 |
+
current_inputs = next_inputs
|
177 |
+
|
178 |
+
# At this point we can assign the actual targets time coordinates.
|
179 |
+
predictions = predictions.assign_coords(time=actual_target_time)
|
180 |
+
if output_datetime is not None:
|
181 |
+
predictions.coords["datetime"] = output_datetime.isel(
|
182 |
+
time=target_slice)
|
183 |
+
yield predictions
|
184 |
+
del predictions
|
185 |
+
|
186 |
+
|
187 |
+
def _get_next_inputs(
|
188 |
+
prev_inputs: xarray.Dataset, next_frame: xarray.Dataset,
|
189 |
+
) -> xarray.Dataset:
|
190 |
+
"""Computes next inputs, from previous inputs and predictions."""
|
191 |
+
|
192 |
+
# Make sure are are predicting all inputs with a time axis.
|
193 |
+
non_predicted_or_forced_inputs = list(
|
194 |
+
set(prev_inputs.keys()) - set(next_frame.keys()))
|
195 |
+
if "time" in prev_inputs[non_predicted_or_forced_inputs].dims:
|
196 |
+
raise ValueError(
|
197 |
+
"Found an input with a time index that is not predicted or forced.")
|
198 |
+
|
199 |
+
# Keys we need to copy from predictions to inputs.
|
200 |
+
next_inputs_keys = list(
|
201 |
+
set(next_frame.keys()).intersection(set(prev_inputs.keys())))
|
202 |
+
next_inputs = next_frame[next_inputs_keys]
|
203 |
+
|
204 |
+
# Apply concatenate next frame with inputs, crop what we don't need.
|
205 |
+
num_inputs = prev_inputs.dims["time"]
|
206 |
+
return (
|
207 |
+
xarray.concat(
|
208 |
+
[prev_inputs, next_inputs], dim="time", data_vars="different")
|
209 |
+
.tail(time=num_inputs))
|
210 |
+
|
211 |
+
|
212 |
+
def extend_targets_template(
|
213 |
+
targets_template: xarray.Dataset,
|
214 |
+
required_num_steps: int) -> xarray.Dataset:
|
215 |
+
"""Extends `targets_template` to `required_num_steps` with lazy arrays.
|
216 |
+
|
217 |
+
It uses lazy dask arrays of zeros, so it does not require instantiating the
|
218 |
+
array in memory.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
targets_template: Input template to extend.
|
222 |
+
required_num_steps: Number of steps required in the returned template.
|
223 |
+
|
224 |
+
Returns:
|
225 |
+
`xarray.Dataset` identical in variables and timestep to `targets_template`
|
226 |
+
full of `dask.array.zeros` such that the time axis has `required_num_steps`.
|
227 |
+
|
228 |
+
"""
|
229 |
+
|
230 |
+
# Extend the "time" and "datetime" coordinates
|
231 |
+
time = targets_template.coords["time"]
|
232 |
+
|
233 |
+
# Assert the first target time corresponds to the timestep.
|
234 |
+
timestep = time[0].data
|
235 |
+
if time.shape[0] > 1:
|
236 |
+
assert np.all(timestep == time[1:] - time[:-1])
|
237 |
+
|
238 |
+
extended_time = (np.arange(required_num_steps) + 1) * timestep
|
239 |
+
|
240 |
+
if "datetime" in targets_template.coords:
|
241 |
+
datetime = targets_template.coords["datetime"]
|
242 |
+
extended_datetime = (datetime[0].data - timestep) + extended_time
|
243 |
+
else:
|
244 |
+
extended_datetime = None
|
245 |
+
|
246 |
+
# Replace the values with empty dask arrays extending the time coordinates.
|
247 |
+
datetime = targets_template.coords["time"]
|
248 |
+
|
249 |
+
def extend_time(data_array: xarray.DataArray) -> xarray.DataArray:
|
250 |
+
dims = data_array.dims
|
251 |
+
shape = list(data_array.shape)
|
252 |
+
shape[dims.index("time")] = required_num_steps
|
253 |
+
dask_data = dask.array.zeros(
|
254 |
+
shape=tuple(shape),
|
255 |
+
chunks=-1, # Will give chunk info directly to `ChunksToZarr``.
|
256 |
+
dtype=data_array.dtype)
|
257 |
+
|
258 |
+
coords = dict(data_array.coords)
|
259 |
+
coords["time"] = extended_time
|
260 |
+
|
261 |
+
if extended_datetime is not None:
|
262 |
+
coords["datetime"] = ("time", extended_datetime)
|
263 |
+
|
264 |
+
return xarray.DataArray(
|
265 |
+
dims=dims,
|
266 |
+
data=dask_data,
|
267 |
+
coords=coords)
|
268 |
+
|
269 |
+
return xarray_tree.map_structure(extend_time, targets_template)
|
graphcast/solar_radiation.py
ADDED
@@ -0,0 +1,605 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computes TOA incident solar radiation compatible with ERA5.

The Top-Of-the-Atmosphere (TOA) incident solar radiation is available in the
ERA5 dataset as the parameter `toa_incident_solar_radiation` (or `tisr`). This
represents the TOA solar radiation flux integrated over a period of one hour
ending at the timestamp given by the `datetime` coordinate. See
https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation and
https://codes.ecmwf.int/grib/param-db/?id=212.
"""

from collections.abc import Callable, Sequence
import dataclasses
import functools

import chex
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import xarray as xa


# Default value of the `integration_period` argument to be compatible with
# ERA5.
_DEFAULT_INTEGRATION_PERIOD = pd.Timedelta(hours=1)

# Default value for the `num_integration_bins` argument. This provides a good
# approximation of the solar radiation in ERA5.
_DEFAULT_NUM_INTEGRATION_BINS = 360

# The length of a Julian year in days.
# https://en.wikipedia.org/wiki/Julian_year_(astronomy)
_JULIAN_YEAR_LENGTH_IN_DAYS = 365.25

# Julian Date for the J2000 epoch, a standard reference used in astronomy.
# https://en.wikipedia.org/wiki/Epoch_(astronomy)#Julian_years_and_J2000
_J2000_EPOCH = 2451545.0

# Number of seconds in a day.
_SECONDS_PER_DAY = 60 * 60 * 24


_TimestampLike = str | pd.Timestamp | np.datetime64
_TimedeltaLike = str | pd.Timedelta | np.timedelta64


# Interface for loading Total Solar Irradiance (TSI) data.
# Returns a xa.DataArray containing yearly average TSI values with a `time`
# coordinate in units of years since 0000-1-1. E.g. 2023.5 corresponds to
# the middle of the year 2023.
TsiDataLoader = Callable[[], xa.DataArray]


# Total Solar Irradiance (TSI): Energy input to the top of the Earth's
# atmosphere in W⋅m⁻². TSI varies with time. This is the reference TSI value
# that can be used when more accurate data is not available.
# https://www.ncei.noaa.gov/products/climate-data-records/total-solar-irradiance
# https://github.com/ecmwf-ifs/ecrad/blob/6db82f929fb75028cc20606a04da87c0abe9b642/radiation/radiation_ecckd.F90#L296
_REFERENCE_TSI = 1361.0


def reference_tsi_data() -> xa.DataArray:
  """A TsiDataProvider that returns a single reference TSI value."""
  return xa.DataArray(
      np.array([_REFERENCE_TSI]),
      dims=["time"],
      coords={"time": np.array([0.0])},
  )


def era5_tsi_data() -> xa.DataArray:
  """A TsiDataProvider that returns ERA5 compatible TSI data."""
  # ECMWF provided the data used for ERA5, which was hardcoded in the IFS
  # (cycle 41r2). The values were scaled down to agree better with more recent
  # observations of the sun.
  time = np.arange(1951.5, 2035.5, 1.0)
  tsi = 0.9965 * np.array([
      # fmt: off
      # 1951-1995 (non-repeating sequence)
      1365.7765, 1365.7676, 1365.6284, 1365.6564, 1365.7773,
      1366.3109, 1366.6681, 1366.6328, 1366.3828, 1366.2767,
      1365.9199, 1365.7484, 1365.6963, 1365.6976, 1365.7341,
      1365.9178, 1366.1143, 1366.1644, 1366.2476, 1366.2426,
      1365.9580, 1366.0525, 1365.7991, 1365.7271, 1365.5345,
      1365.6453, 1365.8331, 1366.2747, 1366.6348, 1366.6482,
      1366.6951, 1366.2859, 1366.1992, 1365.8103, 1365.6416,
      1365.6379, 1365.7899, 1366.0826, 1366.6479, 1366.5533,
      1366.4457, 1366.3021, 1366.0286, 1365.7971, 1365.6996,
      # 1996-2008 (13 year cycle, repeated below)
      1365.6121, 1365.7399, 1366.1021, 1366.3851, 1366.6836,
      1366.6022, 1366.6807, 1366.2300, 1366.0480, 1365.8545,
      1365.8107, 1365.7240, 1365.6918,
      # 2009-2021
      1365.6121, 1365.7399, 1366.1021, 1366.3851, 1366.6836,
      1366.6022, 1366.6807, 1366.2300, 1366.0480, 1365.8545,
      1365.8107, 1365.7240, 1365.6918,
      # 2022-2034
      1365.6121, 1365.7399, 1366.1021, 1366.3851, 1366.6836,
      1366.6022, 1366.6807, 1366.2300, 1366.0480, 1365.8545,
      1365.8107, 1365.7240, 1365.6918,
      # fmt: on
  ])
  return xa.DataArray(tsi, dims=["time"], coords={"time": time})


# HRES compatible TSI data is from IFS cycle 47r1. The dataset can be obtained
# from the ECRAD package: https://confluence.ecmwf.int/display/ECRAD.
# The example code below can load this dataset from a local file.

# def hres_tsi_data() -> xa.DataArray:
#   with open("total_solar_irradiance_CMIP6_47r1.nc", "rb") as f:
#     with xa.load_dataset(f, decode_times=False) as ds:
#       return ds["tsi"]


_DEFAULT_TSI_DATA_LOADER: TsiDataLoader = era5_tsi_data


def get_tsi(
    timestamps: Sequence[_TimestampLike], tsi_data: xa.DataArray
) -> chex.Array:
  """Returns TSI values for the given timestamps.

  TSI values are interpolated from the provided yearly TSI data.

  Args:
    timestamps: Timestamps for which to compute TSI values.
    tsi_data: A DataArray with a single dimension `time` that has coordinates
      in units of years since 0000-1-1. E.g. 2023.5 corresponds to the middle
      of the year 2023.

  Returns:
    An Array containing interpolated TSI data.
  """
  timestamps = pd.DatetimeIndex(timestamps)
  timestamps_date = pd.DatetimeIndex(timestamps.date)
  day_fraction = (timestamps - timestamps_date) / pd.Timedelta(days=1)
  year_length = 365 + timestamps.is_leap_year
  year_fraction = (timestamps.dayofyear - 1 + day_fraction) / year_length
  fractional_year = timestamps.year + year_fraction
  return np.interp(fractional_year, tsi_data.coords["time"].data, tsi_data.data)


@dataclasses.dataclass(frozen=True)
class _OrbitalParameters:
  """Parameters characterising Earth's position relative to the Sun.

  The parameters characterize the position of the Earth in its orbit around the
  Sun for specific points in time. Each attribute is an N-dimensional array
  to represent orbital parameters for multiple points in time.

  Attributes:
    theta: The number of Julian years since the Julian epoch J2000.0.
    rotational_phase: The phase of the Earth's rotation along its axis as a
      ratio with 0 representing the phase at Julian epoch J2000.0 at exactly
      12:00 Terrestrial Time (TT). Multiplying this value by `2*pi` yields the
      phase in radians.
    sin_declination: Sine of the declination of the Sun as seen from the Earth.
    cos_declination: Cosine of the declination of the Sun as seen from the
      Earth.
    eq_of_time_seconds: The value of the equation of time, in seconds.
    solar_distance_au: Earth-Sun distance in astronomical units.
  """

  theta: chex.Array
  rotational_phase: chex.Array
  sin_declination: chex.Array
  cos_declination: chex.Array
  eq_of_time_seconds: chex.Array
  solar_distance_au: chex.Array


def _get_j2000_days(timestamp: pd.Timestamp) -> float:
  """Returns the number of days since the J2000 epoch.

  Args:
    timestamp: A timestamp for which to compute the J2000 days.

  Returns:
    The J2000 days corresponding to the input timestamp.
  """
  return timestamp.to_julian_date() - _J2000_EPOCH


def _get_orbital_parameters(j2000_days: chex.Array) -> _OrbitalParameters:
  """Computes the orbital parameters for the given J2000 days.

  Args:
    j2000_days: Timestamps represented as the number of days since the J2000
      epoch.

  Returns:
    Orbital parameters for the given timestamps. Each attribute of the return
    value is an array containing the same dimensions as the input.
  """
  # Orbital parameters are computed based on the formulas in this code, which
  # were determined empirically to produce radiation values similar to ERA5:
  # https://github.com/ECCC-ASTD-MRD/gem/blob/1d711f7b89971cd7b1e10afc7508d1135b51397d/src/rpnphy/src/base/sucst.F90
  # https://github.com/ECCC-ASTD-MRD/gem/blob/1d711f7b89971cd7b1e10afc7508d1135b51397d/src/rpnphy/src/base/fctast.cdk
  # https://github.com/ECCC-ASTD-MRD/gem/blob/1d711f7b89971cd7b1e10afc7508d1135b51397d/src/rpnphy/src/base/fcttim.cdk
  # There are many variations to these formulas, but since the goal is to match
  # the values in ERA5, the formulas were implemented as is. Comments reference
  # the notation used in those sources. Here are some additional references
  # related to the quantities being computed here:
  # https://aa.usno.navy.mil/faq/sun_approx
  # https://en.wikipedia.org/wiki/Position_of_the_Sun
  # https://en.wikipedia.org/wiki/Equation_of_time

  # Number of Julian years since the J2000 epoch (including fractional years).
  theta = j2000_days / _JULIAN_YEAR_LENGTH_IN_DAYS
  # The phase of the Earth's rotation along its axis as a ratio. 0 represents
  # Julian epoch J2000.0 at exactly 12:00 Terrestrial Time (TT).
  rotational_phase = j2000_days % 1.0

  # REL(PTETA).
  rel = 1.7535 + 6.283076 * theta
  # REM(PTETA).
  rem = 6.240041 + 6.283020 * theta
  # RLLS(PTETA).
  rlls = 4.8951 + 6.283076 * theta

  # Variables used in the three polynomials below.
  one = jnp.ones_like(theta)
  sin_rel = jnp.sin(rel)
  cos_rel = jnp.cos(rel)
  sin_two_rel = jnp.sin(2.0 * rel)
  cos_two_rel = jnp.cos(2.0 * rel)
  sin_two_rlls = jnp.sin(2.0 * rlls)
  cos_two_rlls = jnp.cos(2.0 * rlls)
  sin_four_rlls = jnp.sin(4.0 * rlls)
  sin_rem = jnp.sin(rem)
  sin_two_rem = jnp.sin(2.0 * rem)

  # Ecliptic longitude of the Sun - RLLLS(PTETA).
  rllls = jnp.dot(
      jnp.stack(
          [one, theta, sin_rel, cos_rel, sin_two_rel, cos_two_rel], axis=-1
      ),
      jnp.array([4.8952, 6.283320, -0.0075, -0.0326, -0.0003, 0.0002]),
  )

  # Angle in radians between the Earth's rotational axis and its orbital axis.
  # Equivalent to 23.4393°.
  repsm = 0.409093

  # Declination of the Sun - RDS(teta).
  sin_declination = jnp.sin(repsm) * jnp.sin(rllls)
  cos_declination = jnp.sqrt(1.0 - sin_declination**2)

  # Equation of time in seconds - RET(PTETA).
  eq_of_time_seconds = jnp.dot(
      jnp.stack(
          [
              sin_two_rlls,
              sin_rem,
              sin_rem * cos_two_rlls,
              sin_four_rlls,
              sin_two_rem,
          ],
          axis=-1,
      ),
      jnp.array([591.8, -459.4, 39.5, -12.7, -4.8]),
  )

  # Earth-Sun distance in astronomical units - RRS(PTETA).
  solar_distance_au = jnp.dot(
      jnp.stack([one, sin_rel, cos_rel], axis=-1),
      jnp.array([1.0001, -0.0163, 0.0037]),
  )

  return _OrbitalParameters(
      theta=theta,
      rotational_phase=rotational_phase,
      sin_declination=sin_declination,
      cos_declination=cos_declination,
      eq_of_time_seconds=eq_of_time_seconds,
      solar_distance_au=solar_distance_au,
  )


def _get_solar_sin_altitude(
    op: _OrbitalParameters,
    sin_latitude: chex.Array,
    cos_latitude: chex.Array,
    longitude: chex.Array,
) -> chex.Array:
  """Returns the sine of the solar altitude angle.

  All computations are vectorized. Dimensions of all the inputs should be
  broadcastable using standard NumPy rules. For example, if `op` has shape
  `(T, 1, 1)`, `latitude` has shape `(1, H, 1)`, and `longitude` has shape
  `(1, H, W)`, the return value will have shape `(T, H, W)`.

  Args:
    op: Orbital parameters characterising Earth's position relative to the Sun.
    sin_latitude: Sine of latitude coordinates.
    cos_latitude: Cosine of latitude coordinates.
    longitude: Longitude coordinates in radians.

  Returns:
    Sine of the solar altitude angle for each set of orbital parameters and
    each geographical coordinate. The returned array has the shape resulting
    from broadcasting all the inputs together.
  """
  solar_time = op.rotational_phase + op.eq_of_time_seconds / _SECONDS_PER_DAY
  # https://en.wikipedia.org/wiki/Hour_angle#Solar_hour_angle
  hour_angle = 2.0 * jnp.pi * solar_time + longitude
  # https://en.wikipedia.org/wiki/Solar_zenith_angle
  sin_altitude = (
      cos_latitude * op.cos_declination * jnp.cos(hour_angle)
      + sin_latitude * op.sin_declination
  )
  return sin_altitude


def _get_radiation_flux(
    j2000_days: chex.Array,
    sin_latitude: chex.Array,
    cos_latitude: chex.Array,
    longitude: chex.Array,
    tsi: chex.Array,
) -> chex.Array:
  """Computes the instantaneous TOA incident solar radiation flux.

  Computes the instantaneous Top-Of-the-Atmosphere (TOA) incident radiation
  flux in W⋅m⁻² for the given timestamps and locations on the surface of the
  Earth. See https://en.wikipedia.org/wiki/Solar_irradiance.

  All inputs are assumed to be broadcastable together using standard NumPy
  rules.

  Args:
    j2000_days: Timestamps represented as the number of days since the J2000
      epoch.
    sin_latitude: Sine of latitude coordinates.
    cos_latitude: Cosine of latitude coordinates.
    longitude: Longitude coordinates in radians.
    tsi: Total Solar Irradiance (TSI) in W⋅m⁻². This can be a scalar (default)
      to use the same TSI value for all the inputs, or an array to allow TSI to
      depend on the timestamps.

  Returns:
    The instantaneous TOA incident solar radiation flux in W⋅m⁻² for the given
    timestamps and geographical coordinates. The returned array has the shape
    resulting from broadcasting all the inputs together.
  """
  op = _get_orbital_parameters(j2000_days)
  # Attenuation of the solar radiation based on the solar distance.
  solar_factor = (1.0 / op.solar_distance_au) ** 2
  sin_altitude = _get_solar_sin_altitude(
      op, sin_latitude, cos_latitude, longitude
  )
  return tsi * solar_factor * jnp.maximum(sin_altitude, 0.0)


def _get_integrated_radiation(
    j2000_days: chex.Array,
    sin_latitude: chex.Array,
    cos_latitude: chex.Array,
    longitude: chex.Array,
    tsi: chex.Array,
    integration_period: pd.Timedelta,
    num_integration_bins: int,
) -> chex.Array:
  """Returns the TOA solar radiation flux integrated over a time period.

  Integrates the instantaneous TOA solar radiation flux over a time period.
  The input timestamps represent the end times of each integration period.
  When the integration period is one hour this approximates the
  `toa_incident_solar_radiation` (or `tisr`) parameter from the ERA5 dataset.
  See https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation and
  https://codes.ecmwf.int/grib/param-db/?id=212.

  All inputs are assumed to be broadcastable together using standard NumPy
  rules. To approximate the integral, the instantaneous radiation is computed
  at `num_integration_bins+1` time steps using `_get_radiation_flux` and
  integrated using the trapezoidal rule. A dimension is appended at the end
  of all inputs to compute the instantaneous radiation, which is then
  integrated over to compute the final result.

  Args:
    j2000_days: Timestamps represented as the number of days since the J2000
      epoch. These correspond to the end times of each integration period.
    sin_latitude: Sine of latitude coordinates.
    cos_latitude: Cosine of latitude coordinates.
    longitude: Longitude in radians.
    tsi: Total Solar Irradiance (TSI) in W⋅m⁻².
    integration_period: Integration period.
    num_integration_bins: Number of bins to divide the `integration_period` to
      approximate the integral using the trapezoidal rule.

  Returns:
    The TOA solar radiation flux integrated over the requested time period for
    the given timestamps and geographical coordinates. Unit is J⋅m⁻².
  """
  # Offsets for the integration time steps.
  offsets = (
      pd.timedelta_range(
          start=-integration_period,
          end=pd.Timedelta(0),
          periods=num_integration_bins + 1,
      )
      / pd.Timedelta(days=1)
  ).to_numpy()

  # Integration happens over the time dimension. Compute the instantaneous
  # radiation flux for all the integration time steps by appending a dimension
  # to all the inputs and adding `offsets` to `j2000_days` (will be broadcast
  # over all the other dimensions).
  fluxes = _get_radiation_flux(
      j2000_days=jnp.expand_dims(j2000_days, axis=-1) + offsets,
      sin_latitude=jnp.expand_dims(sin_latitude, axis=-1),
      cos_latitude=jnp.expand_dims(cos_latitude, axis=-1),
      longitude=jnp.expand_dims(longitude, axis=-1),
      tsi=jnp.expand_dims(tsi, axis=-1),
  )

  # Size of each bin in seconds. The instantaneous solar radiation flux is
  # returned in units of W⋅m⁻². Integrating over time expressed in seconds
  # yields a result in units of J⋅m⁻².
  dx = (integration_period / num_integration_bins) / pd.Timedelta(seconds=1)
  return jax.scipy.integrate.trapezoid(fluxes, dx=dx)


_get_integrated_radiation_jitted = jax.jit(
    _get_integrated_radiation,
    static_argnames=["integration_period", "num_integration_bins"],
)


def get_toa_incident_solar_radiation(
    timestamps: Sequence[_TimestampLike],
    latitude: chex.Array,
    longitude: chex.Array,
    tsi_data: xa.DataArray | None = None,
    integration_period: _TimedeltaLike = _DEFAULT_INTEGRATION_PERIOD,
    num_integration_bins: int = _DEFAULT_NUM_INTEGRATION_BINS,
    use_jit: bool = False,
) -> chex.Array:
  """Computes the solar radiation incident at the top of the atmosphere.

  The solar radiation is computed for each element in `timestamps` for all the
  locations on the grid determined by the `latitude` and `longitude`
  parameters.

  To approximate the `toa_incident_solar_radiation` (or `tisr`) parameter from
  the ERA5 dataset, set `integration_period` to one hour (default). See
  https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation and
  https://codes.ecmwf.int/grib/param-db/?id=212.

  Args:
    timestamps: Timestamps for which to compute the solar radiation.
    latitude: The latitude coordinates in degrees of the grid for which to
      compute the solar radiation.
    longitude: The longitude coordinates in degrees of the grid for which to
      compute the solar radiation.
    tsi_data: A DataArray containing yearly TSI data as returned by a
      `TsiDataLoader`. The default is to use ERA5 compatible TSI data.
    integration_period: Timedelta to use to integrate the radiation, e.g. if
      producing radiation for 1989-11-08 21:00:00, and `integration_period` is
      "1h", radiation will be integrated from 1989-11-08 20:00:00 to 1989-11-08
      21:00:00. The default value ("1h") matches ERA5.
    num_integration_bins: Number of equally spaced bins to divide the
      `integration_period` in when approximating the integral using the
      trapezoidal rule. Performance and peak memory usage are affected by this
      value. The default (360) provides a good approximation, but lower values
      may work to improve performance and reduce memory usage.
    use_jit: Set to True to use the jitted implementation, or False (default)
      to use the non-jitted one.

  Returns:
    A 3D array with dimensions (time, lat, lon) containing the total
    top of atmosphere solar radiation integrated for the `integration_period`
    up to each timestamp.
  """
  # Add a trailing dimension to latitude to get dimensions (lat, lon).
  lat = jnp.radians(latitude).reshape((-1, 1))
  lon = jnp.radians(longitude)
  sin_lat = jnp.sin(lat)
  cos_lat = jnp.cos(lat)
  integration_period = pd.Timedelta(integration_period)
  if tsi_data is None:
    tsi_data = _DEFAULT_TSI_DATA_LOADER()
  tsi = get_tsi(timestamps, tsi_data)
  fn = (
      _get_integrated_radiation_jitted if use_jit else _get_integrated_radiation
  )

  # Compute integral for each timestamp individually. Although this could be
  # done in one step, peak memory usage would be proportional to
  # `len(timestamps) * num_integration_bins`. Computing each timestamp
  # individually reduces this to `max(len(timestamps), num_integration_bins)`.
  # E.g. memory usage for a single timestamp, with a full 0.25° grid and 360
  # integration bins is about 1.5 GB (1440 * 721 * 361 * 4 bytes); computing
  # forcings for 40 prediction steps would require 60 GB.
  results = []
  for idx, timestamp in enumerate(timestamps):
    results.append(
        fn(
            j2000_days=jnp.array(_get_j2000_days(pd.Timestamp(timestamp))),
            sin_latitude=sin_lat,
            cos_latitude=cos_lat,
            longitude=lon,
            tsi=tsi[idx],
            integration_period=integration_period,
            num_integration_bins=num_integration_bins,
        )
    )
  return jnp.stack(results, axis=0)


def get_toa_incident_solar_radiation_for_xarray(
    data_array_like: xa.DataArray | xa.Dataset,
    tsi_data: xa.DataArray | None = None,
    integration_period: _TimedeltaLike = _DEFAULT_INTEGRATION_PERIOD,
    num_integration_bins: int = _DEFAULT_NUM_INTEGRATION_BINS,
    use_jit: bool = False,
) -> xa.DataArray:
  """Computes the solar radiation incident at the top of the atmosphere.

  This method is a wrapper for `get_toa_incident_solar_radiation` using
  coordinates from an Xarray and returning an Xarray.

  Args:
    data_array_like: A xa.Dataset or xa.DataArray from which to take the time
      and spatial coordinates for which to compute the solar radiation. It must
      contain `lat` and `lon` spatial dimensions with corresponding
      coordinates. If a `time` dimension is present, the `datetime` coordinate
      should be a vector associated with that dimension containing timestamps
      for which to compute the solar radiation. Otherwise, the `datetime`
      coordinate should be a scalar representing the timestamp for which to
      compute the solar radiation.
    tsi_data: A DataArray containing yearly TSI data as returned by a
      `TsiDataLoader`. The default is to use ERA5 compatible TSI data.
    integration_period: Timedelta to use to integrate the radiation, e.g. if
      producing radiation for 1989-11-08 21:00:00, and `integration_period` is
      "1h", radiation will be integrated from 1989-11-08 20:00:00 to 1989-11-08
      21:00:00. The default value ("1h") matches ERA5.
    num_integration_bins: Number of equally spaced bins to divide the
      `integration_period` in when approximating the integral using the
      trapezoidal rule. Performance and peak memory usage are affected by this
      value. The default (360) provides a good approximation, but lower values
      may work to improve performance and reduce memory usage.
    use_jit: Set to True to use the jitted implementation, or False to use the
      non-jitted one.

  Returns:
    xa.DataArray with dimensions `(time, lat, lon)` if `data_array_like` had
    a `time` dimension; or dimensions `(lat, lon)` otherwise. The `datetime`
    coordinates and those for the dimensions are copied to the returned array.
    The array contains the total top of atmosphere solar radiation integrated
    for `integration_period` up to the corresponding `datetime`.

  Raises:
    ValueError: If there are missing coordinates or dimensions.
  """
  missing_dims = set(["lat", "lon"]) - set(data_array_like.dims)
  if missing_dims:
    raise ValueError(
        f"'{missing_dims}' dimensions are missing in `data_array_like`."
    )

  missing_coords = set(["datetime", "lat", "lon"]) - set(data_array_like.coords)
  if missing_coords:
    raise ValueError(
        f"'{missing_coords}' coordinates are missing in `data_array_like`."
    )

  if "time" in data_array_like.dims:
    timestamps = data_array_like.coords["datetime"].data
  else:
    timestamps = [data_array_like.coords["datetime"].data.item()]

  radiation = get_toa_incident_solar_radiation(
      timestamps=timestamps,
      latitude=data_array_like.coords["lat"].data,
      longitude=data_array_like.coords["lon"].data,
      tsi_data=tsi_data,
      integration_period=integration_period,
      num_integration_bins=num_integration_bins,
      use_jit=use_jit,
  )

  if "time" in data_array_like.dims:
    output = xa.DataArray(radiation, dims=("time", "lat", "lon"))
  else:
    output = xa.DataArray(radiation[0], dims=("lat", "lon"))

  # Preserve as many of the original coordinates as possible, so long as the
  # dimension or the coordinate still exist in the output array.
  for k, coord in data_array_like.coords.items():
    if set(coord.dims).issubset(set(output.dims)):
      output.coords[k] = coord
  return output
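
A minimal sketch of how the xarray wrapper above might be called, not part of the uploaded file; the toy grid here is hypothetical and only mirrors the shapes used in the tests below.

# Hypothetical toy grid with a single timestamp.
import numpy as np
import xarray as xa
from graphcast import solar_radiation

grid = xa.Dataset(
    coords={
        "lat": np.linspace(-90.0, 90.0, 7),
        "lon": np.linspace(0.0, 360.0, 8, endpoint=False),
        "datetime": np.datetime64("2020-01-01T06:00:00"),
    }
)
tisr = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
    grid, integration_period="1h", num_integration_bins=36
)  # DataArray with dims (lat, lon), values in J⋅m⁻²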
graphcast/solar_radiation_test.py
ADDED
@@ -0,0 +1,240 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import timeit
from typing import Sequence

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized
from graphcast import solar_radiation
import numpy as np
import pandas as pd
import xarray as xa


def _get_grid_lat_lon_coords(
    num_lat: int, num_lon: int
) -> tuple[np.ndarray, np.ndarray]:
  """Generates a linear latitude-longitude grid of the given size.

  Args:
    num_lat: Size of the latitude dimension of the grid.
    num_lon: Size of the longitude dimension of the grid.

  Returns:
    A tuple `(lat, lon)` containing 1D arrays with the latitude and longitude
    coordinates in degrees of the generated grid.
  """
  lat = np.linspace(-90.0, 90.0, num=num_lat, endpoint=True)
  lon = np.linspace(0.0, 360.0, num=num_lon, endpoint=False)
  return lat, lon


class SolarRadiationTest(parameterized.TestCase):

  def setUp(self):
    super().setUp()
    np.random.seed(0)

  def test_missing_dim_raises_value_error(self):
    data = xa.DataArray(
        np.random.randn(2, 2),
        coords=[np.array([0.1, 0.2]), np.array([0.0, 0.5])],
        dims=["lon", "x"],
    )
    with self.assertRaisesRegex(
        ValueError, r".* dimensions are missing in `data_array_like`."
    ):
      solar_radiation.get_toa_incident_solar_radiation_for_xarray(
          data, integration_period="1h", num_integration_bins=360
      )

  def test_missing_coordinate_raises_value_error(self):
    data = xa.Dataset(
        data_vars={"var1": (["x", "lat", "lon"], np.random.randn(2, 3, 2))},
        coords={
            "lat": np.array([0.0, 0.1, 0.2]),
            "lon": np.array([0.0, 0.5]),
        },
    )
    with self.assertRaisesRegex(
        ValueError, r".* coordinates are missing in `data_array_like`."
    ):
      solar_radiation.get_toa_incident_solar_radiation_for_xarray(
          data, integration_period="1h", num_integration_bins=360
      )

  def test_shape_multiple_timestamps(self):
    data = xa.Dataset(
        data_vars={"var1": (["time", "lat", "lon"], np.random.randn(2, 4, 2))},
        coords={
            "lat": np.array([0.0, 0.1, 0.2, 0.3]),
            "lon": np.array([0.0, 0.5]),
            "time": np.array([100, 200], dtype="timedelta64[s]"),
            "datetime": xa.Variable(
                "time", np.array([10, 20], dtype="datetime64[D]")
            ),
        },
    )

    actual = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
        data, integration_period="1h", num_integration_bins=2
    )

    self.assertEqual(("time", "lat", "lon"), actual.dims)
    self.assertEqual((2, 4, 2), actual.shape)

  def test_shape_single_timestamp(self):
    data = xa.Dataset(
        data_vars={"var1": (["lat", "lon"], np.random.randn(4, 2))},
        coords={
            "lat": np.array([0.0, 0.1, 0.2, 0.3]),
            "lon": np.array([0.0, 0.5]),
            "datetime": np.datetime64(10, "D"),
        },
    )

    actual = solar_radiation.get_toa_incident_solar_radiation_for_xarray(
        data, integration_period="1h", num_integration_bins=2
    )

    self.assertEqual(("lat", "lon"), actual.dims)
    self.assertEqual((4, 2), actual.shape)

  @parameterized.named_parameters(
      dict(
          testcase_name="one_timestamp_jitted",
          periods=1,
          repeats=3,
          use_jit=True,
      ),
      dict(
          testcase_name="one_timestamp_non_jitted",
          periods=1,
          repeats=3,
          use_jit=False,
      ),
      dict(
          testcase_name="ten_timestamps_non_jitted",
          periods=10,
          repeats=1,
          use_jit=False,
      ),
  )
  def test_full_spatial_resolution(
      self, periods: int, repeats: int, use_jit: bool
  ):
    timestamps = pd.date_range(start="2023-09-25", periods=periods, freq="6h")
    # Generate a linear grid with 0.25 degrees resolution similar to ERA5.
    lat, lon = _get_grid_lat_lon_coords(num_lat=721, num_lon=1440)

    def benchmark() -> None:
      solar_radiation.get_toa_incident_solar_radiation(
          timestamps,
          lat,
          lon,
          integration_period="1h",
          num_integration_bins=360,
          use_jit=use_jit,
      ).block_until_ready()

    results = timeit.repeat(benchmark, repeat=repeats, number=1)

    logging.info(
        "Times to compute `tisr` for input of shape `%d, %d, %d` (seconds): %s",
        len(timestamps),
        len(lat),
        len(lon),
        np.array2string(np.array(results), precision=1),
    )


class GetTsiTest(parameterized.TestCase):

  @parameterized.named_parameters(
      dict(
          testcase_name="reference_tsi_data",
          loader=solar_radiation.reference_tsi_data,
          expected_tsi=np.array([1361.0]),
      ),
      dict(
          testcase_name="era5_tsi_data",
          loader=solar_radiation.era5_tsi_data,
          expected_tsi=np.array([1360.9440]),  # 0.9965 * 1365.7240
      ),
  )
  def test_mid_2020_lookup(
      self, loader: solar_radiation.TsiDataLoader, expected_tsi: np.ndarray
  ):
    tsi_data = loader()

    tsi = solar_radiation.get_tsi(
        [np.datetime64("2020-07-02T00:00:00")], tsi_data
    )

    np.testing.assert_allclose(expected_tsi, tsi)

  @parameterized.named_parameters(
      dict(
          testcase_name="beginning_2020_left_boundary",
          timestamps=[np.datetime64("2020-01-01T00:00:00")],
          expected_tsi=np.array([1000.0]),
      ),
      dict(
          testcase_name="mid_2020_exact",
          timestamps=[np.datetime64("2020-07-02T00:00:00")],
          expected_tsi=np.array([1000.0]),
      ),
      dict(
          testcase_name="beginning_2021_interpolated",
          timestamps=[np.datetime64("2021-01-01T00:00:00")],
          expected_tsi=np.array([1150.0]),
      ),
      dict(
          testcase_name="mid_2021_lookup",
          timestamps=[np.datetime64("2021-07-02T12:00:00")],
          expected_tsi=np.array([1300.0]),
      ),
      dict(
          testcase_name="beginning_2022_interpolated",
          timestamps=[np.datetime64("2022-01-01T00:00:00")],
          expected_tsi=np.array([1250.0]),
      ),
      dict(
          testcase_name="mid_2022_lookup",
          timestamps=[np.datetime64("2022-07-02T12:00:00")],
          expected_tsi=np.array([1200.0]),
      ),
      dict(
          testcase_name="beginning_2023_right_boundary",
          timestamps=[np.datetime64("2023-01-01T00:00:00")],
          expected_tsi=np.array([1200.0]),
      ),
  )
  def test_interpolation(
      self, timestamps: Sequence[np.datetime64], expected_tsi: np.ndarray
  ):
    tsi_data = xa.DataArray(
        np.array([1000.0, 1300.0, 1200.0]),
        dims=["time"],
        coords={"time": np.array([2020.5, 2021.5, 2022.5])},
    )

    tsi = solar_radiation.get_tsi(timestamps, tsi_data)

    np.testing.assert_allclose(expected_tsi, tsi)


if __name__ == "__main__":
  absltest.main()
graphcast/typed_graph.py
ADDED
@@ -0,0 +1,97 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data-structure for storing graphs with typed edges and nodes."""

from typing import NamedTuple, Any, Union, Tuple, Mapping, TypeVar

ArrayLike = Union[Any]  # np.ndarray, jnp.ndarray, tf.tensor
ArrayLikeTree = Union[Any, ArrayLike]  # Nest of ArrayLike

_T = TypeVar('_T')


# All tensors have a "flat_batch_axis", which is similar to the leading
# axes of graph_tuples:
# * In the case of nodes this is simply a shared node and flat batch axis, with
#   size corresponding to the total number of nodes in the flattened batch.
# * In the case of edges this is simply a shared edge and flat batch axis, with
#   size corresponding to the total number of edges in the flattened batch.
# * In the case of globals this is simply the number of graphs in the flattened
#   batch.

# All shapes may also have any additional leading shape "batch_shape".
# Options for building batches are:
# * Use a provided "flatten" method that takes a leading `batch_shape` and
#   folds it into the flat_batch_axis (this will be useful when using
#   `tf.Dataset` which supports batching into RaggedTensors, with leading batch
#   shape even if graphs have different numbers of nodes and edges), so the
#   RaggedBatches can then be converted into something without ragged
#   dimensions that jax can use.
# * Directly build a "flat batch" using a provided function for batching a list
#   of graphs (how it is done in `jraph`).


class NodeSet(NamedTuple):
  """Represents a set of nodes."""
  n_node: ArrayLike  # [num_flat_graphs]
  features: ArrayLikeTree  # Prev. `nodes`: [num_flat_nodes] + feature_shape


class EdgesIndices(NamedTuple):
  """Represents indices to nodes adjacent to the edges."""
  senders: ArrayLike  # [num_flat_edges]
  receivers: ArrayLike  # [num_flat_edges]


class EdgeSet(NamedTuple):
  """Represents a set of edges."""
  n_edge: ArrayLike  # [num_flat_graphs]
  indices: EdgesIndices
  features: ArrayLikeTree  # Prev. `edges`: [num_flat_edges] + feature_shape


class Context(NamedTuple):
  # `n_graph` always contains ones but it is useful to query the leading shape
  # in case of graphs without any nodes or edges sets.
  n_graph: ArrayLike  # [num_flat_graphs]
  features: ArrayLikeTree  # Prev. `globals`: [num_flat_graphs] + feature_shape


class EdgeSetKey(NamedTuple):
  name: str  # Name of the EdgeSet.

  # Sender node set name and receiver node set name connected by the edge set.
  node_sets: Tuple[str, str]


class TypedGraph(NamedTuple):
  """A graph with typed nodes and edges.

  A typed graph is made of a context, multiple sets of nodes and multiple
  sets of edges connecting those nodes (as indicated by the EdgeSetKey).
  """

  context: Context
  nodes: Mapping[str, NodeSet]
  edges: Mapping[EdgeSetKey, EdgeSet]

  def edge_key_by_name(self, name: str) -> EdgeSetKey:
    found_key = [k for k in self.edges.keys() if k.name == name]
    if len(found_key) != 1:
      raise KeyError("invalid edge key '{}'. Available edges: [{}]".format(
          name, ', '.join(x.name for x in self.edges.keys())))
    return found_key[0]

  def edge_by_name(self, name: str) -> EdgeSet:
    return self.edges[self.edge_key_by_name(name)]
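
A minimal sketch, not part of the uploaded file, of how a tiny `TypedGraph` with one node set and one edge set might be assembled; the array values and the "grid"/"grid2grid" names are arbitrary placeholders.

# Two nodes connected by a single directed edge, in one flat graph.
import numpy as np
from graphcast import typed_graph

nodes = {"grid": typed_graph.NodeSet(
    n_node=np.array([2]), features=np.zeros((2, 3)))}
edges = {
    typed_graph.EdgeSetKey(name="grid2grid", node_sets=("grid", "grid")):
        typed_graph.EdgeSet(
            n_edge=np.array([1]),
            indices=typed_graph.EdgesIndices(
                senders=np.array([0]), receivers=np.array([1])),
            features=np.zeros((1, 4)))}
context = typed_graph.Context(n_graph=np.array([1]), features=())
graph = typed_graph.TypedGraph(context=context, nodes=nodes, edges=edges)
edge_set = graph.edge_by_name("grid2grid")  # looks up the EdgeSet by name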
graphcast/typed_graph_net.py
ADDED
@@ -0,0 +1,317 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A library of typed Graph Neural Networks."""

from typing import Callable, Mapping, Optional, Union

from graphcast import typed_graph
import jax.numpy as jnp
import jax.tree_util as tree
import jraph


# All features will be an ArrayTree.
NodeFeatures = EdgeFeatures = SenderFeatures = ReceiverFeatures = Globals = (
    jraph.ArrayTree)

# Signature:
# (node features, outgoing edge features, incoming edge features,
#  globals) -> updated node features
GNUpdateNodeFn = Callable[
    [NodeFeatures, Mapping[str, SenderFeatures], Mapping[str, ReceiverFeatures],
     Globals],
    NodeFeatures]

GNUpdateGlobalFn = Callable[
    [Mapping[str, NodeFeatures], Mapping[str, EdgeFeatures], Globals],
    Globals]


def GraphNetwork(  # pylint: disable=invalid-name
    update_edge_fn: Mapping[str, jraph.GNUpdateEdgeFn],
    update_node_fn: Mapping[str, GNUpdateNodeFn],
    update_global_fn: Optional[GNUpdateGlobalFn] = None,
    aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph
    .segment_sum,
    aggregate_nodes_for_globals_fn: jraph.AggregateNodesToGlobalsFn = jraph
    .segment_sum,
    aggregate_edges_for_globals_fn: jraph.AggregateEdgesToGlobalsFn = jraph
    .segment_sum,
    ):
  """Returns a method that applies a configured GraphNetwork.

  This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
  extended to Typed Graphs with multiple edge sets and node sets and extended
  to allow aggregating not only edges received by the nodes, but also edges
  sent by the nodes.

  Example usage::

    gn = GraphNetwork(update_edge_function,
                      update_node_function, **kwargs)
    # Conduct multiple rounds of message passing with the same parameters:
    for _ in range(num_message_passing_steps):
      graph = gn(graph)

  Args:
    update_edge_fn: mapping of functions used to update a subset of the edge
      types, indexed by edge type name.
    update_node_fn: mapping of functions used to update a subset of the node
      types, indexed by node type name.
    update_global_fn: function used to update the globals or None to deactivate
      globals updates.
    aggregate_edges_for_nodes_fn: function used to aggregate messages to each
      node.
    aggregate_nodes_for_globals_fn: function used to aggregate the nodes for
      the globals.
    aggregate_edges_for_globals_fn: function used to aggregate the edges for
      the globals.

  Returns:
    A method that applies the configured GraphNetwork.
  """

  def _apply_graph_net(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
    """Applies a configured GraphNetwork to a graph.

    This implementation follows Algorithm 1 in https://arxiv.org/abs/1806.01261
    extended to Typed Graphs with multiple edge sets and node sets and extended
    to allow aggregating not only edges received by the nodes, but also edges
    sent by the nodes.

    Args:
      graph: a `TypedGraph` containing the graph.

    Returns:
      Updated `TypedGraph`.
    """

    updated_graph = graph

    # Edge update.
    updated_edges = dict(updated_graph.edges)
    for edge_set_name, edge_fn in update_edge_fn.items():
      edge_set_key = graph.edge_key_by_name(edge_set_name)
      updated_edges[edge_set_key] = _edge_update(
          updated_graph, edge_fn, edge_set_key)
    updated_graph = updated_graph._replace(edges=updated_edges)

    # Node update.
    updated_nodes = dict(updated_graph.nodes)
    for node_set_key, node_fn in update_node_fn.items():
      updated_nodes[node_set_key] = _node_update(
          updated_graph, node_fn, node_set_key, aggregate_edges_for_nodes_fn)
    updated_graph = updated_graph._replace(nodes=updated_nodes)

    # Global update.
    if update_global_fn:
      updated_context = _global_update(
          updated_graph, update_global_fn,
          aggregate_edges_for_globals_fn,
          aggregate_nodes_for_globals_fn)
      updated_graph = updated_graph._replace(context=updated_context)

    return updated_graph

  return _apply_graph_net


def _edge_update(graph, edge_fn, edge_set_key):  # pylint: disable=invalid-name
  """Updates an edge set of a given key."""

  sender_nodes = graph.nodes[edge_set_key.node_sets[0]]
  receiver_nodes = graph.nodes[edge_set_key.node_sets[1]]
  edge_set = graph.edges[edge_set_key]
  senders = edge_set.indices.senders  # pytype: disable=attribute-error
  receivers = edge_set.indices.receivers  # pytype: disable=attribute-error

  sent_attributes = tree.tree_map(
      lambda n: n[senders], sender_nodes.features)
  received_attributes = tree.tree_map(
      lambda n: n[receivers], receiver_nodes.features)

  n_edge = edge_set.n_edge
  sum_n_edge = senders.shape[0]
  global_features = tree.tree_map(
      lambda g: jnp.repeat(g, n_edge, axis=0, total_repeat_length=sum_n_edge),
      graph.context.features)
  new_features = edge_fn(
|
150 |
+
edge_set.features, sent_attributes, received_attributes,
|
151 |
+
global_features)
|
152 |
+
return edge_set._replace(features=new_features)
|
153 |
+
|
154 |
+
|
155 |
+
def _node_update(graph, node_fn, node_set_key, aggregation_fn): # pylint: disable=invalid-name
|
156 |
+
"""Updates an edge set of a given key."""
|
157 |
+
node_set = graph.nodes[node_set_key]
|
158 |
+
sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0]
|
159 |
+
|
160 |
+
sent_features = {}
|
161 |
+
for edge_set_key, edge_set in graph.edges.items():
|
162 |
+
sender_node_set_key = edge_set_key.node_sets[0]
|
163 |
+
if sender_node_set_key == node_set_key:
|
164 |
+
assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
|
165 |
+
senders = edge_set.indices.senders
|
166 |
+
sent_features[edge_set_key.name] = tree.tree_map(
|
167 |
+
lambda e: aggregation_fn(e, senders, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop
|
168 |
+
|
169 |
+
received_features = {}
|
170 |
+
for edge_set_key, edge_set in graph.edges.items():
|
171 |
+
receiver_node_set_key = edge_set_key.node_sets[1]
|
172 |
+
if receiver_node_set_key == node_set_key:
|
173 |
+
assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
|
174 |
+
receivers = edge_set.indices.receivers
|
175 |
+
received_features[edge_set_key.name] = tree.tree_map(
|
176 |
+
lambda e: aggregation_fn(e, receivers, sum_n_node), edge_set.features) # pylint: disable=cell-var-from-loop
|
177 |
+
|
178 |
+
n_node = node_set.n_node
|
179 |
+
global_features = tree.tree_map(
|
180 |
+
lambda g: jnp.repeat(g, n_node, axis=0, total_repeat_length=sum_n_node),
|
181 |
+
graph.context.features)
|
182 |
+
new_features = node_fn(
|
183 |
+
node_set.features, sent_features, received_features, global_features)
|
184 |
+
return node_set._replace(features=new_features)
|
185 |
+
|
186 |
+
|
187 |
+
def _global_update(graph, global_fn, edge_aggregation_fn, node_aggregation_fn): # pylint: disable=invalid-name
|
188 |
+
"""Updates an edge set of a given key."""
|
189 |
+
n_graph = graph.context.n_graph.shape[0]
|
190 |
+
graph_idx = jnp.arange(n_graph)
|
191 |
+
|
192 |
+
edge_features = {}
|
193 |
+
for edge_set_key, edge_set in graph.edges.items():
|
194 |
+
assert isinstance(edge_set.indices, typed_graph.EdgesIndices)
|
195 |
+
sum_n_edge = edge_set.indices.senders.shape[0]
|
196 |
+
edge_gr_idx = jnp.repeat(
|
197 |
+
graph_idx, edge_set.n_edge, axis=0, total_repeat_length=sum_n_edge)
|
198 |
+
edge_features[edge_set_key.name] = tree.tree_map(
|
199 |
+
lambda e: edge_aggregation_fn(e, edge_gr_idx, n_graph), # pylint: disable=cell-var-from-loop
|
200 |
+
edge_set.features)
|
201 |
+
|
202 |
+
node_features = {}
|
203 |
+
for node_set_key, node_set in graph.nodes.items():
|
204 |
+
sum_n_node = tree.tree_leaves(node_set.features)[0].shape[0]
|
205 |
+
node_gr_idx = jnp.repeat(
|
206 |
+
graph_idx, node_set.n_node, axis=0, total_repeat_length=sum_n_node)
|
207 |
+
node_features[node_set_key] = tree.tree_map(
|
208 |
+
lambda n: node_aggregation_fn(n, node_gr_idx, n_graph), # pylint: disable=cell-var-from-loop
|
209 |
+
node_set.features)
|
210 |
+
|
211 |
+
new_features = global_fn(node_features, edge_features, graph.context.features)
|
212 |
+
return graph.context._replace(features=new_features)
|
213 |
+
|
214 |
+
|
215 |
+
InteractionUpdateNodeFn = Callable[
|
216 |
+
[jraph.NodeFeatures,
|
217 |
+
Mapping[str, SenderFeatures],
|
218 |
+
Mapping[str, ReceiverFeatures]],
|
219 |
+
jraph.NodeFeatures]
|
220 |
+
|
221 |
+
|
222 |
+
InteractionUpdateNodeFnNoSentEdges = Callable[
|
223 |
+
[jraph.NodeFeatures,
|
224 |
+
Mapping[str, ReceiverFeatures]],
|
225 |
+
jraph.NodeFeatures]
|
226 |
+
|
227 |
+
|
228 |
+
def InteractionNetwork( # pylint: disable=invalid-name
|
229 |
+
update_edge_fn: Mapping[str, jraph.InteractionUpdateEdgeFn],
|
230 |
+
update_node_fn: Mapping[str, Union[InteractionUpdateNodeFn,
|
231 |
+
InteractionUpdateNodeFnNoSentEdges]],
|
232 |
+
aggregate_edges_for_nodes_fn: jraph.AggregateEdgesToNodesFn = jraph
|
233 |
+
.segment_sum,
|
234 |
+
include_sent_messages_in_node_update: bool = False):
|
235 |
+
"""Returns a method that applies a configured InteractionNetwork.
|
236 |
+
|
237 |
+
An interaction network computes interactions on the edges based on the
|
238 |
+
previous edges features, and on the features of the nodes sending into those
|
239 |
+
edges. It then updates the nodes based on the incoming updated edges.
|
240 |
+
See https://arxiv.org/abs/1612.00222 for more details.
|
241 |
+
|
242 |
+
This implementation extends the behavior to `TypedGraphs` adding an option
|
243 |
+
to include edge features for which a node is a sender in the arguments to
|
244 |
+
the node update function.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
update_edge_fn: mapping of functions used to update a subset of the edge
|
248 |
+
types, indexed by edge type name.
|
249 |
+
update_node_fn: mapping of functions used to update a subset of the node
|
250 |
+
types, indexed by node type name.
|
251 |
+
aggregate_edges_for_nodes_fn: function used to aggregate messages to each
|
252 |
+
node.
|
253 |
+
include_sent_messages_in_node_update: pass edge features for which a node is
|
254 |
+
a sender to the node update function.
|
255 |
+
"""
|
256 |
+
# An InteractionNetwork is a GraphNetwork without globals features,
|
257 |
+
# so we implement the InteractionNetwork as a configured GraphNetwork.
|
258 |
+
|
259 |
+
# An InteractionNetwork edge function does not have global feature inputs,
|
260 |
+
# so we filter the passed global argument in the GraphNetwork.
|
261 |
+
wrapped_update_edge_fn = tree.tree_map(
|
262 |
+
lambda fn: lambda e, s, r, g: fn(e, s, r), update_edge_fn)
|
263 |
+
|
264 |
+
# Similarly, we wrap the update_node_fn to ensure only the expected
|
265 |
+
# arguments are passed to the Interaction net.
|
266 |
+
if include_sent_messages_in_node_update:
|
267 |
+
wrapped_update_node_fn = tree.tree_map(
|
268 |
+
lambda fn: lambda n, s, r, g: fn(n, s, r), update_node_fn)
|
269 |
+
else:
|
270 |
+
wrapped_update_node_fn = tree.tree_map(
|
271 |
+
lambda fn: lambda n, s, r, g: fn(n, r), update_node_fn)
|
272 |
+
return GraphNetwork(
|
273 |
+
update_edge_fn=wrapped_update_edge_fn,
|
274 |
+
update_node_fn=wrapped_update_node_fn,
|
275 |
+
aggregate_edges_for_nodes_fn=aggregate_edges_for_nodes_fn)
|
276 |
+
|
277 |
+
|
278 |
+
def GraphMapFeatures( # pylint: disable=invalid-name
|
279 |
+
embed_edge_fn: Optional[Mapping[str, jraph.EmbedEdgeFn]] = None,
|
280 |
+
embed_node_fn: Optional[Mapping[str, jraph.EmbedNodeFn]] = None,
|
281 |
+
embed_global_fn: Optional[jraph.EmbedGlobalFn] = None):
|
282 |
+
"""Returns function which embeds the components of a graph independently.
|
283 |
+
|
284 |
+
Args:
|
285 |
+
embed_edge_fn: mapping of functions used to embed each edge type,
|
286 |
+
indexed by edge type name.
|
287 |
+
embed_node_fn: mapping of functions used to embed each node type,
|
288 |
+
indexed by node type name.
|
289 |
+
embed_global_fn: function used to embed the globals.
|
290 |
+
"""
|
291 |
+
|
292 |
+
def _embed(graph: typed_graph.TypedGraph) -> typed_graph.TypedGraph:
|
293 |
+
|
294 |
+
updated_edges = dict(graph.edges)
|
295 |
+
if embed_edge_fn:
|
296 |
+
for edge_set_name, embed_fn in embed_edge_fn.items():
|
297 |
+
edge_set_key = graph.edge_key_by_name(edge_set_name)
|
298 |
+
edge_set = graph.edges[edge_set_key]
|
299 |
+
updated_edges[edge_set_key] = edge_set._replace(
|
300 |
+
features=embed_fn(edge_set.features))
|
301 |
+
|
302 |
+
updated_nodes = dict(graph.nodes)
|
303 |
+
if embed_node_fn:
|
304 |
+
for node_set_key, embed_fn in embed_node_fn.items():
|
305 |
+
node_set = graph.nodes[node_set_key]
|
306 |
+
updated_nodes[node_set_key] = node_set._replace(
|
307 |
+
features=embed_fn(node_set.features))
|
308 |
+
|
309 |
+
updated_context = graph.context
|
310 |
+
if embed_global_fn:
|
311 |
+
updated_context = updated_context._replace(
|
312 |
+
features=embed_global_fn(updated_context.features))
|
313 |
+
|
314 |
+
return graph._replace(edges=updated_edges, nodes=updated_nodes,
|
315 |
+
context=updated_context)
|
316 |
+
|
317 |
+
return _embed
|
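The `_node_update` step above relies on `aggregate_edges_for_nodes_fn` (by default `jraph.segment_sum`) to pool per-edge messages onto their receiver nodes. A minimal sketch of that aggregation in isolation, with made-up shapes and values (not part of the uploaded files):

```python
import jax.numpy as jnp
import jraph

# 4 edges carrying 2-dim messages, each directed at one of 3 receiver nodes.
edge_messages = jnp.array([[1.0, 0.0],
                           [0.0, 2.0],
                           [3.0, 1.0],
                           [1.0, 1.0]])
receivers = jnp.array([0, 1, 1, 2])  # receiver node index for each edge
num_nodes = 3

# This mirrors what _node_update does for each edge set:
#   aggregation_fn(edge_features, receivers, sum_n_node)
aggregated = jraph.segment_sum(edge_messages, receivers, num_nodes)
# aggregated == [[1., 0.], [3., 3.], [1., 1.]]
```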
graphcast/xarray_jax.py
ADDED
@@ -0,0 +1,810 @@
1 |
+
# Copyright 2023 DeepMind Technologies Limited.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS-IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Helpers to use xarray.{Variable,DataArray,Dataset} with JAX.
|
15 |
+
|
16 |
+
Allows them to be based on JAX arrays without converting to numpy arrays under
|
17 |
+
the hood, so you can start with a JAX array, do some computation with it in
|
18 |
+
xarray-land, get a JAX array out the other end and (for example) jax.jit
|
19 |
+
through the whole thing. You can even jax.jit a function which accepts and
|
20 |
+
returns xarray.Dataset, DataArray and Variable.
|
21 |
+
|
22 |
+
## Creating xarray datatypes from jax arrays, and vice-versa.
|
23 |
+
|
24 |
+
You can use the xarray_jax.{Variable,DataArray,Dataset} constructors, which have
|
25 |
+
the same API as the standard xarray constructors but will accept JAX arrays
|
26 |
+
without converting them to numpy.
|
27 |
+
|
28 |
+
It does this by wrapping the JAX array in a wrapper before passing it to
|
29 |
+
xarray; you can also do this manually by calling xarray_jax.wrap on your JAX
|
30 |
+
arrays before passing them to the standard xarray constructors.
|
31 |
+
|
32 |
+
To get non-wrapped JAX arrays out the other end, you can use e.g.:
|
33 |
+
|
34 |
+
xarray_jax.jax_vars(dataset)
|
35 |
+
xarray_jax.jax_data(dataset.some_var)
|
36 |
+
|
37 |
+
which will complain if the data isn't actually a JAX array. Use this if you need
|
38 |
+
to make sure the computation has gone via JAX, e.g. if it's the output of code
|
39 |
+
that you want to JIT or compute gradients through. If this is not the case and
|
40 |
+
you want to support passing plain numpy arrays through as well as potentially
|
41 |
+
JAX arrays, you can use:
|
42 |
+
|
43 |
+
xarray_jax.unwrap_vars(dataset)
|
44 |
+
xarray_jax.unwrap_data(dataset.some_var)
|
45 |
+
|
46 |
+
which will unwrap the data if it is a wrapped JAX array, but otherwise pass
|
47 |
+
it through to you without complaint.
|
48 |
+
|
49 |
+
The wrapped JAX arrays aim to support all the core operations from the numpy
|
50 |
+
array API that xarray expects, however there may still be some gaps; if you run
|
51 |
+
into any problems around this, you may need to add a few more proxy methods onto
|
52 |
+
the wrapper class below.
|
53 |
+
|
54 |
+
In future once JAX and xarray support the new Python array API standard
|
55 |
+
(https://data-apis.org/array-api/latest/index.html), we hope to avoid the need
|
56 |
+
for wrapping the JAX arrays like this.
|
57 |
+
|
58 |
+
## jax.jit and pmap of functions taking and returning xarray datatypes
|
59 |
+
|
60 |
+
We register xarray datatypes with jax.tree_util, which allows them to be treated
|
61 |
+
as generic containers of jax arrays by various parts of jax including jax.jit.
|
62 |
+
|
63 |
+
This allows for, e.g.:
|
64 |
+
|
65 |
+
@jax.jit
|
66 |
+
def foo(input: xarray.Dataset) -> xarray.Dataset:
|
67 |
+
...
|
68 |
+
|
69 |
+
It will not work out-of-the-box with shape-modifying transformations like
|
70 |
+
jax.pmap, or e.g. a jax.tree_util.tree_map with some transform that alters array
|
71 |
+
shapes or dimension order. That's because we won't know what dimension names
|
72 |
+
and/or coordinates to use when unflattening, if the results have a different
|
73 |
+
shape to the data that was originally flattened.
|
74 |
+
|
75 |
+
You can work around this using xarray_jax.dims_change_on_unflatten, however,
|
76 |
+
and in the case of jax.pmap we provide a wrapper xarray_jax.pmap which allows
|
77 |
+
it to be used with functions taking and returning xarrays.
|
78 |
+
|
79 |
+
## Treatment of coordinates
|
80 |
+
|
81 |
+
We don't support passing jax arrays as coordinates when constructing a
|
82 |
+
DataArray/Dataset. This is because xarray's advanced indexing and slicing is
|
83 |
+
unlikely to work with jax arrays (at least when a Tracer is used during
|
84 |
+
jax.jit), and also because some important datatypes used for coordinates, like
|
85 |
+
timedelta64 and datetime64, are not supported by jax.
|
86 |
+
|
87 |
+
For the purposes of tree_util and jax.jit, coordinates are not treated as leaves
|
88 |
+
of the tree (array data 'contained' by a Dataset/DataArray), they are just a
|
89 |
+
static part of the structure. That means that if a jit'ed function is called
|
90 |
+
twice with Dataset inputs that use different coordinates, it will compile a
|
91 |
+
separate version of the function for each. The coordinates are treated like
|
92 |
+
static_argnums by jax.jit.
|
93 |
+
|
94 |
+
If you want to use dynamic data for coordinates, we recommend making it a
|
95 |
+
data_var instead of a coord. You won't be able to do indexing and slicing using
|
96 |
+
the coordinate, but that wasn't going to work with a jax array anyway.
|
97 |
+
"""
|
98 |
+
|
99 |
+
import collections
|
100 |
+
import contextlib
|
101 |
+
import contextvars
|
102 |
+
from typing import Any, Callable, Hashable, Iterator, Mapping, Optional, Union, Tuple, TypeVar, cast
|
103 |
+
|
104 |
+
import jax
|
105 |
+
import jax.numpy as jnp
|
106 |
+
import numpy as np
|
107 |
+
import tree
|
108 |
+
import xarray
|
109 |
+
|
110 |
+
|
111 |
+
def Variable(dims, data, **kwargs) -> xarray.Variable: # pylint:disable=invalid-name
|
112 |
+
"""Like xarray.Variable, but can wrap JAX arrays."""
|
113 |
+
return xarray.Variable(dims, wrap(data), **kwargs)
|
114 |
+
|
115 |
+
|
116 |
+
_JAX_COORD_ATTR_NAME = '_jax_coord'
|
117 |
+
|
118 |
+
|
119 |
+
def DataArray( # pylint:disable=invalid-name
|
120 |
+
data,
|
121 |
+
coords=None,
|
122 |
+
dims=None,
|
123 |
+
name=None,
|
124 |
+
attrs=None,
|
125 |
+
jax_coords=None,
|
126 |
+
) -> xarray.DataArray:
|
127 |
+
"""Like xarray.DataArray, but supports using JAX arrays.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
data: As for xarray.DataArray, except jax arrays are also supported.
|
131 |
+
coords: Coordinates for the array, see xarray.DataArray. These coordinates
|
132 |
+
must be based on plain numpy arrays or something convertible to plain
|
133 |
+
numpy arrays. Their values will form a static part of the data structure
|
134 |
+
from the point of view of jax.tree_util. In particular this means these
|
135 |
+
coordinates will be passed as plain numpy arrays even inside a JIT'd
|
136 |
+
function, and the JIT'd function will be recompiled under the hood if the
|
137 |
+
coordinates of DataArrays passed into it change.
|
138 |
+
If this is not convenient for you, see also jax_coords below.
|
139 |
+
dims: See xarray.DataArray.
|
140 |
+
name: See xarray.DataArray.
|
141 |
+
attrs: See xarray.DataArray.
|
142 |
+
jax_coords: Additional coordinates, which *can* use JAX arrays. These
|
143 |
+
coordinates will be treated as JAX data from the point of view of
|
144 |
+
jax.tree_util, that means when JIT'ing they will be passed as tracers and
|
145 |
+
computation involving them will be JIT'd.
|
146 |
+
Unfortunately a side-effect of this is that they can't be used as index
|
147 |
+
coordinates (because xarray's indexing logic is not JIT-able). If you
|
148 |
+
specify a coordinate with the same name as a dimension here, it will not
|
149 |
+
be set as an index coordinate; this behaviour is different to the default
|
150 |
+
for `coords`, and it means that things like `.sel` based on the jax
|
151 |
+
coordinate will not work.
|
152 |
+
Note we require `jax_coords` to be explicitly specified via a different
|
153 |
+
constructor argument to `coords`, rather than just looking for jax arrays
|
154 |
+
within the `coords` and treating them differently. This is because it
|
155 |
+
affects the way jax.tree_util treats them, which is somewhat orthogonal to
|
156 |
+
whether the value is passed in as numpy or not, and generally needs to be
|
157 |
+
handled consistently so is something we encourage explicit control over.
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
An instance of xarray.DataArray. Where JAX arrays are used as data or
|
161 |
+
coords, they will be wrapped with JaxArrayWrapper and can be unwrapped via
|
162 |
+
`unwrap` and `unwrap_data`.
|
163 |
+
"""
|
164 |
+
result = xarray.DataArray(
|
165 |
+
wrap(data), dims=dims, name=name, attrs=attrs or {})
|
166 |
+
return assign_coords(result, coords=coords, jax_coords=jax_coords)
|
167 |
+
|
168 |
+
|
169 |
+
def Dataset( # pylint:disable=invalid-name
|
170 |
+
data_vars,
|
171 |
+
coords=None,
|
172 |
+
attrs=None,
|
173 |
+
jax_coords=None,
|
174 |
+
) -> xarray.Dataset:
|
175 |
+
"""Like xarray.Dataset, but can wrap JAX arrays.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
data_vars: As for xarray.Dataset, except jax arrays are also supported.
|
179 |
+
coords: Coordinates for the dataset, see xarray.Dataset. These coordinates
|
180 |
+
must be based on plain numpy arrays or something convertible to plain
|
181 |
+
numpy arrays. Their values will form a static part of the data structure
|
182 |
+
from the point of view of jax.tree_util. In particular this means these
|
183 |
+
coordinates will be passed as plain numpy arrays even inside a JIT'd
|
184 |
+
function, and the JIT'd function will be recompiled under the hood if the
|
185 |
+
coordinates of DataArrays passed into it change.
|
186 |
+
If this is not convenient for you, see also jax_coords below.
|
187 |
+
attrs: See xarray.Dataset.
|
188 |
+
jax_coords: Additional coordinates, which *can* use JAX arrays. These
|
189 |
+
coordinates will be treated as JAX data from the point of view of
|
190 |
+
jax.tree_util, that means when JIT'ing they will be passed as tracers and
|
191 |
+
computation involving them will be JIT'd.
|
192 |
+
Unfortunately a side-effect of this is that they can't be used as index
|
193 |
+
coordinates (because xarray's indexing logic is not JIT-able). If you
|
194 |
+
specify a coordinate with the same name as a dimension here, it will not
|
195 |
+
be set as an index coordinate; this behaviour is different to the default
|
196 |
+
for `coords`, and it means that things like `.sel` based on the jax
|
197 |
+
coordinate will not work.
|
198 |
+
Note we require `jax_coords` to be explicitly specified via a different
|
199 |
+
constructor argument to `coords`, rather than just looking for jax arrays
|
200 |
+
within the `coords` and treating them differently. This is because it
|
201 |
+
affects the way jax.tree_util treats them, which is somewhat orthogonal to
|
202 |
+
whether the value is passed in as numpy or not, and generally needs to be
|
203 |
+
handled consistently so is something we encourage explicit control over.
|
204 |
+
|
205 |
+
Returns:
|
206 |
+
An instance of xarray.Dataset. Where JAX arrays are used as data, they
|
207 |
+
will be wrapped with JaxArrayWrapper.
|
208 |
+
"""
|
209 |
+
wrapped_data_vars = {}
|
210 |
+
for name, var_like in data_vars.items():
|
211 |
+
# xarray.Dataset accepts a few different formats for data_vars:
|
212 |
+
if isinstance(var_like, jax.Array):
|
213 |
+
wrapped_data_vars[name] = wrap(var_like)
|
214 |
+
elif isinstance(var_like, tuple):
|
215 |
+
# Layout is (dims, data, ...). We wrap data.
|
216 |
+
wrapped_data_vars[name] = (var_like[0], wrap(var_like[1])) + var_like[2:]
|
217 |
+
else:
|
218 |
+
# Could be a plain numpy array or scalar (we don't wrap), or an
|
219 |
+
# xarray.Variable, DataArray etc, which we must assume is already wrapped
|
220 |
+
# if necessary (e.g. if creating using xarray_jax.{Variable,DataArray}).
|
221 |
+
wrapped_data_vars[name] = var_like
|
222 |
+
|
223 |
+
result = xarray.Dataset(
|
224 |
+
data_vars=wrapped_data_vars,
|
225 |
+
attrs=attrs)
|
226 |
+
|
227 |
+
return assign_coords(result, coords=coords, jax_coords=jax_coords)
|
228 |
+
|
229 |
+
|
230 |
+
DatasetOrDataArray = TypeVar(
|
231 |
+
'DatasetOrDataArray', xarray.Dataset, xarray.DataArray)
|
232 |
+
|
233 |
+
|
234 |
+
def assign_coords(
|
235 |
+
x: DatasetOrDataArray,
|
236 |
+
*,
|
237 |
+
coords: Optional[Mapping[Hashable, Any]] = None,
|
238 |
+
jax_coords: Optional[Mapping[Hashable, Any]] = None,
|
239 |
+
) -> DatasetOrDataArray:
|
240 |
+
"""Replacement for assign_coords which works in presence of jax_coords.
|
241 |
+
|
242 |
+
`jax_coords` allow certain specified coordinates to have their data passed as
|
243 |
+
JAX arrays (including through jax.jit boundaries). The compromise in return is
|
244 |
+
that they are not created as index coordinates and cannot be used for .sel
|
245 |
+
and other coordinate-based indexing operations. See docs for `jax_coords` on
|
246 |
+
xarray_jax.Dataset and xarray_jax.DataArray for more information.
|
247 |
+
|
248 |
+
This function can be used to set jax_coords on an existing DataArray or
|
249 |
+
Dataset, and also to set a mix of jax and non-jax coordinates. It implements
|
250 |
+
some workarounds to prevent xarray trying and failing to create IndexVariables
|
251 |
+
from jax arrays under the hood.
|
252 |
+
|
253 |
+
If you have any jax_coords with the same name as a dimension, you'll need to
|
254 |
+
use this function instead of data_array.assign_coords or dataset.assign_coords
|
255 |
+
in general, to avoid an xarray bug where it tries (and in our case fails) to
|
256 |
+
create indexes for existing jax coords. See
|
257 |
+
https://github.com/pydata/xarray/issues/7885.
|
258 |
+
|
259 |
+
Args:
|
260 |
+
x: An xarray Dataset or DataArray.
|
261 |
+
coords: Dict of (non-JAX) coords, or None if not assigning any.
|
262 |
+
jax_coords: Dict of JAX coords, or None if not assigning any. See docs for
|
263 |
+
xarray_jax.Dataset / DataArray for more information on jax_coords.
|
264 |
+
|
265 |
+
Returns:
|
266 |
+
The Dataset or DataArray with coordinates assigned, similarly to
|
267 |
+
Dataset.assign_coords / DataArray.assign_coords.
|
268 |
+
"""
|
269 |
+
coords = {} if coords is None else dict(coords) # Copy before mutating.
|
270 |
+
jax_coords = {} if jax_coords is None else dict(jax_coords)
|
271 |
+
|
272 |
+
# Any existing JAX coords must be dropped and re-added via the workaround
|
273 |
+
# below, since otherwise .assign_coords will trigger an xarray bug where
|
274 |
+
# it tries to recreate the indexes again for the existing coordinates.
|
275 |
+
# Can remove if/when https://github.com/pydata/xarray/issues/7885 fixed.
|
276 |
+
existing_jax_coords = get_jax_coords(x)
|
277 |
+
jax_coords = existing_jax_coords | jax_coords
|
278 |
+
x = x.drop_vars(existing_jax_coords.keys())
|
279 |
+
|
280 |
+
# We need to ensure that xarray doesn't try to create an index for
|
281 |
+
# coordinates with the same name as a dimension, since this will fail if
|
282 |
+
# given a wrapped JAX tracer.
|
283 |
+
# It appears the only way to avoid this is to name them differently to any
|
284 |
+
# dimension name, then rename them back afterwards.
|
285 |
+
renamed_jax_coords = {}
|
286 |
+
for name, coord in jax_coords.items():
|
287 |
+
if isinstance(coord, xarray.DataArray):
|
288 |
+
coord = coord.variable
|
289 |
+
if isinstance(coord, xarray.Variable):
|
290 |
+
coord = coord.copy(deep=False) # Copy before mutating attrs.
|
291 |
+
else:
|
292 |
+
# Must wrap as Variable with the correct dims first if this has not
|
293 |
+
# already been done, otherwise xarray.Dataset will assume the dimension
|
294 |
+
# name is also __NONINDEX_{n}.
|
295 |
+
coord = Variable((name,), coord)
|
296 |
+
|
297 |
+
# We set an attr on each jax_coord identifying it as such. These attrs on
|
298 |
+
# the coord Variable gets reflected on the coord DataArray exposed too, and
|
299 |
+
# when set on coordinates they generally get preserved under the default
|
300 |
+
# keep_attrs setting.
|
301 |
+
# These attrs are used by jax.tree_util registered flatten/unflatten to
|
302 |
+
# determine which coords need to be treated as leaves of the flattened
|
303 |
+
# structure vs static data.
|
304 |
+
coord.attrs[_JAX_COORD_ATTR_NAME] = True
|
305 |
+
renamed_jax_coords[f'__NONINDEX_{name}'] = coord
|
306 |
+
|
307 |
+
x = x.assign_coords(coords=coords | renamed_jax_coords)
|
308 |
+
|
309 |
+
rename_back_mapping = {f'__NONINDEX_{name}': name for name in jax_coords}
|
310 |
+
if isinstance(x, xarray.Dataset):
|
311 |
+
# Using 'rename' doesn't work if renaming to the same name as a dimension.
|
312 |
+
return x.rename_vars(rename_back_mapping)
|
313 |
+
else: # DataArray
|
314 |
+
return x.rename(rename_back_mapping)
|
315 |
+
|
316 |
+
|
317 |
+
def get_jax_coords(x: DatasetOrDataArray) -> Mapping[Hashable, Any]:
|
318 |
+
return {
|
319 |
+
name: coord_var
|
320 |
+
for name, coord_var in x.coords.variables.items()
|
321 |
+
if coord_var.attrs.get(_JAX_COORD_ATTR_NAME, False)}
|
322 |
+
|
323 |
+
|
324 |
+
def assign_jax_coords(
|
325 |
+
x: DatasetOrDataArray,
|
326 |
+
jax_coords: Optional[Mapping[Hashable, Any]] = None,
|
327 |
+
**jax_coords_kwargs
|
328 |
+
) -> DatasetOrDataArray:
|
329 |
+
"""Assigns only jax_coords, with same API as xarray's assign_coords."""
|
330 |
+
return assign_coords(x, jax_coords=jax_coords or jax_coords_kwargs)
|
331 |
+
|
332 |
+
|
333 |
+
def wrap(value):
|
334 |
+
"""Wraps JAX arrays for use in xarray, passing through other values."""
|
335 |
+
if isinstance(value, jax.Array):
|
336 |
+
return JaxArrayWrapper(value)
|
337 |
+
else:
|
338 |
+
return value
|
339 |
+
|
340 |
+
|
341 |
+
def unwrap(value, require_jax=False):
|
342 |
+
"""Unwraps wrapped JAX arrays used in xarray, passing through other values."""
|
343 |
+
if isinstance(value, JaxArrayWrapper):
|
344 |
+
return value.jax_array
|
345 |
+
elif isinstance(value, jax.Array):
|
346 |
+
return value
|
347 |
+
elif require_jax:
|
348 |
+
raise TypeError(f'Expected JAX array, found {type(value)}.')
|
349 |
+
else:
|
350 |
+
return value
|
351 |
+
|
352 |
+
|
353 |
+
def _wrapped(func):
|
354 |
+
"""Surrounds a function with JAX array unwrapping/wrapping."""
|
355 |
+
def wrapped_func(*args, **kwargs):
|
356 |
+
args, kwargs = tree.map_structure(unwrap, (args, kwargs))
|
357 |
+
result = func(*args, **kwargs)
|
358 |
+
return tree.map_structure(wrap, result)
|
359 |
+
return wrapped_func
|
360 |
+
|
361 |
+
|
362 |
+
def unwrap_data(
|
363 |
+
value: Union[xarray.Variable, xarray.DataArray],
|
364 |
+
require_jax: bool = False
|
365 |
+
) -> Union[jax.Array, np.ndarray]:
|
366 |
+
"""The unwrapped (see unwrap) data of a an xarray.Variable or DataArray."""
|
367 |
+
return unwrap(value.data, require_jax=require_jax)
|
368 |
+
|
369 |
+
|
370 |
+
def unwrap_vars(
|
371 |
+
dataset: Mapping[Hashable, xarray.DataArray],
|
372 |
+
require_jax: bool = False
|
373 |
+
) -> Mapping[str, Union[jax.Array, np.ndarray]]:
|
374 |
+
"""The unwrapped data (see unwrap) of the variables in a dataset."""
|
375 |
+
# xarray types variable names as Hashable, but in practice they're invariably
|
376 |
+
# strings and we convert to str to allow for a more useful return type.
|
377 |
+
return {str(name): unwrap_data(var, require_jax=require_jax)
|
378 |
+
for name, var in dataset.items()}
|
379 |
+
|
380 |
+
|
381 |
+
def unwrap_coords(
|
382 |
+
dataset: Union[xarray.Dataset, xarray.DataArray],
|
383 |
+
require_jax: bool = False
|
384 |
+
) -> Mapping[str, Union[jax.Array, np.ndarray]]:
|
385 |
+
"""The unwrapped data (see unwrap) of the coords in a Dataset or DataArray."""
|
386 |
+
return {str(name): unwrap_data(var, require_jax=require_jax)
|
387 |
+
for name, var in dataset.coords.items()}
|
388 |
+
|
389 |
+
|
390 |
+
def jax_data(value: Union[xarray.Variable, xarray.DataArray]) -> jax.Array:
|
391 |
+
"""Like unwrap_data, but will complain if not a jax array."""
|
392 |
+
# Implementing this separately so we can give a more specific return type
|
393 |
+
# for it.
|
394 |
+
return cast(jax.Array, unwrap_data(value, require_jax=True))
|
395 |
+
|
396 |
+
|
397 |
+
def jax_vars(
|
398 |
+
dataset: Mapping[Hashable, xarray.DataArray]) -> Mapping[str, jax.Array]:
|
399 |
+
"""Like unwrap_vars, but will complain if vars are not all jax arrays."""
|
400 |
+
return cast(Mapping[str, jax.Array], unwrap_vars(dataset, require_jax=True))
|
401 |
+
|
402 |
+
|
403 |
+
class JaxArrayWrapper(np.lib.mixins.NDArrayOperatorsMixin):
|
404 |
+
"""Wraps a JAX array into a duck-typed array suitable for use with xarray.
|
405 |
+
|
406 |
+
This uses an older duck-typed array protocol based on __array_ufunc__ and
|
407 |
+
__array_function__ which works with numpy and xarray. (In newer versions
|
408 |
+
of xarray it implements xarray.namedarray._typing._array_function.)
|
409 |
+
|
410 |
+
This is in the process of being superseded by the Python array API standard
|
411 |
+
(https://data-apis.org/array-api/latest/index.html), but JAX hasn't
|
412 |
+
implemented it yet. Once they have, we should be able to get rid of
|
413 |
+
this wrapper and use JAX arrays directly with xarray.
|
414 |
+
|
415 |
+
"""
|
416 |
+
|
417 |
+
def __init__(self, jax_array):
|
418 |
+
self.jax_array = jax_array
|
419 |
+
|
420 |
+
def __array_ufunc__(self, ufunc, method, *args, **kwargs):
|
421 |
+
for x in args:
|
422 |
+
if not isinstance(x, (jax.typing.ArrayLike, type(self))):
|
423 |
+
return NotImplemented
|
424 |
+
if method != '__call__':
|
425 |
+
return NotImplemented
|
426 |
+
try:
|
427 |
+
# Get the corresponding jax.numpy function to the NumPy ufunc:
|
428 |
+
func = getattr(jnp, ufunc.__name__)
|
429 |
+
except AttributeError:
|
430 |
+
return NotImplemented
|
431 |
+
# There may be an 'out' kwarg requesting an in-place operation, e.g. when
|
432 |
+
# this is called via __iadd__ (+=), __imul__ (*=) etc. JAX doesn't support
|
433 |
+
# in-place operations so we just remove this argument and have the ufunc
|
434 |
+
# return a fresh JAX array instead.
|
435 |
+
kwargs.pop('out', None)
|
436 |
+
return _wrapped(func)(*args, **kwargs)
|
437 |
+
|
438 |
+
def __array_function__(self, func, types, args, kwargs):
|
439 |
+
try:
|
440 |
+
# Get the corresponding jax.numpy function to the NumPy function:
|
441 |
+
func = getattr(jnp, func.__name__)
|
442 |
+
except AttributeError:
|
443 |
+
return NotImplemented
|
444 |
+
return _wrapped(func)(*args, **kwargs)
|
445 |
+
|
446 |
+
def __repr__(self):
|
447 |
+
return f'xarray_jax.JaxArrayWrapper({repr(self.jax_array)})'
|
448 |
+
|
449 |
+
# NDArrayOperatorsMixin already proxies most __dunder__ operator methods.
|
450 |
+
# We need to proxy through a few more methods in a similar way:
|
451 |
+
|
452 |
+
# Essential array properties:
|
453 |
+
|
454 |
+
@property
|
455 |
+
def shape(self):
|
456 |
+
return self.jax_array.shape
|
457 |
+
|
458 |
+
@property
|
459 |
+
def dtype(self):
|
460 |
+
return self.jax_array.dtype
|
461 |
+
|
462 |
+
@property
|
463 |
+
def ndim(self):
|
464 |
+
return self.jax_array.ndim
|
465 |
+
|
466 |
+
@property
|
467 |
+
def size(self):
|
468 |
+
return self.jax_array.size
|
469 |
+
|
470 |
+
@property
|
471 |
+
def real(self):
|
472 |
+
return self.jax_array.real
|
473 |
+
|
474 |
+
@property
|
475 |
+
def imag(self):
|
476 |
+
return self.jax_array.imag
|
477 |
+
|
478 |
+
# Array methods not covered by NDArrayOperatorsMixin:
|
479 |
+
|
480 |
+
# Allows conversion to numpy array using np.asarray etc. Warning: doing this
|
481 |
+
# will fail in a jax.jit-ed function.
|
482 |
+
def __array__(self, dtype=None, context=None):
|
483 |
+
return np.asarray(self.jax_array, dtype=dtype)
|
484 |
+
|
485 |
+
__getitem__ = _wrapped(lambda array, *args: array.__getitem__(*args))
|
486 |
+
# We drop the kwargs on this as they are not supported by JAX, but xarray
|
487 |
+
# uses at least one of them (the copy arg).
|
488 |
+
astype = _wrapped(lambda array, *args, **kwargs: array.astype(*args))
|
489 |
+
|
490 |
+
# There are many more methods which are more canonically available via (j)np
|
491 |
+
# functions, e.g. .sum() available via jnp.sum, and also mean, max, min,
|
492 |
+
# argmax, argmin etc. We don't attempt to proxy through all of these as
|
493 |
+
# methods, since this doesn't appear to be expected from a duck-typed array
|
494 |
+
# implementation. But there are a few which xarray calls as methods, so we
|
495 |
+
# proxy those:
|
496 |
+
transpose = _wrapped(jnp.transpose)
|
497 |
+
reshape = _wrapped(jnp.reshape)
|
498 |
+
all = _wrapped(jnp.all)
|
499 |
+
|
500 |
+
|
501 |
+
def apply_ufunc(func, *args, require_jax=False, **apply_ufunc_kwargs):
|
502 |
+
"""Like xarray.apply_ufunc but for jax-specific ufuncs.
|
503 |
+
|
504 |
+
Many numpy ufuncs will work fine out of the box with xarray_jax and
|
505 |
+
JaxArrayWrapper, since JaxArrayWrapper quacks (mostly) like a numpy array and
|
506 |
+
will convert many numpy operations to jax ops under the hood. For these
|
507 |
+
situations, xarray.apply_ufunc should work fine.
|
508 |
+
|
509 |
+
But sometimes you need a jax-specific ufunc which needs to be given a
|
510 |
+
jax array as input or return a jax array as output. In that case you should
|
511 |
+
use this helper as it will remove any JaxArrayWrapper before calling the func,
|
512 |
+
and wrap the result afterwards before handing it back to xarray.
|
513 |
+
|
514 |
+
Args:
|
515 |
+
func: A function that works with jax arrays (e.g. using functions from
|
516 |
+
jax.numpy) but otherwise meets the spec for the func argument to
|
517 |
+
xarray.apply_ufunc.
|
518 |
+
*args: xarray arguments to be mapped to arguments for func
|
519 |
+
(see xarray.apply_ufunc).
|
520 |
+
require_jax: Whether to require that inputs are based on jax arrays or allow
|
521 |
+
those based on plain numpy arrays too.
|
522 |
+
**apply_ufunc_kwargs: See xarray.apply_ufunc.
|
523 |
+
|
524 |
+
Returns:
|
525 |
+
Corresponding xarray results (see xarray.apply_ufunc).
|
526 |
+
"""
|
527 |
+
def wrapped_func(*maybe_wrapped_args):
|
528 |
+
unwrapped_args = [unwrap(a, require_jax) for a in maybe_wrapped_args]
|
529 |
+
result = func(*unwrapped_args)
|
530 |
+
# Result can be an array or a tuple of arrays, this handles both:
|
531 |
+
return jax.tree_util.tree_map(wrap, result)
|
532 |
+
return xarray.apply_ufunc(wrapped_func, *args, **apply_ufunc_kwargs)
|
533 |
+
|
534 |
+
|
535 |
+
def pmap(fn: Callable[..., Any],
|
536 |
+
dim: str,
|
537 |
+
axis_name: Optional[str] = None,
|
538 |
+
devices: ... = None,
|
539 |
+
backend: ... = None) -> Callable[..., Any]:
|
540 |
+
"""Wraps a subset of jax.pmap functionality to handle xarray input/output.
|
541 |
+
|
542 |
+
Constraints:
|
543 |
+
* Any Dataset or DataArray passed to the function must have `dim` as the
|
544 |
+
first dimension. This will be checked. You can ensure this if necessary
|
545 |
+
by calling `.transpose(dim, ...)` beforehand.
|
546 |
+
* All args and return values will be mapped over the first dimension,
|
547 |
+
it will use in_axes=0, out_axes=0.
|
548 |
+
* No support for static_broadcasted_argnums, donate_argnums etc.
|
549 |
+
|
550 |
+
Args:
|
551 |
+
fn: Function to be pmap'd which takes and returns trees which may contain
|
552 |
+
xarray Dataset/DataArray. Any Dataset/DataArrays passed as input must use
|
553 |
+
`dim` as the first dimension on all arrays.
|
554 |
+
dim: The xarray dimension name corresponding to the first dimension that is
|
555 |
+
pmapped over (pmap is called with in_axes=0, out_axes=0).
|
556 |
+
axis_name: Used by jax to identify the mapped axis so that parallel
|
557 |
+
collectives can be applied. Defaults to same as `dim`.
|
558 |
+
devices:
|
559 |
+
backend:
|
560 |
+
See jax.pmap.
|
561 |
+
|
562 |
+
Returns:
|
563 |
+
A pmap'd version of `fn`, which takes and returns Dataset/DataArray with an
|
564 |
+
extra leading dimension `dim` relative to what the original `fn` sees.
|
565 |
+
"""
|
566 |
+
input_treedef = None
|
567 |
+
output_treedef = None
|
568 |
+
|
569 |
+
def fn_passed_to_pmap(*flat_args):
|
570 |
+
assert input_treedef is not None
|
571 |
+
# Inside the pmap the original first dimension will no longer be present:
|
572 |
+
def check_and_remove_leading_dim(dims):
|
573 |
+
try:
|
574 |
+
index = dims.index(dim)
|
575 |
+
except ValueError:
|
576 |
+
index = None
|
577 |
+
if index != 0:
|
578 |
+
raise ValueError(f'Expected dim {dim} at index 0, found at {index}.')
|
579 |
+
return dims[1:]
|
580 |
+
with dims_change_on_unflatten(check_and_remove_leading_dim):
|
581 |
+
args = jax.tree_util.tree_unflatten(input_treedef, flat_args)
|
582 |
+
result = fn(*args)
|
583 |
+
nonlocal output_treedef
|
584 |
+
flat_result, output_treedef = jax.tree_util.tree_flatten(result)
|
585 |
+
return flat_result
|
586 |
+
|
587 |
+
pmapped_fn = jax.pmap(
|
588 |
+
fn_passed_to_pmap,
|
589 |
+
axis_name=axis_name or dim,
|
590 |
+
in_axes=0,
|
591 |
+
out_axes=0,
|
592 |
+
devices=devices,
|
593 |
+
backend=backend)
|
594 |
+
|
595 |
+
def result_fn(*args):
|
596 |
+
nonlocal input_treedef
|
597 |
+
flat_args, input_treedef = jax.tree_util.tree_flatten(args)
|
598 |
+
flat_result = pmapped_fn(*flat_args)
|
599 |
+
assert output_treedef is not None
|
600 |
+
# After the pmap an extra leading axis will be present, we need to add an
|
601 |
+
# xarray dimension for this when unflattening the result:
|
602 |
+
with dims_change_on_unflatten(lambda dims: (dim,) + dims):
|
603 |
+
return jax.tree_util.tree_unflatten(output_treedef, flat_result)
|
604 |
+
|
605 |
+
return result_fn
|
606 |
+
|
607 |
+
|
608 |
+
# Register xarray datatypes with jax.tree_util.
|
609 |
+
|
610 |
+
|
611 |
+
DimsChangeFn = Callable[[Tuple[Hashable, ...]], Tuple[Hashable, ...]]
|
612 |
+
_DIMS_CHANGE_ON_UNFLATTEN_FN: contextvars.ContextVar[DimsChangeFn] = (
|
613 |
+
contextvars.ContextVar('dims_change_on_unflatten_fn'))
|
614 |
+
|
615 |
+
|
616 |
+
@contextlib.contextmanager
|
617 |
+
def dims_change_on_unflatten(dims_change_fn: DimsChangeFn):
|
618 |
+
"""Can be used to change the dims used when unflattening arrays into xarrays.
|
619 |
+
|
620 |
+
This is useful when some axes were added to / removed from the underlying jax
|
621 |
+
arrays after they were flattened using jax.tree_util.tree_flatten, and you
|
622 |
+
want to unflatten them again afterwards using the original treedef but
|
623 |
+
adjusted for the added/removed dimensions.
|
624 |
+
|
625 |
+
It can also be used with jax.tree_util.tree_map, when it's called with a
|
626 |
+
function that adds/removes axes or otherwise changes the axis order.
|
627 |
+
|
628 |
+
When dimensions are removed, any coordinates using those removed dimensions
|
629 |
+
will also be removed on unflatten.
|
630 |
+
|
631 |
+
This is implemented as a context manager that sets some thread-local state
|
632 |
+
affecting the behaviour of our unflatten functions, because it's not possible
|
633 |
+
to directly modify the treedef to change the dims/coords in it (and with
|
634 |
+
tree_map, the treedef isn't exposed to you anyway).
|
635 |
+
|
636 |
+
Args:
|
637 |
+
dims_change_fn: Maps a tuple of dimension names for the original
|
638 |
+
Variable/DataArray/Dataset that was flattened, to an updated tuple of
|
639 |
+
dimensions which should be used when unflattening.
|
640 |
+
|
641 |
+
Yields:
|
642 |
+
To a context manager in whose scope jax.tree_util.tree_unflatten and
|
643 |
+
jax.tree_util.tree_map will apply the dims_change_fn before reconstructing
|
644 |
+
xarrays from jax arrays.
|
645 |
+
"""
|
646 |
+
token = _DIMS_CHANGE_ON_UNFLATTEN_FN.set(dims_change_fn)
|
647 |
+
try:
|
648 |
+
yield
|
649 |
+
finally:
|
650 |
+
_DIMS_CHANGE_ON_UNFLATTEN_FN.reset(token)
|
651 |
+
|
652 |
+
|
653 |
+
def _flatten_variable(v: xarray.Variable) -> Tuple[
|
654 |
+
Tuple[jax.typing.ArrayLike], Tuple[Hashable, ...]]:
|
655 |
+
"""Flattens a Variable for jax.tree_util."""
|
656 |
+
children = (unwrap_data(v),)
|
657 |
+
aux = v.dims
|
658 |
+
return children, aux
|
659 |
+
|
660 |
+
|
661 |
+
def _unflatten_variable(
|
662 |
+
aux: Tuple[Hashable, ...],
|
663 |
+
children: Tuple[jax.typing.ArrayLike]) -> xarray.Variable:
|
664 |
+
"""Unflattens a Variable for jax.tree_util."""
|
665 |
+
dims = aux
|
666 |
+
dims_change_fn = _DIMS_CHANGE_ON_UNFLATTEN_FN.get(None)
|
667 |
+
if dims_change_fn: dims = dims_change_fn(dims)
|
668 |
+
return Variable(dims=dims, data=children[0])
|
669 |
+
|
670 |
+
|
671 |
+
def _split_static_and_jax_coords(
|
672 |
+
coords: xarray.core.coordinates.Coordinates) -> Tuple[
|
673 |
+
Mapping[Hashable, xarray.Variable], Mapping[Hashable, xarray.Variable]]:
|
674 |
+
static_coord_vars = {}
|
675 |
+
jax_coord_vars = {}
|
676 |
+
for name, coord in coords.items():
|
677 |
+
if coord.attrs.get(_JAX_COORD_ATTR_NAME, False):
|
678 |
+
jax_coord_vars[name] = coord.variable
|
679 |
+
else:
|
680 |
+
assert not isinstance(coord, (jax.Array, JaxArrayWrapper))
|
681 |
+
static_coord_vars[name] = coord.variable
|
682 |
+
return static_coord_vars, jax_coord_vars
|
683 |
+
|
684 |
+
|
685 |
+
def _drop_with_none_of_dims(
|
686 |
+
coord_vars: Mapping[Hashable, xarray.Variable],
|
687 |
+
dims: Tuple[Hashable]) -> Mapping[Hashable, xarray.Variable]:
|
688 |
+
return {name: var for name, var in coord_vars.items()
|
689 |
+
if set(var.dims) <= set(dims)}
|
690 |
+
|
691 |
+
|
692 |
+
class _HashableCoords(collections.abc.Mapping):
|
693 |
+
"""Wraps a dict of xarray Variables as hashable, used for static coordinates.
|
694 |
+
|
695 |
+
This needs to be hashable so that when an xarray.Dataset is passed to a
|
696 |
+
jax.jit'ed function, jax can check whether it's seen an array with the
|
697 |
+
same static coordinates(*) before or whether it needs to recompile the
|
698 |
+
function for the new values of the static coordinates.
|
699 |
+
|
700 |
+
(*) note jax_coords are not included in this; their value can be different
|
701 |
+
on different calls without triggering a recompile.
|
702 |
+
"""
|
703 |
+
|
704 |
+
def __init__(self, coord_vars: Mapping[Hashable, xarray.Variable]):
|
705 |
+
self._variables = coord_vars
|
706 |
+
|
707 |
+
def __repr__(self) -> str:
|
708 |
+
return f'_HashableCoords({repr(self._variables)})'
|
709 |
+
|
710 |
+
def __getitem__(self, key: Hashable) -> xarray.Variable:
|
711 |
+
return self._variables[key]
|
712 |
+
|
713 |
+
def __len__(self) -> int:
|
714 |
+
return len(self._variables)
|
715 |
+
|
716 |
+
def __iter__(self) -> Iterator[Hashable]:
|
717 |
+
return iter(self._variables)
|
718 |
+
|
719 |
+
def __hash__(self):
|
720 |
+
if not hasattr(self, '_hash'):
|
721 |
+
self._hash = hash(frozenset((name, var.data.tobytes())
|
722 |
+
for name, var in self._variables.items()))
|
723 |
+
return self._hash
|
724 |
+
|
725 |
+
def __eq__(self, other):
|
726 |
+
if self is other:
|
727 |
+
return True
|
728 |
+
elif not isinstance(other, type(self)):
|
729 |
+
return NotImplemented
|
730 |
+
elif self._variables is other._variables:
|
731 |
+
return True
|
732 |
+
else:
|
733 |
+
return self._variables.keys() == other._variables.keys() and all(
|
734 |
+
variable.equals(other._variables[name])
|
735 |
+
for name, variable in self._variables.items())
|
736 |
+
|
737 |
+
|
738 |
+
def _flatten_data_array(v: xarray.DataArray) -> Tuple[
|
739 |
+
# Children (data variable, jax_coord_vars):
|
740 |
+
Tuple[xarray.Variable, Mapping[Hashable, xarray.Variable]],
|
741 |
+
# Static auxiliary data (name, static_coord_vars):
|
742 |
+
Tuple[Optional[Hashable], _HashableCoords]]:
|
743 |
+
"""Flattens a DataArray for jax.tree_util."""
|
744 |
+
static_coord_vars, jax_coord_vars = _split_static_and_jax_coords(v.coords)
|
745 |
+
children = (v.variable, jax_coord_vars)
|
746 |
+
aux = (v.name, _HashableCoords(static_coord_vars))
|
747 |
+
return children, aux
|
748 |
+
|
749 |
+
|
750 |
+
def _unflatten_data_array(
|
751 |
+
aux: Tuple[Optional[Hashable], _HashableCoords],
|
752 |
+
children: Tuple[xarray.Variable, Mapping[Hashable, xarray.Variable]],
|
753 |
+
) -> xarray.DataArray:
|
754 |
+
"""Unflattens a DataArray for jax.tree_util."""
|
755 |
+
variable, jax_coord_vars = children
|
756 |
+
name, static_coord_vars = aux
|
757 |
+
# Drop static coords which have dims not present in any of the data_vars.
|
758 |
+
# These would generally be dims that were dropped by a dims_change_fn, but
|
759 |
+
# because static coordinates don't go through dims_change_fn on unflatten, we
|
760 |
+
# just drop them where this causes a problem.
|
761 |
+
# Since jax_coords go through the dims_change_fn on unflatten we don't need
|
762 |
+
# to do this for jax_coords.
|
763 |
+
static_coord_vars = _drop_with_none_of_dims(static_coord_vars, variable.dims)
|
764 |
+
return DataArray(
|
765 |
+
variable, name=name, coords=static_coord_vars, jax_coords=jax_coord_vars)
|
766 |
+
|
767 |
+
|
768 |
+
def _flatten_dataset(dataset: xarray.Dataset) -> Tuple[
|
769 |
+
# Children (data variables, jax_coord_vars):
|
770 |
+
Tuple[Mapping[Hashable, xarray.Variable],
|
771 |
+
Mapping[Hashable, xarray.Variable]],
|
772 |
+
# Static auxiliary data (static_coord_vars):
|
773 |
+
_HashableCoords]:
|
774 |
+
"""Flattens a Dataset for jax.tree_util."""
|
775 |
+
variables = {name: data_array.variable
|
776 |
+
for name, data_array in dataset.data_vars.items()}
|
777 |
+
static_coord_vars, jax_coord_vars = _split_static_and_jax_coords(
|
778 |
+
dataset.coords)
|
779 |
+
children = (variables, jax_coord_vars)
|
780 |
+
aux = _HashableCoords(static_coord_vars)
|
781 |
+
return children, aux
|
782 |
+
|
783 |
+
|
784 |
+
def _unflatten_dataset(
|
785 |
+
aux: _HashableCoords,
|
786 |
+
children: Tuple[Mapping[Hashable, xarray.Variable],
|
787 |
+
Mapping[Hashable, xarray.Variable]],
|
788 |
+
) -> xarray.Dataset:
|
789 |
+
"""Unflattens a Dataset for jax.tree_util."""
|
790 |
+
data_vars, jax_coord_vars = children
|
791 |
+
static_coord_vars = aux
|
792 |
+
dataset = xarray.Dataset(data_vars)
|
793 |
+
# Drop static coords which have dims not present in any of the data_vars.
|
794 |
+
# See corresponding comment in _unflatten_data_array.
|
795 |
+
static_coord_vars = _drop_with_none_of_dims(static_coord_vars, dataset.dims) # pytype: disable=wrong-arg-types
|
796 |
+
return assign_coords(
|
797 |
+
dataset, coords=static_coord_vars, jax_coords=jax_coord_vars)
|
798 |
+
|
799 |
+
|
800 |
+
jax.tree_util.register_pytree_node(
|
801 |
+
xarray.Variable, _flatten_variable, _unflatten_variable)
|
802 |
+
# This is a subclass of Variable but still needs registering separately.
|
803 |
+
# Flatten/unflatten for IndexVariable is a bit of a corner case but we do
|
804 |
+
# need to support it.
|
805 |
+
jax.tree_util.register_pytree_node(
|
806 |
+
xarray.IndexVariable, _flatten_variable, _unflatten_variable)
|
807 |
+
jax.tree_util.register_pytree_node(
|
808 |
+
xarray.DataArray, _flatten_data_array, _unflatten_data_array)
|
809 |
+
jax.tree_util.register_pytree_node(
|
810 |
+
xarray.Dataset, _flatten_dataset, _unflatten_dataset)
|
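Putting the pieces above together, a minimal usage sketch (not part of the uploaded files; the variable names and shapes are made up) of building a Dataset from JAX arrays, jitting a function over it, and pulling the JAX data back out:

```python
import jax
import jax.numpy as jnp
import numpy as np
import xarray

from graphcast import xarray_jax


@jax.jit
def double(ds: xarray.Dataset) -> xarray.Dataset:
  # Plain xarray arithmetic; the underlying data stays a (wrapped) JAX array.
  return ds + ds


ds = xarray_jax.Dataset(
    data_vars={'temperature': (('lat', 'lon'), jnp.ones((3, 4)))},
    coords={'lat': np.arange(3) * 10, 'lon': np.arange(4) * 10})

out = double(ds)
print(xarray_jax.jax_data(out['temperature']))  # a jax.Array of 2.0s
```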
graphcast/xarray_jax_test.py
ADDED
@@ -0,0 +1,526 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_jax."""

from absl.testing import absltest
import chex
from graphcast import xarray_jax
import jax
import jax.numpy as jnp
import numpy as np
import xarray


class XarrayJaxTest(absltest.TestCase):

  def test_jax_array_wrapper_with_numpy_api(self):
    # This is just a side benefit of making things work with xarray, but the
    # JaxArrayWrapper does allow you to manipulate JAX arrays using the
    # standard numpy API, without converting them to numpy in the process:
    ones = jnp.ones((3, 4), dtype=np.float32)
    x = xarray_jax.JaxArrayWrapper(ones)
    x = np.abs((x + 2) * (x - 3))
    x = x[:-1, 1:3]
    x = np.concatenate([x, x + 1], axis=0)
    x = np.transpose(x, (1, 0))
    x = np.reshape(x, (-1,))
    x = x.astype(np.int32)
    self.assertIsInstance(x, xarray_jax.JaxArrayWrapper)
    # An explicit conversion gets us out of JAX-land however:
    self.assertIsInstance(np.asarray(x), np.ndarray)

  def test_jax_xarray_variable(self):
    def ops_via_xarray(inputs):
      x = xarray_jax.Variable(('lat', 'lon'), inputs)
      # We'll apply a sequence of operations just to test that the end result is
      # still a JAX array, i.e. we haven't converted to numpy at any point.
      x = np.abs((x + 2) * (x - 3))
      x = x.isel({'lat': slice(0, -1), 'lon': slice(1, 3)})
      x = xarray.Variable.concat([x, x + 1], dim='lat')
      x = x.transpose('lon', 'lat')
      x = x.stack(channels=('lon', 'lat'))
      x = x.sum()
      return xarray_jax.jax_data(x)

    # Check it doesn't leave jax-land when passed concrete values:
    ones = jnp.ones((3, 4), dtype=np.float32)
    result = ops_via_xarray(ones)
    self.assertIsInstance(result, jax.Array)

    # And that you can JIT it and compute gradients through it. These will
    # involve passing jax tracers through the xarray computation:
    jax.jit(ops_via_xarray)(ones)
    jax.grad(ops_via_xarray)(ones)

  def test_jax_xarray_data_array(self):
    def ops_via_xarray(inputs):
      x = xarray_jax.DataArray(dims=('lat', 'lon'),
                               data=inputs,
                               coords={'lat': np.arange(3) * 10,
                                       'lon': np.arange(4) * 10})
      x = np.abs((x + 2) * (x - 3))
      x = x.sel({'lat': slice(0, 20)})
      y = xarray_jax.DataArray(dims=('lat', 'lon'),
                               data=ones,
                               coords={'lat': np.arange(3, 6) * 10,
                                       'lon': np.arange(4) * 10})
      x = xarray.concat([x, y], dim='lat')
      x = x.transpose('lon', 'lat')
      x = x.stack(channels=('lon', 'lat'))
      x = x.unstack()
      x = x.sum()
      return xarray_jax.jax_data(x)

    ones = jnp.ones((3, 4), dtype=np.float32)
    result = ops_via_xarray(ones)
    self.assertIsInstance(result, jax.Array)

    jax.jit(ops_via_xarray)(ones)
    jax.grad(ops_via_xarray)(ones)

  def test_jax_xarray_dataset(self):
    def ops_via_xarray(foo, bar):
      x = xarray_jax.Dataset(
          data_vars={'foo': (('lat', 'lon'), foo),
                     'bar': (('time', 'lat', 'lon'), bar)},
          coords={
              'time': np.arange(2),
              'lat': np.arange(3) * 10,
              'lon': np.arange(4) * 10})
      x = np.abs((x + 2) * (x - 3))
      x = x.sel({'lat': slice(0, 20)})
      y = xarray_jax.Dataset(
          data_vars={'foo': (('lat', 'lon'), foo),
                     'bar': (('time', 'lat', 'lon'), bar)},
          coords={
              'time': np.arange(2),
              'lat': np.arange(3, 6) * 10,
              'lon': np.arange(4) * 10})
      x = xarray.concat([x, y], dim='lat')
      x = x.transpose('lon', 'lat', 'time')
      x = x.stack(channels=('lon', 'lat'))
      x = (x.foo + x.bar).sum()
      return xarray_jax.jax_data(x)

    foo = jnp.ones((3, 4), dtype=np.float32)
    bar = jnp.ones((2, 3, 4), dtype=np.float32)
    result = ops_via_xarray(foo, bar)
    self.assertIsInstance(result, jax.Array)

    jax.jit(ops_via_xarray)(foo, bar)
    jax.grad(ops_via_xarray)(foo, bar)

  def test_jit_function_with_xarray_variable_arguments_and_return(self):
    function = jax.jit(lambda v: v + 1)
    with self.subTest('jax input'):
      inputs = xarray_jax.Variable(
          ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
      _ = function(inputs)
      # We test running the jitted function a second time, to exercise logic in
      # jax which checks if the structure of the inputs (including dimension
      # names and coordinates) is the same as it was for the previous call and
      # so whether it needs to re-trace-and-compile a new version of the
      # function or not. This can run into problems if the 'aux' structure
      # returned by the registered flatten function is not hashable/comparable.
      outputs = function(inputs)
      self.assertEqual(outputs.dims, inputs.dims)
    with self.subTest('numpy input'):
      inputs = xarray.Variable(
          ('lat', 'lon'), np.ones((3, 4), dtype=np.float32))
      _ = function(inputs)
      outputs = function(inputs)
      self.assertEqual(outputs.dims, inputs.dims)

  def test_jit_problem_if_convert_to_plain_numpy_array(self):
    inputs = xarray_jax.DataArray(
        data=jnp.ones((2,), dtype=np.float32), dims=('foo',))
    with self.assertRaises(jax.errors.TracerArrayConversionError):
      # Calling .values on a DataArray converts its values to numpy:
      jax.jit(lambda data_array: data_array.values)(inputs)

  def test_grad_function_with_xarray_variable_arguments(self):
    x = xarray_jax.Variable(('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
    # For grad we still need a JAX scalar as the output:
    jax.grad(lambda v: xarray_jax.jax_data(v.sum()))(x)

  def test_jit_function_with_xarray_data_array_arguments_and_return(self):
    inputs = xarray_jax.DataArray(
        data=jnp.ones((3, 4), dtype=np.float32),
        dims=('lat', 'lon'),
        coords={'lat': np.arange(3),
                'lon': np.arange(4) * 10})
    fn = jax.jit(lambda v: v + 1)
    _ = fn(inputs)
    outputs = fn(inputs)
    self.assertEqual(outputs.dims, inputs.dims)
    chex.assert_trees_all_equal(outputs.coords, inputs.coords)

  def test_jit_function_with_data_array_and_jax_coords(self):
    inputs = xarray_jax.DataArray(
        data=jnp.ones((3, 4), dtype=np.float32),
        dims=('lat', 'lon'),
        coords={'lat': np.arange(3)},
        jax_coords={'lon': jnp.arange(4) * 10})
    # Verify the jax_coord 'lon' retains jax data, and has not been created
    # as an index coordinate:
    self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
    self.assertNotIn('lon', inputs.indexes)

    @jax.jit
    def fn(v):
      # The non-JAX coord is passed with numpy array data and an index:
      self.assertIsInstance(v.coords['lat'].data, np.ndarray)
      self.assertIn('lat', v.indexes)

      # The jax_coord is passed with JAX array data:
      self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper)
      self.assertNotIn('lon', v.indexes)

      # Use the jax coord in the computation:
      v = v + v.coords['lon']

      # Return with an updated jax coord:
      return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1)

    _ = fn(inputs)
    outputs = fn(inputs)

    # Verify the jax_coord 'lon' has jax data in the output too:
    self.assertIsInstance(
        outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
    self.assertNotIn('lon', outputs.indexes)

    self.assertEqual(outputs.dims, inputs.dims)
    chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat'])
    # Check our computations with the coordinate values worked:
    chex.assert_trees_all_equal(
        outputs.coords['lon'].data, (inputs.coords['lon']+1).data)
    chex.assert_trees_all_equal(
        outputs.data, (inputs + inputs.coords['lon']).data)

  def test_jit_function_with_xarray_dataset_arguments_and_return(self):
    foo = jnp.ones((3, 4), dtype=np.float32)
    bar = jnp.ones((2, 3, 4), dtype=np.float32)
    inputs = xarray_jax.Dataset(
        data_vars={'foo': (('lat', 'lon'), foo),
                   'bar': (('time', 'lat', 'lon'), bar)},
        coords={
            'time': np.arange(2),
            'lat': np.arange(3) * 10,
            'lon': np.arange(4) * 10})
    fn = jax.jit(lambda v: v + 1)
    _ = fn(inputs)
    outputs = fn(inputs)
    self.assertEqual({'foo', 'bar'}, outputs.data_vars.keys())
    self.assertEqual(inputs.foo.dims, outputs.foo.dims)
    self.assertEqual(inputs.bar.dims, outputs.bar.dims)
    chex.assert_trees_all_equal(outputs.coords, inputs.coords)

  def test_jit_function_with_dataset_and_jax_coords(self):
    foo = jnp.ones((3, 4), dtype=np.float32)
    bar = jnp.ones((2, 3, 4), dtype=np.float32)
    inputs = xarray_jax.Dataset(
        data_vars={'foo': (('lat', 'lon'), foo),
                   'bar': (('time', 'lat', 'lon'), bar)},
        coords={
            'time': np.arange(2),
            'lat': np.arange(3) * 10,
        },
        jax_coords={'lon': jnp.arange(4) * 10}
    )
    # Verify the jax_coord 'lon' retains jax data, and has not been created
    # as an index coordinate:
    self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
    self.assertNotIn('lon', inputs.indexes)

    @jax.jit
    def fn(v):
      # The non-JAX coords are passed with numpy array data and an index:
      self.assertIsInstance(v.coords['lat'].data, np.ndarray)
      self.assertIn('lat', v.indexes)

      # The jax_coord is passed with JAX array data:
      self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper)
      self.assertNotIn('lon', v.indexes)

      # Use the jax coord in the computation:
      v = v + v.coords['lon']

      # Return with an updated jax coord:
      return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1)

    _ = fn(inputs)
    outputs = fn(inputs)

    # Verify the jax_coord 'lon' has jax data in the output too:
    self.assertIsInstance(
        outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
    self.assertNotIn('lon', outputs.indexes)

    self.assertEqual(outputs.dims, inputs.dims)
    chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat'])
    # Check our computations with the coordinate values worked:
    chex.assert_trees_all_equal(
        (outputs.coords['lon']).data,
        (inputs.coords['lon']+1).data,
    )
    outputs_dict = {key: outputs[key].data for key in outputs}
    inputs_and_inputs_coords_dict = {
        key: (inputs + inputs.coords['lon'])[key].data
        for key in inputs + inputs.coords['lon']
    }
    chex.assert_trees_all_equal(outputs_dict, inputs_and_inputs_coords_dict)

  def test_flatten_unflatten_variable(self):
    variable = xarray_jax.Variable(
        ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
    children, aux = xarray_jax._flatten_variable(variable)
    # Check auxiliary info is hashable/comparable (important for jax.jit):
    hash(aux)
    self.assertEqual(aux, aux)
    roundtrip = xarray_jax._unflatten_variable(aux, children)
    self.assertTrue(variable.equals(roundtrip))

  def test_flatten_unflatten_data_array(self):
    data_array = xarray_jax.DataArray(
        data=jnp.ones((3, 4), dtype=np.float32),
        dims=('lat', 'lon'),
        coords={'lat': np.arange(3)},
        jax_coords={'lon': np.arange(4) * 10},
    )
    children, aux = xarray_jax._flatten_data_array(data_array)
    # Check auxiliary info is hashable/comparable (important for jax.jit):
    hash(aux)
    self.assertEqual(aux, aux)
    roundtrip = xarray_jax._unflatten_data_array(aux, children)
    self.assertTrue(data_array.equals(roundtrip))

  def test_flatten_unflatten_dataset(self):
    foo = jnp.ones((3, 4), dtype=np.float32)
    bar = jnp.ones((2, 3, 4), dtype=np.float32)
    dataset = xarray_jax.Dataset(
        data_vars={'foo': (('lat', 'lon'), foo),
                   'bar': (('time', 'lat', 'lon'), bar)},
        coords={
            'time': np.arange(2),
            'lat': np.arange(3) * 10},
        jax_coords={
            'lon': np.arange(4) * 10})
    children, aux = xarray_jax._flatten_dataset(dataset)
    # Check auxiliary info is hashable/comparable (important for jax.jit):
    hash(aux)
    self.assertEqual(aux, aux)
    roundtrip = xarray_jax._unflatten_dataset(aux, children)
    self.assertTrue(dataset.equals(roundtrip))

  def test_flatten_unflatten_added_dim(self):
    data_array = xarray_jax.DataArray(
        data=jnp.ones((3, 4), dtype=np.float32),
        dims=('lat', 'lon'),
        coords={'lat': np.arange(3),
                'lon': np.arange(4) * 10})
    leaves, treedef = jax.tree_util.tree_flatten(data_array)
    leaves = [jnp.expand_dims(x, 0) for x in leaves]
    with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims):
      with_new_dim = jax.tree_util.tree_unflatten(treedef, leaves)
    self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
    xarray.testing.assert_identical(
        jax.device_get(data_array),
        jax.device_get(with_new_dim.isel(new=0)))

  def test_map_added_dim(self):
    data_array = xarray_jax.DataArray(
        data=jnp.ones((3, 4), dtype=np.float32),
        dims=('lat', 'lon'),
        coords={'lat': np.arange(3),
                'lon': np.arange(4) * 10})
    with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims):
      with_new_dim = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0),
                                            data_array)
    self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
    xarray.testing.assert_identical(
        jax.device_get(data_array),
        jax.device_get(with_new_dim.isel(new=0)))

  def test_map_remove_dim(self):
    foo = jnp.ones((1, 3, 4), dtype=np.float32)
    bar = jnp.ones((1, 2, 3, 4), dtype=np.float32)
    dataset = xarray_jax.Dataset(
        data_vars={'foo': (('batch', 'lat', 'lon'), foo),
                   'bar': (('batch', 'time', 'lat', 'lon'), bar)},
        coords={
            'batch': np.array([123]),
            'time': np.arange(2),
            'lat': np.arange(3) * 10,
            'lon': np.arange(4) * 10})
    with xarray_jax.dims_change_on_unflatten(lambda dims: dims[1:]):
      with_removed_dim = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, 0),
                                                dataset)
    self.assertEqual(('lat', 'lon'), with_removed_dim['foo'].dims)
    self.assertEqual(('time', 'lat', 'lon'), with_removed_dim['bar'].dims)
    self.assertNotIn('batch', with_removed_dim.dims)
    self.assertNotIn('batch', with_removed_dim.coords)
    xarray.testing.assert_identical(
        jax.device_get(dataset.isel(batch=0, drop=True)),
        jax.device_get(with_removed_dim))

  def test_pmap(self):
    devices = jax.local_device_count()
    foo = jnp.zeros((devices, 3, 4), dtype=np.float32)
    bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32)
    dataset = xarray_jax.Dataset({
        'foo': (('device', 'lat', 'lon'), foo),
        'bar': (('device', 'time', 'lat', 'lon'), bar)})

    def func(d):
      self.assertNotIn('device', d.dims)
      return d + 1
    func = xarray_jax.pmap(func, dim='device')

    result = func(dataset)
    xarray.testing.assert_identical(
        jax.device_get(dataset + 1),
        jax.device_get(result))

    # Can call it again with a different argument structure (it will recompile
    # under the hood but should work):
    dataset = dataset.drop_vars('foo')
    result = func(dataset)
    xarray.testing.assert_identical(
        jax.device_get(dataset + 1),
        jax.device_get(result))

  def test_pmap_with_jax_coords(self):
    devices = jax.local_device_count()
    foo = jnp.zeros((devices, 3, 4), dtype=np.float32)
    bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32)
    time = jnp.zeros((devices, 2), dtype=np.float32)
    dataset = xarray_jax.Dataset(
        {'foo': (('device', 'lat', 'lon'), foo),
         'bar': (('device', 'time', 'lat', 'lon'), bar)},
        coords={
            'lat': np.arange(3),
            'lon': np.arange(4),
        },
        jax_coords={
            # Currently any jax_coords need a leading device dimension to use
            # with pmap, same as for data_vars.
            # TODO(matthjw): have pmap automatically broadcast to all devices
            # where the device dimension not present.
            'time': xarray_jax.Variable(('device', 'time'), time),
        }
    )

    def func(d):
      self.assertNotIn('device', d.dims)
      self.assertNotIn('device', d.coords['time'].dims)

      # The jax_coord 'time' should be passed in backed by a JAX array, but
      # not as an index coordinate.
      self.assertIsInstance(d.coords['time'].data, xarray_jax.JaxArrayWrapper)
      self.assertNotIn('time', d.indexes)

      return d + 1
    func = xarray_jax.pmap(func, dim='device')

    result = func(dataset)
    xarray.testing.assert_identical(
        jax.device_get(dataset + 1),
        jax.device_get(result))

    # Can call it again with a different argument structure (it will recompile
    # under the hood but should work):
    dataset = dataset.drop_vars('foo')
    result = func(dataset)
    xarray.testing.assert_identical(
        jax.device_get(dataset + 1),
        jax.device_get(result))

  def test_pmap_with_tree_mix_of_xarray_and_jax_array(self):
    devices = jax.local_device_count()
    data_array = xarray_jax.DataArray(
        data=jnp.ones((devices, 3, 4), dtype=np.float32),
        dims=('device', 'lat', 'lon'))
    plain_array = jnp.ones((devices, 2), dtype=np.float32)
    inputs = {'foo': data_array,
              'bar': plain_array}

    def func(x):
      return x['foo'] + 1, x['bar'] + 1

    func = xarray_jax.pmap(func, dim='device')
    result_foo, result_bar = func(inputs)
    xarray.testing.assert_identical(
        jax.device_get(inputs['foo'] + 1),
        jax.device_get(result_foo))
    np.testing.assert_array_equal(
        jax.device_get(inputs['bar'] + 1),
        jax.device_get(result_bar))

  def test_pmap_complains_when_dim_not_first(self):
    devices = jax.local_device_count()
    data_array = xarray_jax.DataArray(
        data=jnp.ones((3, devices, 4), dtype=np.float32),
        dims=('lat', 'device', 'lon'))

    func = xarray_jax.pmap(lambda x: x+1, dim='device')
    with self.assertRaisesRegex(
        ValueError, 'Expected dim device at index 0, found at 1'):
      func(data_array)

  def test_apply_ufunc(self):
    inputs = xarray_jax.DataArray(
        data=jnp.asarray([[1, 2], [3, 4]]),
        dims=('x', 'y'),
        coords={'x': [0, 1],
                'y': [2, 3]})
    result = xarray_jax.apply_ufunc(
        lambda x: jnp.sum(x, axis=-1),
        inputs,
        input_core_dims=[['x']])
    expected_result = xarray_jax.DataArray(
        data=[4, 6],
        dims=('y',),
        coords={'y': [2, 3]})
    xarray.testing.assert_identical(expected_result, jax.device_get(result))

  def test_apply_ufunc_multiple_return_values(self):
    def ufunc(array):
      return jnp.min(array, axis=-1), jnp.max(array, axis=-1)
    inputs = xarray_jax.DataArray(
        data=jnp.asarray([[1, 4], [3, 2]]),
        dims=('x', 'y'),
        coords={'x': [0, 1],
                'y': [2, 3]})
    result = xarray_jax.apply_ufunc(
        ufunc, inputs, input_core_dims=[['x']], output_core_dims=[[], []])
    expected = (
        # Mins:
        xarray_jax.DataArray(
            data=[1, 2],
            dims=('y',),
            coords={'y': [2, 3]}
        ),
        # Maxes:
        xarray_jax.DataArray(
            data=[3, 4],
            dims=('y',),
            coords={'y': [2, 3]}
        )
    )
    xarray.testing.assert_identical(expected[0], jax.device_get(result[0]))
    xarray.testing.assert_identical(expected[1], jax.device_get(result[1]))

if __name__ == '__main__':
  absltest.main()
graphcast/xarray_tree.py
ADDED
@@ -0,0 +1,70 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for working with trees of xarray.DataArray (including Datasets).

Note that xarray.Dataset doesn't work out-of-the-box with the `tree` library;
it won't work as a leaf node since it implements Mapping, but also won't work
as an internal node since tree doesn't know how to re-create it properly.

To fix this, we reimplement a subset of `map_structure`, exposing its
constituent DataArrays as leaf nodes. This means it can be mapped over as a
generic container of DataArrays, while still preserving the result as a Dataset
where possible.

This is useful because in a few places we need to handle a general
Mapping[str, DataArray] (where the coordinates might not be compatible across
the constituent DataArrays) but also the special case of a Dataset nicely.

For the result e.g. of a tree.map_structure(fn, dataset), if fn returns None for
some of the child DataArrays, they will be omitted from the returned dataset. If
any values other than DataArrays or None are returned, then we don't attempt to
return a Dataset and just return a plain dict of the results. Similarly if
DataArrays are returned but with non-matching coordinates, it will just return a
plain dict of DataArrays.

Note xarray datatypes are registered with `jax.tree_util` by xarray_jax.py,
but `jax.tree_util.tree_map` is distinct from the `xarray_tree.map_structure`,
as the former exposes the underlying JAX/numpy arrays as leaf nodes, while the
latter exposes DataArrays as leaf nodes.
"""

from typing import Any, Callable

import xarray


def map_structure(func: Callable[..., Any], *structures: Any) -> Any:
  """Maps func through given structures with xarrays. See tree.map_structure."""
  if not callable(func):
    raise TypeError(f'func must be callable, got: {func}')
  if not structures:
    raise ValueError('Must provide at least one structure')

  first = structures[0]
  if isinstance(first, xarray.Dataset):
    data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
    if all(isinstance(a, (type(None), xarray.DataArray))
           for a in data.values()):
      data_arrays = [v.rename(k) for k, v in data.items() if v is not None]
      try:
        return xarray.merge(data_arrays, join='exact')
      except ValueError:  # Exact join not possible.
        pass
    return data
  if isinstance(first, dict):
    return {k: map_structure(func, *[s[k] for s in structures])
            for k in first.keys()}
  if isinstance(first, (list, tuple, set)):
    return type(first)(map_structure(func, *s) for s in zip(*structures))
  return func(*structures)
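As a quick orientation for the file above (not part of the uploaded files): map_structure treats each DataArray in a Dataset as a leaf, and reassembles a Dataset whenever the per-leaf results can still be merged exactly; otherwise it falls back to a plain dict. A small usage sketch:

import numpy as np
import xarray
from graphcast import xarray_tree

dataset = xarray.Dataset(
    data_vars={'foo': (('x', 'y'), np.zeros((2, 3))),
               'bar': (('x',), np.zeros((2,)))},
    coords={'x': [1, 2], 'y': [10, 20, 30]})

# DataArray in, DataArray out: the results merge back into a Dataset.
plus_one = xarray_tree.map_structure(lambda leaf: leaf + 1, dataset)

# Returning None for a leaf drops that variable from the resulting Dataset.
only_foo = xarray_tree.map_structure(
    lambda leaf: leaf if leaf.name == 'foo' else None, dataset)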
graphcast/xarray_tree_test.py
ADDED
@@ -0,0 +1,95 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_tree."""

from absl.testing import absltest
from graphcast import xarray_tree
import numpy as np
import xarray


TEST_DATASET = xarray.Dataset(
    data_vars={
        "foo": (("x", "y"), np.zeros((2, 3))),
        "bar": (("x",), np.zeros((2,))),
    },
    coords={
        "x": [1, 2],
        "y": [10, 20, 30],
    }
)


class XarrayTreeTest(absltest.TestCase):

  def test_map_structure_maps_over_leaves_but_preserves_dataset_type(self):
    def fn(leaf):
      self.assertIsInstance(leaf, xarray.DataArray)
      result = leaf + 1
      # Removing the name from the returned DataArray to test that we don't rely
      # on it being present to restore the correct names in the result:
      result = result.rename(None)
      return result

    result = xarray_tree.map_structure(fn, TEST_DATASET)
    self.assertIsInstance(result, xarray.Dataset)
    self.assertSameElements({"foo", "bar"}, result.keys())

  def test_map_structure_on_data_arrays(self):
    data_arrays = dict(TEST_DATASET)
    result = xarray_tree.map_structure(lambda x: x+1, data_arrays)
    self.assertIsInstance(result, dict)
    self.assertSameElements({"foo", "bar"}, result.keys())

  def test_map_structure_on_dataset_plain_dict_when_coords_incompatible(self):
    def fn(leaf):
      # Returns DataArrays that can't be exactly merged back into a Dataset
      # due to the coordinates not matching:
      if leaf.name == "foo":
        return xarray.DataArray(
            data=np.zeros(2), dims=("x",), coords={"x": [1, 2]})
      else:
        return xarray.DataArray(
            data=np.zeros(2), dims=("x",), coords={"x": [3, 4]})

    result = xarray_tree.map_structure(fn, TEST_DATASET)
    self.assertIsInstance(result, dict)
    self.assertSameElements({"foo", "bar"}, result.keys())

  def test_map_structure_on_dataset_drops_vars_with_none_return_values(self):
    def fn(leaf):
      return leaf if leaf.name == "foo" else None

    result = xarray_tree.map_structure(fn, TEST_DATASET)
    self.assertIsInstance(result, xarray.Dataset)
    self.assertSameElements({"foo"}, result.keys())

  def test_map_structure_on_dataset_returns_plain_dict_other_return_types(self):
    def fn(leaf):
      self.assertIsInstance(leaf, xarray.DataArray)
      return "not a DataArray"

    result = xarray_tree.map_structure(fn, TEST_DATASET)
    self.assertEqual({"foo": "not a DataArray",
                      "bar": "not a DataArray"}, result)

  def test_map_structure_two_args_different_variable_orders(self):
    dataset_different_order = TEST_DATASET[["bar", "foo"]]
    def fn(arg1, arg2):
      self.assertEqual(arg1.name, arg2.name)
    xarray_tree.map_structure(fn, TEST_DATASET, dataset_different_order)


if __name__ == "__main__":
  absltest.main()