Skip to content

Commit

Permalink
Unit tests for saving/loading DataProcessor/TaskLoader/ConvNP
Browse files Browse the repository at this point in the history
  • Loading branch information
tom-andersson committed Sep 13, 2023
1 parent 363a5fc commit da4a824
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 71 deletions.
75 changes: 42 additions & 33 deletions tests/test_data_processor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# %%
from typing import Union

import xarray as xr
import numpy as np
import pandas as pd
import unittest
import tempfile

from deepsensor.data.processor import DataProcessor
from tests.utils import gen_random_data_xr, gen_random_data_pandas
from tests.utils import (
gen_random_data_xr,
gen_random_data_pandas,
assert_allclose_pd,
assert_allclose_xr,
)


def _gen_data_xr(coords=None, dims=None, data_vars=None):
Expand Down Expand Up @@ -42,28 +45,6 @@ class TestDataProcessor(unittest.TestCase):
- ...
"""

def assert_allclose_pd(
    self, df1: Union[pd.DataFrame, pd.Series], df2: Union[pd.DataFrame, pd.Series]
):
    """Report whether two pandas objects compare equal.

    Any ``pd.Series`` argument is first converted to a single-column
    ``pd.DataFrame``; the pair is then checked with
    ``pd.testing.assert_frame_equal``.

    Note that, despite the ``assert_`` prefix, a mismatch does not raise:
    the ``AssertionError`` is caught and reported through the return value,
    so callers must check the result (e.g. with ``self.assertTrue``).

    Returns:
        bool: ``True`` if the comparison passed, ``False`` otherwise.
    """
    left, right = (
        obj.to_frame() if isinstance(obj, pd.Series) else obj for obj in (df1, df2)
    )
    try:
        pd.testing.assert_frame_equal(left, right)
    except AssertionError:
        return False
    else:
        return True

def assert_allclose_xr(
    self, da1: Union[xr.DataArray, xr.Dataset], da2: Union[xr.DataArray, xr.Dataset]
):
    """Report whether two xarray objects pass ``xr.testing.assert_allclose``.

    Note that, despite the ``assert_`` prefix, a mismatch does not raise:
    the ``AssertionError`` is caught and reported through the return value,
    so callers must check the result (e.g. with ``self.assertTrue``).

    Returns:
        bool: ``True`` if the comparison passed, ``False`` otherwise.
    """
    matches = True
    try:
        xr.testing.assert_allclose(da1, da2)
    except AssertionError:
        matches = False
    return matches

def test_only_passing_one_x_mapping_raises_valueerror(self):
with self.assertRaises(ValueError):
DataProcessor(x1_map=(20, 40), x2_map=None)
Expand All @@ -90,11 +71,11 @@ def test_unnorm_restores_data_for_each_method(self):
da_norm, df_norm = dp([da_raw, df_raw], method=method)
da_unnorm, df_unnorm = dp.unnormalise([da_norm, df_norm])
self.assertTrue(
self.assert_allclose_xr(da_unnorm, da_raw),
assert_allclose_xr(da_unnorm, da_raw),
f"Original {type(da_raw).__name__} not restored for method {method}.",
)
self.assertTrue(
self.assert_allclose_pd(df_unnorm, df_raw),
assert_allclose_pd(df_unnorm, df_raw),
f"Original {type(df_raw).__name__} not restored for method {method}.",
)

Expand Down Expand Up @@ -122,7 +103,7 @@ def test_different_names_xr(self):

da_unnorm = dp.unnormalise(da_norm)
self.assertTrue(
self.assert_allclose_xr(da_unnorm, da_raw),
assert_allclose_xr(da_unnorm, da_raw),
f"Original {type(da_raw).__name__} not restored.",
)

Expand All @@ -141,7 +122,7 @@ def test_same_names_xr(self):

da_unnorm = dp.unnormalise(da_norm)
self.assertTrue(
self.assert_allclose_xr(da_unnorm, da_raw),
assert_allclose_xr(da_unnorm, da_raw),
f"Original {type(da_raw).__name__} not restored.",
)

Expand Down Expand Up @@ -207,7 +188,7 @@ def test_different_names_pandas(self):
df_unnorm = dp.unnormalise(df_norm)

self.assertTrue(
self.assert_allclose_pd(df_unnorm, df_raw),
assert_allclose_pd(df_unnorm, df_raw),
f"Original {type(df_raw).__name__} not restored.",
)

Expand All @@ -227,7 +208,7 @@ def test_same_names_pandas(self):
df_unnorm = dp.unnormalise(df_norm)

self.assertTrue(
self.assert_allclose_pd(df_unnorm, df_raw),
assert_allclose_pd(df_unnorm, df_raw),
f"Original {type(df_raw).__name__} not restored.",
)

Expand Down Expand Up @@ -273,7 +254,7 @@ def test_extra_indexes_preserved_pandas(self):

self.assertListEqual(list(df_raw.index.names), list(df_unnorm.index.names))
self.assertTrue(
self.assert_allclose_pd(df_unnorm, df_raw),
assert_allclose_pd(df_unnorm, df_raw),
f"Original {type(df_raw).__name__} not restored.",
)

Expand Down Expand Up @@ -301,6 +282,34 @@ def test_wrong_extra_indexes_pandas(self):
with self.assertRaises(ValueError):
dp(df_raw)

def test_saving_and_loading(self):
    """Test saving and loading DataProcessor.

    Normalises some data, saves the DataProcessor to a temporary directory,
    re-instantiates it from that directory and checks that the config of the
    loaded instance equals the config of the original.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        da_raw = _gen_data_xr()
        df_raw = _gen_data_pandas()

        dp = DataProcessor(
            x1_map=(20, 40),
            x2_map=(40, 60),
            time_name="time",
            x1_name="lat",
            x2_name="lon",
        )
        # Normalise some data to store normalisation parameters in config.
        # Only that side effect matters here, so the normalised outputs
        # are deliberately not kept.
        dp(da_raw, method="mean_std")
        dp(df_raw, method="min_max")

        dp.save(tmp_dir)

        dp_loaded = DataProcessor(tmp_dir)

        # Check that the DataProcessor was saved and loaded correctly
        self.assertEqual(
            dp.config,
            dp_loaded.config,
            "Config not saved and loaded correctly",
        )


