Skip to content

Commit 143d542

Browse files
committed
adding some docstrings and fixing imports
1 parent 2809164 commit 143d542

File tree

3 files changed

+70
-12
lines changed

3 files changed

+70
-12
lines changed

pynapple/core/base_class.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from numbers import Number
88

99
import numpy as np
10-
import pandas as pd
1110

1211
from ._core_functions import _count, _restrict, _value_from
1312
from .interval_set import IntervalSet

pynapple/core/interval_set.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ def __init__(
207207
start = start.sort_values("start").reset_index(drop=True)
208208

209209
metadata = start.drop(columns=["start", "end"])
210-
index = start.index
211210
end = start["end"].values.astype(np.float64)
212211
start = start["start"].values.astype(np.float64)
213212

pynapple/core/metadata_class.py

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
import numpy as np
99
import pandas as pd
1010

11-
from .utils import is_array_like
12-
1311

1412
def add_meta_docstring(meta_func, sep="\n"):
1513
meta_doc = getattr(_MetadataMixin, meta_func).__doc__
@@ -243,7 +241,6 @@ def set_info(self, metadata=None, **kwargs):
243241
if isinstance(v, pd.Series):
244242
if np.all(self.metadata_index == v.index.values):
245243
self._metadata[k] = np.array(v)
246-
# self._metadata[k] = v
247244
else:
248245
raise ValueError(
249246
"Metadata index does not match for argument {}".format(k)
@@ -252,7 +249,6 @@ def set_info(self, metadata=None, **kwargs):
252249
elif isinstance(v, (np.ndarray, list, tuple)):
253250
if len(self.metadata_index) == len(v):
254251
self._metadata[k] = np.array(v)
255-
# self._metadata[k] = pd.Series(v, index=self.metadata_index)
256252
else:
257253
raise ValueError(
258254
f"input array length {len(v)} does not match metadata length {len(self.metadata_index)}."
@@ -263,7 +259,6 @@ def set_info(self, metadata=None, **kwargs):
263259
):
264260
# if only one index and metadata is non-iterable, pack into iterable for single assignment
265261
self._metadata[k] = np.array([v])
266-
# self._metadata[k] = pd.Series(v, index=self.metadata_index)
267262

268263
else:
269264
not_set.append({k: v})
@@ -294,7 +289,7 @@ def get_info(self, key):
294289
295290
Returns
296291
-------
297-
pandas.Series or pandas.DataFrame or Any (for single location)
292+
dict or np.ndarray or Any (for single location)
298293
The metadata information based on the key provided.
299294
300295
Raises
@@ -303,16 +298,19 @@ def get_info(self, key):
303298
If the metadata index is not found.
304299
"""
305300
if isinstance(key, str) and (key in self.metadata_columns):
301+
# single metadata column
306302
return self._metadata[key]
307303

308304
elif (
309305
isinstance(key, (list, np.ndarray))
310306
and all(isinstance(k, str) for k in key)
311307
and all(k in self.metadata_columns for k in key)
312308
):
309+
# multiple metadata columns
313310
return self._metadata[key]
314311

315312
else:
313+
# everything else, use .loc
316314
return self._metadata.loc[key]
317315

318316
def drop_info(self, key):
@@ -437,12 +435,18 @@ def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs):
437435

438436

439437
class _Metadata(UserDict):
438+
"""
439+
A custom dictionary class for storing metadata information.
440+
"""
440441

441442
def __init__(self, index):
442443
super().__init__()
443444
self.index = index
444445

445446
def __getitem__(self, key):
447+
"""
448+
Wrapper around typical dictionary __getitem__ to allow for multiple key indexing.
449+
"""
446450
try:
447451
if isinstance(key, list):
448452
return {key: self.data[key] for key in key}
@@ -455,34 +459,67 @@ def __getitem__(self, key):
455459

456460
@property
457461
def columns(self):
462+
"""
463+
Metadata keys (columns).
464+
"""
458465
return list(self.data.keys())
459466

460467
@property
461468
def loc(self):
469+
"""
470+
Pandas-like indexing for metadata.
471+
"""
462472
return _MetadataLoc(self, self.index)
463473

464474
@property
465475
def iloc(self):
476+
"""
477+
Numpy-like indexing for metadata.
478+
"""
466479
return _MetadataILoc(self)
467480

468481
@property
469482
def shape(self):
483+
"""
484+
Metadata shape as (n_index, n_columns).
485+
"""
470486
return (len(self.index), len(self.columns))
471487

472488
@property
473489
def dtypes(self):
490+
"""
491+
Dictionary of data types for each metadata column.
492+
"""
474493
return {k: self.data[k].dtype for k in self.columns}
475494

476495
def as_dataframe(self):
496+
"""
497+
Convert metadata dictionary to a pandas DataFrame.
498+
"""
477499
return pd.DataFrame(self.data, index=self.index)
478500

479501
def groupby(self, by):
502+
"""
503+
Grouping function for metadata.
504+
505+
Parameters
506+
----------
507+
by : str or list of str
508+
Metadata column name(s) to group by.
509+
510+
Returns
511+
-------
512+
dict
513+
Dictionary of object indices (dictionary values) corresponding to each group (dictionary keys).
514+
"""
480515
if isinstance(by, str):
516+
# groupby single column
481517
return {
482518
k: np.where(self.data[by] == k)[0] for k in np.unique(self.data[by])
483519
}
484520

485521
elif isinstance(by, list):
522+
# groupby multiple columns
486523
groups = {
487524
k: np.where(
488525
np.all([self.data[col] == k[c] for c, col in enumerate(by)], axis=0)
@@ -493,9 +530,16 @@ def groupby(self, by):
493530
return {k: v for k, v in groups.items() if len(v)}
494531

495532
def copy(self):
533+
"""
534+
Return a deep copy of the metadata object.
535+
"""
496536
return copy.deepcopy(self)
497537

498538
def merge(self, other):
539+
"""
540+
Merge metadata with another metadata object. Operates in place.
541+
Can only merge metadata with the same columns.
542+
"""
499543
if not isinstance(other, _Metadata):
500544
raise TypeError("Can only merge with another _Metadata object")
501545
if not np.all(self.columns == other.columns):
@@ -507,6 +551,10 @@ def merge(self, other):
507551

508552

509553
class _MetadataLoc:
554+
"""
555+
Helper class for pandas-like indexing of metadata.
556+
Assumes that index corresponds to object index values in first axis, and metadata columns in second axis.
557+
"""
510558

511559
def __init__(self, metadata, index):
512560
self.data = metadata.data
@@ -520,12 +568,12 @@ def __getitem__(self, key):
520568
return {k: self.data[k][key] for k in self.keys}
521569

522570
elif isinstance(key, (Number, str)):
523-
# metadata.loc[Number]
571+
# metadata.loc[Number], single row across all columns
524572
idx = self.index_map[key]
525573
return {k: self.data[k][idx] for k in self.keys}
526574

527575
elif isinstance(key, (list, np.ndarray, pd.Index, slice)):
528-
# metadata.loc[array_like]
576+
# metadata.loc[array_like], multiple rows across all columns
529577
idx = self._get_indexder(key)
530578
return {k: self.data[k][idx] for k in self.keys}
531579

@@ -537,12 +585,13 @@ def __getitem__(self, key):
537585
idx = self._get_indexder(key[0])
538586

539587
if isinstance(key[1], str):
540-
# metadata.loc[Any, str], index metadata field
588+
# metadata.loc[Any, str], slice single metadata column
541589
return self.data[key[1]][idx]
542590

543591
elif isinstance(key[1], list) and all(
544592
isinstance(k, str) for k in key[1]
545593
):
594+
# metadata.loc[Any, [*str]], slice multiple metadata columns
546595
return {k: self.data[k][idx] for k in key[1]}
547596
else:
548597
raise IndexError(f"Too many indices for metadata.loc: {key}")
@@ -551,28 +600,39 @@ def __getitem__(self, key):
551600
raise IndexError(f"Unknown metadata index {key}")
552601

553602
def _get_indexder(self, vals):
603+
"""
604+
Function that maps object index values to positional index.
605+
"""
554606
return [self.index_map[val] for val in vals]
555607

556608

557609
class _MetadataILoc:
610+
"""
611+
Helper class for numpy-like indexing of metadata.
612+
Assumes that indices correspond to positional index of row (object index) in first axis and positional index of column (metadata column) in second axis.
613+
"""
558614

559615
def __init__(self, metadata):
560616
self.data = metadata.data
561617
self.keys = metadata.columns
562618

563619
def __getitem__(self, key):
564620
if isinstance(key, Number):
621+
# metadata.iloc[Number], single row across all columns
565622
return {k: [self.data[k][key]] for k in self.keys}
566623

567-
elif isinstance(key, (Number, slice, list, np.ndarray, pd.Index, pd.Series)):
624+
elif isinstance(key, (slice, list, np.ndarray, pd.Index, pd.Series)):
625+
# metadata.iloc[array_like], multiple rows across all columns
568626
return {k: self.data[k][key] for k in self.keys}
569627

570628
elif isinstance(key, tuple) and len(key) == 2:
571629
columns = self.keys[key[1]]
572630

573631
if isinstance(key[0], Number):
632+
# metadata.iloc[Number, *], single row across column(s)
574633
return {k: [self.data[k][key[0]]] for k in columns}
575634
else:
635+
# metadata.iloc[array_like, *], multiple rows across column(s)
576636
return {k: self.data[k][key[0]] for k in columns}
577637

578638
else:

0 commit comments

Comments
 (0)