Skip to content

Commit

Permalink
adding some docstrings and fixing imports
Browse files Browse the repository at this point in the history
  • Loading branch information
sjvenditto committed Feb 10, 2025
1 parent 2809164 commit 143d542
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 12 deletions.
1 change: 0 additions & 1 deletion pynapple/core/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from numbers import Number

import numpy as np
import pandas as pd

from ._core_functions import _count, _restrict, _value_from
from .interval_set import IntervalSet
Expand Down
1 change: 0 additions & 1 deletion pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,6 @@ def __init__(
start = start.sort_values("start").reset_index(drop=True)

metadata = start.drop(columns=["start", "end"])
index = start.index
end = start["end"].values.astype(np.float64)
start = start["start"].values.astype(np.float64)

Expand Down
80 changes: 70 additions & 10 deletions pynapple/core/metadata_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import numpy as np
import pandas as pd

from .utils import is_array_like


def add_meta_docstring(meta_func, sep="\n"):
meta_doc = getattr(_MetadataMixin, meta_func).__doc__
Expand Down Expand Up @@ -243,7 +241,6 @@ def set_info(self, metadata=None, **kwargs):
if isinstance(v, pd.Series):
if np.all(self.metadata_index == v.index.values):
self._metadata[k] = np.array(v)
# self._metadata[k] = v
else:
raise ValueError(
"Metadata index does not match for argument {}".format(k)
Expand All @@ -252,7 +249,6 @@ def set_info(self, metadata=None, **kwargs):
elif isinstance(v, (np.ndarray, list, tuple)):
if len(self.metadata_index) == len(v):
self._metadata[k] = np.array(v)
# self._metadata[k] = pd.Series(v, index=self.metadata_index)
else:
raise ValueError(
f"input array length {len(v)} does not match metadata length {len(self.metadata_index)}."
Expand All @@ -263,7 +259,6 @@ def set_info(self, metadata=None, **kwargs):
):
# if only one index and metadata is non-iterable, pack into iterable for single assignment
self._metadata[k] = np.array([v])
# self._metadata[k] = pd.Series(v, index=self.metadata_index)

else:
not_set.append({k: v})
Expand Down Expand Up @@ -294,7 +289,7 @@ def get_info(self, key):
Returns
-------
pandas.Series or pandas.DataFrame or Any (for single location)
dict or np.array or Any (for single location)
The metadata information based on the key provided.
Raises
Expand All @@ -303,16 +298,19 @@ def get_info(self, key):
If the metadata index is not found.
"""
if isinstance(key, str) and (key in self.metadata_columns):
# single metadata column
return self._metadata[key]

elif (
isinstance(key, (list, np.ndarray))
and all(isinstance(k, str) for k in key)
and all(k in self.metadata_columns for k in key)
):
# multiple metadata columns
return self._metadata[key]

else:
# everything else, use .loc
return self._metadata.loc[key]

def drop_info(self, key):
Expand Down Expand Up @@ -437,12 +435,18 @@ def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs):


class _Metadata(UserDict):
"""
A custom dictionary class for storing metadata information.
"""

def __init__(self, index):
super().__init__()
self.index = index

def __getitem__(self, key):
"""
Wrapper around typical dictionary __getitem__ to allow for multiple key indexing.
"""
try:
if isinstance(key, list):
return {key: self.data[key] for key in key}
Expand All @@ -455,34 +459,67 @@ def __getitem__(self, key):

@property
def columns(self):
"""
Metadata keys (columns).
"""
return list(self.data.keys())

@property
def loc(self):
"""
Pandas-like indexing for metadata.
"""
return _MetadataLoc(self, self.index)

@property
def iloc(self):
"""
Numpy-like indexing for metadata.
"""
return _MetadataILoc(self)

@property
def shape(self):
"""
Metadata shape as (n_index, n_columns).
"""
return (len(self.index), len(self.columns))

@property
def dtypes(self):
"""
Dictonary of data types for each metadata column.
"""
return {k: self.data[k].dtype for k in self.columns}

def as_dataframe(self):
"""
Convert metadata dictionary to a pandas DataFrame.
"""
return pd.DataFrame(self.data, index=self.index)

def groupby(self, by):
"""
Grouping function for metadata.
Parameters
----------
by : str or list of str
Metadata column name(s) to group by.
Returns
-------
dict
Dictionary of object indices (dictionary values) corresponding to each group (dictionary keys).
"""
if isinstance(by, str):
# groupby single column
return {
k: np.where(self.data[by] == k)[0] for k in np.unique(self.data[by])
}

elif isinstance(by, list):
# groupby multiple columns
groups = {
k: np.where(
np.all([self.data[col] == k[c] for c, col in enumerate(by)], axis=0)
Expand All @@ -493,9 +530,16 @@ def groupby(self, by):
return {k: v for k, v in groups.items() if len(v)}

def copy(self):
"""
Return a deep copy of the metadata object.
"""
return copy.deepcopy(self)

def merge(self, other):
"""
Merge metadata with another metadata object. Operates in place.
Can only merge metadata with the same columns.
"""
if not isinstance(other, _Metadata):
raise TypeError("Can only merge with another _Metadata object")
if not np.all(self.columns == other.columns):
Expand All @@ -507,6 +551,10 @@ def merge(self, other):


class _MetadataLoc:
"""
Helper class for pandas-like indexing of metadata.
Assumes that index corresponds to object index values in first axis, and metadata columns in second axis.
"""

def __init__(self, metadata, index):
self.data = metadata.data
Expand All @@ -520,12 +568,12 @@ def __getitem__(self, key):
return {k: self.data[k][key] for k in self.keys}

elif isinstance(key, (Number, str)):
# metadata.loc[Number]
# metadata.loc[Number], single row across all columns
idx = self.index_map[key]
return {k: self.data[k][idx] for k in self.keys}

elif isinstance(key, (list, np.ndarray, pd.Index, slice)):
# metadata.loc[array_like]
# metadata.loc[array_like], multiple rows across all columns
idx = self._get_indexder(key)
return {k: self.data[k][idx] for k in self.keys}

Expand All @@ -537,12 +585,13 @@ def __getitem__(self, key):
idx = self._get_indexder(key[0])

if isinstance(key[1], str):
# metadata.loc[Any, str], index metadata field
# metadata.loc[Any, str], slice single metadata column
return self.data[key[1]][idx]

elif isinstance(key[1], list) and all(
isinstance(k, str) for k in key[1]
):
# metadata.loc[Any, [*str]], slice multiple metadata columns
return {k: self.data[k][idx] for k in key[1]}
else:
raise IndexError(f"Too many indices for metadata.loc: {key}")
Expand All @@ -551,28 +600,39 @@ def __getitem__(self, key):
raise IndexError(f"Unknown metadata index {key}")

def _get_indexder(self, vals):
"""
Function that maps object index values to positional index.
"""
return [self.index_map[val] for val in vals]


class _MetadataILoc:
"""
Helper class for numpy-like indexing of metadata.
Assumes that indices correspond to positional index of row (object index) in first axis and positional index of column (metadata column) in second axis.
"""

def __init__(self, metadata):
self.data = metadata.data
self.keys = metadata.columns

def __getitem__(self, key):
if isinstance(key, Number):
# metadata.iloc[Number], single row across all columns
return {k: [self.data[k][key]] for k in self.keys}

elif isinstance(key, (Number, slice, list, np.ndarray, pd.Index, pd.Series)):
elif isinstance(key, (slice, list, np.ndarray, pd.Index, pd.Series)):
# metadata.iloc[array_like], multiple rows across all columns
return {k: self.data[k][key] for k in self.keys}

elif isinstance(key, tuple) and len(key) == 2:
columns = self.keys[key[1]]

if isinstance(key[0], Number):
# metadata.iloc[Number, *], single row across column(s)
return {k: [self.data[k][key[0]]] for k in columns}
else:
# metadata.iloc[array_like, *], multiple rows across column(s)
return {k: self.data[k][key[0]] for k in columns}

else:
Expand Down

0 comments on commit 143d542

Please sign in to comment.