if __name__ == "__main__":
unittest.main()
68 changes: 31 additions & 37 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import itertools
import tempfile

from parameterized import parameterized

Expand Down Expand Up @@ -393,53 +394,46 @@ def test_highlevel_predict_coords_align_with_X_t_offgrid(self):

def test_saving_and_loading(self):
"""Test saving and loading of model"""
folder = f"tmp_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(folder, exist_ok=True)
with tempfile.TemporaryDirectory() as folder:
ds_raw = xr.tutorial.open_dataset("air_temperature")

ds_raw = xr.tutorial.open_dataset("air_temperature")
data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = data_processor(ds_raw)

data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = data_processor(ds_raw)
t2m_fpath = f"{folder}/air_temperature_normalised.nc"
ds.to_netcdf(t2m_fpath)

t2m_fpath = f"{folder}/air_temperature_normalised.nc"
ds.to_netcdf(t2m_fpath)
task_loader = TaskLoader(context=t2m_fpath, target=t2m_fpath)

task_loader = TaskLoader(context=t2m_fpath, target=t2m_fpath)

model = ConvNP(
data_processor, task_loader, unet_channels=(32,) * 3, verbose=False
)

# Train the model for a few iterations to test the trained model is restored correctly later.
task = task_loader("2014-12-31", 40, datewise_deterministic=True)
trainer = Trainer(model)
for _ in range(10):
trainer([task])
mean_ds_before, std_ds_before = model.predict(task, X_t=ds_raw)
mean_ds_before["air"].plot()

