From da4a824bd54156a5a7d148eb6b29f8accb8f440b Mon Sep 17 00:00:00 2001
From: Tom Andersson
Date: Wed, 13 Sep 2023 17:16:39 +0100
Subject: [PATCH] Unit tests for saving/loading `DataProcessor`/`TaskLoader`/`ConvNP`

---
 tests/test_data_processor.py | 75 ++++++++++++++++++++----------------
 tests/test_model.py          | 68 +++++++++++++++-----------------
 tests/test_task_loader.py    | 75 +++++++++++++++++++++++++++++++++++-
 tests/utils.py               | 26 +++++++++++++
 4 files changed, 173 insertions(+), 71 deletions(-)

diff --git a/tests/test_data_processor.py b/tests/test_data_processor.py
index a2c4f151..0d8a588c 100644
--- a/tests/test_data_processor.py
+++ b/tests/test_data_processor.py
@@ -1,13 +1,16 @@
 # %%
-from typing import Union
-
-import xarray as xr
 import numpy as np
 import pandas as pd
 import unittest
+import tempfile
 
 from deepsensor.data.processor import DataProcessor
-from tests.utils import gen_random_data_xr, gen_random_data_pandas
+from tests.utils import (
+    gen_random_data_xr,
+    gen_random_data_pandas,
+    assert_allclose_pd,
+    assert_allclose_xr,
+)
 
 
 def _gen_data_xr(coords=None, dims=None, data_vars=None):
