diff --git a/doc/releases.md b/doc/releases.md
index 6baece56..4328521e 100644
--- a/doc/releases.md
+++ b/doc/releases.md
@@ -18,6 +18,13 @@ of the Flatiron institute.
 
 ## Releases
 
+### 0.8.4 (2025-02-07)
+
+- Fix value printing of IntervalSet when rows are collapsed
+- Backward compatibility fix for loading npz files with TsGroup
+- Fix indexing of IntervalSet to be able to use -1
+- Add column names for compute_wavelet_transform
+
 ### 0.8.3 (2025-01-24)
 
 - `compute_mean_power_spectral_density` computes the mean periodogram.
diff --git a/pynapple/__init__.py b/pynapple/__init__.py
index f21739d8..afdc02b6 100644
--- a/pynapple/__init__.py
+++ b/pynapple/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.8.3"
+__version__ = "0.8.4"
 from .core import (
     IntervalSet,
     Ts,
diff --git a/pynapple/core/_core_functions.py b/pynapple/core/_core_functions.py
index 33cb2a0a..1e60c9f2 100644
--- a/pynapple/core/_core_functions.py
+++ b/pynapple/core/_core_functions.py
@@ -1,10 +1,10 @@
 """
-    This module holds the core function of pynapple as well as
-    the dispatch between numba and jax.
+This module holds the core function of pynapple as well as
+the dispatch between numba and jax.
 
-    If pynajax is installed and `nap.nap_config.backend` is set
-    to `jax`, the module will call the functions within pynajax.
-    Otherwise the module will call the functions within `_jitted_functions.py`.
+If pynajax is installed and `nap.nap_config.backend` is set
+to `jax`, the module will call the functions within pynajax.
+Otherwise the module will call the functions within `_jitted_functions.py`.
 
 """
 
diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py
index fae1b565..bdbed867 100644
--- a/pynapple/core/base_class.py
+++ b/pynapple/core/base_class.py
@@ -1,5 +1,5 @@
 """
-    Abstract class for `core` time series.
+Abstract class for `core` time series.
 
 """
 
diff --git a/pynapple/core/config.py b/pynapple/core/config.py
index ca25faa8..dc22efbd 100644
--- a/pynapple/core/config.py
+++ b/pynapple/core/config.py
@@ -2,12 +2,12 @@
 
 ## Backend configuration
 
-By default, pynapple core functions are compiled with [Numba](https://numba.pydata.org/). 
-It is possible to change the backend to [Jax](https://jax.readthedocs.io/en/latest/index.html) 
+By default, pynapple core functions are compiled with [Numba](https://numba.pydata.org/).
+It is possible to change the backend to [Jax](https://jax.readthedocs.io/en/latest/index.html)
 through the [pynajax package](https://github.com/pynapple-org/pynajax).
 While numba core functions runs on CPU, the `jax` backend allows pynapple to use GPU accelerated core functions.
-For some core functions, the `jax` backend offers speed gains (provided that Jax runs on the GPU). 
+For some core functions, the `jax` backend offers speed gains (provided that Jax runs on the GPU).
 
 See the example below to update the backend. Don't forget to install [pynajax](https://github.com/pynapple-org/pynajax).
 
@@ -16,7 +16,7 @@ import numpy as np
 
 nap.nap_config.set_backend("jax") # Default option is 'numba'.
 
-You can view the current backend with 
+You can view the current backend with
 
 >>> print(nap.nap_config.backend)
 'jax'
diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py
index e2692121..36bc245e 100644
--- a/pynapple/core/interval_set.py
+++ b/pynapple/core/interval_set.py
@@ -1,4 +1,4 @@
-""" 
+"""
 The class `IntervalSet` deals with non-overlaping epochs. `IntervalSet` objects can interact with each other or with the time series objects.
 
""" @@ -294,7 +294,7 @@ def __init__( self._class_attributes = self.__dir__() # get list of all attributes self._class_attributes.append("_class_attributes") # add this property self._initialized = True - if drop_meta is False: + if (drop_meta is False) and (metadata is not None): self.set_info(metadata) def __repr__(self): @@ -336,7 +336,7 @@ def __repr__(self): np.hstack( ( self.index[-n_rows:, None], - self.values[0:n_rows], + self.values[-n_rows:], _convert_iter_to_str(metadata.values[-n_rows:]), ), dtype=object, @@ -406,11 +406,6 @@ def __setitem__(self, key, value): ) def __getitem__(self, key): - try: - metadata = _MetadataMixin.__getitem__(self, key) - except Exception: - metadata = pd.DataFrame(index=self.index) - if isinstance(key, str): # self[str] if key == "start": @@ -435,9 +430,15 @@ def __getitem__(self, key): elif isinstance(key, Number): # self[Number] output = self.values.__getitem__(key) + metadata = self._metadata.iloc[key] return IntervalSet(start=output[0], end=output[1], metadata=metadata) - elif isinstance(key, (slice, list, np.ndarray, pd.Series)): - # self[array_like] + elif isinstance(key, (slice, list, np.ndarray)): + # self[array_like], use iloc for metadata + output = self.values.__getitem__(key) + metadata = self._metadata.iloc[key].reset_index(drop=True) + return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata) + elif isinstance(key, pd.Series): + # use loc for metadata output = self.values.__getitem__(key) metadata = _MetadataMixin.__getitem__(self, key).reset_index(drop=True) return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata) @@ -487,7 +488,7 @@ def __getitem__(self, key): if key[1] == slice(None, None, None): # self[Any, :] output = self.values.__getitem__(key[0]) - metadata = _MetadataMixin.__getitem__(self, key[0]) + metadata = self._metadata.iloc[key[0]] if isinstance(key[0], Number): return IntervalSet( @@ -500,7 +501,9 @@ def __getitem__(self, key): metadata=metadata.reset_index(drop=True), ) - elif key[1] == slice(0, 2, None): + elif (key[1] == slice(0, 2, None)) or ( + key[1] == slice(None, 2, None) + ): # self[Any, :2] # allow number indexing for start and end times for backward compatibility output = self.values.__getitem__(key[0]) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index 8e494b2e..c903111d 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -313,7 +313,7 @@ def get_info(self, key): ) ) ): - # assume key is index, or tupe of index and column name + # assume key is index, or tuple of index and column name # metadata[Number], metadata[array_like], metadata[Any, str], or metadata[Any, [*str]] return self._metadata.loc[key] diff --git a/pynapple/core/time_index.py b/pynapple/core/time_index.py index ca423f79..458c4620 100644 --- a/pynapple/core/time_index.py +++ b/pynapple/core/time_index.py @@ -1,12 +1,12 @@ """ - Similar to pandas.Index, `TsIndex` holds the timestamps associated with the data of a time series. - This class deals with conversion between different time units for all pynapple objects as well - as making sure that timestamps are property sorted before initializing any objects. - - - `us`: microseconds - - `ms`: milliseconds - - `s`: seconds (overall default) +Similar to pandas.Index, `TsIndex` holds the timestamps associated with the data of a time series. 
+This class deals with conversion between different time units for all pynapple objects as well
+as making sure that timestamps are properly sorted before initializing any objects.
+
+    - `us`: microseconds
+    - `ms`: milliseconds
+    - `s`: seconds (overall default)
 """
 
 from warnings import warn
diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py
index 213d100a..128bfc85 100644
--- a/pynapple/core/time_series.py
+++ b/pynapple/core/time_series.py
@@ -1,18 +1,18 @@
 """
-
-    Pynapple time series are containers specialized for neurophysiological time series.
-    They provides standardized time representation, plus various functions for manipulating times series with identical sampling frequency.
+Pynapple time series are containers specialized for neurophysiological time series.
 
-    Multiple time series object are avaible depending on the shape of the data.
+They provide a standardized time representation, plus various functions for manipulating time series with identical sampling frequency.
 
-    - `TsdTensor` : for data with of more than 2 dimensions, typically movies.
-    - `TsdFrame` : for column-based data. It can be easily converted to a pandas.DataFrame. Columns can be labelled and selected similar to pandas.
-    - `Tsd` : One-dimensional time series. It can be converted to a pandas.Series.
-    - `Ts` : For timestamps data only.
+Multiple time series objects are available depending on the shape of the data.
 
-    Most of the same functions are available through all classes. Objects behaves like numpy.ndarray. Slicing can be done the same way for example
-    `tsd[0:10]` returns the first 10 rows. Similarly, you can call any numpy functions like `np.mean(tsd, 1)`.
+- `TsdTensor` : for data with more than 2 dimensions, typically movies.
+- `TsdFrame` : for column-based data. It can be easily converted to a pandas.DataFrame. Columns can be labelled and selected similarly to pandas.
+- `Tsd` : One-dimensional time series. It can be converted to a pandas.Series.
+- `Ts` : For timestamps data only.
+
+Most of the same functions are available through all classes. Objects behave like numpy.ndarray. Slicing can be done the same way, for example
+`tsd[0:10]` returns the first 10 rows. Similarly, you can call any numpy function like `np.mean(tsd, 1)`.
 """
 
 import abc
diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py
index 487dcc56..44cbe229 100644
--- a/pynapple/core/ts_group.py
+++ b/pynapple/core/ts_group.py
@@ -1,6 +1,6 @@
 """
 
-The class `TsGroup` helps group objects with different timestamps 
+The class `TsGroup` helps group objects with different timestamps
 (i.e. timestamps of spikes of a population of neurons).
 
 """
@@ -1419,13 +1419,31 @@ def _from_npz_reader(cls, file):
 
         tsgroup = cls(group, time_support=time_support, bypass_check=True)
 
-        # do we need to enforce that these keys are not in metadata?
-        # not_info_keys = {"start", "end", "t", "index", "d", "rate", "keys"}
-
         if "_metadata" in file:  # load metadata if it exists
             if file["_metadata"]:  # check that metadata is not empty
                 metainfo = pd.DataFrame.from_dict(file["_metadata"].item())
                 tsgroup.set_info(metainfo)
+
+        metainfo = {}
+        not_info_keys = {
+            "start",
+            "end",
+            "t",
+            "index",
+            "d",
+            "rate",
+            "keys",
+            "_metadata",
+            "type",
+        }
+
+        for k in set(file.keys()) - not_info_keys:
+            tmp = file[k]
+            if len(tmp) == len(tsgroup):
+                metainfo[k] = tmp
+
+        tsgroup.set_info(**metainfo)
+
         return tsgroup
 
     @add_meta_docstring("set_info")
diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py
index eb271930..38379349 100644
--- a/pynapple/core/utils.py
+++ b/pynapple/core/utils.py
@@ -1,5 +1,5 @@
 """
-    Utility functions
+Utility functions
 """
 
 import os
diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py
index 6c9a15f9..94e07e75 100644
--- a/pynapple/process/_process_functions.py
+++ b/pynapple/process/_process_functions.py
@@ -1,10 +1,10 @@
 """
-    This module holds some process function of pynapple that can be
-    called with numba or pynajax as backend
+This module holds some process function of pynapple that can be
+called with numba or pynajax as backend
 
-    If pynajax is installed and `nap.nap_config.backend` is set
-    to `jax`, the module will call the functions within pynajax.
-    Otherwise the module will call the functions within `_jitted_functions.py`.
+If pynajax is installed and `nap.nap_config.backend` is set
+to `jax`, the module will call the functions within pynajax.
+Otherwise the module will call the functions within `_jitted_functions.py`.
 
 """
 
diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py
index 2af7a731..2b0d8572 100644
--- a/pynapple/process/perievent.py
+++ b/pynapple/process/perievent.py
@@ -1,6 +1,4 @@
-"""Functions to realign time series relative to a reference time.
-
-"""
+"""Functions to realign time series relative to a reference time."""
 
 import numpy as np
 
diff --git a/pynapple/process/wavelets.py b/pynapple/process/wavelets.py
index cb6598ff..98abb82e 100644
--- a/pynapple/process/wavelets.py
+++ b/pynapple/process/wavelets.py
@@ -127,7 +127,10 @@ def compute_wavelet_transform(
 
     if len(output_shape) == 2:
         return nap.TsdFrame(
-            t=sig.index, d=np.squeeze(cwt, axis=1), time_support=sig.time_support
+            t=sig.index,
+            d=np.squeeze(cwt, axis=1),
+            time_support=sig.time_support,
+            columns=freqs,
         )
     else:
         return nap.TsdTensor(
diff --git a/pyproject.toml b/pyproject.toml
index 3b272490..2479df68 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "pynapple"
-version = "0.8.3"
+version = "0.8.4"
 description = "PYthon Neural Analysis Package Pour Laboratoires d’Excellence"
 readme = "README.md"
 authors = [{ name = "Guillaume Viejo", email = "guillaume.viejo@gmail.com" }]
diff --git a/setup.py b/setup.py
index 66003127..1b53583e 100644
--- a/setup.py
+++ b/setup.py
@@ -59,7 +59,7 @@
     test_suite='tests',
     tests_require=test_requirements,
     url='https://github.com/pynapple-org/pynapple',
-    version='v0.8.3',
+    version='v0.8.4',
     zip_safe=False,
     long_description_content_type='text/markdown',
     download_url='https://github.com/pynapple-org/pynapple/archive/refs/tags/v0.8.3.tar.gz'
diff --git a/template_loader.py b/template_loader.py
deleted file mode 100644
index 4b2e56f1..00000000
--- a/template_loader.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# -*- coding: utf-8 -*-
-"""
-Template for building a data loader with pynapple
-"""
-# @Author: gviejo
-# @Date: 2022-01-26 18:13:42
-# @Last Modified by: gviejo
-# @Last Modified time: 2022-08-18 18:02:38
-
-import os
-
-from pynwb import NWBHDF5IO
-
-from pynapple.io.loader import BaseLoader
-
-
-class MyCustomIO(BaseLoader):
-    def __init__(self, path):
-        """
-
-        Parameters
-        ----------
-        path : str
-            The path to the data.
-        """
-        self.basename = os.path.basename(path)
-
-        super().__init__(path)
-
-        # Need to check if nwb file exists and if data are there
-        loading_my_data = True
-        if self.path is not None:
-            nwb_path = os.path.join(self.path, "pynapplenwb")
-            if os.path.exists(nwb_path):
-                files = os.listdir(nwb_path)
-                if len([f for f in files if f.endswith(".nwb")]):
-                    success = self.load_my_nwb(path)
-                    if success:
-                        loading_my_data = False
-
-        # Bypass if data have already been transfered to nwb
-        if loading_my_data:
-            self.load_my_data(path)
-
-            self.save_my_data_in_nwb(path)
-
-    def load_my_data(self, path):
-        """
-        This load the raw data
-
-        Parameters
-        ----------
-        path : str
-            Path to the session
-        """
-        """
-        Load Raw data here
-        """
-        print(path)
-        return None
-
-    def save_my_data_in_nwb(self, path):
-        """
-        Save the raw data to NWB
-
-        Parameters
-        ----------
-        path : TYPE
-            Description
-        """
-        self.nwb_path = os.path.join(path, "pynapplenwb")
-        if not os.path.exists(self.nwb_path):
-            raise RuntimeError("Path {} does not exist.".format(self.nwb_path))
-        self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0]
-        self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename)
-
-        io = NWBHDF5IO(self.nwbfilepath, "r+")
-
-        """
-        Save data in NWB here
-        """
-
-        io.close()
-
-        return
-
-    def load_my_nwb(self, path):
-        """
-        This load the nwb that is already create by the base loader
-
-        Parameters
-        ----------
-        path : str
-            Path to the session
-        """
-        self.nwb_path = os.path.join(path, "pynapplenwb")
-        if not os.path.exists(self.nwb_path):
-            raise RuntimeError("Path {} does not exist.".format(self.nwb_path))
-        self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0]
-        self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename)
-
-        io = NWBHDF5IO(self.nwbfilepath, "r")
-        nwbfile = io.read()
-        print(nwbfile)
-
-        """
-        Add code to write to nwb file here
-        """
-
-        io.close()
-
-
-mydata = MyCustomIO(".")
-
-print(type(mydata))
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 1f7819ff..dceebf0f 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -153,12 +153,17 @@ def test_create_iset_from_df_with_metadata_sort(df, expected):
     "index",
     [
         0,
+        -1,
         slice(0, 2),
         [0, 2],
+        [0, -1],
         (slice(0, 2), slice(None)),
         (slice(0, 2), slice(0, 2)),
         (slice(None), ["start", "end"]),
+        ([0, -1], slice(None, 2)),
         (0, slice(None)),
+        (-1, slice(None)),
+        ([0, -1], slice(None)),
     ],
 )
 def test_get_iset_with_metadata(iset_meta, index):
diff --git a/tests/test_npz_file.py b/tests/test_npz_file.py
index d809efb8..cd85c535 100644
--- a/tests/test_npz_file.py
+++ b/tests/test_npz_file.py
@@ -1,9 +1,3 @@
-# -*- coding: utf-8 -*-
-# @Author: Guillaume Viejo
-# @Date: 2023-07-10 17:08:55
-# @Last Modified by: Guillaume Viejo
-# @Last Modified time: 2024-04-11 13:13:37
-
 """Tests of NPZ file functions"""
 
 import shutil
@@ -115,6 +109,31 @@ def test_load_tsgroup(path, k):
     )
 
 
+@pytest.mark.parametrize("path", [path])
+@pytest.mark.parametrize("k", ["tsgroup", "tsgroup_minfo"])
+def test_load_tsgroup_backward_compatibility(path, k):
+    """
+    For npz files saved without the _metadata keys
+    """
+    file_path = path / (k + ".npz")
+    tmp = dict(np.load(file_path, allow_pickle=True))
+    # Adding one metadata element outside the _metadata key
+    tag = np.random.randn(3)
+    tmp["tag"] = tag
+    np.savez(file_path, **tmp)
+
+    file = nap.NPZFile(file_path)
+    tmp = file.load()
+    assert isinstance(tmp, type(data[k]))
+    assert tmp.keys() == list(data[k].keys())
+    assert np.all(tmp[neu] == data[k][neu] for neu in tmp.keys())
+    np.testing.assert_array_almost_equal(
+        tmp.time_support.values, data[k].time_support.values
+    )
+    assert "rate" in tmp.metadata.columns
+    np.testing.assert_array_almost_equal(tmp.tag.values, tag)
+
+
 @pytest.mark.parametrize("path", [path])
 @pytest.mark.parametrize("k", ["tsd"])
 def test_load_tsd(path, k):
@@ -157,6 +176,50 @@ def test_load_tsdframe(path, k):
     assert np.all(tmp.d == data[k].d)
 
 
+@pytest.mark.parametrize("path", [path])
+@pytest.mark.parametrize("k", ["tsdframe", "tsdframe_minfo"])
+def test_load_tsdframe_backward_compatibility(path, k):
+    file_path = path / (k + ".npz")
+    tmp = dict(np.load(file_path, allow_pickle=True))
+    tmp.pop("_metadata")
+    np.savez(file_path, **tmp)
+    file = nap.NPZFile(file_path)
+    tmp = file.load()
+    assert isinstance(tmp, type(data[k]))
+    assert np.all(tmp.t == data[k].t)
+    np.testing.assert_array_almost_equal(
+        tmp.time_support.values, data[k].time_support.values
+    )
+    assert np.all(tmp.columns == data[k].columns)
+    assert np.all(tmp.d == data[k].d)
+
+
+@pytest.mark.parametrize("path", [path])
+@pytest.mark.parametrize("k", ["iset", "iset_minfo"])
+def test_load_intervalset(path, k):
+    file_path = path / (k + ".npz")
+    file = nap.NPZFile(file_path)
+    tmp = file.load()
+    assert isinstance(tmp, type(data[k]))
+    np.testing.assert_array_almost_equal(tmp.values, data[k].values)
+
+
+@pytest.mark.parametrize("path", [path])
+@pytest.mark.parametrize("k", ["iset", "iset_minfo"])
+def test_load_intervalset_backward_compatibility(path, k):
+    file_path = path / (k + ".npz")
+    tmp = dict(np.load(file_path, allow_pickle=True))
+    tmp.pop("_metadata")
+    np.savez(file_path, **tmp)
+
+    file = nap.NPZFile(file_path)
+    tmp = file.load()
+    assert isinstance(tmp, type(data[k]))
+    np.testing.assert_array_almost_equal(tmp.values, data[k].values)
+    # Testing the slicing
+    np.testing.assert_array_almost_equal(tmp[0].values, data[k].values[0, None])
+
+
 @pytest.mark.parametrize("path", [path])
 def test_load_non_npz(path):
     file_path = path / "random.npz"
diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py
index f9d92800..7a2241e0 100644
--- a/tests/test_signal_processing.py
+++ b/tests/test_signal_processing.py
@@ -397,6 +397,9 @@ def test_compute_wavelet_transform(
     np.testing.assert_array_almost_equal(
         mwt.time_support.values, sig.time_support.values
    )
+    if isinstance(mwt, nap.TsdFrame):
+        # test column names if TsdFrame
+        np.testing.assert_array_almost_equal(mwt.columns, freqs)
 
 
 @pytest.mark.parametrize(
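
Below is a short usage sketch, not part of the patch above, illustrating the user-facing behavior this release changes: negative indexing on `IntervalSet` and frequency-labelled columns on the `TsdFrame` returned by `compute_wavelet_transform`. The data values are made up for illustration; only behavior exercised by the diff and its tests is assumed.

```python
import numpy as np
import pynapple as nap

# Toy interval set; the start/end values are arbitrary.
ep = nap.IntervalSet(start=[0.0, 10.0, 20.0], end=[5.0, 15.0, 25.0])

# 0.8.4: negative indices are accepted, as exercised in tests/test_metadata.py.
last_interval = ep[-1]
first_and_last = ep[[0, -1]]

# 0.8.4: the wavelet transform labels its columns with the requested frequencies.
sig = nap.Tsd(t=np.arange(1000) / 100, d=np.random.randn(1000))
freqs = np.linspace(1, 20, 10)
mwt = nap.compute_wavelet_transform(sig, freqs=freqs, fs=100)
print(mwt.columns)  # expected to match freqs
```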