|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Abstract base classes for an xarray-based Predictor API.""" |
|
|
|
import abc |
|
|
|
from typing import Tuple |
|
|
|
from graphcast import losses |
|
from graphcast import xarray_jax |
|
import jax.numpy as jnp |
|
import xarray |
|
|
|
# Short module-level alias for the `losses.LossAndDiagnostics` type, used in
# the return annotations of `Predictor.loss` and
# `Predictor.loss_and_predictions` below.
LossAndDiagnostics = losses.LossAndDiagnostics |
|
|
|
|
|
class Predictor(abc.ABC):
    """Interface for weather predictors that consume and produce xarray data.

    A Predictor may or may not be trainable. It usually sits on top of a JAX
    model and is responsible for turning xarray Dataset values into the plain
    JAX arrays that model consumes, and for turning the model's outputs back
    into xarray form.

    Several subclasses can coexist, each wrapping a different flavour of
    underlying model: for instance models with stacked inputs/outputs, models
    with separate 2D and 3D inputs/outputs, or autoregressive models.

    A specific model may also be written directly as a Predictor. That is
    useful when its inputs, outputs or loss function have unusual
    requirements, or when expressing it with xarray is simply the most
    convenient option.
    """

    @abc.abstractmethod
    def __call__(self,
                 inputs: xarray.Dataset,
                 targets_template: xarray.Dataset,
                 forcings: xarray.Dataset,
                 **optional_kwargs
                 ) -> xarray.Dataset:
        """Produces predictions for the requested targets.

        The Experiment only calls this for inference / evaluation; training
        goes through `.loss` instead. Implementations should therefore default
        to evaluation-mode predictions. Predictions intended for use inside
        the loss can additionally be supported through an `is_training`
        argument -- LossFunctionPredictor helps with that.

        Args:
          inputs: An xarray.Dataset holding the inputs.
          targets_template: An xarray.Dataset (or other mapping of
            xarray.DataArrays) shaped like the targets. It describes what has
            to be predicted: use it to work out the required variables, levels
            and lead times. Raising an error is acceptable if the request is
            not supported.
          forcings: An xarray.Dataset of forcing terms, i.e. variables that
            are fed to the model but never have to be predicted. Typical
            reasons are that the variable can be computed analytically (e.g.
            the toa radiation of the sun is mostly a function of geometry) or
            that it is treated as controlled by the experiment (e.g. imposing
            a scenario of CO2 emission into the atmosphere). In contrast to
            `inputs`, `forcings` may carry information "from the future",
            meaning information at the target times given by
            `targets_template`.
          **optional_kwargs: Extra optional keyword arguments that an
            implementation chooses to support; each must have a suitable
            default.

        Returns:
          The predictions, as an xarray.Dataset or other mapping of DataArrays
          that can be evaluated against targets whose shape is given by
          `targets_template`. Probabilistic predictors able to draw several
          samples from a predictive distribution should, by convention, return
          them along an extra 'sample' dimension.
        """

    def loss(self,
             inputs: xarray.Dataset,
             targets: xarray.Dataset,
             forcings: xarray.Dataset,
             **optional_kwargs,
             ) -> LossAndDiagnostics:
        """Returns the training loss, for predictors that can be trained.

        The loss lives on the Predictor, rather than being computed by callers
        from the output of `Predictor.__call__`, because this gives the
        Predictor more control over how it is trained. A predictor may want to
        train against targets that differ from what it predicts at evaluation
        time: other lead times or variables, or transformed targets whose
        transform has to be inverted at evaluation time, and so on.

        It is also required for generative models (VAEs, GANs, ...), whose
        training loss is more involved and cannot be written as a
        parameter-free function of predictions and targets.

        The implementation provided here is a placeholder: it ignores
        `targets`, `forcings` and any extra kwargs, and returns an all-zero
        loss with one entry per element of the 'batch' dimension of `inputs`,
        together with empty diagnostics.

        Args:
          inputs: An xarray.Dataset.
          targets: An xarray.Dataset or other mapping of xarray.DataArrays;
            the docs of `__call__` explain the targets in more detail.
          forcings: An xarray.Dataset of forcing terms.
          **optional_kwargs: Extra optional keyword arguments that an
            implementation chooses to support; each must have a suitable
            default.

        Returns:
          loss: A DataArray with dimensions ('batch',) holding one loss per
            batch element. The final loss is the average of these, taken
            locally and across replicas.
          diagnostics: A mapping from name to additional quantities that get
            logged next to the loss; typically these are the individual terms
            of the loss. They should have dimensions ('batch',) as well and
            are averaged over the batch before logging. The loss itself does
            not have to be included in this mapping; it is added
            automatically.
        """
        # Only the batch size of `inputs` matters to the placeholder loss.
        del targets, forcings, optional_kwargs

        per_example_zeros = jnp.zeros(inputs.sizes['batch'])
        placeholder_loss = xarray_jax.DataArray(
            per_example_zeros, dims=('batch',))
        return placeholder_loss, {}

    def loss_and_predictions(
        self,
        inputs: xarray.Dataset,
        targets: xarray.Dataset,
        forcings: xarray.Dataset,
        **optional_kwargs,
    ) -> Tuple[LossAndDiagnostics, xarray.Dataset]:
        """Same as `.loss`, but returns the matching predictions as well.

        This method is optional, since the Experiment never calls it directly.
        It is, however, needed by autoregressive.Predictor when that applies
        an inner Predictor autoregressively during training: every step needs
        a loss, and also predictions that are fed back in as the next step's
        inputs.

        The loss is not necessarily a direct regression of the predictions
        onto the targets; it may be computed from transformed predictions and
        targets, or in some other way. Splitting the work into "first get
        predictions, then compute the loss from them" is therefore not always
        possible, which is why this combined method exists.

        Args:
          inputs:
          targets:
          forcings:
          **optional_kwargs:
            Same meaning as in `self.loss`.

        Returns:
          (loss, diagnostics)
            Same meaning as in `self.loss`.
          predictions:
            The predictions that the loss refers to. Their shape should match
            the result of
            `self.__call__(inputs, targets_template=targets,
            forcings=forcings)`, and they should live in the same 'domain' as
            the inputs, i.e. they must not be transformed differently from
            how the predictor expects its inputs.

        Raises:
          NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
|
|