Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sjvenditto committed Feb 6, 2025
1 parent b87a013 commit 01aa1f6
Show file tree
Hide file tree
Showing 13 changed files with 199 additions and 172 deletions.
2 changes: 1 addition & 1 deletion pynapple/core/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,6 @@ def _from_npz_reader(cls, file):
ts = cls(time_support=iset, **kwargs)
if "_metadata" in file: # load metadata if it exists
if file["_metadata"]: # check if metadata is not empty
m = pd.DataFrame.from_dict(file["_metadata"].item())
m = file["_metadata"].item()
ts.set_info(m)
return ts
4 changes: 2 additions & 2 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def __getitem__(self, key):
elif isinstance(key, list) and all(isinstance(x, str) for x in key):
# self[[*str]]
# only works for list of metadata columns
if all(x in key for x in self.metadata_columns):
if all(x in self.metadata_columns for x in key):
return self._metadata[key]
else:
raise IndexError(
Expand Down Expand Up @@ -943,7 +943,7 @@ def save(self, filename):
start=self.values[:, 0],
end=self.values[:, 1],
type=np.array(["IntervalSet"], dtype=np.str_),
_metadata=self._metadata.to_dict(), # save metadata as dictionary
_metadata=dict(self._metadata), # save metadata as dictionary
)

return
Expand Down
22 changes: 21 additions & 1 deletion pynapple/core/metadata_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Union
from collections import UserDict
from .utils import is_array_like
import copy

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -470,6 +471,10 @@ def iloc(self):
def shape(self):
return (len(self.index), len(self.columns))

@property
def dtypes(self):
return {k: self.data[k].dtype for k in self.columns}

def as_dataframe(self):
return pd.DataFrame(self.data, index=self.index)

Expand All @@ -480,12 +485,27 @@ def groupby(self, by):
}

elif isinstance(by, list):
return {
groups = {
k: np.where(
np.all([self.data[col] == k[c] for c, col in enumerate(by)], axis=0)
)[0]
for k in itertools.product(*[np.unique(self.data[col]) for col in by])
}
# remove empty groups
return {k: v for k, v in groups.items() if len(v)}

def copy(self):
return copy.deepcopy(self)

def merge(self, other):
if not isinstance(other, _Metadata):
raise TypeError("Can only merge with another _Metadata object")
if not np.all(self.columns == other.columns):
raise ValueError("Cannot merge metadata with different columns")

self.index = np.concatenate([self.index, other.index])
for k, v in other.data.items():
self.data[k] = np.concatenate([self.data[k], v])


class _MetadataLoc:
Expand Down
2 changes: 1 addition & 1 deletion pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,7 @@ def save(self, filename):
end=self.time_support.end,
columns=cols_name,
type=np.array(["TsdFrame"], dtype=np.str_),
_metadata=self._metadata.to_dict(), # save metadata as dictionary
_metadata=dict(self._metadata), # save metadata as dictionary
)

