# Source path: weather/graphcast/data_utils_test.py
# (uploaded by Gary0205 as part of "Upload 25 files", commit 6d70ed4, verified)
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for `data_utils.py`."""
import datetime
from absl.testing import absltest
from absl.testing import parameterized
from graphcast import data_utils
import numpy as np
import xarray as xa
class DataUtilsTest(parameterized.TestCase):
  """Tests for the progress features and derived-variable helpers of `data_utils`."""

  def setUp(self):
    super().setUp()
    # Fix the seed for reproducibility.
    np.random.seed(0)

  def test_year_progress_is_zero_at_year_start_or_end(self):
    """Year progress is exactly 0 at whole multiples of AVG_SEC_PER_YEAR."""
    year_progress = data_utils.get_year_progress(
        np.array([
            0,
            data_utils.AVG_SEC_PER_YEAR,
            data_utils.AVG_SEC_PER_YEAR * 42,  # 42 years.
        ])
    )
    np.testing.assert_array_equal(year_progress, np.zeros(year_progress.shape))

  def test_year_progress_is_almost_one_before_year_ends(self):
    """Year progress is in (0.999, 1) shortly before a year boundary."""
    year_progress = data_utils.get_year_progress(
        np.array([
            data_utils.AVG_SEC_PER_YEAR - 1,
            (data_utils.AVG_SEC_PER_YEAR - 1) * 42,  # ~42 years
        ])
    )
    with self.subTest("Year progress values are close to 1"):
      self.assertTrue(np.all(year_progress > 0.999))
    with self.subTest("Year progress values != 1"):
      self.assertTrue(np.all(year_progress < 1.0))

  def test_day_progress_computes_for_all_times_and_longitudes(self):
    """Day progress has one value per (time, longitude) pair."""
    # The upper bound 10**10 does not fit in int32, which is NumPy's default
    # integer type on some platforms, so request int64 explicitly.
    times = np.random.randint(low=0, high=10**10, size=10, dtype=np.int64)
    longitudes = np.arange(0, 360.0, 1.0)
    day_progress = data_utils.get_day_progress(times, longitudes)
    with self.subTest("Day progress is computed for all times and longitudes"):
      self.assertSequenceEqual(
          day_progress.shape, (len(times), len(longitudes))
      )

  @parameterized.named_parameters(
      dict(
          testcase_name="random_date_1",
          year=1988,
          month=11,
          day=7,
          hour=2,
          minute=45,
          second=34,
      ),
      dict(
          testcase_name="random_date_2",
          year=2022,
          month=3,
          day=12,
          hour=7,
          minute=1,
          second=0,
      ),
  )
  def test_day_progress_is_in_between_zero_and_one(
      self, year, month, day, hour, minute, second
  ):
    """Day progress lies in [0, 1) for every 1-degree longitude."""
    # Datetime from a timestamp.
    dt = datetime.datetime(year, month, day, hour, minute, second)
    # Epoch time.
    epoch_time = datetime.datetime(1970, 1, 1)
    # Seconds since epoch.
    seconds_since_epoch = np.array([(dt - epoch_time).total_seconds()])
    # Longitudes with 1 degree resolution.
    longitudes = np.arange(0, 360.0, 1.0)
    day_progress = data_utils.get_day_progress(seconds_since_epoch, longitudes)
    with self.subTest("Day progress >= 0"):
      self.assertTrue(np.all(day_progress >= 0.0))
    with self.subTest("Day progress < 1"):
      self.assertTrue(np.all(day_progress < 1.0))

  def test_day_progress_is_zero_at_day_start_or_end(self):
    """At longitude 0, day progress is exactly 0 on whole-day timestamps."""
    day_progress = data_utils.get_day_progress(
        seconds_since_epoch=np.array([
            0,
            data_utils.SEC_PER_DAY,
            data_utils.SEC_PER_DAY * 42,  # 42 days.
        ]),
        longitude=np.array([0.0]),
    )
    np.testing.assert_array_equal(day_progress, np.zeros(day_progress.shape))

  def test_day_progress_specific_value(self):
    """123 s after the epoch at longitude 0 gives 123 / 86400 ~= 0.00142361."""
    day_progress = data_utils.get_day_progress(
        seconds_since_epoch=np.array([123]),
        longitude=np.array([0.0]),
    )
    np.testing.assert_array_almost_equal(
        day_progress, np.array([[0.00142361]]), decimal=6
    )

  def test_featurize_progress_valid_values_and_dimensions(self):
    """Raw, sin and cos progress features have the given dims and values."""
    day_progress = np.array([0.0, 0.45, 0.213])
    feature_dimensions = ("time",)
    progress_features = data_utils.featurize_progress(
        name="day_progress", dims=feature_dimensions, progress=day_progress
    )
    for feature in progress_features.values():
      with self.subTest(f"Valid dimensions for {feature}"):
        self.assertSequenceEqual(feature.dims, feature_dimensions)

    with self.subTest("Valid values for day_progress"):
      np.testing.assert_array_equal(
          day_progress, progress_features["day_progress"].values
      )

    # Expected values are sin/cos(2 * pi * progress).
    with self.subTest("Valid values for day_progress_sin"):
      np.testing.assert_array_almost_equal(
          np.array([0.0, 0.30901699, 0.97309851]),
          progress_features["day_progress_sin"].values,
          decimal=6,
      )

    with self.subTest("Valid values for day_progress_cos"):
      np.testing.assert_array_almost_equal(
          np.array([1.0, -0.95105652, 0.23038943]),
          progress_features["day_progress_cos"].values,
          decimal=6,
      )

  def test_featurize_progress_invalid_dimensions(self):
    """A 1-D progress array with 2 declared dims raises ValueError."""
    year_progress = np.array([0.0, 0.45, 0.213])
    feature_dimensions = ("time", "longitude")

    with self.assertRaises(ValueError):
      data_utils.featurize_progress(
          name="year_progress", dims=feature_dimensions, progress=year_progress
      )

  def test_add_derived_vars_variables_added(self):
    """`add_derived_vars` keeps existing vars and adds year/day progress."""
    data = xa.Dataset(
        data_vars={
            "var1": (["x", "lon", "datetime"], 8 * np.random.randn(2, 2, 3))
        },
        coords={
            "lon": np.array([0.0, 0.5]),
            "datetime": np.array([
                datetime.datetime(2021, 1, 1),
                datetime.datetime(2023, 1, 1),
                datetime.datetime(2023, 1, 3),
            ]),
        },
    )
    data_utils.add_derived_vars(data)
    all_variables = set(data.variables)

    with self.subTest("Original value was not removed"):
      self.assertIn("var1", all_variables)
    with self.subTest("Year progress feature was added"):
      self.assertIn(data_utils.YEAR_PROGRESS, all_variables)
    with self.subTest("Day progress feature was added"):
      self.assertIn(data_utils.DAY_PROGRESS, all_variables)

  def test_add_derived_vars_existing_vars_not_overridden(self):
    """Pre-existing year/day progress variables are left untouched."""
    dims = ["x", "lon", "datetime"]
    data = xa.Dataset(
        data_vars={
            "var1": (dims, 8 * np.random.randn(2, 2, 3)),
            data_utils.YEAR_PROGRESS: (dims, np.full((2, 2, 3), 0.111)),
            data_utils.DAY_PROGRESS: (dims, np.full((2, 2, 3), 0.222)),
        },
        coords={
            "lon": np.array([0.0, 0.5]),
            "datetime": np.array([
                datetime.datetime(2021, 1, 1),
                datetime.datetime(2023, 1, 1),
                datetime.datetime(2023, 1, 3),
            ]),
        },
    )

    data_utils.add_derived_vars(data)

    with self.subTest("Year progress feature was not overridden"):
      np.testing.assert_allclose(data[data_utils.YEAR_PROGRESS], 0.111)
    with self.subTest("Day progress feature was not overridden"):
      np.testing.assert_allclose(data[data_utils.DAY_PROGRESS], 0.222)

  @parameterized.named_parameters(
      dict(testcase_name="missing_datetime", coord_name="lon"),
      dict(testcase_name="missing_lon", coord_name="datetime"),
  )
  def test_add_derived_vars_missing_coordinate_raises_value_error(
      self, coord_name
  ):
    """A dataset with only one of `lon`/`datetime` coords raises ValueError."""
    with self.subTest(f"Missing {coord_name} coordinate"):
      data = xa.Dataset(
          data_vars={"var1": (["x", coord_name], 8 * np.random.randn(2, 2))},
          coords={
              coord_name: np.array([0.0, 0.5]),
          },
      )
      with self.assertRaises(ValueError):
        data_utils.add_derived_vars(data)

  def test_add_tisr_var_variable_added(self):
    """`add_tisr_var` adds the TISR variable to the dataset."""
    data = xa.Dataset(
        data_vars={
            "var1": (["time", "lat", "lon"], np.full((2, 2, 2), 8.0))
        },
        coords={
            "lat": np.array([2.0, 1.0]),
            "lon": np.array([0.0, 0.5]),
            "time": np.array([100, 200], dtype="timedelta64[s]"),
            "datetime": xa.Variable(
                "time", np.array([10, 20], dtype="datetime64[D]")
            ),
        },
    )

    data_utils.add_tisr_var(data)

    self.assertIn(data_utils.TISR, set(data.variables))

  def test_add_tisr_var_existing_var_not_overridden(self):
    """A pre-existing TISR variable is left untouched by `add_tisr_var`."""
    dims = ["time", "lat", "lon"]
    data = xa.Dataset(
        data_vars={
            "var1": (dims, np.full((2, 2, 2), 8.0)),
            data_utils.TISR: (dims, np.full((2, 2, 2), 1200.0)),
        },
        coords={
            "lat": np.array([2.0, 1.0]),
            "lon": np.array([0.0, 0.5]),
            "time": np.array([100, 200], dtype="timedelta64[s]"),
            "datetime": xa.Variable(
                "time", np.array([10, 20], dtype="datetime64[D]")
            ),
        },
    )

    # Call `add_tisr_var` itself. `add_derived_vars` is the year/day progress
    # helper exercised by the tests above, and calling it here would not test
    # `add_tisr_var`'s handling of an existing TISR variable.
    data_utils.add_tisr_var(data)

    np.testing.assert_allclose(data[data_utils.TISR], 1200.0)

  def test_add_tisr_var_works_with_batch_dim_size_one(self):
    """`add_tisr_var` supports a leading `batch` dimension of size 1."""
    data = xa.Dataset(
        data_vars={
            "var1": (
                ["batch", "time", "lat", "lon"],
                np.full((1, 2, 2, 2), 8.0),
            )
        },
        coords={
            "lat": np.array([2.0, 1.0]),
            "lon": np.array([0.0, 0.5]),
            "time": np.array([100, 200], dtype="timedelta64[s]"),
            "datetime": xa.Variable(
                ("batch", "time"), np.array([[10, 20]], dtype="datetime64[D]")
            ),
        },
    )

    data_utils.add_tisr_var(data)

    self.assertIn(data_utils.TISR, set(data.variables))

  def test_add_tisr_var_fails_with_batch_dim_size_greater_than_one(self):
    """A `batch` dimension of size 2 makes `add_tisr_var` raise ValueError."""
    data = xa.Dataset(
        data_vars={
            "var1": (
                ["batch", "time", "lat", "lon"],
                np.full((2, 2, 2, 2), 8.0),
            )
        },
        coords={
            "lat": np.array([2.0, 1.0]),
            "lon": np.array([0.0, 0.5]),
            "time": np.array([100, 200], dtype="timedelta64[s]"),
            "datetime": xa.Variable(
                ("batch", "time"),
                np.array([[10, 20], [100, 200]], dtype="datetime64[D]"),
            ),
        },
    )

    # The expected message text comes from the underlying xarray error.
    with self.assertRaisesRegex(ValueError, r"cannot select a dimension"):
      data_utils.add_tisr_var(data)
if __name__ == "__main__":
  # Script entry point: hand control to absl's test runner.
  absltest.main()