Skip to content

Commit 5e30a18

Browse files
authored
Merge pull request #26 from pynapple-org/filter
IIR Filter
2 parents 164c047 + 8d26c20 commit 5e30a18

18 files changed

+771
-105
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
*.npz
2+
13
# Byte-compiled / optimized / DLL files
24
__pycache__/
35
*.py[cod]

docs/examples/README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@ The functions that have been optimized with `pynajax` are :
99

1010
- [`threshold`](https://pynapple-org.github.io/pynapple/reference/core/time_series/#pynapple.core.time_series.Tsd.threshold)
1111

12-
- [`event_trigger_average`](https://pynapple-org.github.io/pynapple/reference/process/perievent/#pynapple.process.perievent.compute_event_trigger_average)
12+
- [`event_trigger_average`](https://pynapple-org.github.io/pynapple/reference/process/perievent/#pynapple.process.perievent.compute_event_trigger_average)
13+
14+
- filtering
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""
2+
# filtering
3+
4+
This notebook compare the jax implementation of Butterworth filter with [scipy sosfiltfilt](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sosfiltfilt.html).
5+
6+
Performances of the `'sinc'` mode can be found in the convolve benchmark as it is the function being called underneath.
7+
8+
⚠️ **Warning:** We do not recommend using GPU for filtering as it is much slower for the moment compared to CPU.
9+
10+
11+
"""
12+
import os
13+
import numpy as np
14+
import pynapple as nap
15+
from time import perf_counter
16+
import matplotlib.pyplot as plt
17+
18+
import warnings
19+
warnings.filterwarnings("ignore")
20+
21+
22+
23+
# %%
24+
# Machine Configuration
25+
import jax
26+
print(jax.devices())
27+
28+
# %%
29+
def get_mean_perf(tsd, mode, n=10):
30+
tmp = np.zeros(n)
31+
_ = nap.apply_lowpass_filter(tsd, 0.25 * tsd.rate, mode=mode)
32+
for i in range(n):
33+
t1 = perf_counter()
34+
_ = nap.apply_lowpass_filter(tsd, 0.25 * tsd.rate, mode=mode)
35+
t2 = perf_counter()
36+
tmp[i] = t2 - t1
37+
return [np.mean(tmp), np.std(tmp)]
38+
39+
# %%
40+
# # Increasing number of time points
41+
42+
def benchmark_time_points(mode):
43+
times = []
44+
for T in np.arange(1000, 100000, 20000):
45+
time_array = np.arange(T)/1000
46+
data_array = np.random.randn(len(time_array))
47+
startend = np.linspace(0, time_array[-1], T//100).reshape(T//200, 2)
48+
ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1])
49+
tsd = nap.Tsd(t=time_array, d=data_array)#, time_support=ep)
50+
times.append([T]+get_mean_perf(tsd, mode))
51+
return np.array(times)
52+
53+
54+
# %%
55+
# Calling with numba/scipy
56+
nap.nap_config.set_backend("numba")
57+
times_butter_scipy = benchmark_time_points(mode="butter")
58+
59+
# %%
60+
# Calling with jax
61+
nap.nap_config.set_backend("jax")
62+
times_butter_jax = benchmark_time_points(mode="butter")
63+
64+
# %%
65+
# Figure
66+
67+
plt.figure()
68+
for arr, label in zip(
69+
[times_butter_scipy, times_butter_jax],
70+
["Butter (scipy)", "Butter (jax)"],
71+
):
72+
plt.plot(arr[:, 0], arr[:, 1], "o-", label=label)
73+
plt.fill_between(arr[:, 0], arr[:, 1] - arr[:, 2], arr[:, 1] + arr[:, 2], alpha=0.2)
74+
75+
plt.legend()
76+
plt.xlabel("Number of time points")
77+
plt.ylabel("Time (s)")
78+
plt.title("Butterworth filter low pass")
79+
# plt.show()
80+
81+
82+
# %%
83+
# # Increasing number of dimensions
84+
85+
def benchmark_dimensions(mode):
86+
times = []
87+
T = 60000
88+
for n in np.arange(1, 100, 20):
89+
time_array = np.arange(T)/1000
90+
data_array = np.random.randn(len(time_array), n)
91+
startend = np.linspace(0, time_array[-1], T//100).reshape(T//200, 2)
92+
ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1])
93+
tsd = nap.TsdFrame(t=time_array, d=data_array, time_support=ep)
94+
times.append([n]+get_mean_perf(tsd, mode))
95+
return np.array(times)
96+
97+
# %%
98+
# Calling with numba/scipy
99+
nap.nap_config.set_backend("numba")
100+
dims_butter_scipy = benchmark_dimensions(mode="butter")
101+
102+
# %%
103+
# Calling with jax
104+
nap.nap_config.set_backend("jax")
105+
dims_butter_jax = benchmark_dimensions(mode="butter")
106+
107+
# %%
108+
# Figure
109+
110+
111+
plt.figure()
112+
113+
for arr, label in zip(
114+
[dims_butter_scipy, dims_butter_jax],
115+
["Butter (scipy)", "Butter (jax)"],
116+
):
117+
plt.plot(arr[:, 0], arr[:, 1], "o-", label=label)
118+
plt.fill_between(arr[:, 0], arr[:, 1] - arr[:, 2], arr[:, 1] + arr[:, 2], alpha=0.2)
119+
120+
plt.legend()
121+
plt.xlabel("Number of dimensions")
122+
plt.ylabel("Time (s)")
123+
plt.title("Butterworth filter low pass")
124+
plt.show()
125+

docs/images/convolve_benchmark.png

-3.71 KB
Loading

src/pynajax/jax_core_bin_average.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import jax
22
import jax.numpy as jnp
33
import numpy as np
4-
5-
# import pynapple as nap
64
from numba import jit
75

86

src/pynajax/jax_core_convolve.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,12 @@ def convolve_intervals(time_array, data_array, starts, ends, kernel, trim="both"
170170
extra = (extra[0], extra[1] + 1)
171171

172172
n = len(starts)
173-
idx_start_shift = idx_start + np.arange(1, n + 1) * extra[0] + np.arange(0, n) * extra[1]
174-
idx_end_shift = idx_end + np.arange(1, n + 1) * extra[0] + np.arange(0, n) * extra[1]
173+
idx_start_shift = (
174+
idx_start + np.arange(1, n + 1) * extra[0] + np.arange(0, n) * extra[1]
175+
)
176+
idx_end_shift = (
177+
idx_end + np.arange(1, n + 1) * extra[0] + np.arange(0, n) * extra[1]
178+
)
175179

176180
idx = _get_slicing(idx_start_shift, idx_end_shift)
177181

src/pynajax/jax_core_threshold.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,12 @@ def threshold(time_array, data_array, starts, ends, thr, method):
4747
ix2 = jnp.diff(ix * 1)
4848

4949
new_starts = (
50-
time_array[1:][ix2 == 1] - (time_array[1:][ix2 == 1] - time_array[0:-1][ix2 == 1]) / 2
50+
time_array[1:][ix2 == 1]
51+
- (time_array[1:][ix2 == 1] - time_array[0:-1][ix2 == 1]) / 2
5152
)
5253
new_ends = (
53-
time_array[0:-1][ix2 == -1] + (time_array[1:][ix2 == -1] - time_array[0:-1][ix2 == -1]) / 2
54+
time_array[0:-1][ix2 == -1]
55+
+ (time_array[1:][ix2 == -1] - time_array[0:-1][ix2 == -1]) / 2
5456
)
5557

5658
if ix[0]: # First element to keep as start

src/pynajax/jax_process_filtering.py

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
from functools import partial
2+
3+
import jax
4+
import jax.numpy as jnp
5+
import numpy as np
6+
import scipy.signal as signal
7+
8+
from .utils import (
9+
_get_shifted_indices,
10+
_get_slicing,
11+
_odd_ext_multiepoch,
12+
_revert_epochs,
13+
)
14+
15+
16+
@partial(jax.jit, static_argnums=(3, ))
17+
def _recursion_loop_sos(signal, sos, zi, nan_function):
18+
"""
19+
Applies a recursive second-order section (SOS) filter to the input signal.
20+
21+
Parameters
22+
----------
23+
signal : jnp.ndarray
24+
The input signal to be filtered, with shape (n_samples,).
25+
sos : jnp.ndarray
26+
Array of second-order filter coefficients in the 'sos' format, with shape (n_sections, 6).
27+
zi : jnp.ndarray
28+
Initial conditions for the filter, with shape (n_sections, 2, n_epochs).
29+
nan_function : callable
30+
A function that specifies how to re-initialize the initial conditions when a NaN is encountered in the signal.
31+
It should take two arguments: the epoch number and the current filter state, and return a tuple of the updated
32+
epoch number and the re-initialized filter state.
33+
34+
Returns
35+
-------
36+
jnp.ndarray
37+
The filtered signal, with the same shape as the input signal.
38+
"""
39+
40+
def internal_loop(s, x_zi):
41+
x_cur, zi_slice = x_zi
42+
x_new = sos[s, 0] * x_cur + zi_slice[s, 0]
43+
zi_slice = zi_slice.at[s, 0].set(
44+
sos[s, 1] * x_cur - sos[s, 4] * x_new + zi_slice[s, 1]
45+
)
46+
zi_slice = zi_slice.at[s, 1].set(
47+
sos[s, 2] * x_cur - sos[s, 5] * x_new)
48+
x_cur = x_new
49+
return x_cur, zi_slice
50+
51+
def recursion_step(carry, x):
52+
epoch_num, zi_slice = carry
53+
54+
x_cur, zi_slice = jax.lax.fori_loop(
55+
lower=0, upper=sos.shape[0], body_fun=internal_loop, init_val=(x, zi_slice)
56+
)
57+
58+
# Use jax.lax.cond to choose between nan_case and not_nan_case
59+
epoch_num, zi_slice = jax.lax.cond(
60+
jnp.isnan(x), # Condition to check
61+
nan_function, # Function to call if x is NaN
62+
lambda i, x: (i, zi_slice), # Function to call if x is not NaN
63+
epoch_num,
64+
zi,
65+
)
66+
67+
return (epoch_num, zi_slice), x_cur
68+
69+
_, res = jax.lax.scan(recursion_step, (0, zi[..., 0]), signal)
70+
71+
return res
72+
73+
74+
# vectorize the recursion over signals.
75+
_vmap_recursion_sos = jax.vmap(_recursion_loop_sos, in_axes=(1, None, 2, None), out_axes=1)
76+
77+
78+
def _insert_constant(idx_start, idx_end, data_array, window_size, const=jnp.nan):
79+
"""
80+
Insert a constant value array between epochs in a time series data array.
81+
82+
This function interleaves a constant value array of specified size between each epoch in the data array.
83+
84+
Parameters
85+
----------
86+
idx_start : jnp.ndarray
87+
Array of start indices for each epoch.
88+
idx_end : jnp.ndarray
89+
Array of end indices for each epoch.
90+
data_array : jnp.ndarray
91+
The input data array, with shape (n_samples, ...).
92+
window_size : int
93+
The size of the constant array to be inserted between epochs.
94+
const : float, optional
95+
The constant value to be inserted, by default jnp.nan.
96+
97+
Returns
98+
-------
99+
data_array: jnp.ndarray
100+
The modified data array with the constant arrays inserted.
101+
ix_orig: jnp.ndarray
102+
Indices corresponding to the samples in the original data array.
103+
ix_shift: jnp.ndarray
104+
The shifted indices after the constant array has been interleaved.
105+
idx_start_shift:
106+
The shifted start indices of the epochs in the modified array.
107+
idx_end_shift:
108+
The shifted end indices of the epochs in the modified array.
109+
"""
110+
# shift by a window every epoch
111+
idx_start_shift, idx_end_shift = _get_shifted_indices(
112+
idx_start, idx_end, window_size
113+
)
114+
115+
# get the indices for setting elements
116+
ix_orig = _get_slicing(idx_start, idx_end)
117+
ix_shift = _get_slicing(idx_start_shift, idx_end_shift)
118+
119+
tot_size = ix_shift[-1] - ix_shift[0] + 1
120+
data_array = (
121+
jnp.full((tot_size, *data_array.shape[1:]), const)
122+
.at[ix_shift]
123+
.set(data_array[ix_orig])
124+
)
125+
return data_array, ix_orig, ix_shift, idx_start_shift, idx_end_shift
126+
127+
128+
def jax_sosfiltfilt(sos, time_array, data_array, starts, ends):
129+
"""
130+
Apply forward-backward filtering using a second-order section (SOS) filter.
131+
132+
This function applies an SOS filter to the data array in both forward and reverse directions,
133+
which results in zero-phase filtering.
134+
135+
Parameters
136+
----------
137+
sos : np.ndarray
138+
Array of second-order filter coefficients in the 'sos' format, with shape (n_sections, 6).
139+
time_array : np.ndarray
140+
The time array corresponding to the data, with shape (n_samples,).
141+
data_array : jnp.ndarray
142+
The data array to be filtered, with shape (n_samples, ...).
143+
starts : np.ndarray
144+
Array of start indices for the epochs in the data array.
145+
ends : np.ndarray
146+
Array of end indices for the epochs in the data array.
147+
148+
Returns
149+
-------
150+
: jnp.ndarray
151+
The zero-phase filtered data array, with the same shape as the input data array.
152+
"""
153+
154+
original_shape = data_array.shape
155+
data_array = data_array.reshape(data_array.shape[0], -1)
156+
157+
# same default padding as scipy.sosfiltfilt ("pad" method and "odd" padtype).
158+
n_sections = sos.shape[0]
159+
ntaps = 2 * n_sections + 1
160+
ntaps -= min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum())
161+
pad_num = 3 * ntaps
162+
163+
ext, ix_start_pad, ix_end_pad, ix_data = _odd_ext_multiepoch(pad_num, time_array, data_array, starts, ends)
164+
165+
# get the start/end index of each epoch after padding
166+
ix_start_ep = np.hstack((ix_start_pad[0], ix_start_pad[1:-1] + pad_num))
167+
ix_end_ep = np.hstack((ix_start_ep[1:], ix_end_pad[-1]))
168+
169+
zi = signal.sosfilt_zi(sos)
170+
171+
# this braodcast has shape (*zi.shape, data_array.shape[1], len(ix_start_pad))
172+
z0 = zi[..., jnp.newaxis, jnp.newaxis] * ext.T[jnp.newaxis, jnp.newaxis, ..., ix_start_ep]
173+
174+
if len(starts) > 1:
175+
# multi epoch case augmenting with nans.
176+
aug_data, ix_orig, ix_shift, idx_start_shift, idx_end_shift = _insert_constant(
177+
ix_start_ep, ix_end_ep, ext, window_size=1, const=np.nan
178+
)
179+
180+
# grab the next initial condition, increase the epoch counter
181+
nan_func = lambda ep_num, x: (ep_num + 1, x[..., ep_num + 1])
182+
else:
183+
# single epoch, no augmentation
184+
nan_func = lambda ep_num, x: (ep_num + 1, x[..., 0])
185+
aug_data = ext
186+
idx_start_shift = ix_start_ep
187+
idx_end_shift = ix_end_ep
188+
ix_shift = slice(None)
189+
190+
191+
# call forward recursion
192+
out = _vmap_recursion_sos(aug_data, sos, z0, nan_func)
193+
194+
# reverse time axis
195+
irev = _revert_epochs(idx_start_shift, idx_end_shift)
196+
out = out.at[ix_shift].set(out[irev])
197+
198+
# compute new init cond
199+
z0 = zi[..., jnp.newaxis, jnp.newaxis] * out.T[jnp.newaxis, jnp.newaxis, ..., idx_start_shift]
200+
201+
# call backward recursion
202+
out = _vmap_recursion_sos(out, sos, z0, nan_func)
203+
204+
# re-flip axis
205+
out = out.at[ix_shift].set(out[irev])
206+
207+
# remove nans and padding
208+
out = out[ix_shift][ix_data]
209+
210+
return out.reshape(original_shape)

0 commit comments

Comments
 (0)