|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for xarray_jax.""" |
|
|
|
from absl.testing import absltest |
|
import chex |
|
from graphcast import xarray_jax |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
import xarray |
|
|
|
|
|
class XarrayJaxTest(absltest.TestCase): |
|
|
|
def test_jax_array_wrapper_with_numpy_api(self): |
|
|
|
|
|
|
|
ones = jnp.ones((3, 4), dtype=np.float32) |
|
x = xarray_jax.JaxArrayWrapper(ones) |
|
x = np.abs((x + 2) * (x - 3)) |
|
x = x[:-1, 1:3] |
|
x = np.concatenate([x, x + 1], axis=0) |
|
x = np.transpose(x, (1, 0)) |
|
x = np.reshape(x, (-1,)) |
|
x = x.astype(np.int32) |
|
self.assertIsInstance(x, xarray_jax.JaxArrayWrapper) |
|
|
|
self.assertIsInstance(np.asarray(x), np.ndarray) |
|
|
|
def test_jax_xarray_variable(self): |
|
def ops_via_xarray(inputs): |
|
x = xarray_jax.Variable(('lat', 'lon'), inputs) |
|
|
|
|
|
x = np.abs((x + 2) * (x - 3)) |
|
x = x.isel({'lat': slice(0, -1), 'lon': slice(1, 3)}) |
|
x = xarray.Variable.concat([x, x + 1], dim='lat') |
|
x = x.transpose('lon', 'lat') |
|
x = x.stack(channels=('lon', 'lat')) |
|
x = x.sum() |
|
return xarray_jax.jax_data(x) |
|
|
|
|
|
ones = jnp.ones((3, 4), dtype=np.float32) |
|
result = ops_via_xarray(ones) |
|
self.assertIsInstance(result, jax.Array) |
|
|
|
|
|
|
|
jax.jit(ops_via_xarray)(ones) |
|
jax.grad(ops_via_xarray)(ones) |
|
|
|
def test_jax_xarray_data_array(self): |
|
def ops_via_xarray(inputs): |
|
x = xarray_jax.DataArray(dims=('lat', 'lon'), |
|
data=inputs, |
|
coords={'lat': np.arange(3) * 10, |
|
'lon': np.arange(4) * 10}) |
|
x = np.abs((x + 2) * (x - 3)) |
|
x = x.sel({'lat': slice(0, 20)}) |
|
y = xarray_jax.DataArray(dims=('lat', 'lon'), |
|
data=ones, |
|
coords={'lat': np.arange(3, 6) * 10, |
|
'lon': np.arange(4) * 10}) |
|
x = xarray.concat([x, y], dim='lat') |
|
x = x.transpose('lon', 'lat') |
|
x = x.stack(channels=('lon', 'lat')) |
|
x = x.unstack() |
|
x = x.sum() |
|
return xarray_jax.jax_data(x) |
|
|
|
ones = jnp.ones((3, 4), dtype=np.float32) |
|
result = ops_via_xarray(ones) |
|
self.assertIsInstance(result, jax.Array) |
|
|
|
jax.jit(ops_via_xarray)(ones) |
|
jax.grad(ops_via_xarray)(ones) |
|
|
|
def test_jax_xarray_dataset(self): |
|
def ops_via_xarray(foo, bar): |
|
x = xarray_jax.Dataset( |
|
data_vars={'foo': (('lat', 'lon'), foo), |
|
'bar': (('time', 'lat', 'lon'), bar)}, |
|
coords={ |
|
'time': np.arange(2), |
|
'lat': np.arange(3) * 10, |
|
'lon': np.arange(4) * 10}) |
|
x = np.abs((x + 2) * (x - 3)) |
|
x = x.sel({'lat': slice(0, 20)}) |
|
y = xarray_jax.Dataset( |
|
data_vars={'foo': (('lat', 'lon'), foo), |
|
'bar': (('time', 'lat', 'lon'), bar)}, |
|
coords={ |
|
'time': np.arange(2), |
|
'lat': np.arange(3, 6) * 10, |
|
'lon': np.arange(4) * 10}) |
|
x = xarray.concat([x, y], dim='lat') |
|
x = x.transpose('lon', 'lat', 'time') |
|
x = x.stack(channels=('lon', 'lat')) |
|
x = (x.foo + x.bar).sum() |
|
return xarray_jax.jax_data(x) |
|
|
|
foo = jnp.ones((3, 4), dtype=np.float32) |
|
bar = jnp.ones((2, 3, 4), dtype=np.float32) |
|
result = ops_via_xarray(foo, bar) |
|
self.assertIsInstance(result, jax.Array) |
|
|
|
jax.jit(ops_via_xarray)(foo, bar) |
|
jax.grad(ops_via_xarray)(foo, bar) |
|
|
|
def test_jit_function_with_xarray_variable_arguments_and_return(self): |
|
function = jax.jit(lambda v: v + 1) |
|
with self.subTest('jax input'): |
|
inputs = xarray_jax.Variable( |
|
('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32)) |
|
_ = function(inputs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
outputs = function(inputs) |
|
self.assertEqual(outputs.dims, inputs.dims) |
|
with self.subTest('numpy input'): |
|
inputs = xarray.Variable( |
|
('lat', 'lon'), np.ones((3, 4), dtype=np.float32)) |
|
_ = function(inputs) |
|
outputs = function(inputs) |
|
self.assertEqual(outputs.dims, inputs.dims) |
|
|
|
def test_jit_problem_if_convert_to_plain_numpy_array(self): |
|
inputs = xarray_jax.DataArray( |
|
data=jnp.ones((2,), dtype=np.float32), dims=('foo',)) |
|
with self.assertRaises(jax.errors.TracerArrayConversionError): |
|
|
|
jax.jit(lambda data_array: data_array.values)(inputs) |
|
|
|
def test_grad_function_with_xarray_variable_arguments(self): |
|
x = xarray_jax.Variable(('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32)) |
|
|
|
jax.grad(lambda v: xarray_jax.jax_data(v.sum()))(x) |
|
|
|
def test_jit_function_with_xarray_data_array_arguments_and_return(self): |
|
inputs = xarray_jax.DataArray( |
|
data=jnp.ones((3, 4), dtype=np.float32), |
|
dims=('lat', 'lon'), |
|
coords={'lat': np.arange(3), |
|
'lon': np.arange(4) * 10}) |
|
fn = jax.jit(lambda v: v + 1) |
|
_ = fn(inputs) |
|
outputs = fn(inputs) |
|
self.assertEqual(outputs.dims, inputs.dims) |
|
chex.assert_trees_all_equal(outputs.coords, inputs.coords) |
|
|
|
def test_jit_function_with_data_array_and_jax_coords(self): |
|
inputs = xarray_jax.DataArray( |
|
data=jnp.ones((3, 4), dtype=np.float32), |
|
dims=('lat', 'lon'), |
|
coords={'lat': np.arange(3)}, |
|
jax_coords={'lon': jnp.arange(4) * 10}) |
|
|
|
|
|
self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper) |
|
self.assertNotIn('lon', inputs.indexes) |
|
|
|
@jax.jit |
|
def fn(v): |
|
|
|
self.assertIsInstance(v.coords['lat'].data, np.ndarray) |
|
self.assertIn('lat', v.indexes) |
|
|
|
|
|
self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper) |
|
self.assertNotIn('lon', v.indexes) |
|
|
|
|
|
v = v + v.coords['lon'] |
|
|
|
|
|
return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1) |
|
|
|
_ = fn(inputs) |
|
outputs = fn(inputs) |
|
|
|
|
|
self.assertIsInstance( |
|
outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper) |
|
self.assertNotIn('lon', outputs.indexes) |
|
|
|
self.assertEqual(outputs.dims, inputs.dims) |
|
chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat']) |
|
|
|
chex.assert_trees_all_equal( |
|
outputs.coords['lon'].data, (inputs.coords['lon']+1).data) |
|
chex.assert_trees_all_equal( |
|
outputs.data, (inputs + inputs.coords['lon']).data) |
|
|
|
def test_jit_function_with_xarray_dataset_arguments_and_return(self): |
|
foo = jnp.ones((3, 4), dtype=np.float32) |
|
bar = jnp.ones((2, 3, 4), dtype=np.float32) |
|
inputs = xarray_jax.Dataset( |
|
data_vars={'foo': (('lat', 'lon'), foo), |
|
'bar': (('time', 'lat', 'lon'), bar)}, |
|
coords={ |
|
'time': np.arange(2), |
|
'lat': np.arange(3) * 10, |
|
'lon': np.arange(4) * 10}) |
|
fn = jax.jit(lambda v: v + 1) |
|
_ = fn(inputs) |
|
outputs = fn(inputs) |
|
self.assertEqual({'foo', 'bar'}, outputs.data_vars.keys()) |
|
self.assertEqual(inputs.foo.dims, outputs.foo.dims) |
|
self.assertEqual(inputs.bar.dims, outputs.bar.dims) |
|
chex.assert_trees_all_equal(outputs.coords, inputs.coords) |
|
|
|
def test_jit_function_with_dataset_and_jax_coords(self): |
|
foo = jnp.ones((3, 4), dtype=np.float32) |
|
bar = jnp.ones((2, 3, 4), dtype=np.float32) |
|
inputs = xarray_jax.Dataset( |
|
data_vars={'foo': (('lat', 'lon'), foo), |
|
'bar': (('time', 'lat', 'lon'), bar)}, |
|
coords={ |
|
'time': np.arange(2), |
|
'lat': np.arange(3) * 10, |
|
}, |
|
jax_coords={'lon': jnp.arange(4) * 10} |
|
) |
|
|
|
|
|
self.assertIsInstance(inputs.coords['lon'].data, xarray_jax.JaxArrayWrapper) |
|
self.assertNotIn('lon', inputs.indexes) |
|
|
|
@jax.jit |
|
def fn(v): |
|
|
|
self.assertIsInstance(v.coords['lat'].data, np.ndarray) |
|
self.assertIn('lat', v.indexes) |
|
|
|
|
|
self.assertIsInstance(v.coords['lon'].data, xarray_jax.JaxArrayWrapper) |
|
self.assertNotIn('lon', v.indexes) |
|
|
|
|
|
v = v + v.coords['lon'] |
|
|
|
|
|
return xarray_jax.assign_jax_coords(v, lon=v.coords['lon'] + 1) |
|
|
|
_ = fn(inputs) |
|
outputs = fn(inputs) |
|
|
|
|
|
self.assertIsInstance( |
|
outputs.coords['lon'].data, xarray_jax.JaxArrayWrapper) |
|
self.assertNotIn('lon', outputs.indexes) |
|
|
|
self.assertEqual(outputs.dims, inputs.dims) |
|
chex.assert_trees_all_equal(outputs.coords['lat'], inputs.coords['lat']) |
|
|
|
chex.assert_trees_all_equal( |
|
(outputs.coords['lon']).data, |
|
(inputs.coords['lon']+1).data, |
|
) |
|
outputs_dict = {key: outputs[key].data for key in outputs} |
|
inputs_and_inputs_coords_dict = { |
|
key: (inputs + inputs.coords['lon'])[key].data |
|
for key in inputs + inputs.coords['lon'] |
|
} |
|
chex.assert_trees_all_equal(outputs_dict, inputs_and_inputs_coords_dict) |
|
|
|
def test_flatten_unflatten_variable(self): |
|
variable = xarray_jax.Variable( |
|
('lat', 'lon'), jnp.ones((3, 4), dtype=np.float32)) |
|
children, aux = xarray_jax._flatten_variable(variable) |
|
|
|
hash(aux) |
|
self.assertEqual(aux, aux) |
|
roundtrip = xarray_jax._unflatten_variable(aux, children) |
|
self.assertTrue(variable.equals(roundtrip)) |
|
|
|
def test_flatten_unflatten_data_array(self): |
|
data_array = xarray_jax.DataArray( |
|
data=jnp.ones((3, 4), dtype=np.float32), |
|
dims=('lat', 'lon'), |
|
coords={'lat': np.arange(3)}, |
|
jax_coords={'lon': np.arange(4) * 10}, |
|
) |
|
children, aux = xarray_jax._flatten_data_array(data_array) |
|
|
|
hash(aux) |
|
self.assertEqual(aux, aux) |
|
roundtrip = xarray_jax._unflatten_data_array(aux, children) |
|
self.assertTrue(data_array.equals(roundtrip)) |
|
|
|
def test_flatten_unflatten_dataset(self): |
|
foo = jnp.ones((3, 4), dtype=np.float32) |
|
bar = jnp.ones((2, 3, 4), dtype=np.float32) |
|
dataset = xarray_jax.Dataset( |
|
data_vars={'foo': (('lat', 'lon'), foo), |
|
'bar': (('time', 'lat', 'lon'), bar)}, |
|
coords={ |
|
'time': np.arange(2), |
|
'lat': np.arange(3) * 10}, |
|
jax_coords={ |
|
'lon': np.arange(4) * 10}) |
|
children, aux = xarray_jax._flatten_dataset(dataset) |
|
|
|
hash(aux) |
|
self.assertEqual(aux, aux) |
|
roundtrip = xarray_jax._unflatten_dataset(aux, children) |
|
self.assertTrue(dataset.equals(roundtrip)) |
|
|
|
def test_flatten_unflatten_added_dim(self): |
|
data_array = xarray_jax.DataArray( |
|
data=jnp.ones((3, 4), dtype=np.float32), |
|
dims=('lat', 'lon'), |
|
coords={'lat': np.arange(3), |
|
'lon': np.arange(4) * 10}) |
|
leaves, treedef = jax.tree_util.tree_flatten(data_array) |
|
leaves = [jnp.expand_dims(x, 0) for x in leaves] |
|
with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims): |
|
with_new_dim = jax.tree_util.tree_unflatten(treedef, leaves) |
|
self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims) |
|
xarray.testing.assert_identical( |
|
jax.device_get(data_array), |
|
jax.device_get(with_new_dim.isel(new=0))) |
|
|
|
def test_map_added_dim(self): |
|
data_array = xarray_jax.DataArray( |
|
data=jnp.ones((3, 4), dtype=np.float32), |
|
dims=('lat', 'lon'), |
|
coords={'lat': np.arange(3), |
|
'lon': np.arange(4) * 10}) |
|
with xarray_jax.dims_change_on_unflatten(lambda dims: ('new',) + dims): |
|
with_new_dim = jax.tree_util.tree_map(lambda x: jnp.expand_dims(x, 0), |
|
data_array) |
|
self.assertEqual(('new', 'lat', 'lon'), with_new_dim.dims) |
|
xarray.testing.assert_identical( |
|
jax.device_get(data_array), |
|
jax.device_get(with_new_dim.isel(new=0))) |
|
|
|
def test_map_remove_dim(self): |
|
foo = jnp.ones((1, 3, 4), dtype=np.float32) |
|
bar = jnp.ones((1, 2, 3, 4), dtype=np.float32) |
|
dataset = xarray_jax.Dataset( |
|
data_vars={'foo': (('batch', 'lat', 'lon'), foo), |
|
'bar': (('batch', 'time', 'lat', 'lon'), bar)}, |
|
coords={ |
|
'batch': np.array([123]), |
|
'time': np.arange(2), |
|
'lat': np.arange(3) * 10, |
|
'lon': np.arange(4) * 10}) |
|
with xarray_jax.dims_change_on_unflatten(lambda dims: dims[1:]): |
|
with_removed_dim = jax.tree_util.tree_map(lambda x: jnp.squeeze(x, 0), |
|
dataset) |
|
self.assertEqual(('lat', 'lon'), with_removed_dim['foo'].dims) |
|
self.assertEqual(('time', 'lat', 'lon'), with_removed_dim['bar'].dims) |
|
self.assertNotIn('batch', with_removed_dim.dims) |
|
self.assertNotIn('batch', with_removed_dim.coords) |
|
xarray.testing.assert_identical( |
|
jax.device_get(dataset.isel(batch=0, drop=True)), |
|
jax.device_get(with_removed_dim)) |
|
|
|
def test_pmap(self): |
|
devices = jax.local_device_count() |
|
foo = jnp.zeros((devices, 3, 4), dtype=np.float32) |
|
bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32) |
|
dataset = xarray_jax.Dataset({ |
|
'foo': (('device', 'lat', 'lon'), foo), |
|
'bar': (('device', 'time', 'lat', 'lon'), bar)}) |
|
|
|
def func(d): |
|
self.assertNotIn('device', d.dims) |
|
return d + 1 |
|
func = xarray_jax.pmap(func, dim='device') |
|
|
|
result = func(dataset) |
|
xarray.testing.assert_identical( |
|
jax.device_get(dataset + 1), |
|
jax.device_get(result)) |
|
|
|
|
|
|
|
dataset = dataset.drop_vars('foo') |
|
result = func(dataset) |
|
xarray.testing.assert_identical( |
|
jax.device_get(dataset + 1), |
|
jax.device_get(result)) |
|
|
|
def test_pmap_with_jax_coords(self): |
|
devices = jax.local_device_count() |
|
foo = jnp.zeros((devices, 3, 4), dtype=np.float32) |
|
bar = jnp.zeros((devices, 2, 3, 4), dtype=np.float32) |
|
time = jnp.zeros((devices, 2), dtype=np.float32) |
|
dataset = xarray_jax.Dataset( |
|
{'foo': (('device', 'lat', 'lon'), foo), |
|
'bar': (('device', 'time', 'lat', 'lon'), bar)}, |
|
coords={ |
|
'lat': np.arange(3), |
|
'lon': np.arange(4), |
|
}, |
|
jax_coords={ |
|
|
|
|
|
|
|
|
|
'time': xarray_jax.Variable(('device', 'time'), time), |
|
} |
|
) |
|
|
|
def func(d): |
|
self.assertNotIn('device', d.dims) |
|
self.assertNotIn('device', d.coords['time'].dims) |
|
|
|
|
|
|
|
self.assertIsInstance(d.coords['time'].data, xarray_jax.JaxArrayWrapper) |
|
self.assertNotIn('time', d.indexes) |
|
|
|
return d + 1 |
|
func = xarray_jax.pmap(func, dim='device') |
|
|
|
result = func(dataset) |
|
xarray.testing.assert_identical( |
|
jax.device_get(dataset + 1), |
|
jax.device_get(result)) |
|
|
|
|
|
|
|
dataset = dataset.drop_vars('foo') |
|
result = func(dataset) |
|
xarray.testing.assert_identical( |
|
jax.device_get(dataset + 1), |
|
jax.device_get(result)) |
|
|
|
def test_pmap_with_tree_mix_of_xarray_and_jax_array(self): |
|
devices = jax.local_device_count() |
|
data_array = xarray_jax.DataArray( |
|
data=jnp.ones((devices, 3, 4), dtype=np.float32), |
|
dims=('device', 'lat', 'lon')) |
|
plain_array = jnp.ones((devices, 2), dtype=np.float32) |
|
inputs = {'foo': data_array, |
|
'bar': plain_array} |
|
|
|
def func(x): |
|
return x['foo'] + 1, x['bar'] + 1 |
|
|
|
func = xarray_jax.pmap(func, dim='device') |
|
result_foo, result_bar = func(inputs) |
|
xarray.testing.assert_identical( |
|
jax.device_get(inputs['foo'] + 1), |
|
jax.device_get(result_foo)) |
|
np.testing.assert_array_equal( |
|
jax.device_get(inputs['bar'] + 1), |
|
jax.device_get(result_bar)) |
|
|
|
def test_pmap_complains_when_dim_not_first(self): |
|
devices = jax.local_device_count() |
|
data_array = xarray_jax.DataArray( |
|
data=jnp.ones((3, devices, 4), dtype=np.float32), |
|
dims=('lat', 'device', 'lon')) |
|
|
|
func = xarray_jax.pmap(lambda x: x+1, dim='device') |
|
with self.assertRaisesRegex( |
|
ValueError, 'Expected dim device at index 0, found at 1'): |
|
func(data_array) |
|
|
|
def test_apply_ufunc(self): |
|
inputs = xarray_jax.DataArray( |
|
data=jnp.asarray([[1, 2], [3, 4]]), |
|
dims=('x', 'y'), |
|
coords={'x': [0, 1], |
|
'y': [2, 3]}) |
|
result = xarray_jax.apply_ufunc( |
|
lambda x: jnp.sum(x, axis=-1), |
|
inputs, |
|
input_core_dims=[['x']]) |
|
expected_result = xarray_jax.DataArray( |
|
data=[4, 6], |
|
dims=('y',), |
|
coords={'y': [2, 3]}) |
|
xarray.testing.assert_identical(expected_result, jax.device_get(result)) |
|
|
|
def test_apply_ufunc_multiple_return_values(self): |
|
def ufunc(array): |
|
return jnp.min(array, axis=-1), jnp.max(array, axis=-1) |
|
inputs = xarray_jax.DataArray( |
|
data=jnp.asarray([[1, 4], [3, 2]]), |
|
dims=('x', 'y'), |
|
coords={'x': [0, 1], |
|
'y': [2, 3]}) |
|
result = xarray_jax.apply_ufunc( |
|
ufunc, inputs, input_core_dims=[['x']], output_core_dims=[[], []]) |
|
expected = ( |
|
|
|
xarray_jax.DataArray( |
|
data=[1, 2], |
|
dims=('y',), |
|
coords={'y': [2, 3]} |
|
), |
|
|
|
xarray_jax.DataArray( |
|
data=[3, 4], |
|
dims=('y',), |
|
coords={'y': [2, 3]} |
|
) |
|
) |
|
xarray.testing.assert_identical(expected[0], jax.device_get(result[0])) |
|
xarray.testing.assert_identical(expected[1], jax.device_get(result[1])) |
|
|
|
if __name__ == '__main__': |
|
absltest.main() |
|
|