@@ -42,28 +45,6 @@ class TestDataProcessor(unittest.TestCase):
     - ...
     """
 
-    def assert_allclose_pd(
-        self, df1: Union[pd.DataFrame, pd.Series], df2: Union[pd.DataFrame, pd.Series]
-    ):
-        if isinstance(df1, pd.Series):
-            df1 = df1.to_frame()
-        if isinstance(df2, pd.Series):
-            df2 = df2.to_frame()
-        try:
-            pd.testing.assert_frame_equal(df1, df2)
-        except AssertionError:
-            return False
-        return True
-
-    def assert_allclose_xr(
-        self, da1: Union[xr.DataArray, xr.Dataset], da2: Union[xr.DataArray, xr.Dataset]
-    ):
-        try:
-            xr.testing.assert_allclose(da1, da2)
-        except AssertionError:
-            return False
-        return True
-
     def test_only_passing_one_x_mapping_raises_valueerror(self):
         with self.assertRaises(ValueError):
             DataProcessor(x1_map=(20, 40), x2_map=None)
@@ -90,11 +71,11 @@ def test_unnorm_restores_data_for_each_method(self):
             da_norm, df_norm = dp([da_raw, df_raw], method=method)
             da_unnorm, df_unnorm = dp.unnormalise([da_norm, df_norm])
             self.assertTrue(
-                self.assert_allclose_xr(da_unnorm, da_raw),
+                assert_allclose_xr(da_unnorm, da_raw),
                 f"Original {type(da_raw).__name__} not restored for method {method}.",
             )
             self.assertTrue(
-                self.assert_allclose_pd(df_unnorm, df_raw),
+                assert_allclose_pd(df_unnorm, df_raw),
                 f"Original {type(df_raw).__name__} not restored for method {method}.",
             )
 
@@ -122,7 +103,7 @@ def test_different_names_xr(self):
 
         da_unnorm = dp.unnormalise(da_norm)
         self.assertTrue(
-            self.assert_allclose_xr(da_unnorm, da_raw),
+            assert_allclose_xr(da_unnorm, da_raw),
             f"Original {type(da_raw).__name__} not restored.",
         )
 
@@ -141,7 +122,7 @@ def test_same_names_xr(self):
 
         da_unnorm = dp.unnormalise(da_norm)
         self.assertTrue(
-            self.assert_allclose_xr(da_unnorm, da_raw),
+            assert_allclose_xr(da_unnorm, da_raw),
             f"Original {type(da_raw).__name__} not restored.",
         )
 
@@ -207,7 +188,7 @@ def test_different_names_pandas(self):
 
         df_unnorm = dp.unnormalise(df_norm)
         self.assertTrue(
-            self.assert_allclose_pd(df_unnorm, df_raw),
+            assert_allclose_pd(df_unnorm, df_raw),
             f"Original {type(df_raw).__name__} not restored.",
         )
 
@@ -227,7 +208,7 @@ def test_same_names_pandas(self):
 
         df_unnorm = dp.unnormalise(df_norm)
         self.assertTrue(
-            self.assert_allclose_pd(df_unnorm, df_raw),
+            assert_allclose_pd(df_unnorm, df_raw),
             f"Original {type(df_raw).__name__} not restored.",
        )
 
@@ -273,7 +254,7 @@ def test_extra_indexes_preserved_pandas(self):
 
         self.assertListEqual(list(df_raw.index.names), list(df_unnorm.index.names))
         self.assertTrue(
-            self.assert_allclose_pd(df_unnorm, df_raw),
+            assert_allclose_pd(df_unnorm, df_raw),
             f"Original {type(df_raw).__name__} not restored.",
         )
 
@@ -301,6 +282,34 @@ def test_wrong_extra_indexes_pandas(self):
         with self.assertRaises(ValueError):
             dp(df_raw)
 
+    def test_saving_and_loading(self):
+        """Test saving and loading DataProcessor"""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            da_raw = _gen_data_xr()
+            df_raw = _gen_data_pandas()
+
+            dp = DataProcessor(
+                x1_map=(20, 40),
+                x2_map=(40, 60),
+                time_name="time",
+                x1_name="lat",
+                x2_name="lon",
+            )
+            # Normalise some data to store normalisation parameters in config
+            da_norm = dp(da_raw, method="mean_std")
+            df_norm = dp(df_raw, method="min_max")
+
+            dp.save(tmp_dir)
+
+            dp_loaded = DataProcessor(tmp_dir)
+
+            # Check that the DataProcessor was saved and loaded correctly
+            self.assertEqual(
+                dp.config,
+                dp_loaded.config,
+                "Config not saved and loaded correctly",
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_model.py b/tests/test_model.py
index 9f052c57..a17f70db 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -1,5 +1,6 @@
 import copy
 import itertools
+import tempfile
 
 from parameterized import parameterized
 
@@ -393,53 +394,46 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self):
 
     def test_saving_and_loading(self):
         """Test saving and loading of model"""
-        folder = f"tmp_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}"
-        os.makedirs(folder, exist_ok=True)
+        with tempfile.TemporaryDirectory() as folder:
+            ds_raw = xr.tutorial.open_dataset("air_temperature")
 
-        ds_raw = xr.tutorial.open_dataset("air_temperature")
+            data_processor = DataProcessor(x1_name="lat", x2_name="lon")
+            ds = data_processor(ds_raw)
 
-        data_processor = DataProcessor(x1_name="lat", x2_name="lon")
-        ds = data_processor(ds_raw)
+            t2m_fpath = f"{folder}/air_temperature_normalised.nc"
+            ds.to_netcdf(t2m_fpath)
 
-        t2m_fpath = f"{folder}/air_temperature_normalised.nc"
-        ds.to_netcdf(t2m_fpath)
+            task_loader = TaskLoader(context=t2m_fpath, target=t2m_fpath)
 
-        task_loader = TaskLoader(context=t2m_fpath, target=t2m_fpath)
-
-        model = ConvNP(
-            data_processor, task_loader, unet_channels=(32,) * 3, verbose=False
-        )
-
-        # Train the model for a few iterations to test the trained model is restored correctly later.
-        task = task_loader("2014-12-31", 40, datewise_deterministic=True)
-        trainer = Trainer(model)
-        for _ in range(10):
-            trainer([task])
-        mean_ds_before, std_ds_before = model.predict(task, X_t=ds_raw)
-        mean_ds_before["air"].plot()
-
-        data_processor.save(folder)
-        task_loader.save(folder)
-        model.save(folder)
+            model = ConvNP(
+                data_processor, task_loader, unet_channels=(5,) * 3, verbose=False
+            )
 
-        data_processor_loaded = DataProcessor(folder)
-        task_loader_loaded = TaskLoader(folder)
-        model_loaded = ConvNP(data_processor_loaded, task_loader_loaded, folder)
+            # Train the model for a few iterations to test the trained model is restored correctly later.
+            task = task_loader("2014-12-31", 40, datewise_deterministic=True)
+            trainer = Trainer(model)
+            for _ in range(10):
+                trainer([task])
+            mean_ds_before, std_ds_before = model.predict(task, X_t=ds_raw)
+            mean_ds_before["air"].plot()
 
-        task = task_loader_loaded("2014-12-31", 40, datewise_deterministic=True)
-        mean_ds_loaded, std_ds_loaded = model_loaded.predict(task, X_t=ds_raw)
-        mean_ds_loaded["air"].plot()
+            data_processor.save(folder)
+            task_loader.save(folder)
+            model.save(folder)
 
-        xr.testing.assert_allclose(mean_ds_before, mean_ds_loaded)
-        print("Means match")
+            data_processor_loaded = DataProcessor(folder)
+            task_loader_loaded = TaskLoader(folder)
+            model_loaded = ConvNP(data_processor_loaded, task_loader_loaded, folder)
 
-        xr.testing.assert_allclose(std_ds_before, std_ds_loaded)
-        print("Standard deviations match")
+            task = task_loader_loaded("2014-12-31", 40, datewise_deterministic=True)
+            mean_ds_loaded, std_ds_loaded = model_loaded.predict(task, X_t=ds_raw)
+            mean_ds_loaded["air"].plot()
 
-        # Delete temporary folder
-        import shutil
+            xr.testing.assert_allclose(mean_ds_before, mean_ds_loaded)
+            print("Means match")
 
-        shutil.rmtree(folder)
+            xr.testing.assert_allclose(std_ds_before, std_ds_loaded)
+            print("Standard deviations match")
 
 
 def assert_shape(x, shape: tuple):
diff --git a/tests/test_task_loader.py b/tests/test_task_loader.py
index 1c47f644..c0756e56 100644
--- a/tests/test_task_loader.py
+++ b/tests/test_task_loader.py
@@ -8,8 +8,17 @@
 import pandas as pd
 import unittest
 
+import os
+import shutil
+import tempfile
+
 from deepsensor.errors import InvalidSamplingStrategyError
-from tests.utils import gen_random_data_xr, gen_random_data_pandas
+from tests.utils import (
+    gen_random_data_xr,
+    gen_random_data_pandas,
+    assert_allclose_pd,
+    assert_allclose_xr,
+)
 
 from deepsensor.data.loader import TaskLoader
 
@@ -237,6 +246,70 @@ def test_links(self):
             task = tl("2020-01-01", "split", "split", split_frac=1.1)
             task = tl("2020-01-01", "split", "split", split_frac=-0.1)
 
+    def test_saving_and_loading(self):
+        """Test saving and loading TaskLoader"""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            xarray_fpath = f"{tmp_dir}/da.nc"
+            aux_fpath = f"{tmp_dir}/aux_da.nc"
+            pandas_fpath = f"{tmp_dir}/df.csv"
+            self.da.to_netcdf(xarray_fpath)
+            self.aux_da.to_netcdf(aux_fpath)
+            self.df.to_csv(pandas_fpath)
+
+            # Instantiating with file paths, using all the kwargs
+            tl = TaskLoader(
+                context=[aux_fpath, xarray_fpath, pandas_fpath],
+                target=[xarray_fpath, pandas_fpath],
+                links=[(2, 1)],
+                aux_at_contexts=xarray_fpath,
+                aux_at_targets=xarray_fpath,
+                context_delta_t=[0, -1, 0],
+                target_delta_t=[0, 1],
+            )
+
+            tl.save(tmp_dir)
+
+            tl_loaded = TaskLoader(tmp_dir)
+
+            # Check that the TaskLoader was saved and loaded correctly
+            self.assertEqual(
+                tl.config,
+                tl_loaded.config,
+                "Config not saved and loaded correctly",
+            )
+            for i, context in enumerate(tl.context):
+                if isinstance(context, pd.DataFrame):
+                    self.assertTrue(assert_allclose_pd(context, tl_loaded.context[i]))
+                elif isinstance(context, xr.Dataset):
+                    self.assertTrue(assert_allclose_xr(context, tl_loaded.context[i]))
+                else:
+                    raise ValueError(
+                        f"Context data type {type(context).__name__} not supported."
+                    )
+            self.assertEqual(
+                tl.aux_at_contexts,
+                tl_loaded.aux_at_contexts,
+                "aux_at_contexts not saved and loaded correctly",
+            )
+            self.assertEqual(
+                tl.aux_at_targets,
+                tl_loaded.aux_at_targets,
+                "aux_at_targets not saved and loaded correctly",
+            )
+            self.assertEqual(
+                tl.links, tl_loaded.links, "Links not saved and loaded correctly"
+            )
+            self.assertEqual(
+                tl.context_delta_t,
+                tl_loaded.context_delta_t,
+                "context_delta_t not saved and loaded correctly",
+            )
+            self.assertEqual(
+                tl.target_delta_t,
+                tl_loaded.target_delta_t,
+                "target_delta_t not saved and loaded correctly",
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/utils.py b/tests/utils.py
index cbea8993..3e5b9351 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -2,6 +2,8 @@
 import pandas as pd
 import xarray as xr
 
+from typing import Union
+
 
 def gen_random_data_xr(coords: dict, dims: list = None, data_vars: list = None):
     """Generate random xarray data
@@ -48,3 +50,27 @@ def gen_random_data_pandas(coords: dict, dims: list = None, cols: list = None):
     df = pd.DataFrame(index=mi, columns=cols)
     df[:] = np.random.rand(*df.shape)
     return df
+
+
+def assert_allclose_pd(
+    df1: Union[pd.DataFrame, pd.Series], df2: Union[pd.DataFrame, pd.Series]
+):
+    if isinstance(df1, pd.Series):
+        df1 = df1.to_frame()
+    if isinstance(df2, pd.Series):
+        df2 = df2.to_frame()
+    try:
+        pd.testing.assert_frame_equal(df1, df2)
+    except AssertionError:
+        return False
+    return True
+
+
+def assert_allclose_xr(
+    da1: Union[xr.DataArray, xr.Dataset], da2: Union[xr.DataArray, xr.Dataset]
+):
+    try:
+        xr.testing.assert_allclose(da1, da2)
+    except AssertionError:
+        return False
+    return True