Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add column names for compute_wavelet_transform #405

Merged
merged 2 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pynapple/core/_core_functions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
This module holds the core function of pynapple as well as
the dispatch between numba and jax.
This module holds the core function of pynapple as well as
the dispatch between numba and jax.
If pynajax is installed and `nap.nap_config.backend` is set
to `jax`, the module will call the functions within pynajax.
Otherwise the module will call the functions within `_jitted_functions.py`.
If pynajax is installed and `nap.nap_config.backend` is set
to `jax`, the module will call the functions within pynajax.
Otherwise the module will call the functions within `_jitted_functions.py`.
"""

Expand Down
2 changes: 1 addition & 1 deletion pynapple/core/base_class.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Abstract class for `core` time series.
Abstract class for `core` time series.

"""

Expand Down
8 changes: 4 additions & 4 deletions pynapple/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

## Backend configuration

By default, pynapple core functions are compiled with [Numba](https://numba.pydata.org/).
It is possible to change the backend to [Jax](https://jax.readthedocs.io/en/latest/index.html)
By default, pynapple core functions are compiled with [Numba](https://numba.pydata.org/).
It is possible to change the backend to [Jax](https://jax.readthedocs.io/en/latest/index.html)
through the [pynajax package](https://github.com/pynapple-org/pynajax).

While numba core functions runs on CPU, the `jax` backend allows pynapple to use GPU accelerated core functions.
For some core functions, the `jax` backend offers speed gains (provided that Jax runs on the GPU).
For some core functions, the `jax` backend offers speed gains (provided that Jax runs on the GPU).

See the example below to update the backend. Don't forget to install [pynajax](https://github.com/pynapple-org/pynajax).

Expand All @@ -16,7 +16,7 @@
import numpy as np
nap.nap_config.set_backend("jax") # Default option is 'numba'.

You can view the current backend with
You can view the current backend with

>>> print(nap.nap_config.backend)
'jax'
Expand Down
2 changes: 1 addition & 1 deletion pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
"""
The class `IntervalSet` deals with non-overlaping epochs. `IntervalSet` objects can interact with each other or with the time series objects.
"""

Expand Down
14 changes: 7 additions & 7 deletions pynapple/core/time_index.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""

Similar to pandas.Index, `TsIndex` holds the timestamps associated with the data of a time series.
This class deals with conversion between different time units for all pynapple objects as well
as making sure that timestamps are property sorted before initializing any objects.
- `us`: microseconds
- `ms`: milliseconds
- `s`: seconds (overall default)
Similar to pandas.Index, `TsIndex` holds the timestamps associated with the data of a time series.
This class deals with conversion between different time units for all pynapple objects as well
as making sure that timestamps are property sorted before initializing any objects.

- `us`: microseconds
- `ms`: milliseconds
- `s`: seconds (overall default)
"""

from warnings import warn
Expand Down
20 changes: 10 additions & 10 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
"""

Pynapple time series are containers specialized for neurophysiological time series.

They provides standardized time representation, plus various functions for manipulating times series with identical sampling frequency.
Pynapple time series are containers specialized for neurophysiological time series.

Multiple time series object are avaible depending on the shape of the data.
They provides standardized time representation, plus various functions for manipulating times series with identical sampling frequency.

- `TsdTensor` : for data with of more than 2 dimensions, typically movies.
- `TsdFrame` : for column-based data. It can be easily converted to a pandas.DataFrame. Columns can be labelled and selected similar to pandas.
- `Tsd` : One-dimensional time series. It can be converted to a pandas.Series.
- `Ts` : For timestamps data only.
Multiple time series object are avaible depending on the shape of the data.

Most of the same functions are available through all classes. Objects behaves like numpy.ndarray. Slicing can be done the same way for example
`tsd[0:10]` returns the first 10 rows. Similarly, you can call any numpy functions like `np.mean(tsd, 1)`.
- `TsdTensor` : for data with of more than 2 dimensions, typically movies.
- `TsdFrame` : for column-based data. It can be easily converted to a pandas.DataFrame. Columns can be labelled and selected similar to pandas.
- `Tsd` : One-dimensional time series. It can be converted to a pandas.Series.
- `Ts` : For timestamps data only.

Most of the same functions are available through all classes. Objects behaves like numpy.ndarray. Slicing can be done the same way for example
`tsd[0:10]` returns the first 10 rows. Similarly, you can call any numpy functions like `np.mean(tsd, 1)`.
"""

import abc
Expand Down
2 changes: 1 addition & 1 deletion pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""

The class `TsGroup` helps group objects with different timestamps
The class `TsGroup` helps group objects with different timestamps
(i.e. timestamps of spikes of a population of neurons).

"""
Expand Down
2 changes: 1 addition & 1 deletion pynapple/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Utility functions
Utility functions
"""

import os
Expand Down
10 changes: 5 additions & 5 deletions pynapple/process/_process_functions.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""
This module holds some process function of pynapple that can be
called with numba or pynajax as backend
This module holds some process function of pynapple that can be
called with numba or pynajax as backend
If pynajax is installed and `nap.nap_config.backend` is set
to `jax`, the module will call the functions within pynajax.
Otherwise the module will call the functions within `_jitted_functions.py`.
If pynajax is installed and `nap.nap_config.backend` is set
to `jax`, the module will call the functions within pynajax.
Otherwise the module will call the functions within `_jitted_functions.py`.
"""

Expand Down
4 changes: 1 addition & 3 deletions pynapple/process/perievent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Functions to realign time series relative to a reference time.

"""
"""Functions to realign time series relative to a reference time."""

import numpy as np

Expand Down
5 changes: 4 additions & 1 deletion pynapple/process/wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,10 @@ def compute_wavelet_transform(

if len(output_shape) == 2:
return nap.TsdFrame(
t=sig.index, d=np.squeeze(cwt, axis=1), time_support=sig.time_support
t=sig.index,
d=np.squeeze(cwt, axis=1),
time_support=sig.time_support,
columns=freqs,
)
else:
return nap.TsdTensor(
Expand Down
3 changes: 3 additions & 0 deletions tests/test_signal_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,9 @@ def test_compute_wavelet_transform(
np.testing.assert_array_almost_equal(
mwt.time_support.values, sig.time_support.values
)
if isinstance(mwt, nap.TsdFrame):
# test column names if TsdFrame
np.testing.assert_array_almost_equal(mwt.columns, freqs)


@pytest.mark.parametrize(
Expand Down
Loading