|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Wrappers for Predictors which allow them to work with normalized data. |
|
|
|
The Predictor which is wrapped sees normalized inputs and targets, and makes |
|
normalized predictions. The wrapper handles translating the predictions back |
|
to the original domain. |
|
""" |
|
|
|
import logging |
|
from typing import Optional, Tuple |
|
|
|
from graphcast import predictor_base |
|
from graphcast import xarray_tree |
|
import xarray |
|
|
|
|
|
def normalize(values: xarray.Dataset,
              scales: xarray.Dataset,
              locations: Optional[xarray.Dataset],
              ) -> xarray.Dataset:
  """Shifts and rescales every variable in `values` into the normalized domain.

  For each array, the matching entry of `locations` (only when `locations` is
  given) is subtracted first, and the result is then divided by the matching
  entry of `scales`. Constants are cast to the array's dtype before use.
  Constants that are missing for a variable are skipped with a logged warning
  rather than raising.

  Args:
    values: Data to normalize. Every array reached by
      `xarray_tree.map_structure` must have a `name`.
    scales: Per-variable divisors, looked up by array name.
    locations: Optional per-variable offsets, looked up by array name.

  Returns:
    Normalized data, with the same structure as `values`.

  Raises:
    ValueError: If an array has no name, so its constants cannot be looked up.
  """
  def _apply(arr):
    key = arr.name
    if key is None:
      raise ValueError(
          "Can't look up normalization constants because array has no name.")
    result = arr
    if locations is not None:
      if key not in locations:
        logging.warning('No normalization location found for %s', key)
      else:
        result = result - locations[key].astype(result.dtype)
    if key not in scales:
      logging.warning('No normalization scale found for %s', key)
    else:
      result = result / scales[key].astype(result.dtype)
    return result

  return xarray_tree.map_structure(_apply, values)
|
|
|
|
|
def unnormalize(values: xarray.Dataset,
                scales: xarray.Dataset,
                locations: Optional[xarray.Dataset],
                ) -> xarray.Dataset:
  """Maps normalized variables back to their original domain.

  This is the inverse of `normalize`. Each array is first multiplied by its
  matching entry in `scales`, and then, only when `locations` is given, the
  matching entry of `locations` is added. Constants are cast to the array's
  dtype before use. Constants that are missing for a variable are skipped with
  a logged warning rather than raising.

  Args:
    values: Normalized data. Every array reached by
      `xarray_tree.map_structure` must have a `name`.
    scales: Per-variable multipliers, looked up by array name.
    locations: Optional per-variable offsets, looked up by array name.

  Returns:
    Unnormalized data, with the same structure as `values`.

  Raises:
    ValueError: If an array has no name, so its constants cannot be looked up.
  """
  def _invert(arr):
    key = arr.name
    if key is None:
      raise ValueError(
          "Can't look up normalization constants because array has no name.")
    result = arr
    if key not in scales:
      logging.warning('No normalization scale found for %s', key)
    else:
      result = result * scales[key].astype(result.dtype)
    if locations is not None:
      if key not in locations:
        logging.warning('No normalization location found for %s', key)
      else:
        result = result + locations[key].astype(result.dtype)
    return result

  return xarray_tree.map_structure(_invert, values)
|
|
|
|
|
class InputsAndResiduals(predictor_base.Predictor):
  """Wraps with a residual connection, normalizing inputs and target residuals.

  The inner predictor is given inputs that are normalized using `locations`
  and `scales` to roughly zero-mean unit variance.

  For target variables that are present in the inputs, the inner predictor is
  trained to predict residuals (target - last_frame_of_input) that have been
  normalized using `residual_scales` (and optionally `residual_locations`) to
  roughly unit variance / zero mean.

  This replaces `residual.Predictor` in the case where you want normalization
  that's based on the scales of the residuals.

  Since we return the underlying predictor's loss on the normalized residuals,
  if the underlying predictor is a sum of per-variable losses, the normalization
  will affect the relative weighting of the per-variable loss terms (hopefully
  in a good way).

  For target variables *not* present in the inputs, the inner predictor is
  trained to predict targets directly, that have been normalized in the same
  way as the inputs.

  The transforms applied to the targets (the residual connection and the
  normalization) are applied in reverse to the predictions before returning
  them.
  """

  def __init__(
      self,
      predictor: predictor_base.Predictor,
      stddev_by_level: xarray.Dataset,
      mean_by_level: xarray.Dataset,
      diffs_stddev_by_level: xarray.Dataset,
      diffs_mean_by_level: Optional[xarray.Dataset] = None):
    """Initializes the wrapper.

    Args:
      predictor: Inner predictor. It only ever sees normalized data.
      stddev_by_level: Per-variable scales for inputs, forcings, and any target
        that is not also an input.
      mean_by_level: Per-variable locations paired with `stddev_by_level`.
      diffs_stddev_by_level: Per-variable scales for residual targets
        (target minus the last input frame).
      diffs_mean_by_level: Optional per-variable locations for residual
        targets. Defaults to None, in which case residuals are only rescaled
        and not shifted (the previous, and still default, behavior).
    """
    self._predictor = predictor
    self._scales = stddev_by_level
    self._locations = mean_by_level
    self._residual_scales = diffs_stddev_by_level
    self._residual_locations = diffs_mean_by_level

  def _unnormalize_prediction_and_add_input(self, inputs, norm_prediction):
    """Maps one normalized prediction back to the original data domain.

    Args:
      inputs: The original (unnormalized) inputs.
      norm_prediction: One named variable as output by the inner predictor.

    Returns:
      The prediction in the original domain.

    Raises:
      ValueError: If the prediction does not have exactly one `time` step.
    """
    if norm_prediction.sizes.get('time') != 1:
      raise ValueError(
          'normalization.InputsAndResiduals only supports predicting a '
          'single timestep.')
    if norm_prediction.name in inputs:
      # The inner predictor produced a normalized residual. Undo the residual
      # normalization, then add back the last input frame, which is taken from
      # the *unnormalized* inputs. Integer `isel` drops the `time` dimension,
      # so `last_input` broadcasts against the single-timestep prediction.
      prediction = unnormalize(
          norm_prediction, self._residual_scales, self._residual_locations)
      last_input = inputs[norm_prediction.name].isel(time=-1)
      return prediction + last_input
    # Not an input variable: the inner predictor predicted the target directly,
    # normalized with the same statistics as the inputs.
    return unnormalize(norm_prediction, self._scales, self._locations)

  def _subtract_input_and_normalize_target(self, inputs, target):
    """Converts one target into what the inner predictor is trained on.

    Args:
      inputs: The original (unnormalized) inputs.
      target: One named target variable.

    Returns:
      The normalized residual (target minus last input frame) if the variable
      is an input. Otherwise, the target normalized like the inputs.

    Raises:
      ValueError: If the target does not have exactly one `time` step.
    """
    if target.sizes.get('time') != 1:
      raise ValueError(
          'normalization.InputsAndResiduals only supports wrapping predictors '
          'that predict a single timestep.')
    if target.name in inputs:
      # Residuals are computed in the original domain and only then normalized
      # with the residual statistics.
      last_input = inputs[target.name].isel(time=-1)
      return normalize(
          target - last_input, self._residual_scales, self._residual_locations)
    return normalize(target, self._scales, self._locations)

  def _normalize_inputs_and_forcings(
      self,
      inputs: xarray.Dataset,
      forcings: xarray.Dataset,
  ) -> Tuple[xarray.Dataset, xarray.Dataset]:
    """Normalizes inputs, then forcings, using the non-residual statistics."""
    norm_inputs = normalize(inputs, self._scales, self._locations)
    norm_forcings = normalize(forcings, self._scales, self._locations)
    return norm_inputs, norm_forcings

  def _normalize_targets(self, inputs, targets):
    """Applies `_subtract_input_and_normalize_target` to every target."""
    return xarray_tree.map_structure(
        lambda t: self._subtract_input_and_normalize_target(inputs, t),
        targets)

  def _unnormalize_predictions(self, inputs, norm_predictions):
    """Applies `_unnormalize_prediction_and_add_input` to every prediction."""
    return xarray_tree.map_structure(
        lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
        norm_predictions)

  def __call__(self,
               inputs: xarray.Dataset,
               targets_template: xarray.Dataset,
               forcings: xarray.Dataset,
               **kwargs
               ) -> xarray.Dataset:
    """Runs the inner predictor on normalized data; returns unnormalized output.

    `targets_template` is forwarded to the inner predictor unchanged.
    """
    norm_inputs, norm_forcings = self._normalize_inputs_and_forcings(
        inputs, forcings)
    norm_predictions = self._predictor(
        norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
    # Residuals are re-added against the original, unnormalized inputs.
    return self._unnormalize_predictions(inputs, norm_predictions)

  def loss(self,
           inputs: xarray.Dataset,
           targets: xarray.Dataset,
           forcings: xarray.Dataset,
           **kwargs,
           ) -> predictor_base.LossAndDiagnostics:
    """Returns the loss computed on normalized inputs and targets."""
    norm_inputs, norm_forcings = self._normalize_inputs_and_forcings(
        inputs, forcings)
    norm_target_residuals = self._normalize_targets(inputs, targets)
    return self._predictor.loss(
        norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)

  def loss_and_predictions(
      self,
      inputs: xarray.Dataset,
      targets: xarray.Dataset,
      forcings: xarray.Dataset,
      **kwargs,
      ) -> Tuple[predictor_base.LossAndDiagnostics,
                 xarray.Dataset]:
    """The loss computed on normalized data, with unnormalized predictions."""
    norm_inputs, norm_forcings = self._normalize_inputs_and_forcings(
        inputs, forcings)
    norm_target_residuals = self._normalize_targets(inputs, targets)
    (loss, scalars), norm_predictions = self._predictor.loss_and_predictions(
        norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)
    predictions = self._unnormalize_predictions(inputs, norm_predictions)
    return (loss, scalars), predictions
|
|