Skip to content

Commit

Permalink
finish
Browse files Browse the repository at this point in the history
  • Loading branch information
gviejo committed Sep 16, 2024
1 parent 83fff24 commit 8d26c20
Showing 1 changed file with 32 additions and 57 deletions.
89 changes: 32 additions & 57 deletions docs/examples/plot_benchmark_filtering.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,34 @@
"""
# filtering
"""
This notebook compare the jax implementation of Butterworth filter with [scipy sosfiltfilt](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sosfiltfilt.html).
Performances of the `'sinc'` mode can be found in the convolve benchmark as it is the function being called underneath.
⚠️ **Warning:** We do not recommend using GPU for filtering as it is much slower for the moment compared to CPU.
"""
import os
import numpy as np
import pynapple as nap
import jax.numpy as jnp
from time import perf_counter
import matplotlib.pyplot as plt

import warnings
warnings.filterwarnings("ignore")



# %%
# Machine Configuration
import jax
print(jax.devices())


# %%
def get_mean_perf(tsd, mode, n=10):
tmp = np.zeros(n)
_ = nap.apply_lowpass_filter(tsd, 0.25 * tsd.rate, mode=mode)
for i in range(n):
t1 = perf_counter()
_ = nap.apply_lowpass_filter(tsd, 0.25 * tsd.rate, mode=mode)
Expand All @@ -28,6 +37,7 @@ def get_mean_perf(tsd, mode, n=10):
return [np.mean(tmp), np.std(tmp)]

# %%
# # Increasing number of time points

def benchmark_time_points(mode):
times = []
Expand All @@ -36,57 +46,25 @@ def benchmark_time_points(mode):
data_array = np.random.randn(len(time_array))
startend = np.linspace(0, time_array[-1], T//100).reshape(T//200, 2)
ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1])
tsd = nap.Tsd(t=time_array, d=data_array, time_support=ep)
tsd = nap.Tsd(t=time_array, d=data_array)#, time_support=ep)
times.append([T]+get_mean_perf(tsd, mode))
return np.array(times)

def benchmark_dimensions(mode):
times = []
T = 60000
for n in np.arange(1, 100, 20):
time_array = np.arange(T)/1000
data_array = np.random.randn(len(time_array), n)
startend = np.linspace(0, time_array[-1], T//100).reshape(T//200, 2)
ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1])
tsd = nap.TsdFrame(t=time_array, d=data_array)#, time_support=ep)
times.append([n]+get_mean_perf(tsd, mode))
return np.array(times)


# %%
# # Increasing number of time points
#
# Calling with numba/scipy
nap.nap_config.set_backend("numba")

times_sinc_numba = benchmark_time_points(mode="sinc")
times_butter_scipy = benchmark_time_points(mode="butter")

# %%
# Calling with jax
nap.nap_config.set_backend("jax")

times_sinc_jax = benchmark_time_points(mode="sinc")
times_butter_jax = benchmark_time_points(mode="butter")

# # %%
# %%
# Figure


plt.figure(figsize = (16, 5))
plt.subplot(121)
for arr, label in zip(
[times_sinc_numba, times_sinc_jax],
["Sinc (numba)", "Sinc (jax)"],
):
plt.plot(arr[:, 0], arr[:, 1], "o-", label=label)
plt.fill_between(arr[:, 0], arr[:, 1] - arr[:, 2], arr[:, 1] + arr[:, 2], alpha=0.2)
plt.legend()
plt.xlabel("Number of time points")
plt.ylabel("Time (s)")
plt.title("Windowed-sinc low pass")

plt.subplot(122)
plt.figure()
for arr, label in zip(
[times_butter_scipy, times_butter_jax],
["Butter (scipy)", "Butter (jax)"],
Expand All @@ -103,38 +81,35 @@ def benchmark_dimensions(mode):

# %%
# # Increasing number of dimensions
#

def benchmark_dimensions(mode):
times = []
T = 60000
for n in np.arange(1, 100, 20):
time_array = np.arange(T)/1000
data_array = np.random.randn(len(time_array), n)
startend = np.linspace(0, time_array[-1], T//100).reshape(T//200, 2)
ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1])
tsd = nap.TsdFrame(t=time_array, d=data_array, time_support=ep)
times.append([n]+get_mean_perf(tsd, mode))
return np.array(times)

# %%
# Calling with numba/scipy
nap.nap_config.set_backend("numba")

dims_sinc_numba = benchmark_dimensions(mode="sinc")
dims_butter_scipy = benchmark_dimensions(mode="butter")

# %%
# Calling with jax
nap.nap_config.set_backend("jax")

dims_sinc_jax = benchmark_dimensions(mode="sinc")
dims_butter_jax = benchmark_dimensions(mode="butter")

# # %%
# %%
# Figure


plt.figure(figsize = (16, 5))
plt.subplot(121)
for arr, label in zip(
[dims_sinc_numba, dims_sinc_jax],
["Sinc (numba)", "Sinc (jax)"],
):
plt.plot(arr[:, 0], arr[:, 1], "o-", label=label)
plt.fill_between(arr[:, 0], arr[:, 1] - arr[:, 2], arr[:, 1] + arr[:, 2], alpha=0.2)
plt.legend()
plt.xlabel("Number of dimensions")
plt.ylabel("Time (s)")
plt.title("Windowed-sinc low pass")
plt.figure()

plt.subplot(122)
for arr, label in zip(
[dims_butter_scipy, dims_butter_jax],
["Butter (scipy)", "Butter (jax)"],
Expand Down

0 comments on commit 8d26c20

Please sign in to comment.