# Source path: weather/graphcast/xarray_jax_test.py
# Repository page header (not part of the module): Gary0205's picture,
# "Upload 25 files", commit 6d70ed4 (verified).
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for xarray_jax."""
from absl.testing import absltest
import chex
from graphcast import xarray_jax
import jax
import jax.numpy as jnp
import numpy as np
import xarray
class XarrayJaxTest(absltest.TestCase):
  """Tests for the xarray_jax wrappers, run eagerly and under jit/grad/pmap."""

  def test_jax_array_wrapper_with_numpy_api(self):
    """Numpy-API calls on a JaxArrayWrapper keep returning JaxArrayWrappers."""
    # Being usable through the standard numpy API is a side benefit of making
    # things work with xarray: none of the steps below should convert the
    # wrapped JAX array to numpy.
    wrapped = xarray_jax.JaxArrayWrapper(jnp.ones((3, 4), dtype=np.float32))
    product = np.abs((wrapped + 2) * (wrapped - 3))
    window = product[:-1, 1:3]
    doubled = np.concatenate([window, window + 1], axis=0)
    swapped = np.transpose(doubled, (1, 0))
    flat = np.reshape(swapped, (-1,))
    as_int = flat.astype(np.int32)
    self.assertIsInstance(as_int, xarray_jax.JaxArrayWrapper)

    # Only an explicit conversion gets us out of JAX-land:
    self.assertIsInstance(np.asarray(as_int), np.ndarray)

  def test_jax_xarray_variable(self):
    """A chain of Variable operations keeps JAX data, eagerly and traced."""

    def reduce_via_xarray(inputs):
      var = xarray_jax.Variable(('lat', 'lon'), inputs)
      # The particular operations are not important; the sequence exists to
      # check that the end result is still a JAX array, i.e. that nothing was
      # converted to numpy along the way.
      var = np.abs((var + 2) * (var - 3))
      cropped = var.isel({'lat': slice(0, -1), 'lon': slice(1, 3)})
      joined = xarray.Variable.concat([cropped, cropped + 1], dim='lat')
      reordered = joined.transpose('lon', 'lat')
      stacked = reordered.stack(channels=('lon', 'lat'))
      return xarray_jax.jax_data(stacked.sum())

    ones = jnp.ones((3, 4), dtype=np.float32)

    # With concrete values the result must not leave jax-land:
    self.assertIsInstance(reduce_via_xarray(ones), jax.Array)

    # It must also be possible to JIT it and to differentiate through it. Both
    # pass jax tracers through the xarray computation:
    jax.jit(reduce_via_xarray)(ones)
    jax.grad(reduce_via_xarray)(ones)

  def test_jax_xarray_data_array(self):
    """A chain of DataArray operations keeps JAX data, eagerly and traced."""
    ones = jnp.ones((3, 4), dtype=np.float32)

    def reduce_via_xarray(inputs):
      first = xarray_jax.DataArray(
          data=inputs,
          dims=('lat', 'lon'),
          coords={'lat': np.arange(3) * 10,
                  'lon': np.arange(4) * 10})
      first = np.abs((first + 2) * (first - 3))
      first = first.sel({'lat': slice(0, 20)})
      # NOTE: the second array is built from `ones` captured from the enclosing
      # test, not from `inputs`, so it is not a tracer under jit/grad.
      second = xarray_jax.DataArray(
          data=ones,
          dims=('lat', 'lon'),
          coords={'lat': np.arange(3, 6) * 10,
                  'lon': np.arange(4) * 10})
      combined = xarray.concat([first, second], dim='lat')
      combined = combined.transpose('lon', 'lat')
      combined = combined.stack(channels=('lon', 'lat'))
      combined = combined.unstack()
      return xarray_jax.jax_data(combined.sum())

    self.assertIsInstance(reduce_via_xarray(ones), jax.Array)

    jax.jit(reduce_via_xarray)(ones)
    jax.grad(reduce_via_xarray)(ones)

  def test_jax_xarray_dataset(self):
    """A chain of Dataset operations keeps JAX data, eagerly and traced."""

    def reduce_via_xarray(foo, bar):

      def build(lat):
        # Both datasets share their variables and differ only in 'lat'.
        return xarray_jax.Dataset(
            data_vars={'foo': (('lat', 'lon'), foo),
                       'bar': (('time', 'lat', 'lon'), bar)},
            coords={
                'time': np.arange(2),
                'lat': lat,
                'lon': np.arange(4) * 10})

      first = build(np.arange(3) * 10)
      first = np.abs((first + 2) * (first - 3))
      first = first.sel({'lat': slice(0, 20)})
      second = build(np.arange(3, 6) * 10)
      combined = xarray.concat([first, second], dim='lat')
      combined = combined.transpose('lon', 'lat', 'time')
      combined = combined.stack(channels=('lon', 'lat'))
      total = (combined.foo + combined.bar).sum()
      return xarray_jax.jax_data(total)

    foo = jnp.ones((3, 4), dtype=np.float32)
    bar = jnp.ones((2, 3, 4), dtype=np.float32)

    self.assertIsInstance(reduce_via_xarray(foo, bar), jax.Array)

    jax.jit(reduce_via_xarray)(foo, bar)
    jax.grad(reduce_via_xarray)(foo, bar)

  def test_jit_function_with_xarray_variable_arguments_and_return(self):
    """A jitted function can take and return Variables, called repeatedly."""
    add_one = jax.jit(lambda v: v + 1)

    with self.subTest('jax input'):
      jax_backed = xarray_jax.Variable(
          ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
      add_one(jax_backed)
      # The jitted function is run a second time to exercise the logic in jax
      # that checks whether the structure of the inputs (including dimension
      # names and coordinates) matches the previous call, and therefore
      # whether a new version of the function must be re-traced and compiled.
      # That check can run into problems if the 'aux' structure returned by
      # the registered flatten function is not hashable/comparable.
      result = add_one(jax_backed)
      self.assertEqual(result.dims, jax_backed.dims)

    with self.subTest('numpy input'):
      numpy_backed = xarray.Variable(
          ('lat', 'lon'), np.ones((3, 4), dtype=np.float32))
      add_one(numpy_backed)
      result = add_one(numpy_backed)
      self.assertEqual(result.dims, numpy_backed.dims)

  def test_jit_problem_if_convert_to_plain_numpy_array(self):
    """Reading .values inside a jitted function raises a conversion error."""
    data_array = xarray_jax.DataArray(
        data=jnp.ones((2,), dtype=np.float32), dims=('foo',))
    # Calling .values on a DataArray converts its values to numpy:
    read_values = jax.jit(lambda data_array: data_array.values)
    with self.assertRaises(jax.errors.TracerArrayConversionError):
      read_values(data_array)

  def test_grad_function_with_xarray_variable_arguments(self):
    """jax.grad accepts a Variable argument."""
    variable = xarray_jax.Variable(
        ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))

    def total(v):
      # For grad we still need a JAX scalar as the output:
      return xarray_jax.jax_data(v.sum())

    jax.grad(total)(variable)

  def test_jit_function_with_xarray_data_array_arguments_and_return(self):
    """A jitted function can take and return DataArrays, keeping coords."""
    inputs = xarray_jax.DataArray(
        data=jnp.ones((3, 4), dtype=np.float32),
        dims=('lat', 'lon'),
        coords={'lat': np.arange(3),
                'lon': np.arange(4) * 10})
    add_one = jax.jit(lambda v: v + 1)

    # Call once before the call whose outputs are checked.
    add_one(inputs)
    outputs = add_one(inputs)

    self.assertEqual(outputs.dims, inputs.dims)
    chex.assert_trees_all_equal(outputs.coords, inputs.coords)

  def test_jit_function_with_data_array_and_jax_coords(self):
    """A DataArray's jax_coords stay JAX-backed into and out of a jitted fn."""
    inputs = xarray_jax.DataArray(
        data=jnp.ones((3, 4), dtype=np.float32),
        dims=('lat', 'lon'),
        coords={'lat': np.arange(3)},
        jax_coords={'lon': jnp.arange(4) * 10})

    # The jax_coord 'lon' must retain jax data and must not have been created
    # as an index coordinate:
    self.assertIsInstance(
        inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
    self.assertNotIn('lon', inputs.indexes)

    @jax.jit
    def shift_by_lon(traced):
      # The non-JAX coord is passed with numpy array data and an index:
      self.assertIsInstance(traced.coords['lat'].data, np.ndarray)
      self.assertIn('lat', traced.indexes)

      # The jax_coord is passed with JAX array data:
      self.assertIsInstance(
          traced.coords['lon'].data, xarray_jax.JaxArrayWrapper)
      self.assertNotIn('lon', traced.indexes)

      # Use the jax coord in the computation:
      shifted = traced + traced.coords['lon']

      # Return with an updated jax coord:
      return xarray_jax.assign_jax_coords(
          shifted, lon=shifted.coords['lon'] + 1)

    # Call once before the call whose outputs are checked.
    shift_by_lon(inputs)
    outputs = shift_by_lon(inputs)

    # The jax_coord 'lon' must have jax data in the output too:
    self.assertIsInstance(
        outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
    self.assertNotIn('lon', outputs.indexes)

    self.assertEqual(outputs.dims, inputs.dims)
    chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat'])

    # Check that the computations with the coordinate values worked:
    chex.assert_trees_all_equal(
        outputs.coords['lon'].data, (inputs.coords['lon']+1).data)
    chex.assert_trees_all_equal(
        outputs.data, (inputs + inputs.coords['lon']).data)

  def test_jit_function_with_xarray_dataset_arguments_and_return(self):
    """A jitted function can take and return Datasets, keeping coords."""
    inputs = xarray_jax.Dataset(
        data_vars={
            'foo': (('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32)),
            'bar': (('time', 'lat', 'lon'),
                    jnp.ones((2, 3, 4), dtype=np.float32)),
        },
        coords={
            'time': np.arange(2),
            'lat': np.arange(3) * 10,
            'lon': np.arange(4) * 10})
    add_one = jax.jit(lambda v: v + 1)

    # Call once before the call whose outputs are checked.
    add_one(inputs)
    outputs = add_one(inputs)

    self.assertEqual({'foo', 'bar'}, outputs.data_vars.keys())
    self.assertEqual(inputs.foo.dims, outputs.foo.dims)
    self.assertEqual(inputs.bar.dims, outputs.bar.dims)
    chex.assert_trees_all_equal(outputs.coords, inputs.coords)

  def test_jit_function_with_dataset_and_jax_coords(self):
    """A Dataset's jax_coords stay JAX-backed into and out of a jitted fn."""
    inputs = xarray_jax.Dataset(
        data_vars={
            'foo': (('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32)),
            'bar': (('time', 'lat', 'lon'),
                    jnp.ones((2, 3, 4), dtype=np.float32)),
        },
        coords={
            'time': np.arange(2),
            'lat': np.arange(3) * 10,
        },
        jax_coords={'lon': jnp.arange(4) * 10})

    # The jax_coord 'lon' must retain jax data and must not have been created
    # as an index coordinate:
    self.assertIsInstance(
        inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
    self.assertNotIn('lon', inputs.indexes)

    @jax.jit
    def shift_by_lon(traced):
      # The non-JAX coords are passed with numpy array data and an index:
      self.assertIsInstance(traced.coords['lat'].data, np.ndarray)
      self.assertIn('lat', traced.indexes)

      # The jax_coord is passed with JAX array data:
      self.assertIsInstance(
          traced.coords['lon'].data, xarray_jax.JaxArrayWrapper)
      self.assertNotIn('lon', traced.indexes)

      # Use the jax coord in the computation:
      shifted = traced + traced.coords['lon']

      # Return with an updated jax coord:
      return xarray_jax.assign_jax_coords(
          shifted, lon=shifted.coords['lon'] + 1)

    # Call once before the call whose outputs are checked.
    shift_by_lon(inputs)
    outputs = shift_by_lon(inputs)

    # The jax_coord 'lon' must have jax data in the output too:
    self.assertIsInstance(
        outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper)
    self.assertNotIn('lon', outputs.indexes)

    self.assertEqual(outputs.dims, inputs.dims)
    chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat'])

    # Check that the computations with the coordinate values worked:
    chex.assert_trees_all_equal(
        (outputs.coords['lon']).data,
        (inputs.coords['lon']+1).data,
    )
    expected = inputs + inputs.coords['lon']
    actual_by_name = {name: outputs[name].data for name in outputs}
    expected_by_name = {name: expected[name].data for name in expected}
    chex.assert_trees_all_equal(actual_by_name, expected_by_name)

  def test_flatten_unflatten_variable(self):
    """Flattening then unflattening a Variable gives back an equal one."""
    original = xarray_jax.Variable(
        ('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32))
    children, aux = xarray_jax._flatten_variable(original)

    # The auxiliary info must be hashable/comparable (important for jax.jit):
    hash(aux)
    self.assertEqual(aux, aux)

    restored = xarray_jax._unflatten_variable(aux, children)
    self.assertTrue(original.equals(restored))

  def test_flatten_unflatten_data_array(self):
    """Flattening then unflattening a DataArray gives back an equal one."""
    original = xarray_jax.DataArray(
        data=jnp.ones((3, 4), dtype=np.float32),
        dims=('lat', 'lon'),
        coords={'lat': np.arange(3)},
        jax_coords={'lon': np.arange(4) * 10},
    )
    children, aux = xarray_jax._flatten_data_array(original)

    # The auxiliary info must be hashable/comparable (important for jax.jit):
    hash(aux)
    self.assertEqual(aux, aux)

    restored = xarray_jax._unflatten_data_array(aux, children)
    self.assertTrue(original.equals(restored))

  def test_flatten_unflatten_dataset(self):
    """Flattening then unflattening a Dataset gives back an equal one."""
    original = xarray_jax.Dataset(
        data_vars={
            'foo': (('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32)),
            'bar': (('time', 'lat', 'lon'),
                    jnp.ones((2, 3, 4), dtype=np.float32)),
        },
        coords={
            'time': np.arange(2),
            'lat': np.arange(3) * 10},
        jax_coords={
            'lon': np.arange(4) * 10})
    children, aux = xarray_jax._flatten_dataset(original)

    # The auxiliary info must be hashable/comparable (important for jax.jit):
    hash(aux)
    self.assertEqual(aux, aux)

    restored = xarray_jax._unflatten_dataset(aux, children)
    self.assertTrue(original.equals(restored))

  def test_flatten_unflatten_added_dim(self):
    """Unflattening can prepend a dim when every leaf gained a leading axis."""
    data_array = xarray_jax.DataArray(
        data=jnp.ones((3, 4), dtype=np.float32),
        dims=('lat', 'lon'),
        coords={'lat': np.arange(3),
                'lon': np.arange(4) * 10})
    leaves, treedef = jax.tree_util.tree_flatten(data_array)
    expanded_leaves = [jnp.expand_dims(leaf, 0) for leaf in leaves]

    with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims):
      with_new_dim = jax.tree_util.tree_unflatten(treedef, expanded_leaves)

    self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
    xarray.testing.assert_identical(
        jax.device_get(data_array),
        jax.device_get(with_new_dim.isel(new=0)))

  def test_map_added_dim(self):
    """tree_map can prepend a dim when the mapped function adds an axis."""
    data_array = xarray_jax.DataArray(
        data=jnp.ones((3, 4), dtype=np.float32),
        dims=('lat', 'lon'),
        coords={'lat': np.arange(3),
                'lon': np.arange(4) * 10})

    with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims):
      with_new_dim = jax.tree_util.tree_map(
          lambda x: jnp.expand_dims(x, 0), data_array)

    self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims)
    xarray.testing.assert_identical(
        jax.device_get(data_array),
        jax.device_get(with_new_dim.isel(new=0)))

  def test_map_remove_dim(self):
    """tree_map can drop the leading dim when the function squeezes it."""
    dataset = xarray_jax.Dataset(
        data_vars={
            'foo': (('batch', 'lat', 'lon'),
                    jnp.ones((1, 3, 4), dtype=np.float32)),
            'bar': (('batch', 'time', 'lat', 'lon'),
                    jnp.ones((1, 2, 3, 4), dtype=np.float32)),
        },
        coords={
            'batch': np.array([123]),
            'time': np.arange(2),
            'lat': np.arange(3) * 10,
            'lon': np.arange(4) * 10})

    with xarray_jax.dims_change_on_unflatten(lambda dims: dims[1:]):
      with_removed_dim = jax.tree_util.tree_map(
          lambda x: jnp.squeeze(x, 0), dataset)

    self.assertEqual(('lat', 'lon'), with_removed_dim['foo'].dims)
    self.assertEqual(('time', 'lat', 'lon'), with_removed_dim['bar'].dims)
    self.assertNotIn('batch', with_removed_dim.dims)
    self.assertNotIn('batch', with_removed_dim.coords)
    xarray.testing.assert_identical(
        jax.device_get(dataset.isel(batch=0, drop=True)),
        jax.device_get(with_removed_dim))

  def test_pmap(self):
    """xarray_jax.pmap maps a Dataset function over the 'device' dim."""
    num_devices = jax.local_device_count()
    dataset = xarray_jax.Dataset({
        'foo': (('device', 'lat', 'lon'),
                jnp.zeros((num_devices, 3, 4), dtype=np.float32)),
        'bar': (('device', 'time', 'lat', 'lon'),
                jnp.zeros((num_devices, 2, 3, 4), dtype=np.float32))})

    def add_one(d):
      self.assertNotIn('device', d.dims)
      return d + 1

    pmapped = xarray_jax.pmap(add_one, dim='device')

    result = pmapped(dataset)
    xarray.testing.assert_identical(
        jax.device_get(dataset + 1),
        jax.device_get(result))

    # It can be called again with a different argument structure (it will
    # recompile under the hood but should work):
    without_foo = dataset.drop_vars('foo')
    result = pmapped(without_foo)
    xarray.testing.assert_identical(
        jax.device_get(without_foo + 1),
        jax.device_get(result))

  def test_pmap_with_jax_coords(self):
    """xarray_jax.pmap handles a Dataset whose jax_coord has a device dim."""
    num_devices = jax.local_device_count()
    time = jnp.zeros((num_devices, 2), dtype=np.float32)
    dataset = xarray_jax.Dataset(
        {'foo': (('device', 'lat', 'lon'),
                 jnp.zeros((num_devices, 3, 4), dtype=np.float32)),
         'bar': (('device', 'time', 'lat', 'lon'),
                 jnp.zeros((num_devices, 2, 3, 4), dtype=np.float32))},
        coords={
            'lat': np.arange(3),
            'lon': np.arange(4),
        },
        jax_coords={
            # Currently any jax_coords need a leading device dimension to use
            # with pmap, same as for data_vars.
            # TODO(matthjw): have pmap automatically broadcast to all devices
            # where the device dimension not present.
            'time': xarray_jax.Variable(('device', 'time'), time),
        }
    )

    def add_one(d):
      self.assertNotIn('device', d.dims)
      self.assertNotIn('device', d.coords['time'].dims)

      # The jax_coord 'time' should be passed in backed by a JAX array, but
      # not as an index coordinate.
      self.assertIsInstance(
          d.coords['time'].data, xarray_jax.JaxArrayWrapper)
      self.assertNotIn('time', d.indexes)
      return d + 1

    pmapped = xarray_jax.pmap(add_one, dim='device')

    result = pmapped(dataset)
    xarray.testing.assert_identical(
        jax.device_get(dataset + 1),
        jax.device_get(result))

    # It can be called again with a different argument structure (it will
    # recompile under the hood but should work):
    without_foo = dataset.drop_vars('foo')
    result = pmapped(without_foo)
    xarray.testing.assert_identical(
        jax.device_get(without_foo + 1),
        jax.device_get(result))

  def test_pmap_with_tree_mix_of_xarray_and_jax_array(self):
    """xarray_jax.pmap accepts a dict mixing a DataArray and a plain array."""
    num_devices = jax.local_device_count()
    inputs = {
        'foo': xarray_jax.DataArray(
            data=jnp.ones((num_devices, 3, 4), dtype=np.float32),
            dims=('device', 'lat', 'lon')),
        'bar': jnp.ones((num_devices, 2), dtype=np.float32),
    }

    def add_one_to_each(x):
      return x['foo'] + 1, x['bar'] + 1

    pmapped = xarray_jax.pmap(add_one_to_each, dim='device')
    result_foo, result_bar = pmapped(inputs)

    xarray.testing.assert_identical(
        jax.device_get(inputs['foo'] + 1),
        jax.device_get(result_foo))
    np.testing.assert_array_equal(
        jax.device_get(inputs['bar'] + 1),
        jax.device_get(result_bar))

  def test_pmap_complains_when_dim_not_first(self):
    """xarray_jax.pmap raises ValueError if the mapped dim is not leading."""
    num_devices = jax.local_device_count()
    device_in_middle = xarray_jax.DataArray(
        data=jnp.ones((3, num_devices, 4), dtype=np.float32),
        dims=('lat', 'device', 'lon'))
    pmapped = xarray_jax.pmap(lambda x: x+1, dim='device')

    with self.assertRaisesRegex(
        ValueError, 'Expected dim device at index 0, found at 1'):
      pmapped(device_in_middle)

  def test_apply_ufunc(self):
    """apply_ufunc reduces over the core dim and keeps the remaining coord."""
    inputs = xarray_jax.DataArray(
        data=jnp.asarray([[1, 2], [3, 4]]),
        dims=('x', 'y'),
        coords={'x': [0, 1],
                'y': [2, 3]})

    result = xarray_jax.apply_ufunc(
        lambda x: jnp.sum(x, axis=-1),
        inputs,
        input_core_dims=[['x']])

    expected_result = xarray_jax.DataArray(
        data=[4, 6],
        dims=('y',),
        coords={'y': [2, 3]})
    xarray.testing.assert_identical(expected_result, jax.device_get(result))

  def test_apply_ufunc_multiple_return_values(self):
    """apply_ufunc supports a function that returns two arrays."""

    def min_and_max(array):
      return jnp.min(array, axis=-1), jnp.max(array, axis=-1)

    inputs = xarray_jax.DataArray(
        data=jnp.asarray([[1, 4], [3, 2]]),
        dims=('x', 'y'),
        coords={'x': [0, 1],
                'y': [2, 3]})

    result = xarray_jax.apply_ufunc(
        min_and_max,
        inputs,
        input_core_dims=[['x']],
        output_core_dims=[[], []])

    expected_mins = xarray_jax.DataArray(
        data=[1, 2],
        dims=('y',),
        coords={'y': [2, 3]}
    )
    expected_maxes = xarray_jax.DataArray(
        data=[3, 4],
        dims=('y',),
        coords={'y': [2, 3]}
    )
    xarray.testing.assert_identical(expected_mins, jax.device_get(result[0]))
    xarray.testing.assert_identical(expected_maxes, jax.device_get(result[1]))
# When executed as a script, hand control to absltest's test runner.
if __name__ == '__main__':
  absltest.main()