Skip to content

Commit

Permalink
fix indexing of intervalset to be able to use -1, add additional te…
Browse files Browse the repository at this point in the history
…st cases
  • Loading branch information
sjvenditto committed Feb 3, 2025
1 parent 46918fe commit d453d15
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
23 changes: 13 additions & 10 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def __init__(
self._class_attributes = self.__dir__() # get list of all attributes
self._class_attributes.append("_class_attributes") # add this property
self._initialized = True
if drop_meta is False:
if (drop_meta is False) and (metadata is not None):
self.set_info(metadata)

def __repr__(self):
Expand Down Expand Up @@ -406,11 +406,6 @@ def __setitem__(self, key, value):
)

def __getitem__(self, key):
try:
metadata = _MetadataMixin.__getitem__(self, key)
except Exception:
metadata = pd.DataFrame(index=self.index)

if isinstance(key, str):
# self[str]
if key == "start":
Expand All @@ -435,9 +430,15 @@ def __getitem__(self, key):
elif isinstance(key, Number):
# self[Number]
output = self.values.__getitem__(key)
metadata = self._metadata.iloc[key]
return IntervalSet(start=output[0], end=output[1], metadata=metadata)
elif isinstance(key, (slice, list, np.ndarray, pd.Series)):
# self[array_like]
elif isinstance(key, (slice, list, np.ndarray)):
# self[array_like], use iloc for metadata
output = self.values.__getitem__(key)
metadata = self._metadata.iloc[key].reset_index(drop=True)
return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata)
elif isinstance(key, pd.Series):
# use loc for metadata
output = self.values.__getitem__(key)
metadata = _MetadataMixin.__getitem__(self, key).reset_index(drop=True)
return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata)
Expand Down Expand Up @@ -487,7 +488,7 @@ def __getitem__(self, key):
if key[1] == slice(None, None, None):
# self[Any, :]
output = self.values.__getitem__(key[0])
metadata = _MetadataMixin.__getitem__(self, key[0])
metadata = self._metadata.iloc[key[0]]

if isinstance(key[0], Number):
return IntervalSet(
Expand All @@ -500,7 +501,9 @@ def __getitem__(self, key):
metadata=metadata.reset_index(drop=True),
)

elif key[1] == slice(0, 2, None):
elif (key[1] == slice(0, 2, None)) or (
key[1] == slice(None, 2, None)
):
# self[Any, :2]
# allow number indexing for start and end times for backward compatibility
output = self.values.__getitem__(key[0])
Expand Down
2 changes: 1 addition & 1 deletion pynapple/core/metadata_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def get_info(self, key):
)
)
):
# assume key is index, or tupe of index and column name
# assume key is index, or tuple of index and column name
# metadata[Number], metadata[array_like], metadata[Any, str], or metadata[Any, [*str]]
return self._metadata.loc[key]

Expand Down
5 changes: 5 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,12 +153,17 @@ def test_create_iset_from_df_with_metadata_sort(df, expected):
"index",
[
0,
-1,
slice(0, 2),
[0, 2],
[0, -1],
(slice(0, 2), slice(None)),
(slice(0, 2), slice(0, 2)),
(slice(None), ["start", "end"]),
([0, -1], slice(None, 2)),
(0, slice(None)),
(-1, slice(None)),
([0, -1], slice(None)),
],
)
def test_get_iset_with_metadata(iset_meta, index):
Expand Down

0 comments on commit d453d15

Please sign in to comment.