return
Expand Down
14 changes: 7 additions & 7 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,18 +806,18 @@ def to_tsd(self, *args):
"""
if len(args):
if isinstance(args[0], pd.Series):
if pd.Index.equals(self._metadata.index, args[0].index):
if np.array_equal(self._metadata.index, args[0].index):
_values = args[0].values.flatten()
else:
raise RuntimeError("Index are not equals")
elif isinstance(args[0], (np.ndarray, list)):
if len(self._metadata) == len(args[0]):
if self._metadata.shape[0] == len(args[0]):
_values = np.array(args[0])
else:
raise RuntimeError("Values is not the same length.")
elif isinstance(args[0], str):
if args[0] in self._metadata.columns:
_values = self._metadata[args[0]].values
_values = self._metadata[args[0]]
else:
raise RuntimeError(
"Key {} not in metadata of TsGroup".format(args[0])
Expand Down Expand Up @@ -1097,7 +1097,7 @@ def merge_group(
tsg1 = tsgroups[0]
items = tsg1.items()
keys = set(tsg1.keys())
metadata = tsg1._metadata
metadata = tsg1._metadata.copy()

for i, tsg in enumerate(tsgroups[1:]):
if not ignore_metadata:
Expand All @@ -1106,7 +1106,7 @@ def merge_group(
f"TsGroup at position {i+2} has different metadata columns from previous TsGroup objects. "
"Set `ignore_metadata=True` to bypass the check."
)
metadata = pd.concat([metadata, tsg._metadata], axis=0)
metadata.merge(tsg._metadata)

if not reset_index:
key_overlap = keys.intersection(tsg.keys())
Expand Down Expand Up @@ -1331,7 +1331,7 @@ def save(self, filename):

dicttosave = {"type": np.array(["TsGroup"], dtype=np.str_)}
# don't save rate in metadata since it will be re-added when loading
dicttosave["_metadata"] = self.metadata.drop(columns="rate").to_dict()
dicttosave["_metadata"] = self._metadata.iloc[:, 1:]

# are these things that still need to be enforced?
# for k in self._metadata.columns:
Expand Down Expand Up @@ -1424,7 +1424,7 @@ def _from_npz_reader(cls, file):

if "_metadata" in file: # load metadata if it exists
if file["_metadata"]: # check that metadata is not empty
metainfo = pd.DataFrame.from_dict(file["_metadata"].item())
metainfo = file["_metadata"].item()
tsgroup.set_info(metainfo)
return tsgroup

Expand Down
4 changes: 3 additions & 1 deletion pynapple/process/correlograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ def compute_crosscorrelogram(
crosscorrs = pd.DataFrame.from_dict(crosscorrs)

if norm:
freq = newgroup.get_info("rate")
freq = pd.Series(
index=newgroup.metadata_index, data=newgroup.get_info("rate")
)
freq2 = pd.Series(
index=pairs, data=list(map(lambda n: freq.loc[n[1]], pairs))
)
Expand Down
20 changes: 9 additions & 11 deletions tests/test_jitted.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ def test_jitintersect():
s,
e,
metadata={
"label1": ep1.label1.loc[m[:, 0]].reset_index(drop=True),
"label2": ep2.label2.loc[m[:, 1]].reset_index(drop=True),
"label1": ep1.get_info((m[:, 0], "label1")),
"label2": ep2.get_info((m[:, 1], "label2")),
},
)

Expand All @@ -334,9 +334,9 @@ def test_jitintersect():
)

# stack labels to match up with start and end times
label1 = np.hstack((ep1.label1.values, np.nan * np.ones(len(ep2))))
label1 = np.hstack((ep1.label1, np.nan * np.ones(len(ep2))))
label1 = np.hstack((label1, label1))
label2 = np.hstack((np.nan * np.ones(len(ep1)), ep2.label2.values))
label2 = np.hstack((np.nan * np.ones(len(ep1)), ep2.label2))
label2 = np.hstack((label2, label2))

df = pd.DataFrame(
Expand All @@ -361,7 +361,7 @@ def test_jitintersect():
ep4 = nap.IntervalSet(start, end, metadata={"label1": label1, "label2": label2})

np.testing.assert_array_almost_equal(ep3, ep4)
pd.testing.assert_frame_equal(ep3._metadata, ep4._metadata)
pd.testing.assert_frame_equal(ep3.metadata, ep4.metadata)


def test_jitunion():
Expand Down Expand Up @@ -407,19 +407,17 @@ def test_jitdiff():
s, e, m = nap.core._jitted_functions.jitdiff(
ep1.start, ep1.end, ep2.start, ep2.end
)
ep3 = nap.IntervalSet(
s, e, metadata={"label1": ep1.label1.loc[m].reset_index(drop=True)}
)
ep3 = nap.IntervalSet(s, e, metadata={"label1": ep1.get_info((m, "label1"))})

i_sets = (ep1, ep2)
time = np.hstack(
[i_set["start"] for i_set in i_sets] + [i_set["end"] for i_set in i_sets]
)
label1 = np.hstack(
(
ep1.label1.values,
ep1.label1,
np.nan * np.ones(len(ep2)),
ep1.label1.values,
ep1.label1,
np.nan * np.ones(len(ep2)),
)
)
Expand Down Expand Up @@ -455,7 +453,7 @@ def test_jitdiff():
)

np.testing.assert_array_almost_equal(ep3, ep4)
pd.testing.assert_frame_equal(ep3._metadata, ep4._metadata)
pd.testing.assert_frame_equal(ep3.metadata, ep4.metadata)


def test_jitunion_isets():
Expand Down
Loading

0 comments on commit 01aa1f6

Please sign in to comment.