# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base classes for an xarray-based Predictor API."""

import abc

from typing import Tuple

from graphcast import losses
from graphcast import xarray_jax
import jax.numpy as jnp
import xarray

LossAndDiagnostics = losses.LossAndDiagnostics


class Predictor(abc.ABC):
  """A possibly-trainable predictor of weather, exposing an xarray-based API.

  Typically wraps an underlying JAX model and handles translating the xarray
  Dataset values to and from plain JAX arrays that are convenient for input to
  (and output from) the underlying model.

  Different subclasses may exist to wrap different kinds of underlying model,
  e.g. models taking stacked inputs/outputs, models taking separate 2D and 3D
  inputs/outputs, autoregressive models.

  You can also implement a specific model directly as a Predictor if you want,
  for example if it has quite specific/unique requirements for its input/output
  or loss function, or if it's convenient to implement directly using xarray.
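
  (For a minimal, purely illustrative subclass, see the hypothetical
  `Persistence` sketch at the bottom of this module.)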
  """

  @abc.abstractmethod
  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **optional_kwargs
               ) -> xarray.Dataset:
    """Makes predictions.

    This is only used by the Experiment for inference / evaluation, with
    training going via the .loss method. So it should default to making
    predictions for evaluation, although you can also support making
    predictions for use in the loss via an `is_training` argument -- see
    LossFunctionPredictor, which helps with that.

    Args:
      inputs: An xarray.Dataset of inputs.
      targets_template: An xarray.Dataset or other mapping of xarray.DataArrays,
        with the same shape as the targets, to demonstrate what kind of
        predictions are required. You can use this to determine which variables,
        levels and lead times must be predicted.
        You are free to raise an error if you don't support predicting what is
        requested.
      forcings: An xarray.Dataset of forcing terms. Forcings are variables
        that can be fed to the model but do not need to be predicted. This is
        often because they can be computed analytically (e.g. top-of-atmosphere
        solar radiation is mostly a function of geometry), or because they are
        considered controlled for the experiment (e.g., imposing a scenario of
        CO2 emissions into the atmosphere). Unlike `inputs`, the `forcings` can
        include information "from the future", that is, information at the
        target times specified in `targets_template`.
      **optional_kwargs: Implementations may support extra optional kwargs,
        provided they set appropriate defaults for them.

    Returns:
      Predictions, as an xarray.Dataset or other mapping of DataArrays which
      is capable of being evaluated against targets with shape given by
      targets_template.
      For probabilistic predictors which can return multiple samples from a
      predictive distribution, these should (by convention) be returned along
      an additional 'sample' dimension.
    """

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **optional_kwargs,
           ) -> LossAndDiagnostics:
    """Computes a training loss, for predictors that are trainable.

    Why make this the Predictor's responsibility, rather than letting callers
    compute their own loss function using predictions obtained from
    Predictor.__call__?

    Doing it this way gives Predictors more control over their training setup.
    For example, some predictors may wish to train using different targets to
    the ones they predict at evaluation time -- perhaps different lead times and
    variables, perhaps training to predict transformed versions of targets
    where the transform needs to be inverted at evaluation time, etc.

    It's also necessary for generative models (VAEs, GANs, ...) where the
    training loss is more complex and isn't expressible as a parameter-free
    function of predictions and targets.

    Args:
      inputs: An xarray.Dataset.
      targets: An xarray.Dataset or other mapping of xarray.DataArrays. See
        docs on __call__ for an explanation about the targets.
      forcings: xarray.Dataset of forcing terms.
      **optional_kwargs: Implementations may support extra optional kwargs,
        provided they set appropriate defaults for them.

    Returns:
      loss: A DataArray with dimensions ('batch',) containing losses for each
        element of the batch. These will be averaged to give the final
        loss, locally and across replicas.
      diagnostics: Mapping of additional quantities to log by name alongside the
        loss. These will typically correspond to terms in the loss. They
        should also have dimensions ('batch',) and will be averaged over the
        batch before logging.
        You need not include the loss itself in this dict; it will be added for
        you.
    """
    del targets, forcings, optional_kwargs
    batch_size = inputs.sizes['batch']
    dummy_loss = xarray_jax.DataArray(jnp.zeros(batch_size), dims=('batch',))
    return dummy_loss, {}  # pytype: disable=bad-return-type

  def loss_and_predictions(
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
      **optional_kwargs,
      ) -> Tuple[LossAndDiagnostics, xarray.Dataset]:
    """Like .loss but also returns corresponding predictions.

    Implementing this is optional as it's not used directly by the Experiment,
    but it is required by autoregressive.Predictor when applying an inner
    Predictor autoregressively at training time; we need a loss at each step but
    also predictions to feed back in for the next step.

    Note the loss itself may not directly regress the predictions towards the
    targets; it may be computed in terms of transformed predictions and targets
    (or in some other way). For this reason we can't always cleanly separate
    this into step 1: get predictions, step 2: compute loss from them; hence
    the need for this combined method.

    Args:
      inputs:
      targets:
      forcings:
      **optional_kwargs:
        As for self.loss.

    Returns:
      (loss, diagnostics)
        As for self.loss
      predictions:
        The predictions which the loss relates to. These should be of the same
        shape as what you would get from
        `self.__call__(inputs, targets_template=targets)`, and should be in the
        same 'domain' as the inputs (i.e. they shouldn't be transformed
        differently to how the predictor expects its inputs).
    """
    raise NotImplementedError
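

# ---------------------------------------------------------------------------
# Illustrative sketch only -- not part of the original API. `Persistence` is a
# hypothetical minimal subclass showing one way to satisfy the Predictor
# interface: it predicts that every target equals the most recent input value,
# and computes a per-batch MSE loss honouring the ('batch',) contract above.
# It assumes every target variable is also present in `inputs` and that all
# datasets carry 'batch' and 'time' dimensions.


class Persistence(Predictor):
  """Hypothetical baseline: repeats the last input value at all lead times."""

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **optional_kwargs) -> xarray.Dataset:
    del forcings, optional_kwargs  # Unused by this trivial baseline.
    # Broadcast the last available input time across all requested lead times.
    last_inputs = inputs.isel(time=-1, drop=True)
    predictions = targets_template.copy()
    for name in targets_template.data_vars:
      predictions[name] = last_inputs[name].broadcast_like(
          targets_template[name])
    return predictions

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **optional_kwargs) -> LossAndDiagnostics:
    (loss, diagnostics), _ = self.loss_and_predictions(
        inputs, targets, forcings, **optional_kwargs)
    return loss, diagnostics

  def loss_and_predictions(
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
      **optional_kwargs,
      ) -> Tuple[LossAndDiagnostics, xarray.Dataset]:
    predictions = self(
        inputs, targets_template=targets, forcings=forcings, **optional_kwargs)
    # Per-batch MSE: reduce every dimension except 'batch', then average the
    # per-variable losses so the result has the documented ('batch',) dims.
    squared_error = (predictions - targets) ** 2
    reduce_dims = [d for d in squared_error.dims if d != 'batch']
    per_variable = squared_error.mean(reduce_dims)
    loss = sum(per_variable[name] for name in per_variable.data_vars)
    loss = loss / max(len(per_variable.data_vars), 1)
    return (loss, {}), predictions

# Example usage (hypothetical Datasets):
#   predictor = Persistence()
#   predictions = predictor(inputs, targets_template, forcings)
#   (loss, diagnostics) = predictor.loss(inputs, targets, forcings)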