"""Utils for rolling out models.""" |
|
|
|
from typing import Iterator |
|
|
|
from absl import logging |
|
import chex |
|
import dask.array |
|
from graphcast import xarray_tree |
|
import jax |
|
import numpy as np |
|
import typing_extensions |
|
import xarray |
|
|
|
|
|
class PredictorFn(typing_extensions.Protocol):
  """Functional version of base.Predictor.__call__ with explicit rng."""

  def __call__(
      self, rng: chex.PRNGKey, inputs: xarray.Dataset,
      targets_template: xarray.Dataset,
      forcings: xarray.Dataset,
      **optional_kwargs,
      ) -> xarray.Dataset:
    ...
|
|
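# A minimal sketch (illustrative only, not used elsewhere in this module) of a
# callable satisfying `PredictorFn`: it ignores the rng, inputs and forcings
# and simply returns the targets template filled with zeros. A real predictor
# would instead close a jitted model forward pass over trained parameters.
def _example_zero_predictor_fn(
    rng: chex.PRNGKey,
    inputs: xarray.Dataset,
    targets_template: xarray.Dataset,
    forcings: xarray.Dataset,
    **optional_kwargs,
) -> xarray.Dataset:
  del rng, inputs, forcings, optional_kwargs  # Unused in this sketch.
  return xarray.zeros_like(targets_template)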
|
|
|
def chunked_prediction(
    predictor_fn: PredictorFn,
    rng: chex.PRNGKey,
    inputs: xarray.Dataset,
    targets_template: xarray.Dataset,
    forcings: xarray.Dataset,
    num_steps_per_chunk: int = 1,
    verbose: bool = False,
) -> xarray.Dataset:
  """Outputs a long trajectory by iteratively concatenating chunked predictions.

  Args:
    predictor_fn: Function to use to make predictions for each chunk.
    rng: Random key.
    inputs: Inputs for the model.
    targets_template: Template for the target predictions; the target steps
      must be equispaced in time.
    forcings: Optional forcings for the model.
    num_steps_per_chunk: How many of the steps in `targets_template` to predict
      at each call of `predictor_fn`. It must evenly divide the number of
      steps in `targets_template`.
    verbose: Whether to log the current chunk being predicted.

  Returns:
    Predictions for the targets template.
  """
  chunks_list = []
  for prediction_chunk in chunked_prediction_generator(
      predictor_fn=predictor_fn,
      rng=rng,
      inputs=inputs,
      targets_template=targets_template,
      forcings=forcings,
      num_steps_per_chunk=num_steps_per_chunk,
      verbose=verbose):
    # Copy each chunk to host memory as it is produced, so the accelerator only
    # ever holds a single chunk of predictions at a time.
    chunks_list.append(jax.device_get(prediction_chunk))
  return xarray.concat(chunks_list, dim="time")
|
|
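# Usage sketch (illustrative; `my_inputs`, `my_targets_template` and
# `my_forcings` are hypothetical datasets with an equispaced "time" axis of
# timedeltas, e.g. 8 target steps rolled out in chunks of 2):
#
#   predictions = chunked_prediction(
#       _example_zero_predictor_fn,
#       rng=jax.random.PRNGKey(0),
#       inputs=my_inputs,
#       targets_template=my_targets_template,
#       forcings=my_forcings,
#       num_steps_per_chunk=2)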
|
|
|
def chunked_prediction_generator(
    predictor_fn: PredictorFn,
    rng: chex.PRNGKey,
    inputs: xarray.Dataset,
    targets_template: xarray.Dataset,
    forcings: xarray.Dataset,
    num_steps_per_chunk: int = 1,
    verbose: bool = False,
) -> Iterator[xarray.Dataset]:
  """Outputs a long trajectory by yielding chunked predictions.

  Args:
    predictor_fn: Function to use to make predictions for each chunk.
    rng: Random key.
    inputs: Inputs for the model.
    targets_template: Template for the target predictions; the target steps
      must be equispaced in time.
    forcings: Optional forcings for the model.
    num_steps_per_chunk: How many of the steps in `targets_template` to predict
      at each call of `predictor_fn`. It must evenly divide the number of
      steps in `targets_template`.
    verbose: Whether to log the current chunk being predicted.

  Yields:
    The predictions for each chunk of the rollout, such that if all chunks are
    concatenated in time the result matches the targets template in structure.
  """

  # Create copies to avoid mutating the caller's datasets.
  inputs = xarray.Dataset(inputs)
  targets_template = xarray.Dataset(targets_template)
  forcings = xarray.Dataset(forcings)

  if "datetime" in inputs.coords:
    del inputs.coords["datetime"]

  if "datetime" in targets_template.coords:
    output_datetime = targets_template.coords["datetime"]
    del targets_template.coords["datetime"]
  else:
    output_datetime = None

  if "datetime" in forcings.coords:
    del forcings.coords["datetime"]

  num_target_steps = targets_template.dims["time"]
  num_chunks, remainder = divmod(num_target_steps, num_steps_per_chunk)
  if remainder != 0:
    raise ValueError(
        f"The number of steps per chunk {num_steps_per_chunk} must "
        f"evenly divide the number of target steps {num_target_steps}.")

  if len(np.unique(np.diff(targets_template.coords["time"].data))) > 1:
    raise ValueError("The targets time coordinates must be evenly spaced")

  # The template targets passed to the predictor always use the timedeltas of
  # the first chunk, so the predictor sees identical time coordinates (and
  # therefore does not retrace/recompile) at every iteration.
  targets_chunk_time = targets_template.time.isel(
      time=slice(0, num_steps_per_chunk))

  current_inputs = inputs
  for chunk_index in range(num_chunks):
    if verbose:
      logging.info("Chunk %d/%d", chunk_index, num_chunks)
      logging.flush()

    # Select the target template and forcings for this chunk, relabelling their
    # time coordinates with those of the first chunk (see above).
    target_offset = num_steps_per_chunk * chunk_index
    target_slice = slice(target_offset, target_offset + num_steps_per_chunk)
    current_targets_template = targets_template.isel(time=target_slice)

    actual_target_time = current_targets_template.coords["time"]
    current_targets_template = current_targets_template.assign_coords(
        time=targets_chunk_time).compute()

    current_forcings = forcings.isel(time=target_slice)
    current_forcings = current_forcings.assign_coords(time=targets_chunk_time)
    current_forcings = current_forcings.compute()

    # Make predictions for this chunk.
    rng, this_rng = jax.random.split(rng)
    predictions = predictor_fn(
        rng=this_rng,
        inputs=current_inputs,
        targets_template=current_targets_template,
        forcings=current_forcings)

    next_frame = xarray.merge([predictions, current_forcings])

    # Build the inputs for the next chunk from the newest predictions (and
    # forcings), keeping the time coordinates of the previous inputs so the
    # predictor is always called with the same coordinates.
    next_inputs = _get_next_inputs(current_inputs, next_frame)
    next_inputs = next_inputs.assign_coords(time=current_inputs.coords["time"])
    current_inputs = next_inputs

    # Restore the actual target time (and datetime) coordinates on the yielded
    # predictions.
    predictions = predictions.assign_coords(time=actual_target_time)
    if output_datetime is not None:
      predictions.coords["datetime"] = output_datetime.isel(
          time=target_slice)
    yield predictions
    del predictions
|
|
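# Usage sketch (illustrative; same hypothetical datasets as above): consuming
# the generator directly lets each chunk be written out or reduced before the
# next one is computed, so only one chunk of predictions is resident at a time.
#
#   for i, chunk in enumerate(chunked_prediction_generator(
#       _example_zero_predictor_fn,
#       rng=jax.random.PRNGKey(0),
#       inputs=my_inputs,
#       targets_template=my_targets_template,
#       forcings=my_forcings,
#       num_steps_per_chunk=2)):
#     chunk.to_netcdf(f"predictions_chunk_{i}.nc")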
|
|
|
def _get_next_inputs(
    prev_inputs: xarray.Dataset, next_frame: xarray.Dataset,
    ) -> xarray.Dataset:
  """Computes next inputs, from previous inputs and predictions."""

  # Every input variable with a time axis must be either predicted or forced,
  # otherwise there is no way of rolling it forward.
  non_predicted_or_forced_inputs = list(
      set(prev_inputs.keys()) - set(next_frame.keys()))
  if "time" in prev_inputs[non_predicted_or_forced_inputs].dims:
    raise ValueError(
        "Found an input with a time index that is not predicted or forced.")

  # Variables to carry over from the next frame into the next inputs.
  next_inputs_keys = list(
      set(next_frame.keys()).intersection(set(prev_inputs.keys())))
  next_inputs = next_frame[next_inputs_keys]

  # Concatenate the next frame onto the previous inputs along time and keep
  # only the most recent `num_inputs` steps (a rolling window of inputs).
  num_inputs = prev_inputs.dims["time"]
  return (
      xarray.concat(
          [prev_inputs, next_inputs], dim="time", data_vars="different")
      .tail(time=num_inputs))
|
|
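# Illustrative sketch (not used by the library) of the rolling window that
# `_get_next_inputs` implements: with two input steps and a one-step next
# frame, the newest prediction replaces the oldest input step.
def _example_next_inputs() -> xarray.Dataset:
  hour = np.timedelta64(1, "h")
  prev_inputs = xarray.Dataset(
      {"x": ("time", np.array([1.0, 2.0]))},
      coords={"time": np.array([-6, 0]) * hour})
  next_frame = xarray.Dataset(
      {"x": ("time", np.array([3.0]))},
      coords={"time": np.array([6]) * hour})
  # Returns "x" == [2.0, 3.0] with time coordinates [0h, +6h].
  return _get_next_inputs(prev_inputs, next_frame)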
|
|
|
def extend_targets_template(
    targets_template: xarray.Dataset,
    required_num_steps: int) -> xarray.Dataset:
  """Extends `targets_template` to `required_num_steps` with lazy arrays.

  It uses lazy dask arrays of zeros, so it does not require instantiating the
  extended array in memory.

  Args:
    targets_template: Input template to extend.
    required_num_steps: Number of steps required in the returned template.

  Returns:
    `xarray.Dataset` with the same variables and timestep as `targets_template`,
    filled with `dask.array.zeros`, such that the time axis has
    `required_num_steps`.
  """

  # Extend the "time" (and, if present, "datetime") coordinates.
  time = targets_template.coords["time"]

  # The first target time is the timestep; check the template is equispaced.
  timestep = time[0].data
  if time.shape[0] > 1:
    assert np.all(timestep == time[1:] - time[:-1])

  extended_time = (np.arange(required_num_steps) + 1) * timestep

  if "datetime" in targets_template.coords:
    datetime = targets_template.coords["datetime"]
    extended_datetime = (datetime[0].data - timestep) + extended_time
  else:
    extended_datetime = None

  # Replace each variable's data with a lazy dask array of zeros covering the
  # extended time axis.
  def extend_time(data_array: xarray.DataArray) -> xarray.DataArray:
    dims = data_array.dims
    shape = list(data_array.shape)
    shape[dims.index("time")] = required_num_steps
    dask_data = dask.array.zeros(
        shape=tuple(shape),
        chunks=-1,
        dtype=data_array.dtype)

    coords = dict(data_array.coords)
    coords["time"] = extended_time

    if extended_datetime is not None:
      coords["datetime"] = ("time", extended_datetime)

    return xarray.DataArray(
        dims=dims,
        data=dask_data,
        coords=coords)

  return xarray_tree.map_structure(extend_time, targets_template)
|
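# Usage sketch (illustrative; `my_targets_template` is a hypothetical one-step
# template and `my_forcings` must also cover the extended steps): build a lazy
# 40-step template without materialising the data, then roll out against it.
#
#   extended_template = extend_targets_template(
#       my_targets_template, required_num_steps=40)
#   predictions = chunked_prediction(
#       _example_zero_predictor_fn,
#       rng=jax.random.PRNGKey(0),
#       inputs=my_inputs,
#       targets_template=extended_template,
#       forcings=my_forcings,
#       num_steps_per_chunk=2)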
|