Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overhaul of metadata to move away from pandas #415

Draft
wants to merge 21 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,4 @@ your

# Ignore npz files from testing:
tests/*.npz
/.vscode
3 changes: 1 addition & 2 deletions pynapple/core/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from numbers import Number

import numpy as np
import pandas as pd

from ._core_functions import _count, _restrict, _value_from
from .interval_set import IntervalSet
Expand Down Expand Up @@ -670,6 +669,6 @@ def _from_npz_reader(cls, file):
ts = cls(time_support=iset, **kwargs)
if "_metadata" in file: # load metadata if it exists
if file["_metadata"]: # check if metadata is not empty
m = pd.DataFrame.from_dict(file["_metadata"].item())
m = file["_metadata"].item()
ts.set_info(m)
return ts
261 changes: 165 additions & 96 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def __repr__(self):

# Adding an extra column between actual values and metadata
try:
metadata = self._metadata
metadata = pd.DataFrame(index=self.metadata_index, data=self._metadata.data)
col_names = metadata.columns
except Exception:
# Necessary for backward compatibility when saving IntervalSet as pickle
Expand Down Expand Up @@ -384,12 +384,12 @@ def __getattr__(self, name):
try:
metadata = self._metadata
except Exception:
metadata = pd.DataFrame(index=self.index)
metadata = {}

if name == "_metadata":
return metadata
elif name in metadata.columns:
return _MetadataMixin.__getattr__(self, name)
elif name in metadata.keys():
return _MetadataMixin.__getitem__(self, name)
else:
return super().__getattr__(name)

Expand All @@ -413,113 +413,91 @@ def __getitem__(self, key):
elif key == "end":
return self.values[:, 1]
elif key in self._metadata.columns:
return _MetadataMixin.__getitem__(self, key)
return self._metadata[key]
else:
raise IndexError(
f"Unknown string argument. Should be in {['start', 'end'] + list(self._metadata.keys())}"
)

elif isinstance(key, list) and all(isinstance(x, str) for x in key):
# self[[*str]]
# easiest to convert to dataframe and then slice
# in case of mixing ["start", "end"] with metadata columns
df = self.as_dataframe()
if all(x in key for x in ["start", "end"]):
return IntervalSet(df[key])
# only works for list of metadata columns
if all(x in self.metadata_columns for x in key):
return self._metadata[key]
else:
return df[key]
raise IndexError(
f"Unknown string argument. Should be in {list(self._metadata.keys())}"
)

elif isinstance(key, Number):
# self[Number]
# self[Number], numpy-like indexing
output = self.values.__getitem__(key)
metadata = self._metadata.iloc[key]
return IntervalSet(start=output[0], end=output[1], metadata=metadata)
elif isinstance(key, (slice, list, np.ndarray)):
# self[array_like], use iloc for metadata
output = self.values.__getitem__(key)
metadata = self._metadata.iloc[key].reset_index(drop=True)
return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata)
elif isinstance(key, pd.Series):
# use loc for metadata

elif isinstance(key, (slice, list, np.ndarray, pd.Series, pd.Index)):
# self[array_like], numpy-like indexing
output = self.values.__getitem__(key)
metadata = _MetadataMixin.__getitem__(self, key).reset_index(drop=True)
metadata = self._metadata.iloc[key]
return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata)

elif isinstance(key, tuple):
if len(key) == 2:
if isinstance(key[1], Number):
# self[Any, Number]
# allow number indexing for start and end times for backward compatibility
return self.values.__getitem__(key)

elif isinstance(key[1], str):
# self[Any, str]
if key[1] == "start":
return self.values[key[0], 0]
elif key[1] == "end":
return self.values[key[0], 1]
elif key[1] in self._metadata.columns:
return _MetadataMixin.__getitem__(self, key)

elif isinstance(key[1], (list, np.ndarray)):
if all(isinstance(x, str) for x in key[1]):
# self[Any, [*str]]
# easiest to convert to dataframe and then slice
# in case of mixing ["start", "end"] with metadata columns
df = self.as_dataframe()
if all(x in key[1] for x in ["start", "end"]):
return IntervalSet(df.loc[key])
else:
return df.loc[key]
elif all(isinstance(x, Number) for x in key[1]):
if all(x in [0, 1] for x in key[1]):
# self[Any, [0,1]]
# allow number indexing for start and end times for backward compatibility
output = self.values.__getitem__(key[0])
if isinstance(key[0], Number):
return IntervalSet(start=output[0], end=output[1])
else:
return IntervalSet(start=output[:, 0], end=output[:, 1])
else:
raise IndexError(
f"index {key[1]} out of bounds for IntervalSet axis 1 with size 2"
)
else:
raise IndexError(f"unknown index {key[1]} for index 2")

elif isinstance(key[1], slice):
if key[1] == slice(None, None, None):
# self[Any, :]
if isinstance(key[1], slice):
# self[Any, slice]
if (
(key[1] == slice(None)) # self[Any, :]
or (key[1] == slice(0, None)) # self[Any, 0:]
or (
(key[1].stop > 2) # self[Any, :3+], self[Any, 0:3+]
and ((key[1].start is None) or (key[1].start == 0))
)
):
# slice all rows
output = self.values.__getitem__(key[0])
metadata = self._metadata.iloc[key[0]]
return IntervalSet(output, metadata=metadata)

if isinstance(key[0], Number):
return IntervalSet(
start=output[0], end=output[1], metadata=metadata
)
else:
return IntervalSet(
start=output[:, 0],
end=output[:, 1],
metadata=metadata.reset_index(drop=True),
)

elif (key[1] == slice(0, 2, None)) or (
key[1] == slice(None, 2, None)
):
# self[Any, :2]
# allow number indexing for start and end times for backward compatibility
output = self.values.__getitem__(key[0])
elif (key[1] == slice(0, 2)) or (key[1] == slice(None, 2)):
# slice start and stop
output = self.values.__getitem__(key)
return IntervalSet(output)

if isinstance(key[0], Number):
return IntervalSet(start=output[0], end=output[1])
else:
return IntervalSet(start=output[:, 0], end=output[:, 1])
elif ((key[1] == slice(0, -1)) or (key[1] == slice(None, -1))) and (
len(self.metadata_columns) > 0
):
# slice start and stop and exclude metadata
output = self.values.__getitem__(
(key[0], slice(key[1].start, None, key[1].step))
)
return IntervalSet(output)

else:
raise IndexError(
f"index {key[1]} out of bounds for IntervalSet axis 1 with size 2"
)
# all other cases, use whatever numpy does
output = self.values.__getitem__(key)
return output

elif (
isinstance(key[1], (list, np.ndarray, pd.Series, pd.Index))
and all(isinstance(x, Number) for x in key[1])
and (len(key[1]) == 2)
):
# return IntervalSet if one start and one end is indexed
output = self.values.__getitem__(key)
if (
(np.issubdtype(np.array(key[1]).dtype, bool) and np.all(key[1]))
or np.all(key[1] == [0, 1])
or np.all(key[1] == [0, -1])
or np.all(key[1] == [-2, -1])
):
return IntervalSet(output)
else:
return output

else:
raise IndexError(f"unknown type {type(key[1])} for index 2")
# treat anything else like numpy indexing
return self.values.__getitem__(key)
# raise IndexError(f"unknown type {type(key[1])} for index 2")

else:
raise IndexError(
Expand Down Expand Up @@ -704,9 +682,9 @@ def intersect(self, a):
start2 = a.values[:, 0]
end2 = a.values[:, 1]
s, e, m = jitintersect(start1, end1, start2, end2)
m1 = self._metadata.loc[m[:, 0]].reset_index(drop=True)
m2 = a._metadata.loc[m[:, 1]].reset_index(drop=True)
return IntervalSet(s, e, metadata=m1.join(m2))
m1 = self._metadata.loc[m[:, 0]]
m2 = a._metadata.loc[m[:, 1]]
return IntervalSet(s, e, metadata={**m1, **m2})

def union(self, a):
"""
Expand Down Expand Up @@ -753,7 +731,7 @@ def set_diff(self, a):
start2 = a.values[:, 0]
end2 = a.values[:, 1]
s, e, m = jitdiff(start1, end1, start2, end2)
m1 = self._metadata.loc[m].reset_index(drop=True)
m1 = self._metadata.loc[m]
return IntervalSet(s, e, metadata=m1)

def in_interval(self, tsd):
Expand Down Expand Up @@ -920,7 +898,7 @@ def as_dataframe(self):
_
"""
df = pd.DataFrame(data=self.values, columns=["start", "end"])
return pd.concat([df, self._metadata], axis=1)
return pd.concat([df, self._metadata.as_dataframe()], axis=1)

def save(self, filename):
"""
Expand Down Expand Up @@ -964,7 +942,7 @@ def save(self, filename):
start=self.values[:, 0],
end=self.values[:, 1],
type=np.array(["IntervalSet"], dtype=np.str_),
_metadata=self._metadata.to_dict(), # save metadata as dictionary
_metadata=dict(self._metadata), # save metadata as dictionary
)

return
Expand Down Expand Up @@ -1042,8 +1020,8 @@ def split(self, interval_size, time_units="s"):
tokeep = durations >= interval_size
new_starts = new_starts[tokeep]
new_ends = new_ends[tokeep]
new_meta = new_meta[tokeep]
metadata = self._metadata.loc[new_meta].reset_index(drop=True)
new_meta = new_meta[tokeep].astype(int)
metadata = self._metadata.loc[new_meta]

# Removing 1 microsecond to have strictly non-overlapping intervals for intervals coming from the same epoch
new_ends -= 1e-6
Expand Down Expand Up @@ -1198,3 +1176,94 @@ def get_info(self, key):
2 3 y
"""
return _MetadataMixin.get_info(self, key)

@add_meta_docstring("groupby")
def groupby(self, by, get_group=None):
    """
    Group intervals by one or more metadata columns.

    Parameters are documented by the shared metadata docstring injected via
    ``add_meta_docstring``; this override only supplies IntervalSet-specific
    examples. Delegates directly to ``_MetadataMixin.groupby``.

    Examples
    --------
    >>> import pynapple as nap
    >>> import numpy as np
    >>> times = np.array([[0, 5], [10, 12], [20, 33]])
    >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
    >>> ep = nap.IntervalSet(times, metadata=metadata)

    Grouping by a single column:

    >>> ep.groupby("l2")
    {'x': [0, 1], 'y': [2]}

    Grouping by multiple columns:

    >>> ep.groupby(["l1","l2"])
    {(1, 'x'): [0], (2, 'x'): [1], (2, 'y'): [2]}

    Filtering to a specific group using the output dictionary:

    >>> groups = ep.groupby("l2")
    >>> ep[groups["x"]]
    index    start    end    l1    l2
        0        0      5     1     x
        1       10     12     2     x
    shape: (2, 2), time unit: sec.

    Filtering to a specific group using the get_group argument:

    >>> ep.groupby("l2", get_group="x")
    index    start    end    l1    l2
        0        0      5     1     x
        1       10     12     2     x
    shape: (2, 2), time unit: sec.
    """
    # Dispatch to the mixin implementation; the mixin owns the grouping logic.
    return _MetadataMixin.groupby(self, by, get_group)

@add_meta_docstring("groupby_apply")
def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs):
    """
    Group intervals by metadata and apply ``func`` to each group.

    Parameters are documented by the shared metadata docstring injected via
    ``add_meta_docstring``; this override only supplies IntervalSet-specific
    examples. Delegates directly to ``_MetadataMixin.groupby_apply``.

    Examples
    --------
    >>> import pynapple as nap
    >>> import numpy as np
    >>> times = np.array([[0, 5], [10, 12], [20, 33]])
    >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
    >>> ep = nap.IntervalSet(times, metadata=metadata)

    Apply a numpy function:

    >>> ep.groupby_apply("l2", np.mean)
    {'x': 6.75, 'y': 26.5}

    Apply a custom function:

    >>> ep.groupby_apply("l2", lambda x: x.shape[0])
    {'x': 2, 'y': 1}

    Apply a function with additional arguments:

    >>> ep.groupby_apply("l2", np.mean, axis=1)
    {'x': array([ 2.5, 11. ]), 'y': array([26.5])}

    Applying a function with additional arguments, where the grouped object is not the first argument:

    >>> tsg = nap.TsGroup(
    ...     {
    ...         1: nap.Ts(t=np.arange(0, 40)),
    ...         2: nap.Ts(t=np.arange(0, 40, 0.5), time_units="s"),
    ...         3: nap.Ts(t=np.arange(0, 40, 0.2), time_units="s"),
    ...     },
    ... )
    >>> feature = nap.Tsd(t=np.arange(40), d=np.concatenate([np.zeros(20), np.ones(20)]))
    >>> func_kwargs = {
    ...     "group": tsg,
    ...     "feature": feature,
    ...     "nb_bins": 2,
    ... }
    >>> ep.groupby_apply("l2", nap.compute_1d_tuning_curves, grouped_arg="ep", **func_kwargs)
    {'x':                1         2         3
    0.25  1.025641  1.823362  4.216524
    0.75       NaN       NaN       NaN,
    'y':                1         2         3
    0.25       NaN       NaN       NaN
    0.75  1.025641  1.978022  4.835165}
    """
    # Dispatch to the mixin implementation; the mixin owns the grouping logic.
    return _MetadataMixin.groupby_apply(self, by, func, grouped_arg, **func_kwargs)
Loading
Loading