|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for `data_utils.py`.""" |
|
|
|
import datetime |
|
from absl.testing import absltest |
|
from absl.testing import parameterized |
|
from graphcast import data_utils |
|
import numpy as np |
|
import xarray as xa |
|
|
|
|
|
class DataUtilsTest(parameterized.TestCase):
    """Unit tests for the progress-feature and TISR helpers in `data_utils`."""

    def setUp(self):
        """Seeds NumPy's global RNG so randomly drawn inputs are repeatable."""
        super().setUp()
        np.random.seed(0)

    def test_year_progress_is_zero_at_year_start_or_end(self):
        """Year progress is exactly 0 at whole multiples of `AVG_SEC_PER_YEAR`."""
        year_progress = data_utils.get_year_progress(
            np.array([
                0,
                data_utils.AVG_SEC_PER_YEAR,
                data_utils.AVG_SEC_PER_YEAR * 42,
            ])
        )
        np.testing.assert_array_equal(
            year_progress, np.zeros(year_progress.shape)
        )

    def test_year_progress_is_almost_one_before_year_ends(self):
        """Year progress approaches, but never reaches, 1 just before a year ends."""
        year_progress = data_utils.get_year_progress(
            np.array([
                data_utils.AVG_SEC_PER_YEAR - 1,
                (data_utils.AVG_SEC_PER_YEAR - 1) * 42,
            ])
        )
        with self.subTest("Year progress values are close to 1"):
            self.assertTrue(np.all(year_progress > 0.999))
        with self.subTest("Year progress values != 1"):
            self.assertTrue(np.all(year_progress < 1.0))

    def test_day_progress_computes_for_all_times_and_longitudes(self):
        """Day progress has one entry per (time, longitude) pair."""
        # The upper bound does not fit in a 32-bit integer, so request int64
        # explicitly instead of relying on the platform's default integer
        # dtype; also pass the bound as an int rather than a float.
        times = np.random.randint(low=0, high=10**10, size=10, dtype=np.int64)
        longitudes = np.arange(0, 360.0, 1.0)
        day_progress = data_utils.get_day_progress(times, longitudes)
        with self.subTest(
            "Day progress is computed for all times and longitudes"
        ):
            self.assertSequenceEqual(
                day_progress.shape, (len(times), len(longitudes))
            )

    @parameterized.named_parameters(
        dict(
            testcase_name="random_date_1",
            year=1988,
            month=11,
            day=7,
            hour=2,
            minute=45,
            second=34,
        ),
        dict(
            testcase_name="random_date_2",
            year=2022,
            month=3,
            day=12,
            hour=7,
            minute=1,
            second=0,
        ),
    )
    def test_day_progress_is_in_between_zero_and_one(
        self, year, month, day, hour, minute, second
    ):
        """Day progress lies in [0, 1) at every longitude for a given date."""
        # Express the date as seconds elapsed since the Unix epoch.
        dt = datetime.datetime(year, month, day, hour, minute, second)
        epoch_time = datetime.datetime(1970, 1, 1)
        seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()])

        # Longitudes covering [0, 360) in steps of 1.
        longitudes = np.arange(0, 360.0, 1.0)

        day_progress = data_utils.get_day_progress(
            seconds_since_epoch, longitudes
        )
        with self.subTest("Day progress >= 0"):
            self.assertTrue(np.all(day_progress >= 0.0))
        with self.subTest("Day progress < 1"):
            self.assertTrue(np.all(day_progress < 1.0))

    def test_day_progress_is_zero_at_day_start_or_end(self):
        """Day progress at longitude 0 is 0 at whole multiples of `SEC_PER_DAY`."""
        day_progress = data_utils.get_day_progress(
            seconds_since_epoch=np.array([
                0,
                data_utils.SEC_PER_DAY,
                data_utils.SEC_PER_DAY * 42,
            ]),
            longitude=np.array([0.0]),
        )
        np.testing.assert_array_equal(
            day_progress, np.zeros(day_progress.shape)
        )

    def test_day_progress_specific_value(self):
        """Day progress 123 seconds after the epoch at longitude 0 is pinned."""
        day_progress = data_utils.get_day_progress(
            seconds_since_epoch=np.array([123]),
            longitude=np.array([0.0]),
        )
        np.testing.assert_array_almost_equal(
            day_progress, np.array([[0.00142361]]), decimal=6
        )

    def test_featurize_progress_valid_values_and_dimensions(self):
        """Featurized progress keeps the raw values and adds sin/cos features."""
        day_progress = np.array([0.0, 0.45, 0.213])
        feature_dimensions = ("time",)
        progress_features = data_utils.featurize_progress(
            name="day_progress", dims=feature_dimensions, progress=day_progress
        )
        # Label each sub-test with the feature's name rather than the repr of
        # the whole feature object, which keeps failure reports readable.
        for feature_name, feature in progress_features.items():
            with self.subTest(f"Valid dimensions for {feature_name}"):
                self.assertSequenceEqual(feature.dims, feature_dimensions)

        with self.subTest("Valid values for day_progress"):
            np.testing.assert_array_equal(
                day_progress, progress_features["day_progress"].values
            )

        with self.subTest("Valid values for day_progress_sin"):
            np.testing.assert_array_almost_equal(
                np.array([0.0, 0.30901699, 0.97309851]),
                progress_features["day_progress_sin"].values,
                decimal=6,
            )

        with self.subTest("Valid values for day_progress_cos"):
            np.testing.assert_array_almost_equal(
                np.array([1.0, -0.95105652, 0.23038943]),
                progress_features["day_progress_cos"].values,
                decimal=6,
            )

    def test_featurize_progress_invalid_dimensions(self):
        """A 1-D progress array paired with two dimension names is rejected."""
        year_progress = np.array([0.0, 0.45, 0.213])
        feature_dimensions = ("time", "longitude")
        with self.assertRaises(ValueError):
            data_utils.featurize_progress(
                name="year_progress",
                dims=feature_dimensions,
                progress=year_progress,
            )

    def test_add_derived_vars_variables_added(self):
        """`add_derived_vars` adds year/day progress and keeps existing data."""
        data = xa.Dataset(
            data_vars={
                "var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3))
            },
            coords={
                "lon": np.array([0.0, 0.5]),
                "datetime": np.array([
                    datetime.datetime(2021, 1, 1),
                    datetime.datetime(2023, 1, 1),
                    datetime.datetime(2023, 1, 3),
                ]),
            },
        )
        # The dataset is modified in place; the return value is not used.
        data_utils.add_derived_vars(data)
        all_variables = set(data.variables)

        with self.subTest("Original value was not removed"):
            self.assertIn("var1", all_variables)
        with self.subTest("Year progress feature was added"):
            self.assertIn(data_utils.YEAR_PROGRESS, all_variables)
        with self.subTest("Day progress feature was added"):
            self.assertIn(data_utils.DAY_PROGRESS, all_variables)

    def test_add_derived_vars_existing_vars_not_overridden(self):
        """Pre-existing progress variables survive `add_derived_vars` unchanged."""
        dims = ["x", "lon", "datetime"]
        data = xa.Dataset(
            data_vars={
                "var1": (dims, 8 * np.random.randn(2, 2, 3)),
                data_utils.YEAR_PROGRESS: (dims, np.full((2, 2, 3), 0.111)),
                data_utils.DAY_PROGRESS: (dims, np.full((2, 2, 3), 0.222)),
            },
            coords={
                "lon": np.array([0.0, 0.5]),
                "datetime": np.array([
                    datetime.datetime(2021, 1, 1),
                    datetime.datetime(2023, 1, 1),
                    datetime.datetime(2023, 1, 3),
                ]),
            },
        )

        data_utils.add_derived_vars(data)

        with self.subTest("Year progress feature was not overridden"):
            np.testing.assert_allclose(data[data_utils.YEAR_PROGRESS], 0.111)
        with self.subTest("Day progress feature was not overridden"):
            np.testing.assert_allclose(data[data_utils.DAY_PROGRESS], 0.222)

    @parameterized.named_parameters(
        dict(testcase_name="missing_datetime", coord_name="lon"),
        dict(testcase_name="missing_lon", coord_name="datetime"),
    )
    def test_add_derived_vars_missing_coordinate_raises_value_error(
        self, coord_name
    ):
        """`add_derived_vars` raises when only one of `lon`/`datetime` exists."""
        # `coord_name` is the single coordinate that IS present; the other
        # coordinate is the one reported as missing by the test-case name.
        with self.subTest(f"Missing {coord_name} coordinate"):
            data = xa.Dataset(
                data_vars={
                    "var1": (["x", coord_name], 8 * np.random.randn(2, 2))
                },
                coords={
                    coord_name: np.array([0.0, 0.5]),
                },
            )
            with self.assertRaises(ValueError):
                data_utils.add_derived_vars(data)

    def test_add_tisr_var_variable_added(self):
        """`add_tisr_var` adds the TISR variable to a dataset that lacks it."""
        data = xa.Dataset(
            data_vars={
                "var1": (["time", "lat", "lon"], np.full((2, 2, 2), 8.0))
            },
            coords={
                "lat": np.array([2.0, 1.0]),
                "lon": np.array([0.0, 0.5]),
                "time": np.array([100, 200], dtype="timedelta64[s]"),
                "datetime": xa.Variable(
                    "time", np.array([10, 20], dtype="datetime64[D]")
                ),
            },
        )

        data_utils.add_tisr_var(data)

        self.assertIn(data_utils.TISR, set(data.variables))

    def test_add_tisr_var_existing_var_not_overridden(self):
        """A pre-existing TISR variable survives `add_tisr_var` unchanged."""
        dims = ["time", "lat", "lon"]
        data = xa.Dataset(
            data_vars={
                "var1": (dims, np.full((2, 2, 2), 8.0)),
                data_utils.TISR: (dims, np.full((2, 2, 2), 1200.0)),
            },
            coords={
                "lat": np.array([2.0, 1.0]),
                "lon": np.array([0.0, 0.5]),
                "time": np.array([100, 200], dtype="timedelta64[s]"),
                "datetime": xa.Variable(
                    "time", np.array([10, 20], dtype="datetime64[D]")
                ),
            },
        )

        # Exercise the function this test is named after. Calling
        # `add_derived_vars` here (as before) left the TISR code path
        # untested, so the assertion below could never fail.
        data_utils.add_tisr_var(data)

        np.testing.assert_allclose(data[data_utils.TISR], 1200.0)

    def test_add_tisr_var_works_with_batch_dim_size_one(self):
        """`add_tisr_var` accepts a dataset whose `batch` dimension has size 1."""
        data = xa.Dataset(
            data_vars={
                "var1": (
                    ["batch", "time", "lat", "lon"],
                    np.full((1, 2, 2, 2), 8.0),
                )
            },
            coords={
                "lat": np.array([2.0, 1.0]),
                "lon": np.array([0.0, 0.5]),
                "time": np.array([100, 200], dtype="timedelta64[s]"),
                "datetime": xa.Variable(
                    ("batch", "time"),
                    np.array([[10, 20]], dtype="datetime64[D]"),
                ),
            },
        )

        data_utils.add_tisr_var(data)

        self.assertIn(data_utils.TISR, set(data.variables))

    def test_add_tisr_var_fails_with_batch_dim_size_greater_than_one(self):
        """`add_tisr_var` raises `ValueError` when `batch` has size above 1."""
        data = xa.Dataset(
            data_vars={
                "var1": (
                    ["batch", "time", "lat", "lon"],
                    np.full((2, 2, 2, 2), 8.0),
                )
            },
            coords={
                "lat": np.array([2.0, 1.0]),
                "lon": np.array([0.0, 0.5]),
                "time": np.array([100, 200], dtype="timedelta64[s]"),
                "datetime": xa.Variable(
                    ("batch", "time"),
                    np.array([[10, 20], [100, 200]], dtype="datetime64[D]"),
                ),
            },
        )

        with self.assertRaisesRegex(ValueError, r"cannot select a dimension"):
            data_utils.add_tisr_var(data)
|
|
|
|
|
if __name__ == "__main__":
    # Hand control to absl's test runner when this file is run as a script.
    absltest.main()
|
|