From d453d15e92ca757941e58098fbdd665403d301fb Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Mon, 3 Feb 2025 11:40:08 -0500 Subject: [PATCH 1/3] fix indexing of intervalset to be able to use `-1`, add additional test cases --- pynapple/core/interval_set.py | 23 +++++++++++++---------- pynapple/core/metadata_class.py | 2 +- tests/test_metadata.py | 5 +++++ 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index f735cff8..84b3334c 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -294,7 +294,7 @@ def __init__( self._class_attributes = self.__dir__() # get list of all attributes self._class_attributes.append("_class_attributes") # add this property self._initialized = True - if drop_meta is False: + if (drop_meta is False) and (metadata is not None): self.set_info(metadata) def __repr__(self): @@ -406,11 +406,6 @@ def __setitem__(self, key, value): ) def __getitem__(self, key): - try: - metadata = _MetadataMixin.__getitem__(self, key) - except Exception: - metadata = pd.DataFrame(index=self.index) - if isinstance(key, str): # self[str] if key == "start": @@ -435,9 +430,15 @@ def __getitem__(self, key): elif isinstance(key, Number): # self[Number] output = self.values.__getitem__(key) + metadata = self._metadata.iloc[key] return IntervalSet(start=output[0], end=output[1], metadata=metadata) - elif isinstance(key, (slice, list, np.ndarray, pd.Series)): - # self[array_like] + elif isinstance(key, (slice, list, np.ndarray)): + # self[array_like], use iloc for metadata + output = self.values.__getitem__(key) + metadata = self._metadata.iloc[key].reset_index(drop=True) + return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata) + elif isinstance(key, pd.Series): + # use loc for metadata output = self.values.__getitem__(key) metadata = _MetadataMixin.__getitem__(self, key).reset_index(drop=True) return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata) @@ -487,7 +488,7 @@ def __getitem__(self, key): if key[1] == slice(None, None, None): # self[Any, :] output = self.values.__getitem__(key[0]) - metadata = _MetadataMixin.__getitem__(self, key[0]) + metadata = self._metadata.iloc[key[0]] if isinstance(key[0], Number): return IntervalSet( @@ -500,7 +501,9 @@ def __getitem__(self, key): metadata=metadata.reset_index(drop=True), ) - elif key[1] == slice(0, 2, None): + elif (key[1] == slice(0, 2, None)) or ( + key[1] == slice(None, 2, None) + ): # self[Any, :2] # allow number indexing for start and end times for backward compatibility output = self.values.__getitem__(key[0]) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index 8e494b2e..c903111d 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -313,7 +313,7 @@ def get_info(self, key): ) ) ): - # assume key is index, or tupe of index and column name + # assume key is index, or tuple of index and column name # metadata[Number], metadata[array_like], metadata[Any, str], or metadata[Any, [*str]] return self._metadata.loc[key] diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 1f7819ff..dceebf0f 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -153,12 +153,17 @@ def test_create_iset_from_df_with_metadata_sort(df, expected): "index", [ 0, + -1, slice(0, 2), [0, 2], + [0, -1], (slice(0, 2), slice(None)), (slice(0, 2), slice(0, 2)), (slice(None), ["start", "end"]), + ([0, -1], slice(None, 2)), (0, slice(None)), + (-1, slice(None)), + ([0, -1], slice(None)), ], ) def test_get_iset_with_metadata(iset_meta, index): From ec68518dc6fc4adc12f9528979bb1e10c3f5f365 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 6 Feb 2025 12:14:00 -0500 Subject: [PATCH 2/3] Adding test for backward compatibility of loading npz file --- tests/test_npz_file.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/tests/test_npz_file.py b/tests/test_npz_file.py index d809efb8..4a266668 100644 --- a/tests/test_npz_file.py +++ b/tests/test_npz_file.py @@ -1,9 +1,3 @@ -# -*- coding: utf-8 -*- -# @Author: Guillaume Viejo -# @Date: 2023-07-10 17:08:55 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-11 13:13:37 - """Tests of NPZ file functions""" import shutil @@ -157,6 +151,32 @@ def test_load_tsdframe(path, k): assert np.all(tmp.d == data[k].d) +@pytest.mark.parametrize("path", [path]) +@pytest.mark.parametrize("k", ["iset", "iset_minfo"]) +def test_load_tsdframe(path, k): + file_path = path / (k + ".npz") + file = nap.NPZFile(file_path) + tmp = file.load() + assert isinstance(tmp, type(data[k])) + np.testing.assert_array_almost_equal(tmp.values, data[k].values) + + +@pytest.mark.parametrize("path", [path]) +@pytest.mark.parametrize("k", ["iset", "iset_minfo"]) +def test_load_intervalset_backward_compatibility(path, k): + file_path = path / (k + ".npz") + tmp = dict(np.load(file_path, allow_pickle=True)) + tmp.pop("_metadata") + np.savez(file_path, **tmp) + + file = nap.NPZFile(file_path) + tmp = file.load() + assert isinstance(tmp, type(data[k])) + np.testing.assert_array_almost_equal(tmp.values, data[k].values) + # Testing the slicing + np.testing.assert_array_almost_equal(tmp[0].values, data[k].values[0, None]) + + @pytest.mark.parametrize("path", [path]) def test_load_non_npz(path): file_path = path / "random.npz" From b6f6ed1b4c6da5ab7e14b8b6422b55ea33d6f9a7 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 6 Feb 2025 12:21:06 -0500 Subject: [PATCH 3/3] adding tests for tsdframe --- tests/test_npz_file.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tests/test_npz_file.py b/tests/test_npz_file.py index 4a266668..d3e8db27 100644 --- a/tests/test_npz_file.py +++ b/tests/test_npz_file.py @@ -151,9 +151,27 @@ def test_load_tsdframe(path, k): assert np.all(tmp.d == data[k].d) +@pytest.mark.parametrize("path", [path]) +@pytest.mark.parametrize("k", ["tsdframe", "tsdframe_minfo"]) +def test_load_tsdframe_backward_compatibility(path, k): + file_path = path / (k + ".npz") + tmp = dict(np.load(file_path, allow_pickle=True)) + tmp.pop("_metadata") + np.savez(file_path, **tmp) + file = nap.NPZFile(file_path) + tmp = file.load() + assert isinstance(tmp, type(data[k])) + assert np.all(tmp.t == data[k].t) + np.testing.assert_array_almost_equal( + tmp.time_support.values, data[k].time_support.values + ) + assert np.all(tmp.columns == data[k].columns) + assert np.all(tmp.d == data[k].d) + + @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ["iset", "iset_minfo"]) -def test_load_tsdframe(path, k): +def test_load_intervalset(path, k): file_path = path / (k + ".npz") file = nap.NPZFile(file_path) tmp = file.load()