data_processor.save(folder)
task_loader.save(folder)
model.save(folder)
model = ConvNP(
data_processor, task_loader, unet_channels=(5,) * 3, verbose=False
)

data_processor_loaded = DataProcessor(folder)
task_loader_loaded = TaskLoader(folder)
model_loaded = ConvNP(data_processor_loaded, task_loader_loaded, folder)
# Train the model for a few iterations to test the trained model is restored correctly later.
task = task_loader("2014-12-31", 40, datewise_deterministic=True)
trainer = Trainer(model)
for _ in range(10):
trainer([task])
mean_ds_before, std_ds_before = model.predict(task, X_t=ds_raw)
mean_ds_before["air"].plot()

task = task_loader_loaded("2014-12-31", 40, datewise_deterministic=True)
mean_ds_loaded, std_ds_loaded = model_loaded.predict(task, X_t=ds_raw)
mean_ds_loaded["air"].plot()
data_processor.save(folder)
task_loader.save(folder)
model.save(folder)

xr.testing.assert_allclose(mean_ds_before, mean_ds_loaded)
print("Means match")
data_processor_loaded = DataProcessor(folder)
task_loader_loaded = TaskLoader(folder)
model_loaded = ConvNP(data_processor_loaded, task_loader_loaded, folder)

xr.testing.assert_allclose(std_ds_before, std_ds_loaded)
print("Standard deviations match")
task = task_loader_loaded("2014-12-31", 40, datewise_deterministic=True)
mean_ds_loaded, std_ds_loaded = model_loaded.predict(task, X_t=ds_raw)
mean_ds_loaded["air"].plot()

# Delete temporary folder
import shutil
xr.testing.assert_allclose(mean_ds_before, mean_ds_loaded)
print("Means match")

shutil.rmtree(folder)
xr.testing.assert_allclose(std_ds_before, std_ds_loaded)
print("Standard deviations match")


def assert_shape(x, shape: tuple):
Expand Down
75 changes: 74 additions & 1 deletion tests/test_task_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,17 @@
import pandas as pd
import unittest

import os
import shutil
import tempfile

from deepsensor.errors import InvalidSamplingStrategyError
from tests.utils import gen_random_data_xr, gen_random_data_pandas
from tests.utils import (
gen_random_data_xr,
gen_random_data_pandas,
assert_allclose_pd,
assert_allclose_xr,
)

from deepsensor.data.loader import TaskLoader

Expand Down Expand Up @@ -237,6 +246,70 @@ def test_links(self):
task = tl("2020-01-01", "split", "split", split_frac=1.1)
task = tl("2020-01-01", "split", "split", split_frac=-0.1)

def test_saving_and_loading(self):
    """Test saving and loading TaskLoader.

    Writes the fixture data to files in a temporary directory, builds a
    TaskLoader from those file paths, saves it, re-instantiates it from the
    save directory and checks that the config, the context data and the
    remaining attributes of the loaded instance match the original.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        xarray_fpath = f"{tmp_dir}/da.nc"
        # Must differ from `xarray_fpath`: with an identical path the aux
        # data written below would overwrite the main xarray data.
        aux_fpath = f"{tmp_dir}/aux_da.nc"
        pandas_fpath = f"{tmp_dir}/df.csv"
        self.da.to_netcdf(xarray_fpath)
        self.aux_da.to_netcdf(aux_fpath)
        self.df.to_csv(pandas_fpath)

        # Instantiating with file paths, using all the kwargs.
        # The aux_at_* kwargs point at the aux file, which is the data they
        # effectively received when both paths referred to the same file.
        tl = TaskLoader(
            context=[aux_fpath, xarray_fpath, pandas_fpath],
            target=[xarray_fpath, pandas_fpath],
            links=[(2, 1)],
            aux_at_contexts=aux_fpath,
            aux_at_targets=aux_fpath,
            context_delta_t=[0, -1, 0],
            target_delta_t=[0, 1],
        )

        tl.save(tmp_dir)

        tl_loaded = TaskLoader(tmp_dir)

        # Check that the TaskLoader was saved and loaded correctly
        self.assertEqual(
            tl.config,
            tl_loaded.config,
            "Config not saved and loaded correctly",
        )
        for i, context in enumerate(tl.context):
            # The assert_allclose_* helpers return a bool instead of
            # raising, so their result has to be asserted explicitly.
            if isinstance(context, pd.DataFrame):
                self.assertTrue(
                    assert_allclose_pd(context, tl_loaded.context[i]),
                    f"Context set {i} not saved and loaded correctly",
                )
            elif isinstance(context, xr.Dataset):
                self.assertTrue(
                    assert_allclose_xr(context, tl_loaded.context[i]),
                    f"Context set {i} not saved and loaded correctly",
                )
            else:
                raise ValueError(
                    f"Context data type {type(context).__name__} not supported."
                )
        # NOTE(review): if the aux attributes are xarray objects, assertEqual
        # relies on `==` and may not be a meaningful comparison - confirm.
        self.assertEqual(
            tl.aux_at_contexts,
            tl_loaded.aux_at_contexts,
            "aux_at_contexts not saved and loaded correctly",
        )
        self.assertEqual(
            tl.aux_at_targets,
            tl_loaded.aux_at_targets,
            "aux_at_targets not saved and loaded correctly",
        )
        self.assertEqual(
            tl.links, tl_loaded.links, "Links not saved and loaded correctly"
        )
        self.assertEqual(
            tl.context_delta_t,
            tl_loaded.context_delta_t,
            "context_delta_t not saved and loaded correctly",
        )
        self.assertEqual(
            tl.target_delta_t,
            tl_loaded.target_delta_t,
            "target_delta_t not saved and loaded correctly",
        )


if __name__ == "__main__":
unittest.main()
26 changes: 26 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import pandas as pd
import xarray as xr

from typing import Union


def gen_random_data_xr(coords: dict, dims: list = None, data_vars: list = None):
"""Generate random xarray data
Expand Down Expand Up @@ -48,3 +50,27 @@ def gen_random_data_pandas(coords: dict, dims: list = None, cols: list = None):
df = pd.DataFrame(index=mi, columns=cols)
df[:] = np.random.rand(*df.shape)
return df


def assert_allclose_pd(
    df1: Union[pd.DataFrame, pd.Series], df2: Union[pd.DataFrame, pd.Series]
):
    """Report whether two pandas objects compare equal.

    Any ``pd.Series`` argument is first converted to a single-column
    ``pd.DataFrame``; the pair is then checked with
    ``pd.testing.assert_frame_equal``.

    Note that, despite the ``assert_`` prefix, a mismatch does not raise:
    the ``AssertionError`` is caught and reported through the return value,
    so callers must check the result (e.g. with ``assertTrue``).

    Returns:
        bool: ``True`` if the comparison passed, ``False`` otherwise.
    """

    def _as_frame(obj):
        # Series are promoted to frames so a single comparison routine suffices.
        return obj.to_frame() if isinstance(obj, pd.Series) else obj

    left = _as_frame(df1)
    right = _as_frame(df2)
    try:
        pd.testing.assert_frame_equal(left, right)
    except AssertionError:
        return False
    else:
        return True


def assert_allclose_xr(
    da1: Union[xr.DataArray, xr.Dataset], da2: Union[xr.DataArray, xr.Dataset]
):
    """Report whether two xarray objects pass ``xr.testing.assert_allclose``.

    Note that, despite the ``assert_`` prefix, a mismatch does not raise:
    the ``AssertionError`` is caught and reported through the return value,
    so callers must check the result (e.g. with ``assertTrue``).

    Returns:
        bool: ``True`` if the comparison passed, ``False`` otherwise.
    """
    matches = True
    try:
        xr.testing.assert_allclose(da1, da2)
    except AssertionError:
        matches = False
    return matches

0 comments on commit da4a824

Please sign in to comment.