From 26c2b25e5b451886bda2cdea2a0e1ca383f25609 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 6 Jun 2024 17:02:16 +0100 Subject: [PATCH 001/195] inital commit for signal proc - complex morelet and fft v0 --- pynapple/process/__init__.py | 1 + pynapple/process/signal_processing.py | 201 ++++++++++++++++++++++++++ 2 files changed, 202 insertions(+) create mode 100644 pynapple/process/signal_processing.py diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 2e1af412..db2581d5 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -24,3 +24,4 @@ compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, ) +from .signal_processing import * diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py new file mode 100644 index 00000000..6afda102 --- /dev/null +++ b/pynapple/process/signal_processing.py @@ -0,0 +1,201 @@ +import numpy as np +from itertools import repeat +import pynapple as nap +from tqdm import tqdm +import matplotlib.pyplot as plt + + +# -------------------------------------------------------------------------------- + +def compute_fft(sig, fs): + """ + Performs numpy fft on sig, returns output + ..todo: Make sig handle TsdFrame, TsdTensor + + :param sig: :param fs: :return: + """ + if not isinstance(sig, nap.Tsd): + raise TypeError("Currently compute_fft is only implemented for Tsd") + fft_result = np.fft.fft(sig.values) + fft_freq = np.fft.fftfreq(len(sig.values), 1 / fs) + return fft_result, fft_freq + + +def morlet(M, ncycles=5.0, scaling=1.0): + """ + Defines the complex Morelet wavelet + :param M: Length of the wavelet. :param ncycles: number of wavelet cycles to use. Default is 5 :param scaling: Scaling factor. Default is 1. :return: (M,) ndarray Morelet wavelet + """ + x = np.linspace(-scaling * 2 * np.pi, scaling * 2 * np.pi, M) + return np.exp(1j * ncycles * x) * (np.exp(-0.5 * (x ** 2)) * np.pi ** (-0.25)) + + +""" +The following code has been adapted from functions in the neurodsp package: +https://github.com/neurodsp-tools/neurodsp + +..todo: reference licence in LICENCE directory +""" + + +def check_n_cycles(n_cycles, len_cycles=None): + """Check an input as a number of cycles definition, and make it iterable. + + Parameters ---------- n_cycles : float or list Definition of number of cycles. If a single value, the same number of cycles is used for each frequency value. If a list or list_like, then should be a n_cycles corresponding to each frequency. len_cycles : int, optional What the length of `n_cycles` should, if it's a list. + Returns ------- n_cycles : iterable An iterable version of the number of cycles. """ + if isinstance(n_cycles, (int, float, np.number)): + + if n_cycles <= 0: + raise ValueError('Number of cycles must be a positive number.') + + n_cycles = repeat(n_cycles) + + elif isinstance(n_cycles, (tuple, list, np.ndarray)): + + for cycle in n_cycles: + if cycle <= 0: + raise ValueError('Each number of cycles must be a positive number.') + + if len_cycles and len(n_cycles) != len_cycles: + raise ValueError('The length of number of cycles does not match other inputs.') + + n_cycles = iter(n_cycles) + + return n_cycles + + +def create_freqs(freq_start, freq_stop, freq_step=1): + """Create an array of frequencies. + + Parameters ---------- freq_start : float Starting value for the frequency definition. freq_stop : float Stopping value for the frequency definition, inclusive. freq_step : float, optional, default: 1 Step value, for linearly spaced values between start and stop. + Returns ------- freqs : 1d array Frequency indices. """ + return np.arange(freq_start, freq_stop + freq_step, freq_step) + + +def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='amp'): + """Compute the time-frequency representation of a signal using morlet wavelets. + + Parameters + ---------- + sig : 1d array + Time series. + fs : float + Sampling rate, in Hz. + freqs : 1d array or list of float + If array, frequency values to estimate with morlet wavelets. + If list, define the frequency range, as [freq_start, freq_stop, freq_step]. + The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. + n_cycles : float or 1d array + Length of the filter, as the number of cycles for each frequency. + If 1d array, this defines n_cycles for each frequency. + scaling : float + Scaling factor. + norm : {'sss', 'amp'}, optional + Normalization method: + + * 'sss' - divide by the square root of the sum of squares + * 'amp' - divide by the sum of amplitudes + + Returns + ------- + mwt : 2d array + Time frequency representation of the input signal. + + Notes + ----- + This computes the continuous wavelet transform at specified frequencies across time. + + Examples + -------- + Compute a Morlet wavelet time-frequency representation of a signal: + + >>> from neurodsp.sim import sim_combined + >>> sig = sim_combined(n_seconds=10, fs=500, + ... components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}}) + >>> mwt = compute_wavelet_transform(sig, fs=500, freqs=[1, 30]) + """ + if not isinstance(sig, nap.Tsd) and not isinstance(sig, nap.TsdFrame): + raise TypeError("`sig` must be instance of Tsd or TsdFrame") + + if isinstance(freqs, (tuple, list)): + freqs = create_freqs(*freqs) + n_cycles = check_n_cycles(n_cycles, len(freqs)) + if isinstance(sig, nap.Tsd): + mwt = np.zeros([len(freqs), len(sig)], dtype=complex) + for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): + wav = convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm) + mwt[ind, :] = wav + return nap.TsdFrame(t=sig.index, d=np.transpose(mwt)) + else: + mwt = np.zeros([sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex) + for channel_i in tqdm(range(sig.values.shape[1])): + for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): + wav = convolve_wavelet(sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm) + mwt[:, ind, channel_i] = wav + return nap.TsdTensor(t=sig.index, d=mwt) + + +def convolve_wavelet(sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm='sss'): + """Convolve a signal with a complex wavelet. + + Parameters + ---------- + sig : 1d array + Time series to filter. + fs : float + Sampling rate, in Hz. + freq : float + Center frequency of bandpass filter. + n_cycles : float, optional, default: 7 + Length of the filter, as the number of cycles of the oscillation with specified frequency. + scaling : float, optional, default: 0.5 + Scaling factor for the morlet wavelet. + wavelet_len : int, optional + Length of the wavelet. If defined, this overrides the freq and n_cycles inputs. + norm : {'sss', 'amp'}, optional + Normalization method: + + * 'sss' - divide by the square root of the sum of squares + * 'amp' - divide by the sum of amplitudes + + Returns + ------- + array + Complex time series. + + Notes + ----- + + * The real part of the returned array is the filtered signal. + * Taking np.abs() of output gives the analytic amplitude. + * Taking np.angle() of output gives the analytic phase. + + Examples + -------- + Convolve a complex wavelet with a simulated signal: + + >>> from neurodsp.sim import sim_combined + >>> sig = sim_combined(n_seconds=10, fs=500, + ... components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}}) + >>> cts = convolve_wavelet(sig, fs=500, freq=10) + """ + if norm not in ['sss', 'amp']: + raise ValueError('Given `norm` must be `sss` or `amp`') + + if wavelet_len is None: + wavelet_len = int(n_cycles * fs / freq) + + if wavelet_len > sig.shape[-1]: + raise ValueError('The length of the wavelet is greater than the signal. Can not proceed.') + + morlet_f = morlet(wavelet_len, ncycles=n_cycles, scaling=scaling) + + if norm == 'sss': + morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f) ** 2)) + elif norm == 'amp': + morlet_f = morlet_f / np.sum(np.abs(morlet_f)) + + mwt_real = sig.convolve(np.real(morlet_f)) + mwt_imag = sig.convolve(np.imag(morlet_f)) + + return mwt_real.values + 1j * mwt_imag.values \ No newline at end of file From f080d4f38114fa2540b38a9e87beb8625d517262 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 27 Jun 2024 00:33:58 +0100 Subject: [PATCH 002/195] basic pywavelets functionality matched --- pynapple/process/__init__.py | 6 +- pynapple/process/signal_processing.py | 297 +++++++++++++------------- 2 files changed, 149 insertions(+), 154 deletions(-) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index db2581d5..08d58648 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -24,4 +24,8 @@ compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, ) -from .signal_processing import * +from .signal_processing import ( + compute_wavelet_transform, + compute_spectrum, + compute_welch_spectrum +) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 6afda102..c49e729d 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -1,100 +1,72 @@ +""" +Signal processing tools for Pynapple. + +Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. +""" + import numpy as np -from itertools import repeat import pynapple as nap -from tqdm import tqdm -import matplotlib.pyplot as plt +from math import ceil, floor +import json +from scipy.signal import welch +with open('wavelets.json') as f: + WAVELET_DICT = json.load(f) -# -------------------------------------------------------------------------------- -def compute_fft(sig, fs): +def compute_spectrum(sig, fs=None): """ Performs numpy fft on sig, returns output ..todo: Make sig handle TsdFrame, TsdTensor - :param sig: :param fs: :return: + ---------- + sig : pynapple.Tsd or pynapple.TsdFrame + Time series. + fs : float, optional + Sampling rate, in Hz. If None, will be calculated from the given signal """ if not isinstance(sig, nap.Tsd): raise TypeError("Currently compute_fft is only implemented for Tsd") + if fs is None: + fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) fft_result = np.fft.fft(sig.values) fft_freq = np.fft.fftfreq(len(sig.values), 1 / fs) return fft_result, fft_freq -def morlet(M, ncycles=5.0, scaling=1.0): - """ - Defines the complex Morelet wavelet - :param M: Length of the wavelet. :param ncycles: number of wavelet cycles to use. Default is 5 :param scaling: Scaling factor. Default is 1. :return: (M,) ndarray Morelet wavelet +def compute_welch_spectrum(sig, fs=None): """ - x = np.linspace(-scaling * 2 * np.pi, scaling * 2 * np.pi, M) - return np.exp(1j * ncycles * x) * (np.exp(-0.5 * (x ** 2)) * np.pi ** (-0.25)) - - -""" -The following code has been adapted from functions in the neurodsp package: -https://github.com/neurodsp-tools/neurodsp - -..todo: reference licence in LICENCE directory -""" - + Performs scipy Welch's decomposition on sig, returns output + ..todo: Make sig handle TsdFrame, TsdTensor -def check_n_cycles(n_cycles, len_cycles=None): - """Check an input as a number of cycles definition, and make it iterable. - - Parameters ---------- n_cycles : float or list Definition of number of cycles. If a single value, the same number of cycles is used for each frequency value. If a list or list_like, then should be a n_cycles corresponding to each frequency. len_cycles : int, optional What the length of `n_cycles` should, if it's a list. - Returns ------- n_cycles : iterable An iterable version of the number of cycles. """ - if isinstance(n_cycles, (int, float, np.number)): - - if n_cycles <= 0: - raise ValueError('Number of cycles must be a positive number.') - - n_cycles = repeat(n_cycles) - - elif isinstance(n_cycles, (tuple, list, np.ndarray)): - - for cycle in n_cycles: - if cycle <= 0: - raise ValueError('Each number of cycles must be a positive number.') - - if len_cycles and len(n_cycles) != len_cycles: - raise ValueError('The length of number of cycles does not match other inputs.') - - n_cycles = iter(n_cycles) - - return n_cycles - - -def create_freqs(freq_start, freq_stop, freq_step=1): - """Create an array of frequencies. - - Parameters ---------- freq_start : float Starting value for the frequency definition. freq_stop : float Stopping value for the frequency definition, inclusive. freq_step : float, optional, default: 1 Step value, for linearly spaced values between start and stop. - Returns ------- freqs : 1d array Frequency indices. """ - return np.arange(freq_start, freq_stop + freq_step, freq_step) + ---------- + sig : pynapple.Tsd or pynapple.TsdFrame + Time series. + fs : float, optional + Sampling rate, in Hz. If None, will be calculated from the given signal + """ + if not isinstance(sig, nap.Tsd): + raise TypeError("Currently compute_fft is only implemented for Tsd") + if fs is None: + fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) + freqs, spectogram = welch(sig.values, fs=fs) + return spectogram, freqs -def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='amp'): - """Compute the time-frequency representation of a signal using morlet wavelets. +def compute_wavelet_transform(sig, freqs, fs=None): + """ + Compute the time-frequency representation of a signal using morlet wavelets. Parameters ---------- - sig : 1d array + sig : pynapple.Tsd or pynapple.TsdFrame Time series. - fs : float - Sampling rate, in Hz. freqs : 1d array or list of float If array, frequency values to estimate with morlet wavelets. If list, define the frequency range, as [freq_start, freq_stop, freq_step]. The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. - n_cycles : float or 1d array - Length of the filter, as the number of cycles for each frequency. - If 1d array, this defines n_cycles for each frequency. - scaling : float - Scaling factor. - norm : {'sss', 'amp'}, optional - Normalization method: - - * 'sss' - divide by the square root of the sum of squares - * 'amp' - divide by the sum of amplitudes + fs : float, optional + Sampling rate, in Hz. If None, will be calculated from the given signal Returns ------- @@ -104,98 +76,117 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=7, scaling=0.5, norm='amp Notes ----- This computes the continuous wavelet transform at specified frequencies across time. - - Examples - -------- - Compute a Morlet wavelet time-frequency representation of a signal: - - >>> from neurodsp.sim import sim_combined - >>> sig = sim_combined(n_seconds=10, fs=500, - ... components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}}) - >>> mwt = compute_wavelet_transform(sig, fs=500, freqs=[1, 30]) """ if not isinstance(sig, nap.Tsd) and not isinstance(sig, nap.TsdFrame): - raise TypeError("`sig` must be instance of Tsd or TsdFrame") - - if isinstance(freqs, (tuple, list)): - freqs = create_freqs(*freqs) - n_cycles = check_n_cycles(n_cycles, len(freqs)) + raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") + if fs is None: + fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) + assert fs/2 > np.max(freqs), "`freqs` contains values over the Nyquist frequency." if isinstance(sig, nap.Tsd): - mwt = np.zeros([len(freqs), len(sig)], dtype=complex) - for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): - wav = convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm) - mwt[ind, :] = wav - return nap.TsdFrame(t=sig.index, d=np.transpose(mwt)) - else: - mwt = np.zeros([sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex) - for channel_i in tqdm(range(sig.values.shape[1])): - for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): - wav = convolve_wavelet(sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm) - mwt[:, ind, channel_i] = wav - return nap.TsdTensor(t=sig.index, d=mwt) - - -def convolve_wavelet(sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm='sss'): - """Convolve a signal with a complex wavelet. + mwt, f = _cwt(sig, + freqs=freqs, + wavelet="cmor1.5-1.0", + sampling_period=1/fs) + return nap.TsdFrame(t=sig.index, d=np.transpose(mwt), time_support=sig.time_support) + elif isinstance(sig, nap.TsdFrame): + mwt = np.zeros( + [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex + ) + for channel_i in range(sig.values.shape[1]): + mwt[:, :, channel_i] = np.transpose(_cwt(sig[:, channel_i], + freqs=freqs, + wavelet="cmor1.5-1.0", + sampling_period=1/fs)[0]) + return nap.TsdTensor(t=sig.index, d=mwt, time_support=sig.time_support) + elif isinstance(sig, nap.TsdTensor): + raise NotImplemented("cwt for TsdTensor is not yet implemented") + + +def _cwt(data, freqs, wavelet, sampling_period, axis=-1): + """ + cwt(data, scales, wavelet) + + One dimensional Continuous Wavelet Transform. Parameters ---------- - sig : 1d array - Time series to filter. - fs : float - Sampling rate, in Hz. - freq : float - Center frequency of bandpass filter. - n_cycles : float, optional, default: 7 - Length of the filter, as the number of cycles of the oscillation with specified frequency. - scaling : float, optional, default: 0.5 - Scaling factor for the morlet wavelet. - wavelet_len : int, optional - Length of the wavelet. If defined, this overrides the freq and n_cycles inputs. - norm : {'sss', 'amp'}, optional - Normalization method: - - * 'sss' - divide by the square root of the sum of squares - * 'amp' - divide by the sum of amplitudes + data : pynapple.Tsd or pynapple.TsdFrame + Input time series signal. + freqs : 1d array + Frequency values to estimate with morlet wavelets. + wavelet : Wavelet object or name + Wavelet to use, only implemented for "cmor1.5-1.0". + sampling_period : float + Sampling period for the frequencies output. + The values computed for ``coefs`` are independent of the choice of + ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling + period). + axis: int, optional + Axis over which to compute the CWT. If not given, the last axis is + used. Returns ------- - array - Complex time series. - - Notes - ----- - - * The real part of the returned array is the filtered signal. - * Taking np.abs() of output gives the analytic amplitude. - * Taking np.angle() of output gives the analytic phase. - - Examples - -------- - Convolve a complex wavelet with a simulated signal: - - >>> from neurodsp.sim import sim_combined - >>> sig = sim_combined(n_seconds=10, fs=500, - ... components={'sim_powerlaw': {}, 'sim_oscillation' : {'freq': 10}}) - >>> cts = convolve_wavelet(sig, fs=500, freq=10) + coefs : array_like + Continuous wavelet transform of the input signal for the given scales + and wavelet. The first axis of ``coefs`` corresponds to the scales. + The remaining axes match the shape of ``data``. + frequencies : array_like + If the unit of sampling period are seconds and given, then frequencies + are in hertz. Otherwise, a sampling period of 1 is assumed. + + ..todo:: This should use pynapple convolve but currently that cannot handle imaginary numbers as it uses scipy convolve """ - if norm not in ['sss', 'amp']: - raise ValueError('Given `norm` must be `sss` or `amp`') - - if wavelet_len is None: - wavelet_len = int(n_cycles * fs / freq) - - if wavelet_len > sig.shape[-1]: - raise ValueError('The length of the wavelet is greater than the signal. Can not proceed.') - - morlet_f = morlet(wavelet_len, ncycles=n_cycles, scaling=scaling) - - if norm == 'sss': - morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f) ** 2)) - elif norm == 'amp': - morlet_f = morlet_f / np.sum(np.abs(morlet_f)) - - mwt_real = sig.convolve(np.real(morlet_f)) - mwt_imag = sig.convolve(np.imag(morlet_f)) - - return mwt_real.values + 1j * mwt_imag.values \ No newline at end of file + int_psi = np.array(WAVELET_DICT[wavelet]['int_psi_real'])*1j + np.array(WAVELET_DICT[wavelet]['int_psi_imag']) + x = np.array(WAVELET_DICT[wavelet]["x"]) + central_freq = WAVELET_DICT[wavelet]["central_freq"] + scales = central_freq/(freqs*sampling_period) + out = np.empty((np.size(scales),) + data.shape, dtype=np.complex128) + + if data.ndim > 1: + # move axis to be transformed last (so it is contiguous) + data = data.swapaxes(-1, axis) + # reshape to (n_batch, data.shape[-1]) + data_shape_pre = data.shape + data = data.reshape((-1, data.shape[-1])) + + for i, scale in enumerate(scales): + step = x[1] - x[0] + j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) + j = j.astype(int) # floor + if j[-1] >= int_psi.size: + j = np.extract(j < int_psi.size, j) + int_psi_scale = int_psi[j][::-1] + + if data.ndim == 1: + conv = np.convolve(data, int_psi_scale) + else: + # batch convolution via loop + conv_shape = list(data.shape) + conv_shape[-1] += int_psi_scale.size - 1 + conv_shape = tuple(conv_shape) + conv = np.empty(conv_shape, dtype=np.complex128) + for n in range(data.shape[0]): + conv[n, :] = np.convolve(data[n], int_psi_scale) + + coef = - np.sqrt(scale) * np.diff(conv, axis=-1) + if out.dtype.kind != 'c': + coef = coef.real + # transform axis is always -1 due to the data reshape above + d = (coef.shape[-1] - data.shape[-1]) / 2. + if d > 0: + coef = coef[..., floor(d):-ceil(d)] + elif d < 0: + raise ValueError( + f"Selected scale of {scale} too small.") + if data.ndim > 1: + # restore original data shape and axis position + coef = coef.reshape(data_shape_pre) + coef = coef.swapaxes(axis, -1) + out[i, ...] = coef + + frequencies = central_freq/scales + if np.isscalar(frequencies): + frequencies = np.array([frequencies]) + frequencies /= sampling_period + return out, frequencies From 9aa05ac3755b97363fe54d353c882c328c7593ca Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 27 Jun 2024 19:50:15 +0100 Subject: [PATCH 003/195] different wavelet definition --- pynapple/process/__init__.py | 1 + pynapple/process/signal_processing.py | 212 ++++++++++++++++++++++++++ 2 files changed, 213 insertions(+) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 08d58648..120a3363 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -26,6 +26,7 @@ ) from .signal_processing import ( compute_wavelet_transform, + compute_wavelet_transform_og, compute_spectrum, compute_welch_spectrum ) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index c49e729d..e2b7bf9f 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -9,6 +9,7 @@ from math import ceil, floor import json from scipy.signal import welch +from itertools import repeat with open('wavelets.json') as f: WAVELET_DICT = json.load(f) @@ -190,3 +191,214 @@ def _cwt(data, freqs, wavelet, sampling_period, axis=-1): frequencies = np.array([frequencies]) frequencies /= sampling_period return out, frequencies + + + + + + +# ------------------------------------------------------------------------------- + +def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): + """ + Defines the complex Morelet wavelet kernel + + Parameters + ---------- + M : int + Length of the wavelet + ncycles : float + number of wavelet cycles to use. Default is 5 + scaling: float + Scaling factor. Default is 1.5 + precision: int + Precision of wavelet to use + + Returns + ------- + np.ndarray + Morelet wavelet kernel + """ + x = np.linspace(-precision, precision, M) + return ((np.pi*ncycles) ** (-0.25)) * np.exp(-x**2 / ncycles) * np.exp(1j * 2*np.pi * scaling * x) + +""" +The following code has been adapted from functions in the neurodsp package: +https://github.com/neurodsp-tools/neurodsp + +..todo: reference licence in LICENCE directory +""" + +def _check_n_cycles(n_cycles, len_cycles=None): + """ + Check an input as a number of cycles, and make it iterable. + + Parameters + ---------- + n_cycles : float or list + Definition of number of cycles to check. If a single value, the same number of cycles is used for each + frequency value. If a list or list_like, then should be a n_cycles corresponding to each frequency. + len_cycles: int, optional + What the length of `n_cycles` should be, if it's a list. + + Returns + ------- + iter + An iterable version of the number of cycles. + """ + if isinstance(n_cycles, (int, float, np.number)): + if n_cycles <= 0: + raise ValueError("Number of cycles must be a positive number.") + n_cycles = repeat(n_cycles) + elif isinstance(n_cycles, (tuple, list, np.ndarray)): + for cycle in n_cycles: + if cycle <= 0: + raise ValueError("Each number of cycles must be a positive number.") + if len_cycles and len(n_cycles) != len_cycles: + raise ValueError( + "The length of number of cycles does not match other inputs." + ) + n_cycles = iter(n_cycles) + return n_cycles + + +def _create_freqs(freq_start, freq_stop, freq_step=1): + """ + Creates an array of frequencies. + + ..todo:: Implement log scaling + + Parameters + ---------- + freq_start : float + Starting value for the frequency definition. + freq_stop: float + Stopping value for the frequency definition, inclusive. + freq_step: float, optional + Step value, for linearly spaced values between start and stop. + + Returns + ------- + freqs: 1d array + Frequency indices. + """ + return np.arange(freq_start, freq_stop + freq_step, freq_step) + + +def compute_wavelet_transform_og(sig, fs, freqs, n_cycles=7, scaling=0.5, norm="amp"): + """ + Compute the time-frequency representation of a signal using morlet wavelets. + + ..todo:: better normalization between channels + + Parameters + ---------- + sig : pynapple.Tsd or pynapple.TsdFrame + Time series. + fs : float + Sampling rate, in Hz. + freqs : 1d array or list of float + If array, frequency values to estimate with morlet wavelets. + If list, define the frequency range, as [freq_start, freq_stop, freq_step]. + The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. + n_cycles : float or 1d array + Length of the filter, as the number of cycles for each frequency. + If 1d array, this defines n_cycles for each frequency. + scaling : float + Scaling factor. + norm : {'sss', 'amp'}, optional + Normalization method: + + * 'sss' - divide by the square root of the sum of squares + * 'amp' - divide by the sum of amplitudes + + Returns + ------- + mwt : 2d array + Time frequency representation of the input signal. + + Notes + ----- + This computes the continuous wavelet transform at specified frequencies across time. + """ + if not isinstance(sig, nap.Tsd) and not isinstance(sig, nap.TsdFrame): + raise TypeError("`sig` must be instance of Tsd or TsdFrame") + if isinstance(freqs, (tuple, list)): + freqs = _create_freqs(*freqs) + n_cycles = _check_n_cycles(n_cycles, len(freqs)) + if isinstance(sig, nap.Tsd): + mwt = np.zeros([len(freqs), len(sig)], dtype=complex) + for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): + mwt[ind, :] = _convolve_wavelet(sig, + fs, + freq, + n_cycle, + scaling, + norm=norm) + return nap.TsdFrame(t=sig.index, d=np.transpose(mwt), time_support=sig.time_support) + else: + mwt = np.zeros( + [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex + ) + for channel_i in range(sig.values.shape[1]): + for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): + mwt[:, ind, channel_i] = _convolve_wavelet( + sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm + ) + return nap.TsdTensor(t=sig.index, d=mwt, time_support=sig.time_support) + + +def _convolve_wavelet( + sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm="sss" +): + """ + Convolve a signal with a complex wavelet. + + Parameters + ---------- + sig : pynapple.Tsd + Time series to filter. + fs : float + Sampling rate, in Hz. + freq : float + Center frequency of bandpass filter. + n_cycles : float, optional, default: 7 + Length of the filter, as the number of cycles of the oscillation with specified frequency. + scaling : float, optional, default: 0.5 + Scaling factor for the morlet wavelet. + wavelet_len : int, optional + Length of the wavelet. If defined, this overrides the freq and n_cycles inputs. + norm : {'sss', 'amp'}, optional + Normalization method: + + * 'sss' - divide by the square root of the sum of squares + * 'amp' - divide by the sum of amplitudes + + Returns + ------- + array + Complex- valued time series. + + Notes + ----- + + * The real part of the returned array is the filtered signal. + * Taking np.abs() of output gives the analytic amplitude. + * Taking np.angle() of output gives the analytic phase. ..todo: this this still true? + """ + if norm not in ["sss", "amp"]: + raise ValueError("Given `norm` must be `sss` or `amp`") + if wavelet_len is None: + wavelet_len = int(n_cycles * fs / freq) + if wavelet_len > sig.shape[-1]: + raise ValueError( + "The length of the wavelet is greater than the signal. Can not proceed." + ) + morlet_f = morlet(wavelet_len, ncycles=n_cycles, scaling=scaling) + if norm == "sss": + morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f) ** 2)) + elif norm == "amp": + morlet_f = morlet_f / np.sum(np.abs(morlet_f)) + mwt_real = sig.convolve(np.real(morlet_f)) + mwt_imag = sig.convolve(np.imag(morlet_f)) + return mwt_real.values + 1j * mwt_imag.values \ No newline at end of file From 4af8d90dc7013df6be7753a02cbbcbe4adf5f053 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 27 Jun 2024 22:48:45 +0100 Subject: [PATCH 004/195] wavlet workaround fixed, tutorial added --- docs/examples/tutorial_signal_processing.py | 235 ++++++++++++++++++++ pynapple/process/__init__.py | 1 - pynapple/process/signal_processing.py | 200 +++-------------- 3 files changed, 263 insertions(+), 173 deletions(-) create mode 100644 docs/examples/tutorial_signal_processing.py diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py new file mode 100644 index 00000000..c0d585ba --- /dev/null +++ b/docs/examples/tutorial_signal_processing.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +""" +Signal Processing Local Field Potentials +============ + +Working with Local Field Potential data. + +See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. + +This tutorial was made by Kipp Freud. + +""" +# %% +# !!! warning +# This tutorial uses matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib requests` +# +# mkdocs_gallery_thumbnail_number = 1 +# +# Now, import the necessary libraries: + +import numpy as np +import pynapple as nap +import pandas as pd +import os +#import requests +from zipfile import ZipFile +import matplotlib.pyplot as plt +import matplotlib +matplotlib.use("TkAgg") + +# %% +# *** +# Downloading the data +# ------------------ +# First things first: Let's download and extract the data +# path = "data/A2929-200711" +# extract_to = "data" +# if extract_to not in os.listdir("."): +# os.mkdir(extract_to) +# if path not in os.listdir("."): +# # Download the file +# response = requests.get("https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1") +# zip_path = os.path.join(extract_to, '/downloaded_file.zip') +# # Write the zip file to disk +# with open(zip_path, 'wb') as f: +# f.write(response.content) +# # Unzip the file +# with ZipFile(zip_path, 'r') as zip_ref: +# zip_ref.extractall(extract_to) + + +# %% +# *** +# Parsing the data +# ------------------ +# Now that we have the data, we must append the 2kHz LFP recording to the .nwb file +# eeg_path = "data/A2929-200711/A2929-200711.dat" +# frequency = 20000 # Hz +# n_channels = 16 +# f = open(eeg_path, 'rb') +# startoffile = f.seek(0, 0) +# endoffile = f.seek(0, 2) +# f.close() +# bytes_size = 2 +# n_samples = int((endoffile-startoffile)/n_channels/bytes_size) +# duration = n_samples/frequency +# interval = 1/frequency +# fp = np.memmap(eeg_path, np.int16, 'r', shape = (n_samples, n_channels)) +# timestep = np.arange(0, n_samples)/frequency +# eeg = nap.TsdFrame(t=timestep, d=fp) +# nap.append_NWB_LFP("data/A2929-200711/", +# eeg) + + +# %% +# Let's save the RoiResponseSeries as a variable called 'transients' and print it +FS = 1250 +# data = nap.load_file("data/A2929-200711/pynapplenwb/A2929-200711.nwb") +data = nap.load_file("data/stable.nwb") +print(data["ElectricalSeries"]) +# normed_electrical_series = data["ElectricalSeries"].values +# normed_electrical_series = normed_electrical_series[:, :] +# normed_electrical_series[:, :10] = normed_electrical_series[:, :10] - np.expand_dims(np.mean(normed_electrical_series[:, :10], axis=1), axis=1) +# normed_electrical_series[:, 10:] = normed_electrical_series[:, 10:] - np.expand_dims(np.mean(normed_electrical_series[:, 10:], axis=1), axis=1) +NES = nap.TsdFrame(t=data["ElectricalSeries"].index.values, d=data["ElectricalSeries"].values, time_support=data["ElectricalSeries"].time_support) + +# %% +# *** +# Selecting slices +# ----------------------------------- +# Let's consider a two 1-second slices of data, one from the sleep epoch and one from wake +wake_minute = NES.value_from(NES, nap.IntervalSet(900,960)) +sleep_minute = NES.value_from(NES, nap.IntervalSet(0,60)) + +# %% +# *** +# Plotting the LFP activity of one slices +# ----------------------------------- +# Let's plot +fig, ax = plt.subplots(2) +for channel in range(sleep_minute.shape[1]): + ax[0].plot(sleep_minute.index.values, sleep_minute[:, channel], alpha=0.5, label="Sleep Data") +ax[0].set_title("Sleep ephys") +for channel in range(wake_minute.shape[1]): + ax[1].plot(wake_minute.index.values, wake_minute[:, channel], alpha=0.5, label="Wake Data") +ax[1].set_title("Wake ephys") +plt.show() + + +# %% +# There is much shared information between channels, and wake and sleep don't seem visibly different. +# Let's take the Fourier transforms of one channel for both and see if differences are present +channel = 1 +fig, ax = plt.subplots(1) +fft_sig, fft_freqs = nap.compute_spectrum(sleep_minute[:, channel], + fs=int(FS)) +ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Sleep Data") +ax.set_xlim((0, FS/2 )) +fft_sig, fft_freqs = nap.compute_spectrum(wake_minute[:, channel], + fs=int(FS)) +ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Wake Data") +ax.set_title(f"Fourier Decomposition for channel {channel}") +ax.legend() +plt.show() + + +# %% +# There seems to be some differences presenting themselves - a bump in higher frequencies for waking data? +# Let's explore further with a wavelet decomposition + +def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kwargs): + if np.iscomplexobj(powers): + powers = abs(powers) + ax.imshow(powers, aspect='auto', **kwargs) + ax.invert_yaxis() + ax.set_xlabel('Time (s)') + ax.set_ylabel('Frequency (Hz)') + if isinstance(x_ticks, int): + x_tick_pos = np.linspace(0, times.size, x_ticks) + x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) + else: + x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] + ax.set(xticks=x_tick_pos, xticklabels=x_ticks) + if isinstance(y_ticks, int): + y_ticks_pos = np.linspace(0, freqs.size, y_ticks) + y_ticks = np.round(np.linspace(freqs[0], freqs[-1], y_ticks), 2) + else: + y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] + ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + +fig, ax = plt.subplots(2) +freqs = np.array([2.59, 3.66, 5.18, 8.0, 10.36, 14.65, 20.72, 29.3, 41.44, 58.59, 82.88, 117.19, + 165.75, 234.38, 331.5, 468.75, 624., ]) +mwt_sleep = nap.compute_wavelet_transform( + sleep_minute[:, channel], + fs=None, + freqs=freqs + ) +plot_timefrequency(sleep_minute.index.values[:], freqs[:], np.transpose(mwt_sleep[:,:].values), ax=ax[0]) +ax[0].set_title(f"Sleep Data Wavelet Decomposition: Channel {channel}") +mwt_wake = nap.compute_wavelet_transform( + wake_minute[:, channel], + fs=None, + freqs=freqs + ) +plot_timefrequency(wake_minute.index.values[:], freqs[:], np.transpose(mwt_wake[:,:].values), ax=ax[1]) +ax[1].set_title(f"Wake Data Wavelet Decomposition: Channel {channel}") +plt.margins(0) +plt.show() + +# %% +# Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data +freq = 3 +interval = (937, 939) +wake_second = wake_minute.value_from(wake_minute, nap.IntervalSet(interval[0],interval[1])) +fig, ax = plt.subplots(1) +ax.plot(wake_second.index.values, wake_second[:, channel], alpha=0.5, label="Wake Data") +ax.plot(wake_second.index.values, + mwt_wake.value_from(mwt_wake, nap.IntervalSet(interval[0],interval[1]))[:, freq].values.real, label="Theta oscillations") +ax.set_title(f"{freqs[freq]}Hz oscillation power.") +plt.show() + + +# %% +# Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data +freq = 0 +# interval = (10, 15) +interval = (20, 25) +sleep_second = sleep_minute.value_from(sleep_minute, nap.IntervalSet(interval[0],interval[1])) +_, ax = plt.subplots(1) +ax.plot(sleep_second.index.values, sleep_second[:, channel], alpha=0.5, label="Wake Data") +ax.plot(sleep_second.index.values, + mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0],interval[1]))[:, freq].values.real, label="Slow Wave Oscillations") +ax.set_title(f"{freqs[freq]}Hz oscillation power") +plt.show() + +# %% +# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during slow wave sleep + +_, ax = plt.subplots(20, figsize=(10, 50)) +mwt_sleep = np.transpose(mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0],interval[1]))) +ax[0].plot(sleep_second.index, sleep_second.values[:, 0]) +plot_timefrequency(sleep_second.index, freqs, np.abs(mwt_sleep[:, :]), ax=ax[1]) + +ax[2].plot(sleep_second.index, sleep_second.values[:, 0]) +ax[2].plot(sleep_second.index, mwt_sleep[freq, :].real) +ax[2].set_title(f"{freqs[freq]}Hz") + +ax[3].plot(sleep_second.index, np.abs(mwt_sleep[freq, :])) +# ax[3].plot(lfp.index, lfp.values[:,0]) +ax[4].plot(sleep_second.index, np.angle(mwt_sleep[freq, :])) + +spikes = {} +for i in data["units"].index: + spikes[i] = data["units"][i].times()[ + (data["units"][i].times() > interval[0]) & (data["units"][i].times() < interval[1])] + +phase = {} +for i in spikes.keys(): + phase_i = [] + for spike in spikes[i]: + phase_i.append(np.angle(mwt_sleep[freq, np.argmin(np.abs(sleep_second.index.values - spike))])) + phase[i] = np.array(phase_i) + +for i in range(15): + ax[5 + i].scatter(spikes[i], phase[i]) + ax[5 + i].set_xlim(interval[0], interval[1]) + ax[5 + i].set_ylim(-np.pi, np.pi) + ax[5 + i].set_xlabel("time (s)") + ax[5 + i].set_ylabel("phase") + +plt.tight_layout() +plt.show() \ No newline at end of file diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 120a3363..08d58648 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -26,7 +26,6 @@ ) from .signal_processing import ( compute_wavelet_transform, - compute_wavelet_transform_og, compute_spectrum, compute_welch_spectrum ) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index e2b7bf9f..cc9ab1bc 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -54,151 +54,6 @@ def compute_welch_spectrum(sig, fs=None): return spectogram, freqs -def compute_wavelet_transform(sig, freqs, fs=None): - """ - Compute the time-frequency representation of a signal using morlet wavelets. - - Parameters - ---------- - sig : pynapple.Tsd or pynapple.TsdFrame - Time series. - freqs : 1d array or list of float - If array, frequency values to estimate with morlet wavelets. - If list, define the frequency range, as [freq_start, freq_stop, freq_step]. - The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. - fs : float, optional - Sampling rate, in Hz. If None, will be calculated from the given signal - - Returns - ------- - mwt : 2d array - Time frequency representation of the input signal. - - Notes - ----- - This computes the continuous wavelet transform at specified frequencies across time. - """ - if not isinstance(sig, nap.Tsd) and not isinstance(sig, nap.TsdFrame): - raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") - if fs is None: - fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) - assert fs/2 > np.max(freqs), "`freqs` contains values over the Nyquist frequency." - if isinstance(sig, nap.Tsd): - mwt, f = _cwt(sig, - freqs=freqs, - wavelet="cmor1.5-1.0", - sampling_period=1/fs) - return nap.TsdFrame(t=sig.index, d=np.transpose(mwt), time_support=sig.time_support) - elif isinstance(sig, nap.TsdFrame): - mwt = np.zeros( - [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex - ) - for channel_i in range(sig.values.shape[1]): - mwt[:, :, channel_i] = np.transpose(_cwt(sig[:, channel_i], - freqs=freqs, - wavelet="cmor1.5-1.0", - sampling_period=1/fs)[0]) - return nap.TsdTensor(t=sig.index, d=mwt, time_support=sig.time_support) - elif isinstance(sig, nap.TsdTensor): - raise NotImplemented("cwt for TsdTensor is not yet implemented") - - -def _cwt(data, freqs, wavelet, sampling_period, axis=-1): - """ - cwt(data, scales, wavelet) - - One dimensional Continuous Wavelet Transform. - - Parameters - ---------- - data : pynapple.Tsd or pynapple.TsdFrame - Input time series signal. - freqs : 1d array - Frequency values to estimate with morlet wavelets. - wavelet : Wavelet object or name - Wavelet to use, only implemented for "cmor1.5-1.0". - sampling_period : float - Sampling period for the frequencies output. - The values computed for ``coefs`` are independent of the choice of - ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling - period). - axis: int, optional - Axis over which to compute the CWT. If not given, the last axis is - used. - - Returns - ------- - coefs : array_like - Continuous wavelet transform of the input signal for the given scales - and wavelet. The first axis of ``coefs`` corresponds to the scales. - The remaining axes match the shape of ``data``. - frequencies : array_like - If the unit of sampling period are seconds and given, then frequencies - are in hertz. Otherwise, a sampling period of 1 is assumed. - - ..todo:: This should use pynapple convolve but currently that cannot handle imaginary numbers as it uses scipy convolve - """ - int_psi = np.array(WAVELET_DICT[wavelet]['int_psi_real'])*1j + np.array(WAVELET_DICT[wavelet]['int_psi_imag']) - x = np.array(WAVELET_DICT[wavelet]["x"]) - central_freq = WAVELET_DICT[wavelet]["central_freq"] - scales = central_freq/(freqs*sampling_period) - out = np.empty((np.size(scales),) + data.shape, dtype=np.complex128) - - if data.ndim > 1: - # move axis to be transformed last (so it is contiguous) - data = data.swapaxes(-1, axis) - # reshape to (n_batch, data.shape[-1]) - data_shape_pre = data.shape - data = data.reshape((-1, data.shape[-1])) - - for i, scale in enumerate(scales): - step = x[1] - x[0] - j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) - j = j.astype(int) # floor - if j[-1] >= int_psi.size: - j = np.extract(j < int_psi.size, j) - int_psi_scale = int_psi[j][::-1] - - if data.ndim == 1: - conv = np.convolve(data, int_psi_scale) - else: - # batch convolution via loop - conv_shape = list(data.shape) - conv_shape[-1] += int_psi_scale.size - 1 - conv_shape = tuple(conv_shape) - conv = np.empty(conv_shape, dtype=np.complex128) - for n in range(data.shape[0]): - conv[n, :] = np.convolve(data[n], int_psi_scale) - - coef = - np.sqrt(scale) * np.diff(conv, axis=-1) - if out.dtype.kind != 'c': - coef = coef.real - # transform axis is always -1 due to the data reshape above - d = (coef.shape[-1] - data.shape[-1]) / 2. - if d > 0: - coef = coef[..., floor(d):-ceil(d)] - elif d < 0: - raise ValueError( - f"Selected scale of {scale} too small.") - if data.ndim > 1: - # restore original data shape and axis position - coef = coef.reshape(data_shape_pre) - coef = coef.swapaxes(axis, -1) - out[i, ...] = coef - - frequencies = central_freq/scales - if np.isscalar(frequencies): - frequencies = np.array([frequencies]) - frequencies /= sampling_period - return out, frequencies - - - - - - -# ------------------------------------------------------------------------------- - def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ Defines the complex Morelet wavelet kernel @@ -222,13 +77,6 @@ def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): x = np.linspace(-precision, precision, M) return ((np.pi*ncycles) ** (-0.25)) * np.exp(-x**2 / ncycles) * np.exp(1j * 2*np.pi * scaling * x) -""" -The following code has been adapted from functions in the neurodsp package: -https://github.com/neurodsp-tools/neurodsp - -..todo: reference licence in LICENCE directory -""" - def _check_n_cycles(n_cycles, len_cycles=None): """ Check an input as a number of cycles, and make it iterable. @@ -285,12 +133,10 @@ def _create_freqs(freq_start, freq_stop, freq_step=1): return np.arange(freq_start, freq_stop + freq_step, freq_step) -def compute_wavelet_transform_og(sig, fs, freqs, n_cycles=7, scaling=0.5, norm="amp"): +def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="amp"): """ Compute the time-frequency representation of a signal using morlet wavelets. - ..todo:: better normalization between channels - Parameters ---------- sig : pynapple.Tsd or pynapple.TsdFrame @@ -325,6 +171,8 @@ def compute_wavelet_transform_og(sig, fs, freqs, n_cycles=7, scaling=0.5, norm=" raise TypeError("`sig` must be instance of Tsd or TsdFrame") if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) + if fs is None: + fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) n_cycles = _check_n_cycles(n_cycles, len(freqs)) if isinstance(sig, nap.Tsd): mwt = np.zeros([len(freqs), len(sig)], dtype=complex) @@ -349,7 +197,7 @@ def compute_wavelet_transform_og(sig, fs, freqs, n_cycles=7, scaling=0.5, norm=" def _convolve_wavelet( - sig, fs, freq, n_cycles=7, scaling=0.5, wavelet_len=None, norm="sss" + sig, fs, freq, n_cycles=1.5, scaling=1.0, precision=10, norm="sss" ): """ Convolve a signal with a complex wavelet. @@ -366,8 +214,6 @@ def _convolve_wavelet( Length of the filter, as the number of cycles of the oscillation with specified frequency. scaling : float, optional, default: 0.5 Scaling factor for the morlet wavelet. - wavelet_len : int, optional - Length of the wavelet. If defined, this overrides the freq and n_cycles inputs. norm : {'sss', 'amp'}, optional Normalization method: @@ -384,21 +230,31 @@ def _convolve_wavelet( * The real part of the returned array is the filtered signal. * Taking np.abs() of output gives the analytic amplitude. - * Taking np.angle() of output gives the analytic phase. ..todo: this this still true? + * Taking np.angle() of output gives the analytic phase. """ if norm not in ["sss", "amp"]: raise ValueError("Given `norm` must be `sss` or `amp`") - if wavelet_len is None: - wavelet_len = int(n_cycles * fs / freq) - if wavelet_len > sig.shape[-1]: + morlet_f = morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) + x = np.linspace(-8, 8, int(2**precision)) + int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) + scale = scaling / (freq/fs) + j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) + j = j.astype(int) # floor + if j[-1] >= int_psi.size: + j = np.extract(j < int_psi.size, j) + int_psi_scale = int_psi[j][::-1] + conv = np.convolve(sig, int_psi_scale) + coef = - np.sqrt(scale) * np.diff(conv, axis=-1) + # transform axis is always -1 due to the data reshape above + d = (coef.shape[-1] - sig.shape[-1]) / 2. + if d > 0: + coef = coef[..., floor(d):-ceil(d)] + elif d < 0: raise ValueError( - "The length of the wavelet is greater than the signal. Can not proceed." - ) - morlet_f = morlet(wavelet_len, ncycles=n_cycles, scaling=scaling) - if norm == "sss": - morlet_f = morlet_f / np.sqrt(np.sum(np.abs(morlet_f) ** 2)) - elif norm == "amp": - morlet_f = morlet_f / np.sum(np.abs(morlet_f)) - mwt_real = sig.convolve(np.real(morlet_f)) - mwt_imag = sig.convolve(np.imag(morlet_f)) - return mwt_real.values + 1j * mwt_imag.values \ No newline at end of file + f"Selected scale of {scale} too small.") + return coef + +def _integrate(arr, step): + integral = np.cumsum(arr) + integral *= step + return integral From 01c5435d0fd5be04f5c83ee812e242a0193a9d09 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 28 Jun 2024 16:36:51 +0100 Subject: [PATCH 005/195] tutorial cleaning --- docs/examples/tutorial_signal_processing.py | 69 ++++++++++----------- 1 file changed, 32 insertions(+), 37 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index c0d585ba..6955df5c 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -22,9 +22,8 @@ import numpy as np import pynapple as nap -import pandas as pd import os -#import requests +import requests from zipfile import ZipFile import matplotlib.pyplot as plt import matplotlib @@ -35,20 +34,20 @@ # Downloading the data # ------------------ # First things first: Let's download and extract the data -# path = "data/A2929-200711" -# extract_to = "data" -# if extract_to not in os.listdir("."): -# os.mkdir(extract_to) -# if path not in os.listdir("."): -# # Download the file -# response = requests.get("https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1") -# zip_path = os.path.join(extract_to, '/downloaded_file.zip') -# # Write the zip file to disk -# with open(zip_path, 'wb') as f: -# f.write(response.content) -# # Unzip the file -# with ZipFile(zip_path, 'r') as zip_ref: -# zip_ref.extractall(extract_to) +path = "data/A2929-200711" +extract_to = "data" +if extract_to not in os.listdir("."): + os.mkdir(extract_to) +if path not in os.listdir("."): +# Download the file + response = requests.get("https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1") + zip_path = os.path.join(extract_to, '/downloaded_file.zip') + # Write the zip file to disk + with open(zip_path, 'wb') as f: + f.write(response.content) + # Unzip the file + with ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall(extract_to) # %% @@ -56,22 +55,22 @@ # Parsing the data # ------------------ # Now that we have the data, we must append the 2kHz LFP recording to the .nwb file -# eeg_path = "data/A2929-200711/A2929-200711.dat" -# frequency = 20000 # Hz -# n_channels = 16 -# f = open(eeg_path, 'rb') -# startoffile = f.seek(0, 0) -# endoffile = f.seek(0, 2) -# f.close() -# bytes_size = 2 -# n_samples = int((endoffile-startoffile)/n_channels/bytes_size) -# duration = n_samples/frequency -# interval = 1/frequency -# fp = np.memmap(eeg_path, np.int16, 'r', shape = (n_samples, n_channels)) -# timestep = np.arange(0, n_samples)/frequency -# eeg = nap.TsdFrame(t=timestep, d=fp) -# nap.append_NWB_LFP("data/A2929-200711/", -# eeg) +eeg_path = "data/A2929-200711/A2929-200711.dat" +frequency = 20000 # Hz +n_channels = 16 +f = open(eeg_path, 'rb') +startoffile = f.seek(0, 0) +endoffile = f.seek(0, 2) +f.close() +bytes_size = 2 +n_samples = int((endoffile-startoffile)/n_channels/bytes_size) +duration = n_samples/frequency +interval = 1/frequency +fp = np.memmap(eeg_path, np.int16, 'r', shape = (n_samples, n_channels)) +timestep = np.arange(0, n_samples)/frequency +eeg = nap.TsdFrame(t=timestep, d=fp) +nap.append_NWB_LFP("data/A2929-200711/", + eeg) # %% @@ -80,17 +79,13 @@ # data = nap.load_file("data/A2929-200711/pynapplenwb/A2929-200711.nwb") data = nap.load_file("data/stable.nwb") print(data["ElectricalSeries"]) -# normed_electrical_series = data["ElectricalSeries"].values -# normed_electrical_series = normed_electrical_series[:, :] -# normed_electrical_series[:, :10] = normed_electrical_series[:, :10] - np.expand_dims(np.mean(normed_electrical_series[:, :10], axis=1), axis=1) -# normed_electrical_series[:, 10:] = normed_electrical_series[:, 10:] - np.expand_dims(np.mean(normed_electrical_series[:, 10:], axis=1), axis=1) -NES = nap.TsdFrame(t=data["ElectricalSeries"].index.values, d=data["ElectricalSeries"].values, time_support=data["ElectricalSeries"].time_support) # %% # *** # Selecting slices # ----------------------------------- # Let's consider a two 1-second slices of data, one from the sleep epoch and one from wake +NES = nap.TsdFrame(t=data["ElectricalSeries"].index.values, d=data["ElectricalSeries"].values, time_support=data["ElectricalSeries"].time_support) wake_minute = NES.value_from(NES, nap.IntervalSet(900,960)) sleep_minute = NES.value_from(NES, nap.IntervalSet(0,60)) From 027878e71e6c6bb290c614a7e26b4e0f370451c9 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 28 Jun 2024 16:41:44 +0100 Subject: [PATCH 006/195] linting --- docs/examples/tutorial_signal_processing.py | 168 +++++++++++++------- pynapple/process/signal_processing.py | 54 ++++--- 2 files changed, 141 insertions(+), 81 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 6955df5c..b3e2f9af 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -20,14 +20,14 @@ # # Now, import the necessary libraries: -import numpy as np -import pynapple as nap import os -import requests from zipfile import ZipFile + import matplotlib.pyplot as plt -import matplotlib -matplotlib.use("TkAgg") +import numpy as np +import requests + +import pynapple as nap # %% # *** @@ -39,14 +39,16 @@ if extract_to not in os.listdir("."): os.mkdir(extract_to) if path not in os.listdir("."): -# Download the file - response = requests.get("https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1") - zip_path = os.path.join(extract_to, '/downloaded_file.zip') + # Download the file + response = requests.get( + "https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1" + ) + zip_path = os.path.join(extract_to, "/downloaded_file.zip") # Write the zip file to disk - with open(zip_path, 'wb') as f: + with open(zip_path, "wb") as f: f.write(response.content) # Unzip the file - with ZipFile(zip_path, 'r') as zip_ref: + with ZipFile(zip_path, "r") as zip_ref: zip_ref.extractall(extract_to) @@ -56,21 +58,20 @@ # ------------------ # Now that we have the data, we must append the 2kHz LFP recording to the .nwb file eeg_path = "data/A2929-200711/A2929-200711.dat" -frequency = 20000 # Hz +frequency = 20000 # Hz n_channels = 16 -f = open(eeg_path, 'rb') +f = open(eeg_path, "rb") startoffile = f.seek(0, 0) endoffile = f.seek(0, 2) f.close() bytes_size = 2 -n_samples = int((endoffile-startoffile)/n_channels/bytes_size) -duration = n_samples/frequency -interval = 1/frequency -fp = np.memmap(eeg_path, np.int16, 'r', shape = (n_samples, n_channels)) -timestep = np.arange(0, n_samples)/frequency +n_samples = int((endoffile - startoffile) / n_channels / bytes_size) +duration = n_samples / frequency +interval = 1 / frequency +fp = np.memmap(eeg_path, np.int16, "r", shape=(n_samples, n_channels)) +timestep = np.arange(0, n_samples) / frequency eeg = nap.TsdFrame(t=timestep, d=fp) -nap.append_NWB_LFP("data/A2929-200711/", - eeg) +nap.append_NWB_LFP("data/A2929-200711/", eeg) # %% @@ -85,9 +86,13 @@ # Selecting slices # ----------------------------------- # Let's consider a two 1-second slices of data, one from the sleep epoch and one from wake -NES = nap.TsdFrame(t=data["ElectricalSeries"].index.values, d=data["ElectricalSeries"].values, time_support=data["ElectricalSeries"].time_support) -wake_minute = NES.value_from(NES, nap.IntervalSet(900,960)) -sleep_minute = NES.value_from(NES, nap.IntervalSet(0,60)) +NES = nap.TsdFrame( + t=data["ElectricalSeries"].index.values, + d=data["ElectricalSeries"].values, + time_support=data["ElectricalSeries"].time_support, +) +wake_minute = NES.value_from(NES, nap.IntervalSet(900, 960)) +sleep_minute = NES.value_from(NES, nap.IntervalSet(0, 60)) # %% # *** @@ -96,10 +101,17 @@ # Let's plot fig, ax = plt.subplots(2) for channel in range(sleep_minute.shape[1]): - ax[0].plot(sleep_minute.index.values, sleep_minute[:, channel], alpha=0.5, label="Sleep Data") + ax[0].plot( + sleep_minute.index.values, + sleep_minute[:, channel], + alpha=0.5, + label="Sleep Data", + ) ax[0].set_title("Sleep ephys") for channel in range(wake_minute.shape[1]): - ax[1].plot(wake_minute.index.values, wake_minute[:, channel], alpha=0.5, label="Wake Data") + ax[1].plot( + wake_minute.index.values, wake_minute[:, channel], alpha=0.5, label="Wake Data" + ) ax[1].set_title("Wake ephys") plt.show() @@ -109,12 +121,10 @@ # Let's take the Fourier transforms of one channel for both and see if differences are present channel = 1 fig, ax = plt.subplots(1) -fft_sig, fft_freqs = nap.compute_spectrum(sleep_minute[:, channel], - fs=int(FS)) +fft_sig, fft_freqs = nap.compute_spectrum(sleep_minute[:, channel], fs=int(FS)) ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Sleep Data") -ax.set_xlim((0, FS/2 )) -fft_sig, fft_freqs = nap.compute_spectrum(wake_minute[:, channel], - fs=int(FS)) +ax.set_xlim((0, FS / 2)) +fft_sig, fft_freqs = nap.compute_spectrum(wake_minute[:, channel], fs=int(FS)) ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Wake Data") ax.set_title(f"Fourier Decomposition for channel {channel}") ax.legend() @@ -125,13 +135,14 @@ # There seems to be some differences presenting themselves - a bump in higher frequencies for waking data? # Let's explore further with a wavelet decomposition + def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kwargs): if np.iscomplexobj(powers): powers = abs(powers) - ax.imshow(powers, aspect='auto', **kwargs) + ax.imshow(powers, aspect="auto", **kwargs) ax.invert_yaxis() - ax.set_xlabel('Time (s)') - ax.set_ylabel('Frequency (Hz)') + ax.set_xlabel("Time (s)") + ax.set_ylabel("Frequency (Hz)") if isinstance(x_ticks, int): x_tick_pos = np.linspace(0, times.size, x_ticks) x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) @@ -145,22 +156,43 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + fig, ax = plt.subplots(2) -freqs = np.array([2.59, 3.66, 5.18, 8.0, 10.36, 14.65, 20.72, 29.3, 41.44, 58.59, 82.88, 117.19, - 165.75, 234.38, 331.5, 468.75, 624., ]) +freqs = np.array( + [ + 2.59, + 3.66, + 5.18, + 8.0, + 10.36, + 14.65, + 20.72, + 29.3, + 41.44, + 58.59, + 82.88, + 117.19, + 165.75, + 234.38, + 331.5, + 468.75, + 624.0, + ] +) mwt_sleep = nap.compute_wavelet_transform( - sleep_minute[:, channel], - fs=None, - freqs=freqs - ) -plot_timefrequency(sleep_minute.index.values[:], freqs[:], np.transpose(mwt_sleep[:,:].values), ax=ax[0]) + sleep_minute[:, channel], fs=None, freqs=freqs +) +plot_timefrequency( + sleep_minute.index.values[:], + freqs[:], + np.transpose(mwt_sleep[:, :].values), + ax=ax[0], +) ax[0].set_title(f"Sleep Data Wavelet Decomposition: Channel {channel}") -mwt_wake = nap.compute_wavelet_transform( - wake_minute[:, channel], - fs=None, - freqs=freqs - ) -plot_timefrequency(wake_minute.index.values[:], freqs[:], np.transpose(mwt_wake[:,:].values), ax=ax[1]) +mwt_wake = nap.compute_wavelet_transform(wake_minute[:, channel], fs=None, freqs=freqs) +plot_timefrequency( + wake_minute.index.values[:], freqs[:], np.transpose(mwt_wake[:, :].values), ax=ax[1] +) ax[1].set_title(f"Wake Data Wavelet Decomposition: Channel {channel}") plt.margins(0) plt.show() @@ -169,11 +201,18 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data freq = 3 interval = (937, 939) -wake_second = wake_minute.value_from(wake_minute, nap.IntervalSet(interval[0],interval[1])) +wake_second = wake_minute.value_from( + wake_minute, nap.IntervalSet(interval[0], interval[1]) +) fig, ax = plt.subplots(1) ax.plot(wake_second.index.values, wake_second[:, channel], alpha=0.5, label="Wake Data") -ax.plot(wake_second.index.values, - mwt_wake.value_from(mwt_wake, nap.IntervalSet(interval[0],interval[1]))[:, freq].values.real, label="Theta oscillations") +ax.plot( + wake_second.index.values, + mwt_wake.value_from(mwt_wake, nap.IntervalSet(interval[0], interval[1]))[ + :, freq + ].values.real, + label="Theta oscillations", +) ax.set_title(f"{freqs[freq]}Hz oscillation power.") plt.show() @@ -183,11 +222,20 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw freq = 0 # interval = (10, 15) interval = (20, 25) -sleep_second = sleep_minute.value_from(sleep_minute, nap.IntervalSet(interval[0],interval[1])) +sleep_second = sleep_minute.value_from( + sleep_minute, nap.IntervalSet(interval[0], interval[1]) +) _, ax = plt.subplots(1) -ax.plot(sleep_second.index.values, sleep_second[:, channel], alpha=0.5, label="Wake Data") -ax.plot(sleep_second.index.values, - mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0],interval[1]))[:, freq].values.real, label="Slow Wave Oscillations") +ax.plot( + sleep_second.index.values, sleep_second[:, channel], alpha=0.5, label="Wake Data" +) +ax.plot( + sleep_second.index.values, + mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0], interval[1]))[ + :, freq + ].values.real, + label="Slow Wave Oscillations", +) ax.set_title(f"{freqs[freq]}Hz oscillation power") plt.show() @@ -195,7 +243,9 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during slow wave sleep _, ax = plt.subplots(20, figsize=(10, 50)) -mwt_sleep = np.transpose(mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0],interval[1]))) +mwt_sleep = np.transpose( + mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0], interval[1])) +) ax[0].plot(sleep_second.index, sleep_second.values[:, 0]) plot_timefrequency(sleep_second.index, freqs, np.abs(mwt_sleep[:, :]), ax=ax[1]) @@ -210,13 +260,19 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw spikes = {} for i in data["units"].index: spikes[i] = data["units"][i].times()[ - (data["units"][i].times() > interval[0]) & (data["units"][i].times() < interval[1])] + (data["units"][i].times() > interval[0]) + & (data["units"][i].times() < interval[1]) + ] phase = {} for i in spikes.keys(): phase_i = [] for spike in spikes[i]: - phase_i.append(np.angle(mwt_sleep[freq, np.argmin(np.abs(sleep_second.index.values - spike))])) + phase_i.append( + np.angle( + mwt_sleep[freq, np.argmin(np.abs(sleep_second.index.values - spike))] + ) + ) phase[i] = np.array(phase_i) for i in range(15): @@ -227,4 +283,4 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw ax[5 + i].set_ylabel("phase") plt.tight_layout() -plt.show() \ No newline at end of file +plt.show() diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index cc9ab1bc..21ab311c 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -4,21 +4,23 @@ Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. """ -import numpy as np -import pynapple as nap -from math import ceil, floor import json -from scipy.signal import welch from itertools import repeat +from math import ceil, floor + +import numpy as np +from scipy.signal import welch -with open('wavelets.json') as f: +import pynapple as nap + +with open("wavelets.json") as f: WAVELET_DICT = json.load(f) def compute_spectrum(sig, fs=None): - """ - Performs numpy fft on sig, returns output - ..todo: Make sig handle TsdFrame, TsdTensor + """ + Performs numpy fft on sig, returns output + ..todo: Make sig handle TsdFrame, TsdTensor ---------- sig : pynapple.Tsd or pynapple.TsdFrame @@ -29,7 +31,7 @@ def compute_spectrum(sig, fs=None): if not isinstance(sig, nap.Tsd): raise TypeError("Currently compute_fft is only implemented for Tsd") if fs is None: - fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) + fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) fft_result = np.fft.fft(sig.values) fft_freq = np.fft.fftfreq(len(sig.values), 1 / fs) return fft_result, fft_freq @@ -49,7 +51,7 @@ def compute_welch_spectrum(sig, fs=None): if not isinstance(sig, nap.Tsd): raise TypeError("Currently compute_fft is only implemented for Tsd") if fs is None: - fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) + fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) freqs, spectogram = welch(sig.values, fs=fs) return spectogram, freqs @@ -75,7 +77,12 @@ def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): Morelet wavelet kernel """ x = np.linspace(-precision, precision, M) - return ((np.pi*ncycles) ** (-0.25)) * np.exp(-x**2 / ncycles) * np.exp(1j * 2*np.pi * scaling * x) + return ( + ((np.pi * ncycles) ** (-0.25)) + * np.exp(-(x**2) / ncycles) + * np.exp(1j * 2 * np.pi * scaling * x) + ) + def _check_n_cycles(n_cycles, len_cycles=None): """ @@ -172,18 +179,15 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) if fs is None: - fs = sig.index.shape[0]/(sig.index.max() - sig.index.min()) + fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) n_cycles = _check_n_cycles(n_cycles, len(freqs)) if isinstance(sig, nap.Tsd): mwt = np.zeros([len(freqs), len(sig)], dtype=complex) for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): - mwt[ind, :] = _convolve_wavelet(sig, - fs, - freq, - n_cycle, - scaling, - norm=norm) - return nap.TsdFrame(t=sig.index, d=np.transpose(mwt), time_support=sig.time_support) + mwt[ind, :] = _convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm) + return nap.TsdFrame( + t=sig.index, d=np.transpose(mwt), time_support=sig.time_support + ) else: mwt = np.zeros( [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex @@ -237,23 +241,23 @@ def _convolve_wavelet( morlet_f = morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) x = np.linspace(-8, 8, int(2**precision)) int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) - scale = scaling / (freq/fs) + scale = scaling / (freq / fs) j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) j = j.astype(int) # floor if j[-1] >= int_psi.size: j = np.extract(j < int_psi.size, j) int_psi_scale = int_psi[j][::-1] conv = np.convolve(sig, int_psi_scale) - coef = - np.sqrt(scale) * np.diff(conv, axis=-1) + coef = -np.sqrt(scale) * np.diff(conv, axis=-1) # transform axis is always -1 due to the data reshape above - d = (coef.shape[-1] - sig.shape[-1]) / 2. + d = (coef.shape[-1] - sig.shape[-1]) / 2.0 if d > 0: - coef = coef[..., floor(d):-ceil(d)] + coef = coef[..., floor(d) : -ceil(d)] elif d < 0: - raise ValueError( - f"Selected scale of {scale} too small.") + raise ValueError(f"Selected scale of {scale} too small.") return coef + def _integrate(arr, step): integral = np.cumsum(arr) integral *= step From 18b3bc1da0561c7e43ba3ddd5a1b5fa3031b2a0a Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 28 Jun 2024 16:44:59 +0100 Subject: [PATCH 007/195] more linting --- pynapple/process/__init__.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 08d58648..1cb2f735 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -15,6 +15,11 @@ shift_timestamps, shuffle_ts_intervals, ) +from .signal_processing import ( + compute_spectrum, + compute_wavelet_transform, + compute_welch_spectrum, +) from .tuning_curves import ( compute_1d_mutual_info, compute_1d_tuning_curves, @@ -24,8 +29,3 @@ compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, ) -from .signal_processing import ( - compute_wavelet_transform, - compute_spectrum, - compute_welch_spectrum -) From ebdbe67320d23b5b23f6b52590c59033926e7463 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 28 Jun 2024 16:50:05 +0100 Subject: [PATCH 008/195] json removal --- pynapple/process/signal_processing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 21ab311c..5f4e969e 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -13,9 +13,6 @@ import pynapple as nap -with open("wavelets.json") as f: - WAVELET_DICT = json.load(f) - def compute_spectrum(sig, fs=None): """ From 75d3a460525683896de25e3b67e41a2d65b12377 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 28 Jun 2024 17:52:16 +0100 Subject: [PATCH 009/195] basic tests added --- pynapple/process/signal_processing.py | 4 +-- tests/test_signal_processing.py | 37 +++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) create mode 100644 tests/test_signal_processing.py diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 5f4e969e..245825f7 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -53,7 +53,7 @@ def compute_welch_spectrum(sig, fs=None): return spectogram, freqs -def morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): +def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ Defines the complex Morelet wavelet kernel @@ -235,7 +235,7 @@ def _convolve_wavelet( """ if norm not in ["sss", "amp"]: raise ValueError("Given `norm` must be `sss` or `amp`") - morlet_f = morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) + morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) x = np.linspace(-8, 8, int(2**precision)) int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) scale = scaling / (freq / fs) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py new file mode 100644 index 00000000..8da68921 --- /dev/null +++ b/tests/test_signal_processing.py @@ -0,0 +1,37 @@ +"""Tests of `signal_processing` for pynapple""" + +import numpy as np +import pytest + +import pynapple as nap + + +def test_compute_spectrum(): + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + r = nap.compute_spectrum(sig) + assert len(r[1]) == 1024 + assert len(r[0]) == 1024 + assert r[0].dtype == np.complex128 + assert r[1].dtype == np.float64 + + +def test_compute_welch_spectrum(): + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + r = nap.compute_welch_spectrum(sig) + assert r[0].dtype == np.float64 + assert r[1].dtype == np.float64 + + +def test_compute_wavelet_transform(): + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 10) + + sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 10, 4) From 334e785fceaf4ba022718dd7899a4dea64da953d Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 1 Jul 2024 15:08:10 +0100 Subject: [PATCH 010/195] remove unused import --- pynapple/process/signal_processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 245825f7..15da4906 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -4,7 +4,6 @@ Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. """ -import json from itertools import repeat from math import ceil, floor From a3ab81cffe06cd3d7e3ae086f155e7f6a35925d3 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 2 Jul 2024 19:43:30 +0100 Subject: [PATCH 011/195] minor notebook changes --- docs/examples/tutorial_signal_processing.py | 59 ++++++++++----------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index b3e2f9af..c7da10e9 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -57,21 +57,21 @@ # Parsing the data # ------------------ # Now that we have the data, we must append the 2kHz LFP recording to the .nwb file -eeg_path = "data/A2929-200711/A2929-200711.dat" -frequency = 20000 # Hz -n_channels = 16 -f = open(eeg_path, "rb") -startoffile = f.seek(0, 0) -endoffile = f.seek(0, 2) -f.close() -bytes_size = 2 -n_samples = int((endoffile - startoffile) / n_channels / bytes_size) -duration = n_samples / frequency -interval = 1 / frequency -fp = np.memmap(eeg_path, np.int16, "r", shape=(n_samples, n_channels)) -timestep = np.arange(0, n_samples) / frequency -eeg = nap.TsdFrame(t=timestep, d=fp) -nap.append_NWB_LFP("data/A2929-200711/", eeg) +# eeg_path = "data/A2929-200711/A2929-200711.dat" +# frequency = 20000 # Hz +# n_channels = 16 +# f = open(eeg_path, "rb") +# startoffile = f.seek(0, 0) +# endoffile = f.seek(0, 2) +# f.close() +# bytes_size = 2 +# n_samples = int((endoffile - startoffile) / n_channels / bytes_size) +# duration = n_samples / frequency +# interval = 1 / frequency +# fp = np.memmap(eeg_path, np.int16, "r", shape=(n_samples, n_channels)) +# timestep = np.arange(0, n_samples) / frequency +# eeg = nap.TsdFrame(t=timestep, d=fp) +# nap.append_NWB_LFP("data/A2929-200711/", eeg) # %% @@ -91,8 +91,9 @@ d=data["ElectricalSeries"].values, time_support=data["ElectricalSeries"].time_support, ) -wake_minute = NES.value_from(NES, nap.IntervalSet(900, 960)) -sleep_minute = NES.value_from(NES, nap.IntervalSet(0, 60)) +wake_minute = NES.restrict(nap.IntervalSet(900, 960)) +sleep_minute = NES.restrict(nap.IntervalSet(0, 60)) +channel = 1 # %% # *** @@ -102,7 +103,6 @@ fig, ax = plt.subplots(2) for channel in range(sleep_minute.shape[1]): ax[0].plot( - sleep_minute.index.values, sleep_minute[:, channel], alpha=0.5, label="Sleep Data", @@ -110,7 +110,9 @@ ax[0].set_title("Sleep ephys") for channel in range(wake_minute.shape[1]): ax[1].plot( - wake_minute.index.values, wake_minute[:, channel], alpha=0.5, label="Wake Data" + wake_minute[:, channel], + alpha=0.5, + label="Wake Data" ) ax[1].set_title("Wake ephys") plt.show() @@ -119,7 +121,6 @@ # %% # There is much shared information between channels, and wake and sleep don't seem visibly different. # Let's take the Fourier transforms of one channel for both and see if differences are present -channel = 1 fig, ax = plt.subplots(1) fft_sig, fft_freqs = nap.compute_spectrum(sleep_minute[:, channel], fs=int(FS)) ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Sleep Data") @@ -201,14 +202,13 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data freq = 3 interval = (937, 939) -wake_second = wake_minute.value_from( - wake_minute, nap.IntervalSet(interval[0], interval[1]) -) +wake_second = wake_minute.restrict(nap.IntervalSet(interval[0], interval[1])) +mwt_wake_second = mwt_wake.restrict(nap.IntervalSet(interval[0], interval[1])) fig, ax = plt.subplots(1) ax.plot(wake_second.index.values, wake_second[:, channel], alpha=0.5, label="Wake Data") ax.plot( wake_second.index.values, - mwt_wake.value_from(mwt_wake, nap.IntervalSet(interval[0], interval[1]))[ + mwt_wake_second[ :, freq ].values.real, label="Theta oscillations", @@ -222,16 +222,15 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw freq = 0 # interval = (10, 15) interval = (20, 25) -sleep_second = sleep_minute.value_from( - sleep_minute, nap.IntervalSet(interval[0], interval[1]) -) +sleep_second = sleep_minute.restrict(nap.IntervalSet(interval[0], interval[1])) +mwt_sleep_second = mwt_sleep.restrict(nap.IntervalSet(interval[0], interval[1])) _, ax = plt.subplots(1) ax.plot( - sleep_second.index.values, sleep_second[:, channel], alpha=0.5, label="Wake Data" + sleep_second[:, channel], alpha=0.5, label="Wake Data" ) ax.plot( sleep_second.index.values, - mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0], interval[1]))[ + mwt_sleep_second[ :, freq ].values.real, label="Slow Wave Oscillations", @@ -244,7 +243,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw _, ax = plt.subplots(20, figsize=(10, 50)) mwt_sleep = np.transpose( - mwt_sleep.value_from(mwt_sleep, nap.IntervalSet(interval[0], interval[1])) + mwt_sleep_second ) ax[0].plot(sleep_second.index, sleep_second.values[:, 0]) plot_timefrequency(sleep_second.index, freqs, np.abs(mwt_sleep[:, :]), ax=ax[1]) From 80c12d76a6c0ae3a7168eda74e14fe9dd7c7d558 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 2 Jul 2024 21:04:15 +0100 Subject: [PATCH 012/195] spectogram now takes tdsframe --- pynapple/process/signal_processing.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 15da4906..ad1a9b8e 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -8,6 +8,7 @@ from math import ceil, floor import numpy as np +import pandas as pd from scipy.signal import welch import pynapple as nap @@ -16,7 +17,7 @@ def compute_spectrum(sig, fs=None): """ Performs numpy fft on sig, returns output - ..todo: Make sig handle TsdFrame, TsdTensor + ..todo: Make sig handle TsdTensor ---------- sig : pynapple.Tsd or pynapple.TsdFrame @@ -24,13 +25,13 @@ def compute_spectrum(sig, fs=None): fs : float, optional Sampling rate, in Hz. If None, will be calculated from the given signal """ - if not isinstance(sig, nap.Tsd): - raise TypeError("Currently compute_fft is only implemented for Tsd") + if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): + raise TypeError("Currently compute_fft is only implemented for Tsd or TsdFrame") if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) - fft_result = np.fft.fft(sig.values) + fft_result = np.fft.fft(sig.values, axis=0) fft_freq = np.fft.fftfreq(len(sig.values), 1 / fs) - return fft_result, fft_freq + return pd.DataFrame(fft_result, fft_freq) def compute_welch_spectrum(sig, fs=None): From c1a5a26eb31e058f277375fc81e3cebe4abe4084 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 3 Jul 2024 16:49:26 +0100 Subject: [PATCH 013/195] review changes --- docs/examples/tutorial_signal_processing.py | 144 ++++++++++---------- pynapple/process/__init__.py | 4 +- pynapple/process/signal_processing.py | 12 +- 3 files changed, 81 insertions(+), 79 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index c7da10e9..a4d58520 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -20,12 +20,8 @@ # # Now, import the necessary libraries: -import os -from zipfile import ZipFile - import matplotlib.pyplot as plt import numpy as np -import requests import pynapple as nap @@ -33,51 +29,28 @@ # *** # Downloading the data # ------------------ -# First things first: Let's download and extract the data -path = "data/A2929-200711" -extract_to = "data" -if extract_to not in os.listdir("."): - os.mkdir(extract_to) -if path not in os.listdir("."): - # Download the file - response = requests.get( - "https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1" - ) - zip_path = os.path.join(extract_to, "/downloaded_file.zip") - # Write the zip file to disk - with open(zip_path, "wb") as f: - f.write(response.content) - # Unzip the file - with ZipFile(zip_path, "r") as zip_ref: - zip_ref.extractall(extract_to) - - -# %% -# *** -# Parsing the data -# ------------------ -# Now that we have the data, we must append the 2kHz LFP recording to the .nwb file -# eeg_path = "data/A2929-200711/A2929-200711.dat" -# frequency = 20000 # Hz -# n_channels = 16 -# f = open(eeg_path, "rb") -# startoffile = f.seek(0, 0) -# endoffile = f.seek(0, 2) -# f.close() -# bytes_size = 2 -# n_samples = int((endoffile - startoffile) / n_channels / bytes_size) -# duration = n_samples / frequency -# interval = 1 / frequency -# fp = np.memmap(eeg_path, np.int16, "r", shape=(n_samples, n_channels)) -# timestep = np.arange(0, n_samples) / frequency -# eeg = nap.TsdFrame(t=timestep, d=fp) -# nap.append_NWB_LFP("data/A2929-200711/", eeg) +# First things first: Let's download and extract the data - currently commented as correct NWB is not online +# path = "data/A2929-200711" +# extract_to = "data" +# if extract_to not in os.listdir("."): +# os.mkdir(extract_to) +# if path not in os.listdir("."): +# # Download the file +# response = requests.get( +# "https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1" +# ) +# zip_path = os.path.join(extract_to, "/downloaded_file.zip") +# # Write the zip file to disk +# with open(zip_path, "wb") as f: +# f.write(response.content) +# # Unzip the file +# with ZipFile(zip_path, "r") as zip_ref: +# zip_ref.extractall(extract_to) # %% # Let's save the RoiResponseSeries as a variable called 'transients' and print it FS = 1250 -# data = nap.load_file("data/A2929-200711/pynapplenwb/A2929-200711.nwb") data = nap.load_file("data/stable.nwb") print(data["ElectricalSeries"]) @@ -109,29 +82,66 @@ ) ax[0].set_title("Sleep ephys") for channel in range(wake_minute.shape[1]): - ax[1].plot( - wake_minute[:, channel], - alpha=0.5, - label="Wake Data" - ) + ax[1].plot(wake_minute[:, channel], alpha=0.5, label="Wake Data") ax[1].set_title("Wake ephys") plt.show() # %% -# There is much shared information between channels, and wake and sleep don't seem visibly different. -# Let's take the Fourier transforms of one channel for both and see if differences are present -fig, ax = plt.subplots(1) -fft_sig, fft_freqs = nap.compute_spectrum(sleep_minute[:, channel], fs=int(FS)) -ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Sleep Data") -ax.set_xlim((0, FS / 2)) -fft_sig, fft_freqs = nap.compute_spectrum(wake_minute[:, channel], fs=int(FS)) -ax.plot(fft_freqs, np.abs(fft_sig), alpha=0.5, label="Wake Data") -ax.set_title(f"Fourier Decomposition for channel {channel}") -ax.legend() +# Let's take the Fourier transforms of one channel for both waking and sleeping and see if differences are present +channel = 1 +fig, ax = plt.subplots(2) +fft = nap.compute_spectogram(sleep_minute, fs=int(FS)) +ax[0].plot( + fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Sleep Data", c="blue" +) +ax[0].set_xlim((0, FS / 2)) +ax[0].set_xlabel("Freq (Hz)") +ax[0].set_ylabel("Frequency Power") + +ax[0].set_title("Sleep LFP Decomposition") +fft = nap.compute_spectogram(wake_minute, fs=int(FS)) +ax[1].plot( + fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Wake Data", c="orange" +) +ax[1].set_xlim((0, FS / 2)) +fig.suptitle(f"Fourier Decomposition for channel {channel}") +ax[1].set_title("Sleep LFP Decomposition") +ax[1].set_xlabel("Freq (Hz)") +ax[1].set_ylabel("Frequency Power") + +# ax.legend() plt.show() +# %% +# Let's now consider the Welch spectograms of waking and sleeping data... + +fig, ax = plt.subplots(2) +fft = nap.compute_welch_spectogram(sleep_minute, fs=int(FS)) +ax[0].plot( + fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Sleep Data", color="blue" +) +ax[0].set_xlim((0, FS / 2)) +ax[0].set_title("Sleep LFP Decomposition") +ax[0].set_xlabel("Freq (Hz)") +ax[0].set_ylabel("Frequency Power") +welch = nap.compute_welch_spectogram(wake_minute, fs=int(FS)) +ax[1].plot( + welch.index, + np.abs(welch.iloc[:, channel]), + alpha=0.5, + label="Wake Data", + color="orange", +) +ax[1].set_xlim((0, FS / 2)) +fig.suptitle(f"Welch Decomposition for channel {channel}") +ax[1].set_title("Sleep LFP Decomposition") +ax[1].set_xlabel("Freq (Hz)") +ax[1].set_ylabel("Frequency Power") +# ax.legend() +plt.show() + # %% # There seems to be some differences presenting themselves - a bump in higher frequencies for waking data? # Let's explore further with a wavelet decomposition @@ -208,9 +218,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw ax.plot(wake_second.index.values, wake_second[:, channel], alpha=0.5, label="Wake Data") ax.plot( wake_second.index.values, - mwt_wake_second[ - :, freq - ].values.real, + mwt_wake_second[:, freq].values.real, label="Theta oscillations", ) ax.set_title(f"{freqs[freq]}Hz oscillation power.") @@ -225,14 +233,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw sleep_second = sleep_minute.restrict(nap.IntervalSet(interval[0], interval[1])) mwt_sleep_second = mwt_sleep.restrict(nap.IntervalSet(interval[0], interval[1])) _, ax = plt.subplots(1) -ax.plot( - sleep_second[:, channel], alpha=0.5, label="Wake Data" -) +ax.plot(sleep_second[:, channel], alpha=0.5, label="Wake Data") ax.plot( sleep_second.index.values, - mwt_sleep_second[ - :, freq - ].values.real, + mwt_sleep_second[:, freq].values.real, label="Slow Wave Oscillations", ) ax.set_title(f"{freqs[freq]}Hz oscillation power") @@ -242,9 +246,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during slow wave sleep _, ax = plt.subplots(20, figsize=(10, 50)) -mwt_sleep = np.transpose( - mwt_sleep_second -) +mwt_sleep = np.transpose(mwt_sleep_second) ax[0].plot(sleep_second.index, sleep_second.values[:, 0]) plot_timefrequency(sleep_second.index, freqs, np.abs(mwt_sleep[:, :]), ax=ax[1]) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 1cb2f735..fb7e22b9 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -16,9 +16,9 @@ shuffle_ts_intervals, ) from .signal_processing import ( - compute_spectrum, + compute_spectogram, compute_wavelet_transform, - compute_welch_spectrum, + compute_welch_spectogram, ) from .tuning_curves import ( compute_1d_mutual_info, diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index ad1a9b8e..3bb4be8e 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -14,7 +14,7 @@ import pynapple as nap -def compute_spectrum(sig, fs=None): +def compute_spectogram(sig, fs=None): """ Performs numpy fft on sig, returns output ..todo: Make sig handle TsdTensor @@ -34,7 +34,7 @@ def compute_spectrum(sig, fs=None): return pd.DataFrame(fft_result, fft_freq) -def compute_welch_spectrum(sig, fs=None): +def compute_welch_spectogram(sig, fs=None): """ Performs scipy Welch's decomposition on sig, returns output ..todo: Make sig handle TsdFrame, TsdTensor @@ -45,12 +45,12 @@ def compute_welch_spectrum(sig, fs=None): fs : float, optional Sampling rate, in Hz. If None, will be calculated from the given signal """ - if not isinstance(sig, nap.Tsd): + if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): raise TypeError("Currently compute_fft is only implemented for Tsd") if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) - freqs, spectogram = welch(sig.values, fs=fs) - return spectogram, freqs + freqs, spectogram = welch(sig.values, fs=fs, axis=0) + return pd.DataFrame(spectogram, freqs) def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): @@ -145,7 +145,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a ---------- sig : pynapple.Tsd or pynapple.TsdFrame Time series. - fs : float + fs : float or None Sampling rate, in Hz. freqs : 1d array or list of float If array, frequency values to estimate with morlet wavelets. From dfe6e411e8f2fed704a770f7da339d85ac754e38 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 3 Jul 2024 16:55:25 +0100 Subject: [PATCH 014/195] updated function names in test --- tests/test_signal_processing.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 8da68921..b2d3c8b9 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -6,20 +6,20 @@ import pynapple as nap -def test_compute_spectrum(): +def test_compute_spectogram(): t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) - r = nap.compute_spectrum(sig) + r = nap.compute_spectogram(sig) assert len(r[1]) == 1024 assert len(r[0]) == 1024 assert r[0].dtype == np.complex128 assert r[1].dtype == np.float64 -def test_compute_welch_spectrum(): +def test_ccompute_welch_spectogram(): t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) - r = nap.compute_welch_spectrum(sig) + r = nap.compute_welch_spectogram(sig) assert r[0].dtype == np.float64 assert r[1].dtype == np.float64 From 3a9173a9149ac7f760c3fb05642221ae4c942acc Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 3 Jul 2024 17:59:09 +0100 Subject: [PATCH 015/195] updated tests --- tests/test_signal_processing.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index b2d3c8b9..85ed3b8e 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -1,6 +1,7 @@ """Tests of `signal_processing` for pynapple""" import numpy as np +import pandas as pd import pytest import pynapple as nap @@ -10,18 +11,15 @@ def test_compute_spectogram(): t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) r = nap.compute_spectogram(sig) - assert len(r[1]) == 1024 - assert len(r[0]) == 1024 - assert r[0].dtype == np.complex128 - assert r[1].dtype == np.float64 + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 1024 -def test_ccompute_welch_spectogram(): +def test_compute_welch_spectogram(): t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) r = nap.compute_welch_spectogram(sig) - assert r[0].dtype == np.float64 - assert r[1].dtype == np.float64 + assert isinstance(r, pd.DataFrame) def test_compute_wavelet_transform(): From cfb606630b37cb97d828a338d83d2032280987ae Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 3 Jul 2024 19:01:32 +0100 Subject: [PATCH 016/195] expanded test coverage --- pynapple/process/signal_processing.py | 8 +++- tests/test_signal_processing.py | 56 +++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 3bb4be8e..a38b56ab 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -26,7 +26,9 @@ def compute_spectogram(sig, fs=None): Sampling rate, in Hz. If None, will be calculated from the given signal """ if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): - raise TypeError("Currently compute_fft is only implemented for Tsd or TsdFrame") + raise TypeError( + "Currently compute_spectogram is only implemented for Tsd or TsdFrame" + ) if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) fft_result = np.fft.fft(sig.values, axis=0) @@ -46,7 +48,9 @@ def compute_welch_spectogram(sig, fs=None): Sampling rate, in Hz. If None, will be calculated from the given signal """ if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): - raise TypeError("Currently compute_fft is only implemented for Tsd") + raise TypeError( + "Currently compute_welch_spectogram is only implemented for Tsd or TsdFrame" + ) if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) freqs, spectogram = welch(sig.values, fs=fs, axis=0) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 85ed3b8e..fd728efe 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -14,6 +14,18 @@ def test_compute_spectogram(): assert isinstance(r, pd.DataFrame) assert r.shape[0] == 1024 + sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) + r = nap.compute_spectogram(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape == (1024, 4) + + with pytest.raises(TypeError) as e_info: + nap.compute_spectogram("a_string") + assert ( + str(e_info.value) + == "Currently compute_spectogram is only implemented for Tsd or TsdFrame" + ) + def test_compute_welch_spectogram(): t = np.linspace(0, 1, 1024) @@ -21,6 +33,18 @@ def test_compute_welch_spectogram(): r = nap.compute_welch_spectogram(sig) assert isinstance(r, pd.DataFrame) + sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) + r = nap.compute_welch_spectogram(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape[1] == 4 + + with pytest.raises(TypeError) as e_info: + nap.compute_welch_spectogram("a_string") + assert ( + str(e_info.value) + == "Currently compute_welch_spectogram is only implemented for Tsd or TsdFrame" + ) + def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1024) @@ -29,7 +53,39 @@ def test_compute_wavelet_transform(): mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10) + t = np.linspace(0, 1, 1024) + sig = nap.Tsd(d=np.random.random(1024), t=t) + freqs = (1, 51, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 6) + + sig = nap.Tsd(d=np.random.random(1024), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, 2.5), 5)) + ) + assert mwt.shape == (1024, 10) + sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) freqs = np.linspace(1, 600, 10) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10, 4) + + with pytest.raises(ValueError) as e_info: + nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) + assert str(e_info.value) == "Number of cycles must be a positive number." + + with pytest.raises(ValueError) as e_info: + nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, -2.5), 5)) + ) + assert str(e_info.value) == "Each number of cycles must be a positive number." + + with pytest.raises(ValueError) as e_info: + nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, 2.5), 2)) + ) + assert ( + str(e_info.value) + == "The length of number of cycles does not match other inputs." + ) From a3d30eaa77093567dbeb2276dc3f637d5b2fc744 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 5 Jul 2024 12:30:14 +0200 Subject: [PATCH 017/195] Fixing kwargs --- pynapple/core/interval_set.py | 2 +- pynapple/core/time_series.py | 6 +++++- pynapple/core/ts_group.py | 13 ++++++++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 3f65f802..9f9798d9 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -80,7 +80,7 @@ class IntervalSet(NDArrayOperatorsMixin): A class representing a (irregular) set of time intervals in elapsed time, with relative operations """ - def __init__(self, start, end=None, time_units="s", **kwargs): + def __init__(self, start, end=None, time_units="s"): """ IntervalSet initializer diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 95ae2c37..92c2241c 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -514,7 +514,11 @@ def convolve(self, array, ep=None, trim="both"): new_data_array = _convolve(time_array, data_array, starts, ends, array, trim) - return self.__class__(t=time_array, d=new_data_array, time_support=ep) + kwargs_dict = dict(time_support=ep) + if hasattr(self, "columns"): + kwargs_dict["columns"] = self.columns + + return self.__class__(t=time_array, d=new_data_array, **kwargs_dict) def smooth(self, std, windowsize=None, time_units="s", size_factor=100, norm=True): """Smooth a time series with a gaussian kernel. diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 6c052733..82eeab23 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -101,6 +101,17 @@ def __init__( - If the converted keys are not unique, i.e. {1: ts_2, "2": ts_2} is valid, {1: ts_2, "1": ts_2} is invalid. """ + # Check input type + if time_units not in ["s", "ms", "us"]: + raise ValueError("Argument time_units should be 's', 'ms' or 'us'") + if not isinstance(bypass_check, bool): + raise TypeError("Argument bypass_check should be of type bool") + passed_time_support = False + if time_support is not None and not isinstance(time_support, IntervalSet): + raise TypeError("Argument time_support should be of type IntervalSet") + else: + passed_time_support = True + self._initialized = False # convert all keys to integer @@ -141,7 +152,7 @@ def __init__( ) # If time_support is passed, all elements of data are restricted prior to init - if isinstance(time_support, IntervalSet): + if passed_time_support: self.time_support = time_support if not bypass_check: data = {k: data[k].restrict(self.time_support) for k in self.index} From cc2cdfa8591bc4084f4e56b890f02d773f0954eb Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Sun, 7 Jul 2024 09:36:59 +0200 Subject: [PATCH 018/195] fixed tsgroup --- pynapple/core/base_class.py | 1 + pynapple/core/time_series.py | 6 +++--- pynapple/core/ts_group.py | 10 +++++++--- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index da8c91ce..cf053574 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -335,6 +335,7 @@ def restrict(self, iset): 0 0.0 500.0 """ + assert isinstance(iset, IntervalSet), "Argument should be IntervalSet" time_array = self.index.values diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 92c2241c..ee62f433 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1002,9 +1002,9 @@ def __getitem__(self, key, *args, **kwargs): if all(is_array_like(a) for a in [index, output]): if output.shape[0] == index.shape[0]: - if isinstance(columns, pd.Index): - if not pd.api.types.is_integer_dtype(columns): - kwargs["columns"] = columns + # if isinstance(columns, pd.Index): + # if not pd.api.types.is_integer_dtype(columns): + kwargs["columns"] = columns return _get_class(output)( t=index, d=output, time_support=self.time_support, **kwargs diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 82eeab23..2b72f69c 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -107,10 +107,14 @@ def __init__( if not isinstance(bypass_check, bool): raise TypeError("Argument bypass_check should be of type bool") passed_time_support = False - if time_support is not None and not isinstance(time_support, IntervalSet): - raise TypeError("Argument time_support should be of type IntervalSet") - else: + + if isinstance(time_support, IntervalSet): passed_time_support = True + else: + if time_support is not None: + raise TypeError("Argument time_support should be of type IntervalSet") + else: + passed_time_support = False self._initialized = False From 4148d6d83ded7a3a382e742bdf67dd139302c441 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 8 Jul 2024 17:55:58 +0100 Subject: [PATCH 019/195] notebook various changes --- docs/examples/tutorial_signal_processing.py | 76 ++++++++++++--------- pynapple/process/signal_processing.py | 2 - 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index a4d58520..8684ed02 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -19,53 +19,54 @@ # mkdocs_gallery_thumbnail_number = 1 # # Now, import the necessary libraries: - +import matplotlib +matplotlib.use("TkAgg") import matplotlib.pyplot as plt import numpy as np import pynapple as nap +from examples_utils import data, plotting # %% # *** # Downloading the data # ------------------ -# First things first: Let's download and extract the data - currently commented as correct NWB is not online -# path = "data/A2929-200711" -# extract_to = "data" -# if extract_to not in os.listdir("."): -# os.mkdir(extract_to) -# if path not in os.listdir("."): -# # Download the file -# response = requests.get( -# "https://www.dropbox.com/s/su4oaje57g3kit9/A2929-200711.zip?dl=1" -# ) -# zip_path = os.path.join(extract_to, "/downloaded_file.zip") -# # Write the zip file to disk -# with open(zip_path, "wb") as f: -# f.write(response.content) -# # Unzip the file -# with ZipFile(zip_path, "r") as zip_ref: -# zip_ref.extractall(extract_to) +# First things first: Let's download and extract the data - download section currently commented as correct NWB +# is not online +# path = data.download_data( +# "Achilles_10252013.nwb", "https://osf.io/hu5ma/download", "../data" +# ) +# data = nap.load_file(path) -# %% -# Let's save the RoiResponseSeries as a variable called 'transients' and print it -FS = 1250 -data = nap.load_file("data/stable.nwb") -print(data["ElectricalSeries"]) +data = nap.load_file("../data/Achillies_ephys.nwb") +FS = len(data["LFP"].index[:]) / (data["LFP"].index[-1] - data["LFP"].index[0]) +print(data) # %% # *** # Selecting slices # ----------------------------------- -# Let's consider a two 1-second slices of data, one from the sleep epoch and one from wake -NES = nap.TsdFrame( - t=data["ElectricalSeries"].index.values, - d=data["ElectricalSeries"].values, - time_support=data["ElectricalSeries"].time_support, +# Let's consider two 60-second slices of data, one from the sleep epoch and one from wake + +wake_minute_interval = nap.IntervalSet( + data["epochs"]["MazeEpoch"]["start"] + 60., + data["epochs"]["MazeEpoch"]["start"] + 120., +) +sleep_minute_interval = nap.IntervalSet( + data["epochs"]["POSTEpoch"]["start"] + 60., + data["epochs"]["POSTEpoch"]["start"] + 120., +) +wake_minute = nap.TsdFrame( + t=data["LFP"].restrict(wake_minute_interval).index.values, + d=data["LFP"].restrict(wake_minute_interval).values, + time_support=data["LFP"].restrict(wake_minute_interval).time_support, +) +sleep_minute = nap.TsdFrame( + t=data["LFP"].restrict(sleep_minute_interval).index.values, + d=data["LFP"].restrict(sleep_minute_interval).values, + time_support=data["LFP"].restrict(sleep_minute_interval).time_support, ) -wake_minute = NES.restrict(nap.IntervalSet(900, 960)) -sleep_minute = NES.restrict(nap.IntervalSet(0, 60)) channel = 1 # %% @@ -211,7 +212,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # %% # Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data freq = 3 -interval = (937, 939) +interval = ( + wake_minute_interval["start"], + wake_minute_interval["start"]+2 +) wake_second = wake_minute.restrict(nap.IntervalSet(interval[0], interval[1])) mwt_wake_second = mwt_wake.restrict(nap.IntervalSet(interval[0], interval[1])) fig, ax = plt.subplots(1) @@ -229,7 +233,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data freq = 0 # interval = (10, 15) -interval = (20, 25) +interval = ( + sleep_minute_interval["start"]+30, + sleep_minute_interval["start"]+35 +) sleep_second = sleep_minute.restrict(nap.IntervalSet(interval[0], interval[1])) mwt_sleep_second = mwt_sleep.restrict(nap.IntervalSet(interval[0], interval[1])) _, ax = plt.subplots(1) @@ -276,8 +283,11 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw ) phase[i] = np.array(phase_i) +spikes = {k: v for k,v in spikes.items() if len(v) > 0} +phase = {k: v for k,v in phase.items() if len(v) > 0} + for i in range(15): - ax[5 + i].scatter(spikes[i], phase[i]) + ax[5 + i].scatter(spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]]) ax[5 + i].set_xlim(interval[0], interval[1]) ax[5 + i].set_ylim(-np.pi, np.pi) ax[5 + i].set_xlabel("time (s)") diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index a38b56ab..72fc650f 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -17,7 +17,6 @@ def compute_spectogram(sig, fs=None): """ Performs numpy fft on sig, returns output - ..todo: Make sig handle TsdTensor ---------- sig : pynapple.Tsd or pynapple.TsdFrame @@ -39,7 +38,6 @@ def compute_spectogram(sig, fs=None): def compute_welch_spectogram(sig, fs=None): """ Performs scipy Welch's decomposition on sig, returns output - ..todo: Make sig handle TsdFrame, TsdTensor ---------- sig : pynapple.Tsd or pynapple.TsdFrame From 19c91835cd29171f7e67a66421b1e63b022dcc13 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 8 Jul 2024 21:00:01 +0100 Subject: [PATCH 020/195] compute_wavelet_transform can now handle TsdTensor --- docs/examples/tutorial_signal_processing.py | 23 ++++++-------- pynapple/process/signal_processing.py | 34 ++++++++++++--------- tests/test_signal_processing.py | 14 +++++++++ 3 files changed, 42 insertions(+), 29 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 8684ed02..0bb17fcb 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -20,6 +20,7 @@ # # Now, import the necessary libraries: import matplotlib + matplotlib.use("TkAgg") import matplotlib.pyplot as plt import numpy as np @@ -50,12 +51,12 @@ # Let's consider two 60-second slices of data, one from the sleep epoch and one from wake wake_minute_interval = nap.IntervalSet( - data["epochs"]["MazeEpoch"]["start"] + 60., - data["epochs"]["MazeEpoch"]["start"] + 120., + data["epochs"]["MazeEpoch"]["start"] + 60.0, + data["epochs"]["MazeEpoch"]["start"] + 120.0, ) sleep_minute_interval = nap.IntervalSet( - data["epochs"]["POSTEpoch"]["start"] + 60., - data["epochs"]["POSTEpoch"]["start"] + 120., + data["epochs"]["POSTEpoch"]["start"] + 60.0, + data["epochs"]["POSTEpoch"]["start"] + 120.0, ) wake_minute = nap.TsdFrame( t=data["LFP"].restrict(wake_minute_interval).index.values, @@ -212,10 +213,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # %% # Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data freq = 3 -interval = ( - wake_minute_interval["start"], - wake_minute_interval["start"]+2 -) +interval = (wake_minute_interval["start"], wake_minute_interval["start"] + 2) wake_second = wake_minute.restrict(nap.IntervalSet(interval[0], interval[1])) mwt_wake_second = mwt_wake.restrict(nap.IntervalSet(interval[0], interval[1])) fig, ax = plt.subplots(1) @@ -233,10 +231,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw # Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data freq = 0 # interval = (10, 15) -interval = ( - sleep_minute_interval["start"]+30, - sleep_minute_interval["start"]+35 -) +interval = (sleep_minute_interval["start"] + 30, sleep_minute_interval["start"] + 35) sleep_second = sleep_minute.restrict(nap.IntervalSet(interval[0], interval[1])) mwt_sleep_second = mwt_sleep.restrict(nap.IntervalSet(interval[0], interval[1])) _, ax = plt.subplots(1) @@ -283,8 +278,8 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw ) phase[i] = np.array(phase_i) -spikes = {k: v for k,v in spikes.items() if len(v) > 0} -phase = {k: v for k,v in phase.items() if len(v) > 0} +spikes = {k: v for k, v in spikes.items() if len(v) > 0} +phase = {k: v for k, v in phase.items() if len(v) > 0} for i in range(15): ax[5 + i].scatter(spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]]) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 72fc650f..ca61f2fe 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -173,30 +173,34 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a ----- This computes the continuous wavelet transform at specified frequencies across time. """ - if not isinstance(sig, nap.Tsd) and not isinstance(sig, nap.TsdFrame): - raise TypeError("`sig` must be instance of Tsd or TsdFrame") + if not isinstance(sig, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)): + raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) n_cycles = _check_n_cycles(n_cycles, len(freqs)) if isinstance(sig, nap.Tsd): - mwt = np.zeros([len(freqs), len(sig)], dtype=complex) + sig = sig.reshape((sig.shape[0], 1)) + output_shape = (sig.shape[0], len(freqs)) + else: + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) + mwt = np.zeros( + [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex + ) + for channel_i in range(sig.values.shape[1]): for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): - mwt[ind, :] = _convolve_wavelet(sig, fs, freq, n_cycle, scaling, norm=norm) + mwt[:, ind, channel_i] = _convolve_wavelet( + sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm + ) + if len(output_shape) == 2: return nap.TsdFrame( - t=sig.index, d=np.transpose(mwt), time_support=sig.time_support + t=sig.index, d=mwt.reshape(output_shape), time_support=sig.time_support ) - else: - mwt = np.zeros( - [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex - ) - for channel_i in range(sig.values.shape[1]): - for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): - mwt[:, ind, channel_i] = _convolve_wavelet( - sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm - ) - return nap.TsdTensor(t=sig.index, d=mwt, time_support=sig.time_support) + return nap.TsdTensor( + t=sig.index, d=mwt.reshape(output_shape), time_support=sig.time_support + ) def _convolve_wavelet( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index fd728efe..c50860ba 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -47,6 +47,14 @@ def test_compute_welch_spectogram(): def test_compute_wavelet_transform(): + + ##..todo put there + t = np.linspace(0, 1, 1024) # can remove this when we move it + sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 10, 4, 2) + t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) freqs = np.linspace(1, 600, 10) @@ -71,6 +79,8 @@ def test_compute_wavelet_transform(): mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10, 4) + #..todo: here + with pytest.raises(ValueError) as e_info: nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) assert str(e_info.value) == "Number of cycles must be a positive number." @@ -89,3 +99,7 @@ def test_compute_wavelet_transform(): str(e_info.value) == "The length of number of cycles does not match other inputs." ) + + +if __name__ == "__main__": + test_compute_wavelet_transform() \ No newline at end of file From c59e01024e78de3ea416df362bb865f63bd2bc5d Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 10 Jul 2024 11:23:02 -0400 Subject: [PATCH 021/195] update convolve --- pynapple/core/time_series.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index ee62f433..ed34e9d1 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -487,9 +487,9 @@ def convolve(self, array, ep=None, trim="both"): Tsd, TsdFrame or TsdTensor The convolved time series """ - assert is_array_like( - array - ), "Input should be a numpy array (or jax array if pynajax is installed)." + if not is_array_like(array): + raise RuntimeError("Input should be a numpy array (or jax array if pynajax is installed).") + assert array.ndim == 1, "Input should be a one dimensional array." assert trim in [ "both", From 2d6069804e54f5be3b4697e87f50c451e9da17c6 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 10 Jul 2024 11:30:15 -0400 Subject: [PATCH 022/195] Update --- pynapple/core/time_series.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index ed34e9d1..05282444 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -475,8 +475,8 @@ def convolve(self, array, ep=None, trim="both"): Parameters ---------- array : array-like - One dimensional input array-like. - + 1-D or 2-D array with kernel(s) to be used for convolution. + First dimension is assumed to be time. ep : None, optional The epochs to apply the convolution trim : str, optional @@ -488,14 +488,13 @@ def convolve(self, array, ep=None, trim="both"): The convolved time series """ if not is_array_like(array): - raise RuntimeError("Input should be a numpy array (or jax array if pynajax is installed).") - - assert array.ndim == 1, "Input should be a one dimensional array." - assert trim in [ - "both", - "left", - "right", - ], "Unknow argument. trim should be 'both', 'left' or 'right'." + raise IOError("Input should be a numpy array (or jax array if pynajax is installed).") + + if array.ndim == 0: + raise IOError("Provide a kernel with at least 1 dimension, current kernel has 0 dimensions") + + if trim not in ["both","left","right"]: + raise IOError("Unknow argument. trim should be 'both', 'left' or 'right'.") time_array = self.index.values data_array = self.values @@ -505,7 +504,8 @@ def convolve(self, array, ep=None, trim="both"): starts = ep.start ends = ep.end else: - assert isinstance(ep, IntervalSet) + if not isinstance(ep, IntervalSet): + raise IOError("ep should be an object of type IntervalSet") starts = ep.start ends = ep.end idx = _restrict(time_array, starts, ends) From 9b0bc02de96f27912bcfa1789e4b361d8386010b Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 12 Jul 2024 12:44:00 -0400 Subject: [PATCH 023/195] New convolution --- pynapple/core/_core_functions.py | 64 ++++++++--------- pynapple/core/_jitted_functions.py | 40 +++++------ pynapple/core/time_series.py | 60 +++++++++++----- pynapple/core/utils.py | 9 +++ tests/test_time_series.py | 108 +++++++++++++++++++++++------ 5 files changed, 191 insertions(+), 90 deletions(-) diff --git a/pynapple/core/_core_functions.py b/pynapple/core/_core_functions.py index a7f67d4d..0ba1e915 100644 --- a/pynapple/core/_core_functions.py +++ b/pynapple/core/_core_functions.py @@ -11,7 +11,7 @@ import numpy as np from scipy import signal -from ._jitted_functions import ( +from ._jitted_functions import ( # pjitconvolve, jitbin_array, jitcount, jitremove_nan, @@ -19,7 +19,6 @@ jitrestrict_with_count, jitthreshold, jitvaluefrom, - pjitconvolve, ) from .utils import get_backend @@ -99,36 +98,37 @@ def _convolve(time_array, data_array, starts, ends, array, trim="both"): return convolve(time_array, data_array, starts, ends, array, trim) else: - if data_array.ndim == 1: - new_data_array = np.zeros(data_array.shape) - k = array.shape[0] - for s, e in zip(starts, ends): - idx_s = np.searchsorted(time_array, s) - idx_e = np.searchsorted(time_array, e, side="right") - - t = idx_e - idx_s - if trim == "left": - cut = (k - 1, t + k - 1) - elif trim == "right": - cut = (0, t) - else: - cut = ((k - 1) // 2, t + k - 1 - ((k - 1) // 2) - (1 - k % 2)) - # scipy is actually faster for Tsd - new_data_array[idx_s:idx_e] = signal.convolve( - data_array[idx_s:idx_e], array - )[cut[0] : cut[1]] - - return new_data_array - else: - new_data_array = np.zeros(data_array.shape) - for s, e in zip(starts, ends): - idx_s = np.searchsorted(time_array, s) - idx_e = np.searchsorted(time_array, e, side="right") - new_data_array[idx_s:idx_e] = pjitconvolve( - data_array[idx_s:idx_e], array, trim=trim - ) - - return new_data_array + # reshape to 2d + shape = data_array.shape + data_array = np.reshape(data_array, (shape[0], -1)) + + kshape = array.shape + k = kshape[0] + array = array.reshape(k, -1) + + new_data_array = np.zeros((shape[0], int(np.prod(shape[1:])), *array.shape[1:])) + + for s, e in zip(starts, ends): + idx_s = np.searchsorted(time_array, s) + idx_e = np.searchsorted(time_array, e, side="right") + + t = idx_e - idx_s + if trim == "left": + cut = (k - 1, t + k - 1) + elif trim == "right": + cut = (0, t) + else: + cut = ((k - 1) // 2, t + k - 1 - ((k - 1) // 2) - (1 - k % 2)) + + for i in range(data_array.shape[1]): + for j in range(array.shape[1]): + new_data_array[idx_s:idx_e, i, j] = signal.convolve( + data_array[idx_s:idx_e, i], array[:, j] + )[cut[0] : cut[1]] + + new_data_array = new_data_array.reshape((*shape, *kshape[1:])) + + return new_data_array def _bin_average(time_array, data_array, starts, ends, bin_size): diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index 669343c6..f7751910 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -375,33 +375,33 @@ def _jitbin_array(countin, time_array, data_array, starts, ends, bin_size): return (new_time_array, new_data_array) -@jit(nopython=True) -def jitconvolve(d, a): - return np.convolve(d, a) +# @jit(nopython=True) +# def jitconvolve(d, a): +# return np.convolve(d, a) -@njit(parallel=True) -def pjitconvolve(data_array, array, trim="both"): - shape = data_array.shape - t = shape[0] - k = array.shape[0] +# @njit(parallel=True) +# def pjitconvolve(data_array, array, trim="both"): +# shape = data_array.shape +# t = shape[0] +# k = array.shape[0] - data_array = data_array.reshape(t, -1) - new_data_array = np.zeros(data_array.shape) +# data_array = data_array.reshape(t, -1) +# new_data_array = np.zeros(data_array.shape) - if trim == "both": - cut = ((k - 1) // 2, t + k - 1 - ((k - 1) // 2) - (1 - k % 2)) - elif trim == "left": - cut = (k - 1, t + k - 1) - elif trim == "right": - cut = (0, t) +# if trim == "both": +# cut = ((k - 1) // 2, t + k - 1 - ((k - 1) // 2) - (1 - k % 2)) +# elif trim == "left": +# cut = (k - 1, t + k - 1) +# elif trim == "right": +# cut = (0, t) - for i in prange(data_array.shape[1]): - new_data_array[:, i] = jitconvolve(data_array[:, i], array)[cut[0] : cut[1]] +# for i in prange(data_array.shape[1]): +# new_data_array[:, i] = jitconvolve(data_array[:, i], array)[cut[0] : cut[1]] - new_data_array = new_data_array.reshape(shape) +# new_data_array = new_data_array.reshape(shape) - return new_data_array +# return new_data_array ################################ diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 05282444..b8470eba 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -488,12 +488,19 @@ def convolve(self, array, ep=None, trim="both"): The convolved time series """ if not is_array_like(array): - raise IOError("Input should be a numpy array (or jax array if pynajax is installed).") + raise IOError( + "Input should be a numpy array (or jax array if pynajax is installed)." + ) + + if len(array) == 0: + raise IOError( + "Input array is length 0" + ) - if array.ndim == 0: - raise IOError("Provide a kernel with at least 1 dimension, current kernel has 0 dimensions") + if array.ndim > 2: + raise IOError("Array should be 1 or 2 dimension.") - if trim not in ["both","left","right"]: + if trim not in ["both", "left", "right"]: raise IOError("Unknow argument. trim should be 'both', 'left' or 'right'.") time_array = self.index.values @@ -505,7 +512,7 @@ def convolve(self, array, ep=None, trim="both"): ends = ep.end else: if not isinstance(ep, IntervalSet): - raise IOError("ep should be an object of type IntervalSet") + raise IOError("ep should be an object of type IntervalSet") starts = ep.start ends = ep.end idx = _restrict(time_array, starts, ends) @@ -515,10 +522,14 @@ def convolve(self, array, ep=None, trim="both"): new_data_array = _convolve(time_array, data_array, starts, ends, array, trim) kwargs_dict = dict(time_support=ep) - if hasattr(self, "columns"): + + nap_class = _get_class(new_data_array) + + if isinstance(self, TsdFrame) and array.ndim==1: # keep columns kwargs_dict["columns"] = self.columns - return self.__class__(t=time_array, d=new_data_array, **kwargs_dict) + return nap_class(t=time_array, d=new_data_array, **kwargs_dict) + def smooth(self, std, windowsize=None, time_units="s", size_factor=100, norm=True): """Smooth a time series with a gaussian kernel. @@ -573,18 +584,21 @@ def smooth(self, std, windowsize=None, time_units="s", size_factor=100, norm=Tru Time series convolved with a gaussian kernel """ - assert isinstance(std, (int, float)), "std should be type int or float" - assert isinstance(size_factor, int), "size_factor should be of type int" - assert isinstance(norm, bool), "norm should be of type boolean" - assert isinstance(time_units, str), "time_units should be of type str" + if not isinstance(std, (int, float)): + raise IOError("std should be type int or float") + if not isinstance(size_factor, int): + raise IOError("size_factor should be of type int") + if not isinstance(norm, bool): + raise IOError("norm should be of type boolean") + if not isinstance(time_units, str): + raise IOError("time_units should be of type str") std = TsIndex.format_timestamps(np.array([std]), time_units)[0] std_size = int(self.rate * std) if windowsize is not None: - assert isinstance( - windowsize, (int, float) - ), "windowsize should be type int or float" + if not isinstance(windowsize, Number): + raise IOError("windowsize should be type int or float") windowsize = TsIndex.format_timestamps(np.array([windowsize]), time_units)[ 0 ] @@ -615,12 +629,22 @@ def interpolate(self, ts, ep=None, left=None, right=None): right : None, optional Value to return for ts > tsd[-1], default is tsd[-1]. """ - assert isinstance( - ts, Base - ), "First argument should be an instance of Ts, Tsd, TsdFrame or TsdTensor" + if not isinstance(ts, Base): + raise IOError( + "First argument should be an instance of Ts, Tsd, TsdFrame or TsdTensor" + ) - if not isinstance(ep, IntervalSet): + if left is not None and not isinstance(left, Number): + raise IOError("Argument left should be of type float or int") + + if right is not None and not isinstance(right, Number): + raise IOError("Argument right should be of type float or int") + + if ep is None: ep = self.time_support + else: + if not isinstance(ep, IntervalSet): + raise IOError("ep should be an object of type IntervalSet") new_t = ts.restrict(ep).index diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 17d90546..8d978b43 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -122,12 +122,21 @@ def is_array_like(obj): has_ndim = hasattr(obj, "ndim") # Check for indexability (try to access the first element) + try: obj[0] is_indexable = True except Exception: is_indexable = False + if not is_indexable: + if hasattr(obj, "__len__"): + try: + if len(obj) == 0: + is_indexable = True # Could be an empty array + except: + is_indexable = False + # Check for iterable property try: iter(obj) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 0d48c727..cb3bfbb4 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -408,7 +408,31 @@ def test_dropna(self, tsd): assert len(new_tsd) == 0 assert len(new_tsd.time_support) == 0 - def test_convolve(self, tsd): + def test_convolve_raise_errors(self, tsd): + if not isinstance(tsd, nap.Ts): + + with pytest.raises(IOError) as e_info: + tsd.convolve([1,2,3]) + assert str(e_info.value) == "Input should be a numpy array (or jax array if pynajax is installed)." + + with pytest.raises(IOError) as e_info: + tsd.convolve(np.array([])) + assert str(e_info.value) == "Input array is length 0" + + with pytest.raises(IOError) as e_info: + tsd.convolve(np.ones(3), trim='a') + assert str(e_info.value) == "Unknow argument. trim should be 'both', 'left' or 'right'." + + with pytest.raises(IOError) as e_info: + tsd.convolve(np.ones((2,3,4))) + assert str(e_info.value) == "Array should be 1 or 2 dimension." + + with pytest.raises(IOError) as e_info: + tsd.convolve(np.ones(3), ep=[1,2,3,4]) + assert str(e_info.value) == "ep should be an object of type IntervalSet" + + + def test_convolve_1d_kernel(self, tsd): array = np.random.randn(10) if not isinstance(tsd, nap.Ts): tsd2 = tsd.convolve(array) @@ -421,14 +445,6 @@ def test_convolve(self, tsd): tsd2.values.reshape(tsd2.shape[0], -1) ) - with pytest.raises(AssertionError) as e_info: - tsd.convolve([1,2,3]) - assert str(e_info.value) == "Input should be a numpy array (or jax array if pynajax is installed)." - - with pytest.raises(AssertionError) as e_info: - tsd.convolve(np.random.rand(2,3)) - assert str(e_info.value) == "Input should be a one dimensional array." - ep = nap.IntervalSet(start=[0, 60], end=[40,100]) tsd3 = tsd.convolve(array, ep) @@ -456,9 +472,41 @@ def test_convolve(self, tsd): tsd2.values.reshape(tsd2.shape[0], -1) ) - with pytest.raises(AssertionError) as e_info: - tsd.convolve(array, trim='a') - assert str(e_info.value) == "Unknow argument. trim should be 'both', 'left' or 'right'." + def test_convolve_2d_kernel(self, tsd): + array = np.random.randn(10, 3) + if not isinstance(tsd, nap.Ts): + # no epochs + tsd2 = tsd.convolve(array) + tmp = tsd.values.reshape(tsd.shape[0], -1) + + output = [] + + for i in range(tmp.shape[1]): + for j in range(array.shape[1]): + output.append( + np.convolve(tmp[:,i], array[:,j], mode='full')[4:-5] + ) + + output = np.array(output).T + np.testing.assert_array_almost_equal(output,tsd2.values.reshape(tsd.shape[0], -1)) + + # epochs + ep = nap.IntervalSet(start=[0, 60], end=[40,100]) + tsd2 = tsd.convolve(array, ep) + + for k in range(len(ep)): + tmp = tsd.restrict(ep[k]) + tmp2 = tmp.values.reshape(tmp.shape[0], -1) + output = [] + for i in range(tmp2.shape[1]): + for j in range(array.shape[1]): + output.append( + np.convolve(tmp2[:,i], array[:,j], mode='full')[4:-5] + ) + output = np.array(output).T + np.testing.assert_array_almost_equal( + output,tsd2.restrict(ep[k]).values.reshape(tmp.shape[0], -1) + ) def test_smooth(self, tsd): if not isinstance(tsd, nap.Ts): @@ -514,23 +562,23 @@ def test_smooth(self, tsd): def test_smooth_raise_error(self, tsd): if not isinstance(tsd, nap.Ts): - with pytest.raises(AssertionError) as e_info: + with pytest.raises(IOError) as e_info: tsd.smooth('a') assert str(e_info.value) == "std should be type int or float" - with pytest.raises(AssertionError) as e_info: + with pytest.raises(IOError) as e_info: tsd.smooth(1, size_factor='b') assert str(e_info.value) == "size_factor should be of type int" - with pytest.raises(AssertionError) as e_info: + with pytest.raises(IOError) as e_info: tsd.smooth(1, norm=1) assert str(e_info.value) == "norm should be of type boolean" - with pytest.raises(AssertionError) as e_info: + with pytest.raises(IOError) as e_info: tsd.smooth(1, time_units = 0) assert str(e_info.value) == "time_units should be of type str" - with pytest.raises(AssertionError) as e_info: + with pytest.raises(IOError) as e_info: tsd.smooth(1, windowsize='a') assert str(e_info.value) == "windowsize should be type int or float" @@ -757,10 +805,22 @@ def test_interpolate(self, tsd): tsd2 = tsd.interpolate(ts) np.testing.assert_array_almost_equal(tsd2.values, y) - with pytest.raises(AssertionError) as e: + with pytest.raises(IOError) as e: tsd.interpolate([0, 1, 2]) assert str(e.value) == "First argument should be an instance of Ts, Tsd, TsdFrame or TsdTensor" + with pytest.raises(IOError) as e: + tsd.interpolate(ts, left='a') + assert str(e.value) == "Argument left should be of type float or int" + + with pytest.raises(IOError) as e: + tsd.interpolate(ts, right='a') + assert str(e.value) == "Argument right should be of type float or int" + + with pytest.raises(IOError) as e: + tsd.interpolate(ts, ep=[1,2,3,4]) + assert str(e.value) == "ep should be an object of type IntervalSet" + # Right left ep = nap.IntervalSet(start=0, end=5) tsd = nap.Tsd(t=np.arange(1,4), d=np.arange(3), time_support=ep) @@ -989,7 +1049,7 @@ def test_interpolate(self, tsdframe): tsdframe2 = tsdframe.interpolate(ts) np.testing.assert_array_almost_equal(tsdframe2.values, data_stack) - with pytest.raises(AssertionError) as e: + with pytest.raises(IOError) as e: tsdframe.interpolate([0, 1, 2]) assert str(e.value) == "First argument should be an instance of Ts, Tsd, TsdFrame or TsdTensor" @@ -1021,6 +1081,14 @@ def test_interpolate_with_ep(self, tsdframe): tsdframe2 = tsdframe.interpolate(ts, ep) assert len(tsdframe2) == 0 + def test_convolve_keep_columns(self, tsdframe): + array = np.random.randn(10) + tsdframe = nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 3), time_units="s", columns=['a', 'b', 'c']) + tsd2 = tsdframe.convolve(array) + + assert isinstance(tsd2, nap.TsdFrame) + np.testing.assert_array_equal(tsd2.columns, tsdframe.columns) + #################################################### # Test for ts #################################################### @@ -1343,7 +1411,7 @@ def test_interpolate(self, tsdtensor): tsdtensor2 = tsdtensor.interpolate(ts) np.testing.assert_array_almost_equal(tsdtensor2.values, data_stack) - with pytest.raises(AssertionError) as e: + with pytest.raises(IOError) as e: tsdtensor.interpolate([0, 1, 2]) assert str(e.value) == "First argument should be an instance of Ts, Tsd, TsdFrame or TsdTensor" From 2f3b457151d8024c4d6a5c450a9ef013889897b2 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 12 Jul 2024 12:50:23 -0400 Subject: [PATCH 024/195] linting --- pynapple/core/_jitted_functions.py | 2 +- pynapple/core/time_series.py | 7 ++----- pynapple/core/utils.py | 4 ++-- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index f7751910..dae40a8a 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -1,5 +1,5 @@ import numpy as np -from numba import jit, njit, prange +from numba import jit # , njit, prange ################################ diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index b8470eba..b7c8c5e2 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -493,9 +493,7 @@ def convolve(self, array, ep=None, trim="both"): ) if len(array) == 0: - raise IOError( - "Input array is length 0" - ) + raise IOError("Input array is length 0") if array.ndim > 2: raise IOError("Array should be 1 or 2 dimension.") @@ -525,12 +523,11 @@ def convolve(self, array, ep=None, trim="both"): nap_class = _get_class(new_data_array) - if isinstance(self, TsdFrame) and array.ndim==1: # keep columns + if isinstance(self, TsdFrame) and array.ndim == 1: # keep columns kwargs_dict["columns"] = self.columns return nap_class(t=time_array, d=new_data_array, **kwargs_dict) - def smooth(self, std, windowsize=None, time_units="s", size_factor=100, norm=True): """Smooth a time series with a gaussian kernel. diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index 8d978b43..aa27eb7c 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -133,8 +133,8 @@ def is_array_like(obj): if hasattr(obj, "__len__"): try: if len(obj) == 0: - is_indexable = True # Could be an empty array - except: + is_indexable = True # Could be an empty array + except Exception: is_indexable = False # Check for iterable property From 3ad7b5a61795a9db092a6fdc83659188cf11a4bc Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 12 Jul 2024 15:15:20 -0400 Subject: [PATCH 025/195] adding tests for utils --- tests/test_utils.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..fdafbb3d --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,20 @@ +"""Tests of utils for `pynapple` package.""" + +import pynapple as nap +import numpy as np +import pandas as pd +import pytest + +def test_get_backend(): + assert nap.core.utils.get_backend() == "numba" + +def test_is_array_like(): + assert nap.core.utils.is_array_like(np.ones(3)) + assert nap.core.utils.is_array_like(np.array([])) + assert not nap.core.utils.is_array_like([1,2,3]) + assert not nap.core.utils.is_array_like(1) + assert not nap.core.utils.is_array_like('a') + assert not nap.core.utils.is_array_like(True) + assert not nap.core.utils.is_array_like((1,2,3)) + assert not nap.core.utils.is_array_like({0:1}) + assert not nap.core.utils.is_array_like(np.array(0)) \ No newline at end of file From 63f52b2e32f540525d73fa50455f5c0687bfa893 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 12 Jul 2024 20:36:07 +0100 Subject: [PATCH 026/195] PR comment changes --- docs/examples/tutorial_signal_processing.py | 471 ++++++++++++++++---- pynapple/process/signal_processing.py | 91 +++- tests/test_signal_processing.py | 40 +- 3 files changed, 472 insertions(+), 130 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 0bb17fcb..dece1036 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -14,86 +14,138 @@ # !!! warning # This tutorial uses matplotlib for displaying the figure # -# You can install all with `pip install matplotlib requests` +# You can install all with `pip install matplotlib requests tqdm` # # mkdocs_gallery_thumbnail_number = 1 # # Now, import the necessary libraries: import matplotlib - matplotlib.use("TkAgg") import matplotlib.pyplot as plt import numpy as np - +import os +import requests +import tqdm +import math import pynapple as nap -from examples_utils import data, plotting +import scipy # %% # *** # Downloading the data # ------------------ -# First things first: Let's download and extract the data - download section currently commented as correct NWB -# is not online +# First things first: Let's download the data and save it locally + +path = "Achilles_10252013_EEG.nwb" +if path not in os.listdir("."): + r = requests.get(f"https://osf.io/2dfvp/download", stream=True) + block_size = 1024*1024 + with open(path, 'wb') as f: + for data in tqdm.tqdm(r.iter_content(block_size), unit='MB', unit_scale=True, + total=math.ceil(int(r.headers.get('content-length', 0))//block_size)): + f.write(data) -# path = data.download_data( -# "Achilles_10252013.nwb", "https://osf.io/hu5ma/download", "../data" -# ) -# data = nap.load_file(path) -data = nap.load_file("../data/Achillies_ephys.nwb") -FS = len(data["LFP"].index[:]) / (data["LFP"].index[-1] - data["LFP"].index[0]) +# %% +# *** +# Loading the data +# ------------------ +# Loading the data, calculating the sampling frequency + +data = nap.load_file(path) +FS = len(data["eeg"].index[:]) / (data["eeg"].index[-1] - data["eeg"].index[0]) print(data) + # %% # *** # Selecting slices # ----------------------------------- # Let's consider two 60-second slices of data, one from the sleep epoch and one from wake -wake_minute_interval = nap.IntervalSet( - data["epochs"]["MazeEpoch"]["start"] + 60.0, - data["epochs"]["MazeEpoch"]["start"] + 120.0, +REM_minute_interval = nap.IntervalSet( + data["rem"]["start"][0] + 60.0, + data["rem"]["start"][0] + 120.0, ) -sleep_minute_interval = nap.IntervalSet( - data["epochs"]["POSTEpoch"]["start"] + 60.0, - data["epochs"]["POSTEpoch"]["start"] + 120.0, + +SWS_minute_interval = nap.IntervalSet( + data["nrem"]["start"][0] + 10.0, + data["nrem"]["start"][0] + 70.0, ) -wake_minute = nap.TsdFrame( - t=data["LFP"].restrict(wake_minute_interval).index.values, - d=data["LFP"].restrict(wake_minute_interval).values, - time_support=data["LFP"].restrict(wake_minute_interval).time_support, + +RUN_minute_interval = nap.IntervalSet( + data["forward_ep"]["start"][-18] + 0., + data["forward_ep"]["start"][-18] + 60., ) -sleep_minute = nap.TsdFrame( - t=data["LFP"].restrict(sleep_minute_interval).index.values, - d=data["LFP"].restrict(sleep_minute_interval).values, - time_support=data["LFP"].restrict(sleep_minute_interval).time_support, + +REM_minute = nap.TsdFrame( + t=data["eeg"].restrict(REM_minute_interval).index.values, + d=data["eeg"].restrict(REM_minute_interval).values, + time_support=data["eeg"].restrict(REM_minute_interval).time_support, ) -channel = 1 + +SWS_minute = nap.TsdFrame( + t=data["eeg"].restrict(SWS_minute_interval).index.values, + d=data["eeg"].restrict(SWS_minute_interval).values, + time_support=data["eeg"].restrict(SWS_minute_interval).time_support, +) + +RUN_minute = nap.TsdFrame( + t=data["eeg"].restrict(RUN_minute_interval).index.values, + d=data["eeg"].restrict(RUN_minute_interval).values, + time_support=data["eeg"].restrict(RUN_minute_interval).time_support, +) +# RUN_position = nap.TsdFrame( +# t=data["position"].restrict(RUN_minute_interval).index.values[1:], +# d=np.diff(data['position'].restrict(RUN_minute_interval)), +# time_support=data["position"].restrict(RUN_minute_interval).time_support, +# ) +RUN_position = nap.TsdFrame( + t=data["position"].restrict(RUN_minute_interval).index.values[:], + d=data['position'].restrict(RUN_minute_interval), + time_support=data["position"].restrict(RUN_minute_interval).time_support, +) + +channel = 0 # %% # *** # Plotting the LFP activity of one slices # ----------------------------------- # Let's plot -fig, ax = plt.subplots(2) -for channel in range(sleep_minute.shape[1]): + +fig, ax = plt.subplots(3) + +for channel in range(SWS_minute.shape[1]): ax[0].plot( - sleep_minute[:, channel], + SWS_minute[:, channel], alpha=0.5, label="Sleep Data", ) -ax[0].set_title("Sleep ephys") -for channel in range(wake_minute.shape[1]): - ax[1].plot(wake_minute[:, channel], alpha=0.5, label="Wake Data") -ax[1].set_title("Wake ephys") +ax[0].set_title("non-REM ephys") +ax[0].set_ylabel("LFP (v)") +ax[0].set_xlabel("time (s)") +ax[0].margins(0) +for channel in range(REM_minute.shape[1]): + ax[1].plot(REM_minute[:, channel], alpha=0.5, label="Wake Data", color="orange") +ax[1].set_ylabel("LFP (v)") +ax[1].set_xlabel("time (s)") +ax[1].set_title("REM ephys") +ax[1].margins(0) +for channel in range(RUN_minute.shape[1]): + ax[2].plot(RUN_minute[:, channel], alpha=0.5, label="Wake Data", color="green") +ax[2].set_ylabel("LFP (v)") +ax[2].set_xlabel("time (s)") +ax[2].set_title("Running ephys") +ax[2].margins(0) plt.show() # %% # Let's take the Fourier transforms of one channel for both waking and sleeping and see if differences are present -channel = 1 -fig, ax = plt.subplots(2) -fft = nap.compute_spectogram(sleep_minute, fs=int(FS)) +channel = 0 +fig, ax = plt.subplots(3) +fft = nap.compute_spectogram(SWS_minute, fs=int(FS)) ax[0].plot( fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Sleep Data", c="blue" ) @@ -101,17 +153,27 @@ ax[0].set_xlabel("Freq (Hz)") ax[0].set_ylabel("Frequency Power") -ax[0].set_title("Sleep LFP Decomposition") -fft = nap.compute_spectogram(wake_minute, fs=int(FS)) +ax[0].set_title("non-REM LFP Decomposition") +fft = nap.compute_spectogram(REM_minute, fs=int(FS)) ax[1].plot( fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Wake Data", c="orange" ) ax[1].set_xlim((0, FS / 2)) fig.suptitle(f"Fourier Decomposition for channel {channel}") -ax[1].set_title("Sleep LFP Decomposition") +ax[1].set_title("REM LFP Decomposition") ax[1].set_xlabel("Freq (Hz)") ax[1].set_ylabel("Frequency Power") +fft = nap.compute_spectogram(RUN_minute, fs=int(FS)) +ax[2].plot( + fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Running Data", c="green" +) +ax[2].set_xlim((0, FS / 2)) +fig.suptitle(f"Fourier Decomposition for channel {channel}") +ax[2].set_title("Running LFP Decomposition") +ax[2].set_xlabel("Freq (Hz)") +ax[2].set_ylabel("Frequency Power") + # ax.legend() plt.show() @@ -119,28 +181,46 @@ # %% # Let's now consider the Welch spectograms of waking and sleeping data... -fig, ax = plt.subplots(2) -fft = nap.compute_welch_spectogram(sleep_minute, fs=int(FS)) +fig, ax = plt.subplots(3) +welch = nap.compute_welch_spectogram(SWS_minute, fs=int(FS)) ax[0].plot( - fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Sleep Data", color="blue" + welch.index, + np.abs(welch.iloc[:, channel]), + alpha=0.5, + label="non-REM Data", + color="blue" ) ax[0].set_xlim((0, FS / 2)) -ax[0].set_title("Sleep LFP Decomposition") +ax[0].set_title("non-REM LFP Decomposition") ax[0].set_xlabel("Freq (Hz)") ax[0].set_ylabel("Frequency Power") -welch = nap.compute_welch_spectogram(wake_minute, fs=int(FS)) +welch = nap.compute_welch_spectogram(REM_minute, fs=int(FS)) ax[1].plot( welch.index, np.abs(welch.iloc[:, channel]), alpha=0.5, - label="Wake Data", + label="REM Data", color="orange", ) ax[1].set_xlim((0, FS / 2)) fig.suptitle(f"Welch Decomposition for channel {channel}") -ax[1].set_title("Sleep LFP Decomposition") +ax[1].set_title("REM LFP Decomposition") ax[1].set_xlabel("Freq (Hz)") ax[1].set_ylabel("Frequency Power") + +welch = nap.compute_welch_spectogram(RUN_minute, fs=int(FS)) +ax[2].plot( + welch.index, + np.abs(welch.iloc[:, channel]), + alpha=0.5, + label="Running Data", + color="green", +) +ax[2].set_xlim((0, FS / 2)) +fig.suptitle(f"Welch Decomposition for channel {channel}") +ax[2].set_title("Running LFP Decomposition") +ax[2].set_xlabel("Freq (Hz)") +ax[2].set_ylabel("Frequency Power") # ax.legend() plt.show() @@ -149,7 +229,7 @@ # Let's explore further with a wavelet decomposition -def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kwargs): +def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=None, ax=None, **kwargs): if np.iscomplexobj(powers): powers = abs(powers) ax.imshow(powers, aspect="auto", **kwargs) @@ -166,11 +246,24 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw y_ticks_pos = np.linspace(0, freqs.size, y_ticks) y_ticks = np.round(np.linspace(freqs[0], freqs[-1], y_ticks), 2) else: + y_ticks = freqs y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) - -fig, ax = plt.subplots(2) +fig = plt.figure(constrained_layout=True, figsize=(10, 50)) +num_cells = 10 +axd = fig.subplot_mosaic( + [ + ["wd_sws"], + ["lfp_sws"], + ["wd_rem"], + ["lfp_rem"], + ["wd_run"], + ["lfp_run"], + ["pos_run"] + ], + height_ratios=[1, .2, 1, .2, 1, .2, .2] +) freqs = np.array( [ 2.59, @@ -185,80 +278,192 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw 58.59, 82.88, 117.19, - 165.75, + 150.00, + 190.00, 234.38, + 270.00, 331.5, - 468.75, - 624.0, + 390.00, + # 468.75, + # 520.00, + # 570.00, + # 624.0, ] ) -mwt_sleep = nap.compute_wavelet_transform( - sleep_minute[:, channel], fs=None, freqs=freqs +mwt_SWS = nap.compute_wavelet_transform( + SWS_minute[:, channel], fs=None, freqs=freqs ) plot_timefrequency( - sleep_minute.index.values[:], + SWS_minute.index.values[:], freqs[:], - np.transpose(mwt_sleep[:, :].values), - ax=ax[0], + np.transpose(mwt_SWS[:, :].values), + ax=axd["wd_sws"], ) -ax[0].set_title(f"Sleep Data Wavelet Decomposition: Channel {channel}") -mwt_wake = nap.compute_wavelet_transform(wake_minute[:, channel], fs=None, freqs=freqs) +axd["wd_sws"].set_title(f"non-REM Data Wavelet Decomposition: Channel {channel}") + +mwt_REM = nap.compute_wavelet_transform(REM_minute[:, channel], fs=None, freqs=freqs) plot_timefrequency( - wake_minute.index.values[:], freqs[:], np.transpose(mwt_wake[:, :].values), ax=ax[1] + REM_minute.index.values[:], freqs[:], np.transpose(mwt_REM[:, :].values), ax=axd["wd_rem"] ) -ax[1].set_title(f"Wake Data Wavelet Decomposition: Channel {channel}") -plt.margins(0) +axd["wd_rem"].set_title(f"REM Data Wavelet Decomposition: Channel {channel}") + +mwt_RUN = nap.compute_wavelet_transform(RUN_minute[:, channel], fs=None, freqs=freqs) +plot_timefrequency( + RUN_minute.index.values[:], freqs[:], np.transpose(mwt_RUN[:, :].values), ax=axd["wd_run"] +) +axd["wd_run"].set_title(f"Running Data Wavelet Decomposition: Channel {channel}") + +axd["lfp_sws"].plot(SWS_minute) +axd["lfp_rem"].plot(REM_minute) +axd["lfp_run"].plot(RUN_minute) +axd["pos_run"].plot(RUN_position) +axd["pos_run"].margins(0) +for k in ["lfp_sws", "lfp_rem", "lfp_run"]: + axd[k].margins(0) + axd[k].set_ylabel("LFP (v)") + axd[k].get_xaxis().set_visible(False) + axd[k].spines['top'].set_visible(False) + axd[k].spines['right'].set_visible(False) + axd[k].spines['bottom'].set_visible(False) + axd[k].spines['left'].set_visible(False) plt.show() -# %% -# Let's focus on the waking data. Let's see if we can isolate the theta oscillations from the data +# %%g freq = 3 -interval = (wake_minute_interval["start"], wake_minute_interval["start"] + 2) -wake_second = wake_minute.restrict(nap.IntervalSet(interval[0], interval[1])) -mwt_wake_second = mwt_wake.restrict(nap.IntervalSet(interval[0], interval[1])) +interval = (REM_minute_interval["start"] + 0, REM_minute_interval["start"] + 5) +REM_second = REM_minute.restrict(nap.IntervalSet(interval[0], interval[1])) +mwt_REM_second = mwt_REM.restrict(nap.IntervalSet(interval[0], interval[1])) fig, ax = plt.subplots(1) -ax.plot(wake_second.index.values, wake_second[:, channel], alpha=0.5, label="Wake Data") +ax.plot(REM_second.index.values, REM_second[:, channel], alpha=0.5, label="Wake Data") ax.plot( - wake_second.index.values, - mwt_wake_second[:, freq].values.real, + REM_second.index.values, + mwt_REM_second[:, freq].values.real, label="Theta oscillations", ) ax.set_title(f"{freqs[freq]}Hz oscillation power.") plt.show() +# %% +# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during wakeful theta oscillation + +fig = plt.figure(constrained_layout=True, figsize=(10, 50)) +num_cells = 10 +axd = fig.subplot_mosaic( + [ + ["raw_lfp"]*2, + ["wavelet"]*2, + ["fit_wavelet"]*2, + ["wavelet_power"]*2, + ["wavelet_phase"]*2 + ] + [[f"spikes_phasetime_{i}", f"spikephase_hist_{i}"] for i in range(num_cells)], +) + + +# _, ax = plt.subplots(25, figsize=(10, 50)) +mwt_REM = np.transpose(mwt_REM_second) +axd["raw_lfp"].plot(REM_second.index, REM_second.values[:, 0]) +axd["raw_lfp"].margins(0) +plot_timefrequency(REM_second.index, freqs, np.abs(mwt_REM[:, :]), ax=axd["wavelet"]) + +axd["fit_wavelet"].plot(REM_second.index, REM_second.values[:, 0]) +axd["fit_wavelet"].plot(REM_second.index, mwt_REM[freq, :].real) +axd["fit_wavelet"].set_title(f"{freqs[freq]}Hz") +axd["fit_wavelet"].margins(0) + +axd["wavelet_power"].plot(REM_second.index, np.abs(mwt_REM[freq, :])) +axd["wavelet_power"].margins(0) +# ax[3].plot(lfp.index, lfp.values[:,0]) +axd["wavelet_phase"].plot(REM_second.index, np.angle(mwt_REM[freq, :])) +axd["wavelet_phase"].margins(0) + +spikes = {} +for i in data["units"].index: + spikes[i] = data["units"][i].times()[ + (data["units"][i].times() > interval[0]) + & (data["units"][i].times() < interval[1]) + ] + +phase = {} +for i in spikes.keys(): + phase_i = [] + for spike in spikes[i]: + phase_i.append( + np.angle( + mwt_REM[freq, np.argmin(np.abs(REM_second.index.values - spike))] + ) + ) + phase[i] = np.array(phase_i) + +spikes = {k: v for k, v in spikes.items() if len(v) > 20} +phase = {k: v for k, v in phase.items() if len(v) > 20} + +variances = {key: scipy.stats.circvar(value, low=-np.pi, high=np.pi) for key, value in phase.items()} +spikes = dict(sorted(spikes.items(), key=lambda item: variances[item[0]])) +phase = dict(sorted(phase.items(), key=lambda item: variances[item[0]])) + +for i in range(num_cells): + axd[f"spikes_phasetime_{i}"].scatter(spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]]) + axd[f"spikes_phasetime_{i}"].set_xlim(interval[0], interval[1]) + axd[f"spikes_phasetime_{i}"].set_ylim(-np.pi, np.pi) + axd[f"spikes_phasetime_{i}"].set_xlabel("time (s)") + axd[f"spikes_phasetime_{i}"].set_ylabel("phase") + + axd[f"spikephase_hist_{i}"].hist(phase[list(phase.keys())[i]], bins=np.linspace(-np.pi, np.pi, 10)) + axd[f"spikephase_hist_{i}"].set_xlim(-np.pi, np.pi) + +plt.tight_layout() +plt.show() + # %% # Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data freq = 0 # interval = (10, 15) -interval = (sleep_minute_interval["start"] + 30, sleep_minute_interval["start"] + 35) -sleep_second = sleep_minute.restrict(nap.IntervalSet(interval[0], interval[1])) -mwt_sleep_second = mwt_sleep.restrict(nap.IntervalSet(interval[0], interval[1])) +interval = (SWS_minute_interval["start"] + 30, SWS_minute_interval["start"] + 50) +SWS_second = SWS_minute.restrict(nap.IntervalSet(interval[0], interval[1])) +mwt_SWS_second = mwt_SWS.restrict(nap.IntervalSet(interval[0], interval[1])) _, ax = plt.subplots(1) -ax.plot(sleep_second[:, channel], alpha=0.5, label="Wake Data") +ax.plot(SWS_second[:, channel], alpha=0.5, label="Wake Data") ax.plot( - sleep_second.index.values, - mwt_sleep_second[:, freq].values.real, + SWS_second.index.values, + mwt_SWS_second[:, freq].values.real, label="Slow Wave Oscillations", ) ax.set_title(f"{freqs[freq]}Hz oscillation power") plt.show() # %% -# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during slow wave sleep +# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during wakeful theta oscillation -_, ax = plt.subplots(20, figsize=(10, 50)) -mwt_sleep = np.transpose(mwt_sleep_second) -ax[0].plot(sleep_second.index, sleep_second.values[:, 0]) -plot_timefrequency(sleep_second.index, freqs, np.abs(mwt_sleep[:, :]), ax=ax[1]) +fig = plt.figure(constrained_layout=True, figsize=(10, 50)) +num_cells = 5 +axd = fig.subplot_mosaic( + [ + ["raw_lfp"]*2, + ["wavelet"]*2, + ["fit_wavelet"]*2, + ["wavelet_power"]*2, + ["wavelet_phase"]*2 + ] + [[f"spikes_phasetime_{i}", f"spikephase_hist_{i}"] for i in range(num_cells)], +) -ax[2].plot(sleep_second.index, sleep_second.values[:, 0]) -ax[2].plot(sleep_second.index, mwt_sleep[freq, :].real) -ax[2].set_title(f"{freqs[freq]}Hz") -ax[3].plot(sleep_second.index, np.abs(mwt_sleep[freq, :])) -# ax[3].plot(lfp.index, lfp.values[:,0]) -ax[4].plot(sleep_second.index, np.angle(mwt_sleep[freq, :])) +# _, ax = plt.subplots(25, figsize=(10, 50)) +mwt_SWS = np.transpose(mwt_SWS_second) +axd["raw_lfp"].plot(SWS_second.index, SWS_second.values[:, 0]) +axd["raw_lfp"].margins(0) + +plot_timefrequency(SWS_second.index, freqs, np.abs(mwt_SWS[:, :]), ax=axd["wavelet"]) + +axd["fit_wavelet"].plot(SWS_second.index, SWS_second.values[:, 0]) +axd["fit_wavelet"].plot(SWS_second.index, mwt_SWS[freq, :].real) +axd["fit_wavelet"].set_title(f"{freqs[freq]}Hz") +axd["fit_wavelet"].margins(0) + +axd["wavelet_power"].plot(SWS_second.index, np.abs(mwt_SWS[freq, :])) +axd["wavelet_power"].margins(0) +axd["wavelet_phase"].plot(SWS_second.index, np.angle(mwt_SWS[freq, :])) +axd["wavelet_phase"].margins(0) spikes = {} for i in data["units"].index: @@ -273,7 +478,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw for spike in spikes[i]: phase_i.append( np.angle( - mwt_sleep[freq, np.argmin(np.abs(sleep_second.index.values - spike))] + mwt_SWS[freq, np.argmin(np.abs(SWS_second.index.values - spike))] ) ) phase[i] = np.array(phase_i) @@ -281,12 +486,88 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, y_ticks=5, ax=None, **kw spikes = {k: v for k, v in spikes.items() if len(v) > 0} phase = {k: v for k, v in phase.items() if len(v) > 0} -for i in range(15): - ax[5 + i].scatter(spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]]) - ax[5 + i].set_xlim(interval[0], interval[1]) - ax[5 + i].set_ylim(-np.pi, np.pi) - ax[5 + i].set_xlabel("time (s)") - ax[5 + i].set_ylabel("phase") +for i in range(num_cells): + axd[f"spikes_phasetime_{i}"].scatter(spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]]) + axd[f"spikes_phasetime_{i}"].set_xlim(interval[0], interval[1]) + axd[f"spikes_phasetime_{i}"].set_ylim(-np.pi, np.pi) + axd[f"spikes_phasetime_{i}"].set_xlabel("time (s)") + axd[f"spikes_phasetime_{i}"].set_ylabel("phase") + + axd[f"spikephase_hist_{i}"].hist(phase[list(phase.keys())[i]], bins=np.linspace(-np.pi, np.pi, 10)) + axd[f"spikephase_hist_{i}"].set_xlim(-np.pi, np.pi) plt.tight_layout() plt.show() + +# %% +# Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data +# interval = (10, 15) + +# for run in [-16, -15, -13, -20]: +# interval = ( +# data["forward_ep"]["start"][run], +# data["forward_ep"]["end"][run]+3., +# ) +# print(interval) +# RUN_second_r = RUN_minute.restrict(nap.IntervalSet(interval[0], interval[1])) +# RUN_position_r = RUN_position.restrict(nap.IntervalSet(interval[0], interval[1])) +# mwt_RUN_second_r = mwt_RUN.restrict(nap.IntervalSet(interval[0], interval[1])) +# _, ax = plt.subplots(3) +# plot_timefrequency( +# RUN_second_r.index.values[:], freqs[:], np.transpose(mwt_RUN_second_r[:, :].values), ax=ax[0] +# ) +# ax[1].plot(RUN_second_r[:, channel], alpha=0.5, label="Wake Data") +# ax[1].margins(0) +# +# ax[2].plot(RUN_position, alpha=0.5, label="Wake Data") +# ax[2].set_xlim(RUN_second_r[:, channel].index.min(), RUN_second_r[:, channel].index.max()) +# ax[2].margins(0) +# plt.show() + + +RUN_minute_interval = nap.IntervalSet( + data["forward_ep"]["start"][0], + data["forward_ep"]["end"][-1] +) + +RUN_minute = nap.TsdFrame( + t=data["eeg"].restrict(RUN_minute_interval).index.values, + d=data["eeg"].restrict(RUN_minute_interval).values, + time_support=data["eeg"].restrict(RUN_minute_interval).time_support, +) + +RUN_position = nap.TsdFrame( + t=data["position"].restrict(RUN_minute_interval).index.values[:], + d=data['position'].restrict(RUN_minute_interval), + time_support=data["position"].restrict(RUN_minute_interval).time_support, +) + +mwt_RUN = nap.compute_wavelet_transform(RUN_minute[:, channel], + freqs=freqs, + fs=None, + norm=None, + n_cycles=3.5, + scaling=1) + +for run in range(len(data["forward_ep"]["start"])): + interval = ( + data["forward_ep"]["start"][run], + data["forward_ep"]["end"][run]+5., + ) + if interval[1] - interval[0] < 6: + continue + print(interval) + RUN_second_r = RUN_minute.restrict(nap.IntervalSet(interval[0], interval[1])) + RUN_position_r = RUN_position.restrict(nap.IntervalSet(interval[0], interval[1])) + mwt_RUN_second_r = mwt_RUN.restrict(nap.IntervalSet(interval[0], interval[1])) + _, ax = plt.subplots(3) + plot_timefrequency( + RUN_second_r.index.values[:], freqs[:], np.transpose(mwt_RUN_second_r[:, :].values), ax=ax[0] + ) + ax[1].plot(RUN_second_r[:, channel], alpha=0.5, label="Wake Data") + ax[1].margins(0) + + ax[2].plot(RUN_position, alpha=0.5, label="Wake Data") + ax[2].set_xlim(RUN_second_r[:, channel].index.min(), RUN_second_r[:, channel].index.max()) + ax[2].margins(0) + plt.show() diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index ca61f2fe..f7b1d51e 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -14,7 +14,7 @@ import pynapple as nap -def compute_spectogram(sig, fs=None): +def compute_spectogram(sig, fs=None, ep=None, full_range=False): """ Performs numpy fft on sig, returns output @@ -23,22 +23,41 @@ def compute_spectogram(sig, fs=None): Time series. fs : float, optional Sampling rate, in Hz. If None, will be calculated from the given signal + ep : pynapple.IntervalSet or None, optional + The epoch to calculate the fft on. Must be length 1. + full_range : bool, optional + If true, will return full fft frequency range, otherwise will return only positive values """ if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): raise TypeError( "Currently compute_spectogram is only implemented for Tsd or TsdFrame" ) + if not (ep is None or isinstance(ep, nap.IntervalSet)): + raise TypeError("ep param must be a pynapple IntervalSet object, or None") + if ep is None: + ep = sig.time_support + if len(ep) != 1: + raise ValueError("Given epoch (or signal time_support) must have length 1") if fs is None: fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) - fft_result = np.fft.fft(sig.values, axis=0) - fft_freq = np.fft.fftfreq(len(sig.values), 1 / fs) - return pd.DataFrame(fft_result, fft_freq) - + fft_result = np.fft.fft(sig.restrict(ep).values, axis=0) + fft_freq = np.fft.fftfreq(len(sig.restrict(ep).values), 1 / fs) + ret = pd.DataFrame(fft_result, fft_freq) + ret.sort_index(inplace=True) + if not full_range: + return ret.loc[ret.index >= 0] + return ret def compute_welch_spectogram(sig, fs=None): """ - Performs scipy Welch's decomposition on sig, returns output + Performs scipy Welch's decomposition on sig, returns output. + Estimates the power spectral density of a signal by segmenting it into overlapping sections, applying a + window function to each segment, computing their FFTs, and averaging the resulting periodograms to reduce noise. + + ..todo: remove this or add binsize parameter + ..todo: be careful of border artifacts + Parameters ---------- sig : pynapple.Tsd or pynapple.TsdFrame Time series. @@ -57,16 +76,16 @@ def compute_welch_spectogram(sig, fs=None): def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ - Defines the complex Morelet wavelet kernel + Defines the complex Morlet wavelet kernel Parameters ---------- M : int Length of the wavelet ncycles : float - number of wavelet cycles to use. Default is 5 + number of wavelet cycles to use. Default is 1.5 scaling: float - Scaling factor. Default is 1.5 + Scaling factor. Default is 1.0 precision: int Precision of wavelet to use @@ -139,7 +158,7 @@ def _create_freqs(freq_start, freq_stop, freq_step=1): return np.arange(freq_start, freq_stop + freq_step, freq_step) -def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="amp"): +def compute_wavelet_transform(sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, precision=10, norm=None): """ Compute the time-frequency representation of a signal using morlet wavelets. @@ -147,20 +166,21 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a ---------- sig : pynapple.Tsd or pynapple.TsdFrame Time series. - fs : float or None - Sampling rate, in Hz. freqs : 1d array or list of float If array, frequency values to estimate with morlet wavelets. If list, define the frequency range, as [freq_start, freq_stop, freq_step]. The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. + fs : float or None + Sampling rate, in Hz. Defaults to sig.rate if None is given n_cycles : float or 1d array Length of the filter, as the number of cycles for each frequency. If 1d array, this defines n_cycles for each frequency. scaling : float Scaling factor. - norm : {'sss', 'amp'}, optional + norm : {None, 'sss', 'amp'}, optional Normalization method: + * None - no normalization * 'sss' - divide by the square root of the sum of squares * 'amp' - divide by the sum of amplitudes @@ -178,7 +198,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) if fs is None: - fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) + fs = sig.rate n_cycles = _check_n_cycles(n_cycles, len(freqs)) if isinstance(sig, nap.Tsd): sig = sig.reshape((sig.shape[0], 1)) @@ -192,7 +212,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a for channel_i in range(sig.values.shape[1]): for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): mwt[:, ind, channel_i] = _convolve_wavelet( - sig[:, channel_i], fs, freq, n_cycle, scaling, norm=norm + sig[:, channel_i], fs, freq, n_cycle, scaling, precision=precision, norm=norm ) if len(output_shape) == 2: return nap.TsdFrame( @@ -204,7 +224,7 @@ def compute_wavelet_transform(sig, fs, freqs, n_cycles=1.5, scaling=1.0, norm="a def _convolve_wavelet( - sig, fs, freq, n_cycles=1.5, scaling=1.0, precision=10, norm="sss" + sig, fs, freq, n_cycles=1.5, scaling=1.0, precision=10, norm=None ): """ Convolve a signal with a complex wavelet. @@ -221,17 +241,22 @@ def _convolve_wavelet( Length of the filter, as the number of cycles of the oscillation with specified frequency. scaling : float, optional, default: 0.5 Scaling factor for the morlet wavelet. - norm : {'sss', 'amp'}, optional + precision: int, optional, defaul: 10 + Precision of wavelet - higher number will lead to higher resolution wavelet (i.e. a longer filter bank + to be convolved with the signal) + norm : {'sss', 'amp', None}, optional Normalization method: * 'sss' - divide by the square root of the sum of squares * 'amp' - divide by the sum of amplitudes + * None - no normalization Returns ------- array - Complex- valued time series. + Complex-valued time series. + ..todo: fix scaling Notes ----- @@ -239,19 +264,27 @@ def _convolve_wavelet( * Taking np.abs() of output gives the analytic amplitude. * Taking np.angle() of output gives the analytic phase. """ - if norm not in ["sss", "amp"]: - raise ValueError("Given `norm` must be `sss` or `amp`") + if norm not in ["sss", "amp", None]: + raise ValueError("Given `norm` must be None, `sss` or `amp`") morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) x = np.linspace(-8, 8, int(2**precision)) int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) + scale = scaling / (freq / fs) j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) j = j.astype(int) # floor if j[-1] >= int_psi.size: j = np.extract(j < int_psi.size, j) int_psi_scale = int_psi[j][::-1] + conv = np.convolve(sig, int_psi_scale) - coef = -np.sqrt(scale) * np.diff(conv, axis=-1) + if norm == "sss": + coef = -np.sqrt(scale) * np.diff(conv, axis=-1) + elif norm == "amp": + coef = -scale * np.diff(conv, axis=-1) + else: + coef = np.diff(conv, axis=-1) #No normalization seems to be most effective... take others out? Why scale? ..todo + # transform axis is always -1 due to the data reshape above d = (coef.shape[-1] - sig.shape[-1]) / 2.0 if d > 0: @@ -262,6 +295,22 @@ def _convolve_wavelet( def _integrate(arr, step): + """ + Integrates an array with respect to some step param. Used for integrating complex wavelets. + + Parameters + ---------- + arr : np.ndarray + wave function to be integrated + step : float + Step size of vgiven wave function array + + Returns + ------- + array + Complex-valued integrated wavelet + + """ integral = np.cumsum(arr) integral *= step return integral diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index c50860ba..e302c446 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -8,16 +8,31 @@ def test_compute_spectogram(): - t = np.linspace(0, 1, 1024) - sig = nap.Tsd(d=np.random.random(1024), t=t) + with pytest.raises(ValueError) as e_info: + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t, + time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81])) + r = nap.compute_spectogram(sig) + assert ( + str(e_info.value) + == "Given epoch (or signal time_support) must have length 1" + ) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) r = nap.compute_spectogram(sig) assert isinstance(r, pd.DataFrame) - assert r.shape[0] == 1024 + assert r.shape[0] == 500 - sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) r = nap.compute_spectogram(sig) assert isinstance(r, pd.DataFrame) - assert r.shape == (1024, 4) + assert r.shape == (500, 4) + + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) + r = nap.compute_spectogram(sig, full_range=True) + assert isinstance(r, pd.DataFrame) + assert r.shape == (1000, 4) with pytest.raises(TypeError) as e_info: nap.compute_spectogram("a_string") @@ -48,13 +63,6 @@ def test_compute_welch_spectogram(): def test_compute_wavelet_transform(): - ##..todo put there - t = np.linspace(0, 1, 1024) # can remove this when we move it - sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert mwt.shape == (1024, 10, 4, 2) - t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) freqs = np.linspace(1, 600, 10) @@ -79,7 +87,11 @@ def test_compute_wavelet_transform(): mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10, 4) - #..todo: here + t = np.linspace(0, 1, 1024) # can remove this when we move it + sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert mwt.shape == (1024, 10, 4, 2) with pytest.raises(ValueError) as e_info: nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) @@ -102,4 +114,4 @@ def test_compute_wavelet_transform(): if __name__ == "__main__": - test_compute_wavelet_transform() \ No newline at end of file + test_compute_spectogram() \ No newline at end of file From dda072f2630f52f805563e0bff54e6b20c1d60ad Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 12 Jul 2024 21:20:25 +0100 Subject: [PATCH 027/195] filterbank changes --- pynapple/process/signal_processing.py | 40 +++++++++++++++++++++++++-- tests/test_signal_processing.py | 2 +- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index f7b1d51e..629f0bd1 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -199,7 +199,7 @@ def compute_wavelet_transform(sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, pr freqs = _create_freqs(*freqs) if fs is None: fs = sig.rate - n_cycles = _check_n_cycles(n_cycles, len(freqs)) + # n_cycles = _check_n_cycles(n_cycles, len(freqs)) if isinstance(sig, nap.Tsd): sig = sig.reshape((sig.shape[0], 1)) output_shape = (sig.shape[0], len(freqs)) @@ -209,6 +209,18 @@ def compute_wavelet_transform(sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, pr mwt = np.zeros( [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex ) + + filter_bank = _generate_morelet_filterbank(freqs, fs, n_cycles, scaling, precision) + # + import matplotlib + matplotlib.use("TkAgg") + import matplotlib.pyplot as plt + plt.clf() + for f in filter_bank: + plt.plot(f) + plt.show() + conv = np.convolve(sig, filter_bank) + for channel_i in range(sig.values.shape[1]): for ind, (freq, n_cycle) in enumerate(zip(freqs, n_cycles)): mwt[:, ind, channel_i] = _convolve_wavelet( @@ -222,6 +234,30 @@ def compute_wavelet_transform(sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, pr t=sig.index, d=mwt.reshape(output_shape), time_support=sig.time_support ) +def _generate_morelet_filterbank(freqs, fs, n_cycles, scaling, precision): + """ + Make docsting #..todo: + + :param freqs: + :param n_cycles: + :param scaling: + :param precision: + :return: + """ + filter_bank = [] + morlet_f = _morlet(int(2 ** precision), ncycles=n_cycles, scaling=scaling) + x = np.linspace(-8, 8, int(2 ** precision)) + int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) + for freq in freqs: + scale = scaling / (freq / fs) + j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) + j = j.astype(int) # floor + if j[-1] >= int_psi.size: + j = np.extract(j < int_psi.size, j) + int_psi_scale = int_psi[j][::-1] + filter_bank.append(int_psi_scale) + return filter_bank + def _convolve_wavelet( sig, fs, freq, n_cycles=1.5, scaling=1.0, precision=10, norm=None @@ -256,7 +292,6 @@ def _convolve_wavelet( array Complex-valued time series. - ..todo: fix scaling Notes ----- @@ -276,6 +311,7 @@ def _convolve_wavelet( if j[-1] >= int_psi.size: j = np.extract(j < int_psi.size, j) int_psi_scale = int_psi[j][::-1] + print(len(int_psi_scale)) conv = np.convolve(sig, int_psi_scale) if norm == "sss": diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index e302c446..7be3eef9 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -114,4 +114,4 @@ def test_compute_wavelet_transform(): if __name__ == "__main__": - test_compute_spectogram() \ No newline at end of file + test_compute_wavelet_transform() \ No newline at end of file From d0c0ddd349810d111dfff9847699e16649ce90e5 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Sat, 13 Jul 2024 00:37:07 +0100 Subject: [PATCH 028/195] fixed test --- pynapple/process/signal_processing.py | 42 ++++----------------------- tests/test_signal_processing.py | 32 +++++--------------- 2 files changed, 13 insertions(+), 61 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index acc711ce..0ec96d10 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -102,39 +102,6 @@ def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): ) -def _check_n_cycles(n_cycles, len_cycles=None): - """ - Check an input as a number of cycles, and make it iterable. - - Parameters - ---------- - n_cycles : float or list - Definition of number of cycles to check. If a single value, the same number of cycles is used for each - frequency value. If a list or list_like, then should be a n_cycles corresponding to each frequency. - len_cycles: int, optional - What the length of `n_cycles` should be, if it's a list. - - Returns - ------- - iter - An iterable version of the number of cycles. - """ - if isinstance(n_cycles, (int, float, np.number)): - if n_cycles <= 0: - raise ValueError("Number of cycles must be a positive number.") - n_cycles = repeat(n_cycles) - elif isinstance(n_cycles, (tuple, list, np.ndarray)): - for cycle in n_cycles: - if cycle <= 0: - raise ValueError("Each number of cycles must be a positive number.") - if len_cycles and len(n_cycles) != len_cycles: - raise ValueError( - "The length of number of cycles does not match other inputs." - ) - n_cycles = iter(n_cycles) - return n_cycles - - def _create_freqs(freq_start, freq_stop, freq_step=1): """ Creates an array of frequencies. @@ -199,6 +166,9 @@ def compute_wavelet_transform( if not isinstance(sig, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)): raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") + if isinstance(n_cycles, (int, float, np.number)): + if n_cycles <= 0: + raise ValueError("Number of cycles must be a positive number.") if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) @@ -206,8 +176,6 @@ def compute_wavelet_transform( if fs is None: fs = sig.rate - # n_cycles = _check_n_cycles(n_cycles, len(freqs)) - if isinstance(sig, nap.Tsd): sig = sig.reshape((sig.shape[0], 1)) output_shape = (sig.shape[0], len(freqs)) @@ -229,9 +197,9 @@ def compute_wavelet_transform( elif norm == "amp": coef *= -scaling / (freqs[f_i] / fs) coef = np.insert( - coef, 1, coef[0] + coef, 1, coef[0], axis=0 ) # slightly hacky line, necessary to make output correct shape - mwt[:, f_i, :] = np.expand_dims(coef, axis=1) + mwt[:, f_i, :] = coef if len(coef.shape) == 2 else np.expand_dims(coef, axis=1) if len(output_shape) == 2: return nap.TsdFrame( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index d9159416..68347c3a 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -10,12 +10,14 @@ def test_compute_spectogram(): with pytest.raises(ValueError) as e_info: t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.random.random(1000), t=t, - time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81])) + sig = nap.Tsd( + d=np.random.random(1000), + t=t, + time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), + ) r = nap.compute_spectogram(sig) assert ( - str(e_info.value) - == "Given epoch (or signal time_support) must have length 1" + str(e_info.value) == "Given epoch (or signal time_support) must have length 1" ) t = np.linspace(0, 1, 1000) @@ -75,13 +77,6 @@ def test_compute_wavelet_transform(): mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 6) - sig = nap.Tsd(d=np.random.random(1024), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, 2.5), 5)) - ) - assert mwt.shape == (1024, 10) - sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) freqs = np.linspace(1, 600, 10) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) @@ -97,17 +92,6 @@ def test_compute_wavelet_transform(): nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) assert str(e_info.value) == "Number of cycles must be a positive number." - with pytest.raises(ValueError) as e_info: - nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, -2.5), 5)) - ) - assert str(e_info.value) == "Each number of cycles must be a positive number." - with pytest.raises(ValueError) as e_info: - nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=tuple(np.repeat((1.5, 2.5), 2)) - ) - assert ( - str(e_info.value) - == "The length of number of cycles does not match other inputs." - ) +if __name__ == "__main__": + test_compute_wavelet_transform() From 9c53af536b791b52b94044234b6f0e78688aef49 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Sat, 13 Jul 2024 00:49:17 +0100 Subject: [PATCH 029/195] unused import removed --- pynapple/process/signal_processing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 0ec96d10..e82c1901 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -4,8 +4,6 @@ Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. """ -from itertools import repeat - import numpy as np import pandas as pd from scipy.signal import welch From f6b11d3977906a293f2cc9ea7e2f6480d69b7cc9 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Mon, 15 Jul 2024 08:38:44 +0200 Subject: [PATCH 030/195] added lazy loading option for nwb loading function --- pynapple/io/misc.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pynapple/io/misc.py b/pynapple/io/misc.py index a3bab79e..cd08228d 100644 --- a/pynapple/io/misc.py +++ b/pynapple/io/misc.py @@ -22,7 +22,7 @@ from .suite2p import Suite2P -def load_file(path): +def load_file(path, lazy_loading=True): """Load file. Current format supported is (npz,nwb,) .npz -> If the file is compatible with a pynapple format, the function will return a pynapple object. @@ -34,6 +34,8 @@ def load_file(path): ---------- path : str Path to the file + lazy_loading : bool + Lazy loading of the data, used only for NWB files Returns ------- @@ -49,7 +51,7 @@ def load_file(path): if path.endswith(".npz"): return NPZFile(path).load() elif path.endswith(".nwb"): - return NWBFile(path) + return NWBFile(path, lazy_loading=lazy_loading) else: raise RuntimeError("File format not supported") else: From 2b4bc86ab1466299acc6e610193b17071073341e Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 15 Jul 2024 17:38:18 +0100 Subject: [PATCH 031/195] logspacing --- pynapple/process/signal_processing.py | 15 ++++++++++----- tests/test_signal_processing.py | 12 +++++++++++- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index e82c1901..8c0a0fc7 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -48,7 +48,7 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): def compute_welch_spectogram(sig, fs=None): """ - Performs scipy Welch's decomposition on sig, returns output. + Performs Welch's decomposition on sig, returns output. Estimates the power spectral density of a signal by segmenting it into overlapping sections, applying a window function to each segment, computing their FFTs, and averaging the resulting periodograms to reduce noise. @@ -100,27 +100,32 @@ def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): ) -def _create_freqs(freq_start, freq_stop, freq_step=1): +def _create_freqs(freq_start, freq_stop, log_scaling=False, freq_step=1, log_base=np.e): """ Creates an array of frequencies. - ..todo:: Implement log scaling - Parameters ---------- freq_start : float Starting value for the frequency definition. freq_stop: float Stopping value for the frequency definition, inclusive. + log_scaling: Bool + If True, will use log spacing with base log_base for frequency spacing. Default False. freq_step: float, optional Step value, for linearly spaced values between start and stop. + log_base: float + If log_scaling==True, this defines the base of the log to use. Returns ------- freqs: 1d array Frequency indices. """ - return np.arange(freq_start, freq_stop + freq_step, freq_step) + if not log_scaling: + return np.arange(freq_start, freq_stop + freq_step, freq_step) + else: + return np.logspace(freq_start, freq_stop, base=log_base) def compute_wavelet_transform( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 68347c3a..5055cdbf 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -45,6 +45,16 @@ def test_compute_spectogram(): def test_compute_welch_spectogram(): + t = np.linspace(0, 1, 10000) + sig = nap.TsdFrame( + d=np.random.random((10000, 4)), + t=t, + time_support=nap.IntervalSet(start=[0.1, 0.4], end=[0.2, 0.525]), + ) + r = nap.compute_welch_spectogram(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape[1] == 4 + t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) r = nap.compute_welch_spectogram(sig) @@ -94,4 +104,4 @@ def test_compute_wavelet_transform(): if __name__ == "__main__": - test_compute_wavelet_transform() + test_compute_welch_spectogram() From 8660aff3db2148d71a09f07f28efb9fc359bdfea Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 15 Jul 2024 14:33:29 -0400 Subject: [PATCH 032/195] added pickling support for tsgroup --- pynapple/core/ts_group.py | 4 ++++ tests/test_time_series.py | 26 +++++++++++++++++++++++++- tests/test_ts_group.py | 39 ++++++++++++++++++++++++++++++++++----- 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 2b72f69c..e5460b1b 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -202,6 +202,10 @@ def __getattr__(self, name): AttributeError If the requested attribute is not a metadata column. """ + # avoid infinite recursion when pickling due to + # self._metadata.column having attributes '__reduce__', '__reduce_ex__' + if name in ('__getstate__', '__setstate__', '__reduce__', '__reduce_ex__'): + raise AttributeError(name) # Check if the requested attribute is part of the metadata if name in self._metadata.columns: return self._metadata[name] diff --git a/tests/test_time_series.py b/tests/test_time_series.py index cb3bfbb4..906486c6 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1,10 +1,12 @@ """Tests of time series for `pynapple` package.""" -import pynapple as nap +import pickle + import numpy as np import pandas as pd import pytest +import pynapple as nap # tsd1 = nap.Tsd(t=np.arange(100), d=np.random.rand(100), time_units="s") # tsd2 = nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 3), columns = ['a', 'b', 'c']) @@ -1444,3 +1446,25 @@ def test_interpolate_with_ep(self, tsdtensor): tsdframe2 = tsdtensor.interpolate(ts, ep) assert len(tsdframe2) == 0 +@pytest.mark.parametrize("obj", + [ + nap.Tsd(t=np.arange(10), d=np.random.rand(10), time_units="s"), + nap.TsdFrame( + t=np.arange(10), d=np.random.rand(10, 3), time_units="s", columns=["a","b","c"] + ), + nap.TsdTensor(t=np.arange(10), d=np.random.rand(10, 3, 2), time_units="s"), + ]) +def test_pickling(obj): + """Test that pikling works as expected.""" + # pickle and unpickle ts_group + pickled_obj = pickle.dumps(obj) + unpickled_obj = pickle.loads(pickled_obj) + + # Ensure time is the same + assert np.all(obj.t == unpickled_obj.t) + + # Ensure data is the same + assert np.all(obj.d == unpickled_obj.d) + + # Ensure time support is the same + assert np.all(obj.time_support == unpickled_obj.time_support) diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index fc221f60..3ade3c94 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -6,13 +6,17 @@ """Tests of ts group for `pynapple` package.""" -import pynapple as nap +import pickle +import warnings +from collections import UserDict +from contextlib import nullcontext as does_not_raise + import numpy as np import pandas as pd import pytest -from collections import UserDict -import warnings -from contextlib import nullcontext as does_not_raise + +import pynapple as nap + @pytest.fixture def group(): @@ -854,4 +858,29 @@ def test_merge_time_support(self, ts_group, time_support, reset_time_support, ex np.testing.assert_array_almost_equal( ts_group.time_support.as_units("s").to_numpy(), merged.time_support.as_units("s").to_numpy() - ) \ No newline at end of file + ) + + +def test_pickling(ts_group): + """Test that pikling works as expected.""" + # pickle and unpickle ts_group + pickled_obj = pickle.dumps(ts_group) + unpickled_obj = pickle.loads(pickled_obj) + + # Ensure the type is the same + assert type(ts_group) is type(unpickled_obj), "Types are different" + + # Ensure that TsGroup have same len + assert len(ts_group) == len(unpickled_obj) + + # Ensure that metadata content is the same + assert np.all(unpickled_obj._metadata == ts_group._metadata) + + # Ensure that metadata columns are the same + assert np.all(unpickled_obj._metadata.columns == ts_group._metadata.columns) + + # Ensure that the Ts are the same + assert all([np.all(ts_group[key].t == unpickled_obj[key].t) for key in unpickled_obj.keys()]) + + # Ensure time support is the same + assert np.all(ts_group.time_support == unpickled_obj.time_support) From 0c59fba27a841d33f5056e5a2b1bff434099aae6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 15 Jul 2024 14:44:27 -0400 Subject: [PATCH 033/195] linters --- pynapple/core/ts_group.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index e5460b1b..3c96230f 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -204,7 +204,7 @@ def __getattr__(self, name): """ # avoid infinite recursion when pickling due to # self._metadata.column having attributes '__reduce__', '__reduce_ex__' - if name in ('__getstate__', '__setstate__', '__reduce__', '__reduce_ex__'): + if name in ("__getstate__", "__setstate__", "__reduce__", "__reduce_ex__"): raise AttributeError(name) # Check if the requested attribute is part of the metadata if name in self._metadata.columns: From c5f8cfb58c7bc24fa4bddc954f583b7240672f5d Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 00:26:50 +0200 Subject: [PATCH 034/195] npz refactoring and lazy loading --- pynapple/core/interval_set.py | 19 ++++ pynapple/core/time_series.py | 68 +++++++++++++ pynapple/core/ts_group.py | 63 ++++++++++++ pynapple/io/interface_npz.py | 182 ++++++++++++++-------------------- pynapple/io/misc.py | 12 ++- tests/npzfilestest/tsd2.json | 4 + 6 files changed, 238 insertions(+), 110 deletions(-) create mode 100644 tests/npzfilestest/tsd2.json diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 9f9798d9..1362fce2 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -669,3 +669,22 @@ def save(self, filename): ) return + + @classmethod + def _from_npz_reader(cls, file): + """ + Load an IntervalSet object from a npz file. + + The file should contain the keys 'start', 'end' and 'type'. The 'type' key should be 'IntervalSet'. + + Returns + ------- + IntervalSet + The IntervalSet object + + Raises + ------ + RuntimeError + If the file does not contain the 'start', 'end' and 'type' keys. + """ + return cls(start=file["start"], end=file["end"]) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index b7c8c5e2..065aa473 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -858,6 +858,23 @@ def save(self, filename): ) return + + @classmethod + def _from_npz_reader(cls, file): + """ + Load a TsdTensor object from a npz file. + + Parameters + ---------- + file : str + The opened npz file + + Returns + ------- + Tsd + The Tsd object + """ + return cls(t=file["t"], d=file["d"], time_support=IntervalSet(start=file["start"], end=file["end"])) class TsdFrame(BaseTsd): @@ -1144,6 +1161,23 @@ def save(self, filename): ) return + + @classmethod + def _from_npz_reader(cls, file): + """ + Load a Tsd object from a npz file. + + Parameters + ---------- + file : str + The opened npz file + + Returns + ------- + Tsd + The Tsd object + """ + return cls(t=file["t"], d=file["d"], columns=file["columns"], time_support=IntervalSet(start=file["start"], end=file["end"])) class Tsd(BaseTsd): @@ -1466,6 +1500,23 @@ def save(self, filename): ) return + + @classmethod + def _from_npz_reader(cls, file): + """ + Load a Tsd object from a npz file. + + Parameters + ---------- + file : str + The opened npz file + + Returns + ------- + Tsd + The Tsd object + """ + return cls(t=file["t"], d=file["d"], time_support=IntervalSet(start=file["start"], end=file["end"])) class Ts(Base): @@ -1758,3 +1809,20 @@ def save(self, filename): ) return + + @classmethod + def _from_npz_reader(cls, file): + """ + Load a Ts object from a npz file. + + Parameters + ---------- + file : str + The opened npz file + + Returns + ------- + Tsd + The Ts object + """ + return cls(t=file["t"], time_support=IntervalSet(start=file["start"], end=file["end"])) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 2b72f69c..d2e044d6 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1379,3 +1379,66 @@ def save(self, filename): np.savez(filename, **dicttosave) return + + @classmethod + def _from_npz_reader(cls, file): + """ + Load a Tsd object from a npz file. + + Parameters + ---------- + file : str + The opened npz file + + Returns + ------- + Tsd + The Tsd object + """ + + times = file["t"] + index = file["index"] + has_data = "d" in file.keys() + time_support = IntervalSet(file["start"], file["end"]) + + + if has_data: + data = file["data"] + + if "keys" in file.keys(): + keys = file["keys"] + else: + keys = np.unique(index) + + group = {} + for key in keys: + filtering_index = index == key + t = times[filtering_index] + + if has_data: + group[key] = nap.Tsd( + t=t, + d=data[filtering_index], + time_support=time_support, + ) + else: + group[key] = Ts( + t=t, time_support=time_support + ) + + tsgroup = cls( + group, time_support=time_support, bypass_check=True + ) + + metainfo = {} + not_info_keys = {"start", "end", "t", "index", "d", "rate", "keys"} + + for k in set(file.keys()) - not_info_keys: + tmp = file[k] + if len(tmp) == len(tsgroup): + metainfo[k] = tmp + + tsgroup.set_info(**metainfo) + return tsgroup + + diff --git a/pynapple/io/interface_npz.py b/pynapple/io/interface_npz.py index 4795e0b2..6a232f43 100644 --- a/pynapple/io/interface_npz.py +++ b/pynapple/io/interface_npz.py @@ -6,11 +6,6 @@ # @Last Modified by: Guillaume Viejo # @Last Modified time: 2024-04-02 14:32:25 -""" -File classes help to validate and load pynapple objects or NWB files. -Data are always lazy-loaded. -Both classes behaves like dictionary. -""" import os @@ -18,6 +13,59 @@ from .. import core as nap +# +EXPECTED_ENTRIES = {"TsGroup": {"t", "start", "end", "index"}, + "TsdFrame": {"t", "d", "start", "end", "columns"}, + "TsdTensor": {"t", "d", "start", "end"}, + "Tsd": {"t", "d", "start", "end"}, + "Ts": {"t", "start", "end"}, + "IntervalSet": {"start", "end"}} + + +def _find_class_from_variables(file_variables, data_ndims=None): + + if data_ndims is not None: + # either TsdTensor or Tsd: + assert EXPECTED_ENTRIES["Tsd"].issubset(file_variables) + + return "Tsd" if data_ndims == 1 else "TsdTensor" + + for possible_type, espected_variables in EXPECTED_ENTRIES.items(): + if espected_variables.issubset(file_variables): + return possible_type + + return "npz" + + +class LazyNPZLoader: + """Class that lazily loads an NPZ file. + """ + def __init__(self, file_path, lazy_loading=False): + self.lazy_loading = lazy_loading + self.file_path = file_path + self.npz_file = np.load(file_path, allow_pickle=True, mmap_mode='r' if lazy_loading else None) + self.data = {key: None for key in self.npz_file.keys()} + + def __getitem__(self, key): + if key not in self.data: + raise KeyError(f"{key} not found in the NPZ file") + + if self.data[key] is None: + self.data[key] = self._load_array(key) + + return self.data[key] + + def _load_array(self, key): + if self.lazy_loading: + array_info = self.npz_file.zip.read(self.npz_file.zip.NameToInfo[key].filename) + np_array = np.frombuffer(array_info, dtype=self.npz_file[key].dtype).reshape(self.npz_file[key].shape) + return np.memmap(self.npz_file.filename, dtype=np_array.dtype, mode='r', shape=np_array.shape) + else: + return self.npz_file[key] + + def keys(self): + return self.npz_file.keys() + class NPZFile(object): """Class that points to a NPZ file that can be loaded as a pynapple object. @@ -35,8 +83,10 @@ class NPZFile(object): dtype: int64 """ + # valid_types = ["Ts", "Tsd", "TsdFrame", "TsdTensor", "TsGroup", "IntervalSet"] + - def __init__(self, path): + def __init__(self, path, lazy_loading=False): """Initialization of the NPZ file Parameters @@ -46,35 +96,22 @@ def __init__(self, path): """ self.path = path self.name = os.path.basename(path) - self.file = np.load(self.path, allow_pickle=True) - self.type = "" - - # First check if type is explicitely defined - possible = ["Ts", "Tsd", "TsdFrame", "TsdTensor", "TsGroup", "IntervalSet"] - if "type" in self.file.keys(): - if len(self.file["type"]) == 1: - if isinstance(self.file["type"][0], np.str_): - if self.file["type"] in possible: - self.type = self.file["type"][0] - - # Second check manually - if self.type == "": - k = set(self.file.keys()) - if {"t", "start", "end", "index"}.issubset(k): - self.type = "TsGroup" - elif {"t", "d", "start", "end", "columns"}.issubset(k): - self.type = "TsdFrame" - elif {"t", "d", "start", "end"}.issubset(k): - if self.file["d"].ndim == 1: - self.type = "Tsd" - else: - self.type = "TsdTensor" - elif {"t", "start", "end"}.issubset(k): - self.type = "Ts" - elif {"start", "end"}.issubset(k): - self.type = "IntervalSet" - else: - self.type = "npz" + self.file = LazyNPZLoader(path, lazy_loading=lazy_loading) # np.load(self.path, allow_pickle=True) + type_ = "" + + # First check if type is explicitely defined in the file: + try: + type_ = self.file["type"][0] + assert type_ in EXPECTED_ENTRIES.keys() + + # if not, use heuristics: + except (KeyError, IndexError, AssertionError): + file_variables = set(self.file.keys()) + data_ndims = self.file["d"].ndim if "d" in file_variables else None + + type_ = _find_class_from_variables(file_variables, data_ndims) + + self.type = type_ def load(self): """Load the NPZ file @@ -85,74 +122,7 @@ def load(self): A pynapple object """ if self.type == "npz": - return self.file - else: - time_support = nap.IntervalSet(self.file["start"], self.file["end"]) - if self.type == "TsGroup": - - times = self.file["t"] - index = self.file["index"] - has_data = False - if "d" in self.file.keys(): - data = self.file["data"] - has_data = True - - if "keys" in self.file.keys(): - keys = self.file["keys"] - else: - keys = np.unique(index) - - group = {} - for k in keys: - if has_data: - group[k] = nap.Tsd( - t=times[index == k], - d=data[index == k], - time_support=time_support, - ) - else: - group[k] = nap.Ts( - t=times[index == k], time_support=time_support - ) - - tsgroup = nap.TsGroup( - group, time_support=time_support, bypass_check=True - ) - - metainfo = {} - for k in set(self.file.keys()) - { - "start", - "end", - "t", - "index", - "d", - "rate", - "keys", - }: - tmp = self.file[k] - if len(tmp) == len(tsgroup): - metainfo[k] = tmp - tsgroup.set_info(**metainfo) - return tsgroup - - elif self.type == "TsdFrame": - return nap.TsdFrame( - t=self.file["t"], - d=self.file["d"], - time_support=time_support, - columns=self.file["columns"], - ) - elif self.type == "TsdTensor": - return nap.TsdTensor( - t=self.file["t"], d=self.file["d"], time_support=time_support - ) - elif self.type == "Tsd": - return nap.Tsd( - t=self.file["t"], d=self.file["d"], time_support=time_support - ) - elif self.type == "Ts": - return nap.Ts(t=self.file["t"], time_support=time_support) - elif self.type == "IntervalSet": - return time_support - else: - return self.file + return self.file.npz_file + + return getattr(nap, self.type)._from_npz_reader(self.file) + \ No newline at end of file diff --git a/pynapple/io/misc.py b/pynapple/io/misc.py index cd08228d..75c102af 100644 --- a/pynapple/io/misc.py +++ b/pynapple/io/misc.py @@ -22,7 +22,7 @@ from .suite2p import Suite2P -def load_file(path, lazy_loading=True): +def load_file(path, lazy_loading=None): """Load file. Current format supported is (npz,nwb,) .npz -> If the file is compatible with a pynapple format, the function will return a pynapple object. @@ -34,8 +34,9 @@ def load_file(path, lazy_loading=True): ---------- path : str Path to the file - lazy_loading : bool - Lazy loading of the data, used only for NWB files + lazy_loading : bool, optional + Lazy loading of the data. If not specified, the function will use the defaults + True for nwb and False for npz. Returns ------- @@ -49,8 +50,11 @@ def load_file(path, lazy_loading=True): """ if os.path.isfile(path): if path.endswith(".npz"): - return NPZFile(path).load() + lazy_loading = False if lazy_loading is None else lazy_loading + return NPZFile(path, lazy_loading=lazy_loading).load() + elif path.endswith(".nwb"): + lazy_loading = True if lazy_loading is None else lazy_loading return NWBFile(path, lazy_loading=lazy_loading) else: raise RuntimeError("File format not supported") diff --git a/tests/npzfilestest/tsd2.json b/tests/npzfilestest/tsd2.json new file mode 100644 index 00000000..41a9e443 --- /dev/null +++ b/tests/npzfilestest/tsd2.json @@ -0,0 +1,4 @@ +{ + "time": "2024-07-17 00:22:50.255786", + "info": "Test description" +} \ No newline at end of file From 358892e42307e3978ba77d8adc1ee68bd0856c17 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 00:35:07 +0200 Subject: [PATCH 035/195] small impros --- pynapple/core/ts_group.py | 2 +- tests/test_npz_file.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index d2e044d6..f25e6150 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1416,7 +1416,7 @@ def _from_npz_reader(cls, file): t = times[filtering_index] if has_data: - group[key] = nap.Tsd( + group[key] = Tsd( t=t, d=data[filtering_index], time_support=time_support, diff --git a/tests/test_npz_file.py b/tests/test_npz_file.py index 74eec96f..2a92522c 100644 --- a/tests/test_npz_file.py +++ b/tests/test_npz_file.py @@ -61,7 +61,7 @@ def test_load(path, k): file_path = os.path.join(path, k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert type(tmp) == type(data[k]) + assert isinstance(tmp, data[k]) @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['tsgroup']) @@ -69,7 +69,7 @@ def test_load_tsgroup(path, k): file_path = os.path.join(path, k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert type(tmp) == type(data[k]) + assert isinstance(tmp, data[k]) assert tmp.keys() == data[k].keys() assert np.all(tmp._metadata == data[k]._metadata) assert np.all(tmp[neu] == data[k][neu] for neu in tmp.keys()) @@ -82,7 +82,7 @@ def test_load_tsd(path, k): file_path = os.path.join(path, k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert type(tmp) == type(data[k]) + assert isinstance(tmp, data[k]) assert np.all(tmp.d == data[k].d) assert np.all(tmp.t == data[k].t) np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) @@ -94,7 +94,7 @@ def test_load_ts(path, k): file_path = os.path.join(path, k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert type(tmp) == type(data[k]) + assert isinstance(tmp, data[k]) assert np.all(tmp.t == data[k].t) np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) @@ -106,7 +106,7 @@ def test_load_tsdframe(path, k): file_path = os.path.join(path, k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert type(tmp) == type(data[k]) + assert isinstance(tmp, data[k]) assert np.all(tmp.t == data[k].t) np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) assert np.all(tmp.columns == data[k].columns) From c566989c59a10ff844e2b31f9699b4386bc34ce4 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 00:39:17 +0200 Subject: [PATCH 036/195] blaked --- pynapple/core/_jitted_functions.py | 1 - pynapple/core/base_class.py | 1 - pynapple/core/config.py | 1 - pynapple/core/interval_set.py | 2 +- pynapple/core/time_series.py | 32 +++++++++++---- pynapple/core/ts_group.py | 17 +++----- pynapple/io/interface_npz.py | 57 ++++++++++++++++---------- pynapple/io/misc.py | 4 +- pynapple/process/_process_functions.py | 1 - 9 files changed, 66 insertions(+), 50 deletions(-) diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index dae40a8a..4dae7a99 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -322,7 +322,6 @@ def jitbin_array(time_array, data_array, starts, ends, bin_size): @jit(nopython=True) def _jitbin_array(countin, time_array, data_array, starts, ends, bin_size): - m = starts.shape[0] f = data_array.shape[1:] diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index cf053574..d8f1021f 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -23,7 +23,6 @@ class Base(abc.ABC): _initialized = False def __init__(self, t, time_units="s", time_support=None): - if isinstance(t, TsIndex): self.index = t else: diff --git a/pynapple/core/config.py b/pynapple/core/config.py index 97eaafaa..fbc59ebb 100644 --- a/pynapple/core/config.py +++ b/pynapple/core/config.py @@ -98,7 +98,6 @@ def backend(self, backend): self.set_backend(backend) def set_backend(self, backend): - assert backend in ["numba", "jax"], "Options for backend are 'jax' or 'numba'" # Try to import pynajax diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 1362fce2..b3df9de3 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -669,7 +669,7 @@ def save(self, filename): ) return - + @classmethod def _from_npz_reader(cls, file): """ diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 065aa473..294cba6d 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -858,7 +858,7 @@ def save(self, filename): ) return - + @classmethod def _from_npz_reader(cls, file): """ @@ -874,7 +874,11 @@ def _from_npz_reader(cls, file): Tsd The Tsd object """ - return cls(t=file["t"], d=file["d"], time_support=IntervalSet(start=file["start"], end=file["end"])) + return cls( + t=file["t"], + d=file["d"], + time_support=IntervalSet(start=file["start"], end=file["end"]), + ) class TsdFrame(BaseTsd): @@ -1039,7 +1043,6 @@ def __getitem__(self, key, *args, **kwargs): if all(is_array_like(a) for a in [index, output]): if output.shape[0] == index.shape[0]: - # if isinstance(columns, pd.Index): # if not pd.api.types.is_integer_dtype(columns): kwargs["columns"] = columns @@ -1161,7 +1164,7 @@ def save(self, filename): ) return - + @classmethod def _from_npz_reader(cls, file): """ @@ -1177,7 +1180,12 @@ def _from_npz_reader(cls, file): Tsd The Tsd object """ - return cls(t=file["t"], d=file["d"], columns=file["columns"], time_support=IntervalSet(start=file["start"], end=file["end"])) + return cls( + t=file["t"], + d=file["d"], + columns=file["columns"], + time_support=IntervalSet(start=file["start"], end=file["end"]), + ) class Tsd(BaseTsd): @@ -1500,7 +1508,7 @@ def save(self, filename): ) return - + @classmethod def _from_npz_reader(cls, file): """ @@ -1516,7 +1524,11 @@ def _from_npz_reader(cls, file): Tsd The Tsd object """ - return cls(t=file["t"], d=file["d"], time_support=IntervalSet(start=file["start"], end=file["end"])) + return cls( + t=file["t"], + d=file["d"], + time_support=IntervalSet(start=file["start"], end=file["end"]), + ) class Ts(Base): @@ -1809,7 +1821,7 @@ def save(self, filename): ) return - + @classmethod def _from_npz_reader(cls, file): """ @@ -1825,4 +1837,6 @@ def _from_npz_reader(cls, file): Tsd The Ts object """ - return cls(t=file["t"], time_support=IntervalSet(start=file["start"], end=file["end"])) + return cls( + t=file["t"], time_support=IntervalSet(start=file["start"], end=file["end"]) + ) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 52b3068e..f96ed32e 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1383,7 +1383,7 @@ def save(self, filename): np.savez(filename, **dicttosave) return - + @classmethod def _from_npz_reader(cls, file): """ @@ -1405,7 +1405,6 @@ def _from_npz_reader(cls, file): has_data = "d" in file.keys() time_support = IntervalSet(file["start"], file["end"]) - if has_data: data = file["data"] @@ -1426,23 +1425,17 @@ def _from_npz_reader(cls, file): time_support=time_support, ) else: - group[key] = Ts( - t=t, time_support=time_support - ) + group[key] = Ts(t=t, time_support=time_support) - tsgroup = cls( - group, time_support=time_support, bypass_check=True - ) + tsgroup = cls(group, time_support=time_support, bypass_check=True) metainfo = {} not_info_keys = {"start", "end", "t", "index", "d", "rate", "keys"} - + for k in set(file.keys()) - not_info_keys: tmp = file[k] if len(tmp) == len(tsgroup): metainfo[k] = tmp - + tsgroup.set_info(**metainfo) return tsgroup - - diff --git a/pynapple/io/interface_npz.py b/pynapple/io/interface_npz.py index 6a232f43..28f939b5 100644 --- a/pynapple/io/interface_npz.py +++ b/pynapple/io/interface_npz.py @@ -13,23 +13,24 @@ from .. import core as nap -# -EXPECTED_ENTRIES = {"TsGroup": {"t", "start", "end", "index"}, - "TsdFrame": {"t", "d", "start", "end", "columns"}, - "TsdTensor": {"t", "d", "start", "end"}, - "Tsd": {"t", "d", "start", "end"}, - "Ts": {"t", "start", "end"}, - "IntervalSet": {"start", "end"}} +# +EXPECTED_ENTRIES = { + "TsGroup": {"t", "start", "end", "index"}, + "TsdFrame": {"t", "d", "start", "end", "columns"}, + "TsdTensor": {"t", "d", "start", "end"}, + "Tsd": {"t", "d", "start", "end"}, + "Ts": {"t", "start", "end"}, + "IntervalSet": {"start", "end"}, +} def _find_class_from_variables(file_variables, data_ndims=None): - if data_ndims is not None: # either TsdTensor or Tsd: assert EXPECTED_ENTRIES["Tsd"].issubset(file_variables) return "Tsd" if data_ndims == 1 else "TsdTensor" - + for possible_type, espected_variables in EXPECTED_ENTRIES.items(): if espected_variables.issubset(file_variables): return possible_type @@ -38,28 +39,39 @@ def _find_class_from_variables(file_variables, data_ndims=None): class LazyNPZLoader: - """Class that lazily loads an NPZ file. - """ + """Class that lazily loads an NPZ file.""" + def __init__(self, file_path, lazy_loading=False): self.lazy_loading = lazy_loading self.file_path = file_path - self.npz_file = np.load(file_path, allow_pickle=True, mmap_mode='r' if lazy_loading else None) + self.npz_file = np.load( + file_path, allow_pickle=True, mmap_mode="r" if lazy_loading else None + ) self.data = {key: None for key in self.npz_file.keys()} def __getitem__(self, key): if key not in self.data: raise KeyError(f"{key} not found in the NPZ file") - + if self.data[key] is None: self.data[key] = self._load_array(key) - + return self.data[key] def _load_array(self, key): if self.lazy_loading: - array_info = self.npz_file.zip.read(self.npz_file.zip.NameToInfo[key].filename) - np_array = np.frombuffer(array_info, dtype=self.npz_file[key].dtype).reshape(self.npz_file[key].shape) - return np.memmap(self.npz_file.filename, dtype=np_array.dtype, mode='r', shape=np_array.shape) + array_info = self.npz_file.zip.read( + self.npz_file.zip.NameToInfo[key].filename + ) + np_array = np.frombuffer( + array_info, dtype=self.npz_file[key].dtype + ).reshape(self.npz_file[key].shape) + return np.memmap( + self.npz_file.filename, + dtype=np_array.dtype, + mode="r", + shape=np_array.shape, + ) else: return self.npz_file[key] @@ -83,8 +95,8 @@ class NPZFile(object): dtype: int64 """ + # valid_types = ["Ts", "Tsd", "TsdFrame", "TsdTensor", "TsGroup", "IntervalSet"] - def __init__(self, path, lazy_loading=False): """Initialization of the NPZ file @@ -96,7 +108,9 @@ def __init__(self, path, lazy_loading=False): """ self.path = path self.name = os.path.basename(path) - self.file = LazyNPZLoader(path, lazy_loading=lazy_loading) # np.load(self.path, allow_pickle=True) + self.file = LazyNPZLoader( + path, lazy_loading=lazy_loading + ) # np.load(self.path, allow_pickle=True) type_ = "" # First check if type is explicitely defined in the file: @@ -123,6 +137,5 @@ def load(self): """ if self.type == "npz": return self.file.npz_file - - return getattr(nap, self.type)._from_npz_reader(self.file) - \ No newline at end of file + + return getattr(nap, self.type)._from_npz_reader(self.file) diff --git a/pynapple/io/misc.py b/pynapple/io/misc.py index 75c102af..d7a67b69 100644 --- a/pynapple/io/misc.py +++ b/pynapple/io/misc.py @@ -36,7 +36,7 @@ def load_file(path, lazy_loading=None): Path to the file lazy_loading : bool, optional Lazy loading of the data. If not specified, the function will use the defaults - True for nwb and False for npz. + True for nwb and False for npz. Returns ------- @@ -52,7 +52,7 @@ def load_file(path, lazy_loading=None): if path.endswith(".npz"): lazy_loading = False if lazy_loading is None else lazy_loading return NPZFile(path, lazy_loading=lazy_loading).load() - + elif path.endswith(".nwb"): lazy_loading = True if lazy_loading is None else lazy_loading return NWBFile(path, lazy_loading=lazy_loading) diff --git a/pynapple/process/_process_functions.py b/pynapple/process/_process_functions.py index 9d9ea5ff..b528a52d 100644 --- a/pynapple/process/_process_functions.py +++ b/pynapple/process/_process_functions.py @@ -227,7 +227,6 @@ def _perievent_trigger_average( def _perievent_continuous( time_array, data_array, time_target_array, starts, ends, windowsize ): - idx, slice_idx, N_target, w_starts = _jitcontinuous_perievent( time_array, time_target_array, starts, ends, windowsize ) From 9079ddc5cc4730195f87bc9da40990bfa01261a6 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 00:47:35 +0200 Subject: [PATCH 037/195] fixed tests --- tests/npzfilestest/tsd2.json | 2 +- tests/test_npz_file.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/npzfilestest/tsd2.json b/tests/npzfilestest/tsd2.json index 41a9e443..73959461 100644 --- a/tests/npzfilestest/tsd2.json +++ b/tests/npzfilestest/tsd2.json @@ -1,4 +1,4 @@ { - "time": "2024-07-17 00:22:50.255786", + "time": "2024-07-17 00:45:43.446025", "info": "Test description" } \ No newline at end of file diff --git a/tests/test_npz_file.py b/tests/test_npz_file.py index 2a92522c..a1978799 100644 --- a/tests/test_npz_file.py +++ b/tests/test_npz_file.py @@ -61,7 +61,7 @@ def test_load(path, k): file_path = os.path.join(path, k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert isinstance(tmp, data[k]) + assert isinstance(tmp, type(data[k])) @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['tsgroup']) @@ -69,7 +69,7 @@ def test_load_tsgroup(path, k): file_path = os.path.join(path, k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert isinstance(tmp, data[k]) + assert isinstance(tmp, type(data[k])) assert tmp.keys() == data[k].keys() assert np.all(tmp._metadata == data[k]._metadata) assert np.all(tmp[neu] == data[k][neu] for neu in tmp.keys()) @@ -82,7 +82,7 @@ def test_load_tsd(path, k): file_path = os.path.join(path, k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert isinstance(tmp, data[k]) + assert isinstance(tmp, type(data[k])) assert np.all(tmp.d == data[k].d) assert np.all(tmp.t == data[k].t) np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) @@ -94,7 +94,7 @@ def test_load_ts(path, k): file_path = os.path.join(path, k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert isinstance(tmp, data[k]) + assert isinstance(tmp, type(data[k])) assert np.all(tmp.t == data[k].t) np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) @@ -106,7 +106,7 @@ def test_load_tsdframe(path, k): file_path = os.path.join(path, k+".npz") file = nap.NPZFile(file_path) tmp = file.load() - assert isinstance(tmp, data[k]) + assert isinstance(tmp, type(data[k])) assert np.all(tmp.t == data[k].t) np.testing.assert_array_almost_equal(tmp.time_support.values, data[k].time_support.values) assert np.all(tmp.columns == data[k].columns) @@ -118,7 +118,7 @@ def test_load_tsdframe(path, k): def test_load_non_npz(path): file_path = os.path.join(path, "random.npz") tmp = np.random.rand(100) - np.savez(file_path, a = tmp) + np.savez(file_path, a=tmp) file = nap.NPZFile(file_path) assert file.type == "npz" From 295dfacd9ee2a56461a9d3bde3a2545b3f8ecc81 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 09:35:13 +0200 Subject: [PATCH 038/195] Moved class method to base class --- pynapple/core/base_class.py | 20 +++++++++ pynapple/core/time_series.py | 83 ------------------------------------ tests/npzfilestest/tsd2.json | 4 -- 3 files changed, 20 insertions(+), 87 deletions(-) delete mode 100644 tests/npzfilestest/tsd2.json diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index d8f1021f..79f13486 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -426,6 +426,26 @@ def get(self, start, end=None, time_units="s"): idx_end = np.searchsorted(time_array, end, side="right") return self[idx_start:idx_end] + @classmethod + def _from_npz_reader(cls, file): + """Load a time series object from a npz file interface. + + Parameters + ---------- + file : npz file interface + The file interface to read from + + Returns + ------- + out : Ts or Tsd or TsdFrame or TsdTensor + The time series object + """ + kwargs = { + key: file[key] for key in file.keys() if key not in ["start", "end", "type"] + } + iset = IntervalSet(start=file["start"], end=file["end"]) + return cls(time_support=iset, **kwargs) + # def find_gaps(self, min_gap, time_units='s'): # """ # finds gaps in a tsd larger than min_gap. Return an IntervalSet. diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 294cba6d..f3de4cf6 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -859,27 +859,6 @@ def save(self, filename): return - @classmethod - def _from_npz_reader(cls, file): - """ - Load a TsdTensor object from a npz file. - - Parameters - ---------- - file : str - The opened npz file - - Returns - ------- - Tsd - The Tsd object - """ - return cls( - t=file["t"], - d=file["d"], - time_support=IntervalSet(start=file["start"], end=file["end"]), - ) - class TsdFrame(BaseTsd): """ @@ -1165,28 +1144,6 @@ def save(self, filename): return - @classmethod - def _from_npz_reader(cls, file): - """ - Load a Tsd object from a npz file. - - Parameters - ---------- - file : str - The opened npz file - - Returns - ------- - Tsd - The Tsd object - """ - return cls( - t=file["t"], - d=file["d"], - columns=file["columns"], - time_support=IntervalSet(start=file["start"], end=file["end"]), - ) - class Tsd(BaseTsd): """ @@ -1509,27 +1466,6 @@ def save(self, filename): return - @classmethod - def _from_npz_reader(cls, file): - """ - Load a Tsd object from a npz file. - - Parameters - ---------- - file : str - The opened npz file - - Returns - ------- - Tsd - The Tsd object - """ - return cls( - t=file["t"], - d=file["d"], - time_support=IntervalSet(start=file["start"], end=file["end"]), - ) - class Ts(Base): """ @@ -1821,22 +1757,3 @@ def save(self, filename): ) return - - @classmethod - def _from_npz_reader(cls, file): - """ - Load a Ts object from a npz file. - - Parameters - ---------- - file : str - The opened npz file - - Returns - ------- - Tsd - The Ts object - """ - return cls( - t=file["t"], time_support=IntervalSet(start=file["start"], end=file["end"]) - ) diff --git a/tests/npzfilestest/tsd2.json b/tests/npzfilestest/tsd2.json deleted file mode 100644 index 73959461..00000000 --- a/tests/npzfilestest/tsd2.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "time": "2024-07-17 00:45:43.446025", - "info": "Test description" -} \ No newline at end of file From ae2aa3e674ef8e98ffa9da5f6719d75f83d3be1b Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 10:39:16 +0200 Subject: [PATCH 039/195] Final cleanup and load_file test in lazy_load --- pynapple/io/interface_npz.py | 49 +++--------------------------------- pynapple/io/misc.py | 18 ++++++++----- tests/npzfilestest/tsd2.json | 4 +++ tests/test_lazy_loading.py | 23 ++++++++++++----- tests/test_npz_file.py | 5 ---- 5 files changed, 36 insertions(+), 63 deletions(-) create mode 100644 tests/npzfilestest/tsd2.json diff --git a/pynapple/io/interface_npz.py b/pynapple/io/interface_npz.py index 28f939b5..0ed0dae1 100644 --- a/pynapple/io/interface_npz.py +++ b/pynapple/io/interface_npz.py @@ -38,47 +38,6 @@ def _find_class_from_variables(file_variables, data_ndims=None): return "npz" -class LazyNPZLoader: - """Class that lazily loads an NPZ file.""" - - def __init__(self, file_path, lazy_loading=False): - self.lazy_loading = lazy_loading - self.file_path = file_path - self.npz_file = np.load( - file_path, allow_pickle=True, mmap_mode="r" if lazy_loading else None - ) - self.data = {key: None for key in self.npz_file.keys()} - - def __getitem__(self, key): - if key not in self.data: - raise KeyError(f"{key} not found in the NPZ file") - - if self.data[key] is None: - self.data[key] = self._load_array(key) - - return self.data[key] - - def _load_array(self, key): - if self.lazy_loading: - array_info = self.npz_file.zip.read( - self.npz_file.zip.NameToInfo[key].filename - ) - np_array = np.frombuffer( - array_info, dtype=self.npz_file[key].dtype - ).reshape(self.npz_file[key].shape) - return np.memmap( - self.npz_file.filename, - dtype=np_array.dtype, - mode="r", - shape=np_array.shape, - ) - else: - return self.npz_file[key] - - def keys(self): - return self.npz_file.keys() - - class NPZFile(object): """Class that points to a NPZ file that can be loaded as a pynapple object. Objects have a save function in npz format as well as the Folder class. @@ -98,7 +57,7 @@ class NPZFile(object): # valid_types = ["Ts", "Tsd", "TsdFrame", "TsdTensor", "TsGroup", "IntervalSet"] - def __init__(self, path, lazy_loading=False): + def __init__(self, path): """Initialization of the NPZ file Parameters @@ -108,9 +67,7 @@ def __init__(self, path, lazy_loading=False): """ self.path = path self.name = os.path.basename(path) - self.file = LazyNPZLoader( - path, lazy_loading=lazy_loading - ) # np.load(self.path, allow_pickle=True) + self.file = np.load(self.path, allow_pickle=True) type_ = "" # First check if type is explicitely defined in the file: @@ -136,6 +93,6 @@ def load(self): A pynapple object """ if self.type == "npz": - return self.file.npz_file + return self.file return getattr(nap, self.type)._from_npz_reader(self.file) diff --git a/pynapple/io/misc.py b/pynapple/io/misc.py index d7a67b69..81cc4c95 100644 --- a/pynapple/io/misc.py +++ b/pynapple/io/misc.py @@ -20,6 +20,7 @@ from .neurosuite import NeuroSuite from .phy import Phy from .suite2p import Suite2P +import warnings def load_file(path, lazy_loading=None): @@ -34,9 +35,9 @@ def load_file(path, lazy_loading=None): ---------- path : str Path to the file - lazy_loading : bool, optional + lazy_loading : bool, optional default True Lazy loading of the data. If not specified, the function will use the defaults - True for nwb and False for npz. + True. Works only with NWB files. Returns ------- @@ -50,12 +51,17 @@ def load_file(path, lazy_loading=None): """ if os.path.isfile(path): if path.endswith(".npz"): - lazy_loading = False if lazy_loading is None else lazy_loading - return NPZFile(path, lazy_loading=lazy_loading).load() + if lazy_loading: + warnings.warn("Lazy loading is not supported for NPZ files") + return NPZFile(path).load() elif path.endswith(".nwb"): - lazy_loading = True if lazy_loading is None else lazy_loading - return NWBFile(path, lazy_loading=lazy_loading) + # preserves class init default: + kwargs_for_lazyloading = ( + {} if lazy_loading is None else {"lazy_loading": lazy_loading} + ) + + return NWBFile(path, **kwargs_for_lazyloading) else: raise RuntimeError("File format not supported") else: diff --git a/tests/npzfilestest/tsd2.json b/tests/npzfilestest/tsd2.json new file mode 100644 index 00000000..b091ce6c --- /dev/null +++ b/tests/npzfilestest/tsd2.json @@ -0,0 +1,4 @@ +{ + "time": "2024-07-17 10:34:35.997644", + "info": "Test description" +} \ No newline at end of file diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index 96c30bc9..a4da8074 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -206,12 +206,23 @@ def test_lazy_load_nwb(lazy): except: nwb = nap.NWBFile("nwbfilestest/basic/pynapplenwb/A2929-200711.nwb", lazy_loading=lazy) - tsd = nwb["z"] - if lazy: - assert isinstance(tsd.d, h5py.Dataset) - else: - assert not isinstance(tsd.d, h5py.Dataset) - nwb.io.close() + assert isinstance(nwb["z"].d, h5py.Dataset) is lazy + + +@pytest.mark.parametrize( + "lazy", + [ + (True), + (False), + ] +) +def test_lazy_load_function(lazy): + try: + nwb = nap.load_file("tests/nwbfilestest/basic/pynapplenwb/A2929-200711.nwb", lazy_loading=lazy) + except: + nwb = nap.load_file("nwbfilestest/basic/pynapplenwb/A2929-200711.nwb", lazy_loading=lazy) + + assert isinstance(nwb["z"].d, h5py.Dataset) is lazy @pytest.mark.parametrize("data", [np.ones(10), np.ones((10, 2)), np.ones((10, 2, 2))]) diff --git a/tests/test_npz_file.py b/tests/test_npz_file.py index a1978799..c69805ba 100644 --- a/tests/test_npz_file.py +++ b/tests/test_npz_file.py @@ -125,8 +125,3 @@ def test_load_non_npz(path): a = file.load() assert isinstance(a, np.lib.npyio.NpzFile) np.testing.assert_array_equal(tmp, a['a']) - - - - - From d68f34823bba97ee6336f8f276b7590afce992cd Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 10:43:04 +0200 Subject: [PATCH 040/195] Comment removed --- pynapple/io/interface_npz.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pynapple/io/interface_npz.py b/pynapple/io/interface_npz.py index 0ed0dae1..932813f1 100644 --- a/pynapple/io/interface_npz.py +++ b/pynapple/io/interface_npz.py @@ -55,8 +55,6 @@ class NPZFile(object): """ - # valid_types = ["Ts", "Tsd", "TsdFrame", "TsdTensor", "TsGroup", "IntervalSet"] - def __init__(self, path): """Initialization of the NPZ file From d469536be2de79488a11631d4c00c66b1e58627d Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 10:49:30 +0200 Subject: [PATCH 041/195] linting --- pynapple/core/time_series.py | 8 +++++++- pynapple/io/__init__.py | 8 +++++++- pynapple/io/misc.py | 2 +- pynapple/process/perievent.py | 5 ++++- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index f3de4cf6..5ad4b7c0 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -27,7 +27,13 @@ from scipy import signal from tabulate import tabulate -from ._core_functions import _bin_average, _convolve, _dropna, _restrict, _threshold +from ._core_functions import ( + _bin_average, + _convolve, + _dropna, + _restrict, + _threshold, +) from .base_class import Base from .interval_set import IntervalSet from .time_index import TsIndex diff --git a/pynapple/io/__init__.py b/pynapple/io/__init__.py index 20b194c7..f4eb2a70 100644 --- a/pynapple/io/__init__.py +++ b/pynapple/io/__init__.py @@ -1,4 +1,10 @@ from .folder import Folder from .interface_npz import NPZFile from .interface_nwb import NWBFile -from .misc import append_NWB_LFP, load_eeg, load_file, load_folder, load_session +from .misc import ( + append_NWB_LFP, + load_eeg, + load_file, + load_folder, + load_session, +) diff --git a/pynapple/io/misc.py b/pynapple/io/misc.py index 81cc4c95..2dff3fc1 100644 --- a/pynapple/io/misc.py +++ b/pynapple/io/misc.py @@ -5,6 +5,7 @@ """ import os +import warnings from xml.dom import minidom import numpy as np @@ -20,7 +21,6 @@ from .neurosuite import NeuroSuite from .phy import Phy from .suite2p import Suite2P -import warnings def load_file(path, lazy_loading=None): diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index a3dbb5d1..84aed7b1 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -5,7 +5,10 @@ import numpy as np from .. import core as nap -from ._process_functions import _perievent_continuous, _perievent_trigger_average +from ._process_functions import ( + _perievent_continuous, + _perievent_trigger_average, +) def _align_tsd(tsd, tref, window, time_support): From ed2726a51ad0568ace6b5beb2ba480b39a74c0f3 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 11:04:34 +0200 Subject: [PATCH 042/195] isorted black --- pynapple/core/time_series.py | 8 +------- pynapple/io/__init__.py | 8 +------- pynapple/process/perievent.py | 5 +---- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 5ad4b7c0..f3de4cf6 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -27,13 +27,7 @@ from scipy import signal from tabulate import tabulate -from ._core_functions import ( - _bin_average, - _convolve, - _dropna, - _restrict, - _threshold, -) +from ._core_functions import _bin_average, _convolve, _dropna, _restrict, _threshold from .base_class import Base from .interval_set import IntervalSet from .time_index import TsIndex diff --git a/pynapple/io/__init__.py b/pynapple/io/__init__.py index f4eb2a70..20b194c7 100644 --- a/pynapple/io/__init__.py +++ b/pynapple/io/__init__.py @@ -1,10 +1,4 @@ from .folder import Folder from .interface_npz import NPZFile from .interface_nwb import NWBFile -from .misc import ( - append_NWB_LFP, - load_eeg, - load_file, - load_folder, - load_session, -) +from .misc import append_NWB_LFP, load_eeg, load_file, load_folder, load_session diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index 84aed7b1..a3dbb5d1 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -5,10 +5,7 @@ import numpy as np from .. import core as nap -from ._process_functions import ( - _perievent_continuous, - _perievent_trigger_average, -) +from ._process_functions import _perievent_continuous, _perievent_trigger_average def _align_tsd(tsd, tref, window, time_support): From 5ed5b769e90d7476d54cc630593e72b190c02436 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 12:28:13 +0200 Subject: [PATCH 043/195] Added option for closing an open file --- pynapple/io/interface_nwb.py | 4 ++++ tests/npzfilestest/tsd2.json | 2 +- tests/test_lazy_loading.py | 2 ++ tests/test_nwb.py | 1 + 4 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pynapple/io/interface_nwb.py b/pynapple/io/interface_nwb.py index 70f33956..124b9787 100644 --- a/pynapple/io/interface_nwb.py +++ b/pynapple/io/interface_nwb.py @@ -480,3 +480,7 @@ def __getitem__(self, key): return self.data[key] else: raise KeyError("Can't find key {} in group index.".format(key)) + + def close(self): + """Close the NWB file""" + self.io.close() diff --git a/tests/npzfilestest/tsd2.json b/tests/npzfilestest/tsd2.json index b091ce6c..b64c10cb 100644 --- a/tests/npzfilestest/tsd2.json +++ b/tests/npzfilestest/tsd2.json @@ -1,4 +1,4 @@ { - "time": "2024-07-17 10:34:35.997644", + "time": "2024-07-17 12:27:07.477318", "info": "Test description" } \ No newline at end of file diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index a4da8074..0e3aa929 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -207,6 +207,7 @@ def test_lazy_load_nwb(lazy): nwb = nap.NWBFile("nwbfilestest/basic/pynapplenwb/A2929-200711.nwb", lazy_loading=lazy) assert isinstance(nwb["z"].d, h5py.Dataset) is lazy + nwb.close() @pytest.mark.parametrize( @@ -223,6 +224,7 @@ def test_lazy_load_function(lazy): nwb = nap.load_file("nwbfilestest/basic/pynapplenwb/A2929-200711.nwb", lazy_loading=lazy) assert isinstance(nwb["z"].d, h5py.Dataset) is lazy + nwb.close() @pytest.mark.parametrize("data", [np.ones(10), np.ones((10, 2)), np.ones((10, 2, 2))]) diff --git a/tests/test_nwb.py b/tests/test_nwb.py index 943726e9..f27cbafa 100644 --- a/tests/test_nwb.py +++ b/tests/test_nwb.py @@ -86,6 +86,7 @@ def test_NWBFile(): assert nwb.name == "A2929-200711" assert isinstance(nwb.io, pynwb.NWBHDF5IO) + nwb.close() def test_NWBFile_missing_file(): From f110d1e0e3250b0e30e8ca6b856cb39fc3e796a5 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 16:44:18 +0100 Subject: [PATCH 044/195] better tests, finished notebook 1 --- docs/examples/tutorial_signal_processing.py | 705 ++++++++------------ pynapple/process/signal_processing.py | 32 +- tests/test_signal_processing.py | 27 +- 3 files changed, 310 insertions(+), 454 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 80296e65..7f2383a9 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -1,17 +1,16 @@ # -*- coding: utf-8 -*- """ -Signal Processing Local Field Potentials +Grosmark & Buzsáki (2016) Tutorial ============ +This tutorial demonstrates how we can use the signal processing tools within Pynapple to aid with data analysis. +We will examine the dataset from [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/). -Working with Local Field Potential data. - -See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. +Specifically, we will examine Local Field Potential data from a period of active traversal of a linear track. This tutorial was made by Kipp Freud. """ -import math -import os + # %% # !!! warning @@ -19,13 +18,14 @@ # # You can install all with `pip install matplotlib requests tqdm` # -# mkdocs_gallery_thumbnail_number = 1 -# -# Now, import the necessary libraries: +# First, import the necessary libraries: + +import math +import os + import matplotlib.pyplot as plt import numpy as np import requests -import scipy import tqdm import pynapple as nap @@ -34,7 +34,7 @@ # *** # Downloading the data # ------------------ -# First things first: Let's download the data and save it locally +# Let's download the data and save it locally path = "Achilles_10252013_EEG.nwb" if path not in os.listdir("."): @@ -54,7 +54,7 @@ # *** # Loading the data # ------------------ -# Loading the data, calculating the sampling frequency +# Let's load and print the full dataset. data = nap.load_file(path) FS = len(data["eeg"].index[:]) / (data["eeg"].index[-1] - data["eeg"].index[0]) @@ -65,172 +65,125 @@ # *** # Selecting slices # ----------------------------------- -# Let's consider two 60-second slices of data, one from the sleep epoch and one from wake - -REM_minute_interval = nap.IntervalSet( - data["rem"]["start"][0] + 60.0, - data["rem"]["start"][0] + 120.0, +# We will consider a single run of the experiment - where the rodent completes a full traversal of the linear track, +# followed by 4 seconds of post-traversal activity. + +# Define the run to use for this Analysis +run_index = 7 +# Define the IntervalSet for this run and instantiate both LFP and +# Position TsdFrame objects +RUN_interval = nap.IntervalSet( + data["forward_ep"]["start"][run_index], + data["forward_ep"]["end"][run_index] + 4.0, ) - -SWS_minute_interval = nap.IntervalSet( - data["nrem"]["start"][0] + 10.0, - data["nrem"]["start"][0] + 70.0, +RUN_Tsd = nap.TsdFrame( + t=data["eeg"].restrict(RUN_interval).index.values - + data["forward_ep"]["start"][run_index], + d=data["eeg"].restrict(RUN_interval).values, ) - -RUN_minute_interval = nap.IntervalSet( - data["forward_ep"]["start"][-18] + 0.0, - data["forward_ep"]["start"][-18] + 60.0, +RUN_pos = nap.TsdFrame( + t=data["position"].restrict(RUN_interval).index.values - + data["forward_ep"]["start"][run_index], + d=data["position"].restrict(RUN_interval).asarray(), ) +# The given dataset has only one channel, so we set channel = 0 here +channel = 0 -REM_minute = nap.TsdFrame( - t=data["eeg"].restrict(REM_minute_interval).index.values, - d=data["eeg"].restrict(REM_minute_interval).values, - time_support=data["eeg"].restrict(REM_minute_interval).time_support, -) +# %% +# *** +# Plotting the LFP and Behavioural Activity +# ----------------------------------- -SWS_minute = nap.TsdFrame( - t=data["eeg"].restrict(SWS_minute_interval).index.values, - d=data["eeg"].restrict(SWS_minute_interval).values, - time_support=data["eeg"].restrict(SWS_minute_interval).time_support, +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +axd = fig.subplot_mosaic( + [ + ["ephys"], + ["pos"] + ], + height_ratios=[1, 0.2], ) -RUN_minute = nap.TsdFrame( - t=data["eeg"].restrict(RUN_minute_interval).index.values, - d=data["eeg"].restrict(RUN_minute_interval).values, - time_support=data["eeg"].restrict(RUN_minute_interval).time_support, +axd["ephys"].plot( + RUN_Tsd[:, channel].restrict( + nap.IntervalSet( + 0.0, + data["forward_ep"]["end"][run_index] - + data["forward_ep"]["start"][run_index] + ) + ), + label="Traversal LFP Data", + color="green" ) - -RUN_position = nap.TsdFrame( - t=data["position"].restrict(RUN_minute_interval).index.values[:], - d=data["position"].restrict(RUN_minute_interval), - time_support=data["position"].restrict(RUN_minute_interval).time_support, +axd["ephys"].plot( + RUN_Tsd[:, channel].restrict( + nap.IntervalSet( + data["forward_ep"]["end"][run_index] - + data["forward_ep"]["start"][run_index], + data["forward_ep"]["end"][run_index] - + data["forward_ep"]["start"][run_index] + 5.0, + ) + ), + label="Post Traversal LFP Data", + color="blue" ) +axd["ephys"].set_title("Traversal & Post Traversal LFP") +axd["ephys"].set_ylabel("LFP (v)") +axd["ephys"].set_xlabel("time (s)") +axd["ephys"].margins(0) +axd["ephys"].legend() +axd["pos"].plot(RUN_pos, color="black") +axd["pos"].margins(0) +axd["pos"].set_xlabel("time (s)") +axd["pos"].set_ylabel("Linearized Position") +axd["pos"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) -channel = 0 # %% # *** -# Plotting the LFP activity of one slices +# Getting the LFP Spectogram # ----------------------------------- -# Let's plot - -fig, ax = plt.subplots(3) - -for channel in range(SWS_minute.shape[1]): - ax[0].plot( - SWS_minute[:, channel], - alpha=0.5, - label="Sleep Data", - ) -ax[0].set_title("non-REM ephys") -ax[0].set_ylabel("LFP (v)") -ax[0].set_xlabel("time (s)") -ax[0].margins(0) -for channel in range(REM_minute.shape[1]): - ax[1].plot(REM_minute[:, channel], alpha=0.5, label="Wake Data", color="orange") -ax[1].set_ylabel("LFP (v)") -ax[1].set_xlabel("time (s)") -ax[1].set_title("REM ephys") -ax[1].margins(0) -for channel in range(RUN_minute.shape[1]): - ax[2].plot(RUN_minute[:, channel], alpha=0.5, label="Wake Data", color="green") -ax[2].set_ylabel("LFP (v)") -ax[2].set_xlabel("time (s)") -ax[2].set_title("Running ephys") -ax[2].margins(0) -plt.show() - - -# %% # Let's take the Fourier transforms of one channel for both waking and sleeping and see if differences are present -channel = 0 -fig, ax = plt.subplots(3) -fft = nap.compute_spectogram(SWS_minute, fs=int(FS)) -ax[0].plot( - fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Sleep Data", c="blue" -) -ax[0].set_xlim((0, FS / 2)) -ax[0].set_xlabel("Freq (Hz)") -ax[0].set_ylabel("Frequency Power") - -ax[0].set_title("non-REM LFP Decomposition") -fft = nap.compute_spectogram(REM_minute, fs=int(FS)) -ax[1].plot( - fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Wake Data", c="orange" -) -ax[1].set_xlim((0, FS / 2)) -fig.suptitle(f"Fourier Decomposition for channel {channel}") -ax[1].set_title("REM LFP Decomposition") -ax[1].set_xlabel("Freq (Hz)") -ax[1].set_ylabel("Frequency Power") - -fft = nap.compute_spectogram(RUN_minute, fs=int(FS)) -ax[2].plot( - fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, label="Running Data", c="green" -) -ax[2].set_xlim((0, FS / 2)) -fig.suptitle(f"Fourier Decomposition for channel {channel}") -ax[2].set_title("Running LFP Decomposition") -ax[2].set_xlabel("Freq (Hz)") -ax[2].set_ylabel("Frequency Power") -# ax.legend() -plt.show() +fft = nap.compute_spectogram(RUN_Tsd, fs=int(FS)) - -# %% -# Let's now consider the Welch spectograms of waking and sleeping data... - -fig, ax = plt.subplots(3) -welch = nap.compute_welch_spectogram(SWS_minute, fs=int(FS)) -ax[0].plot( - welch.index, - np.abs(welch.iloc[:, channel]), - alpha=0.5, - label="non-REM Data", - color="blue", -) -ax[0].set_xlim((0, FS / 2)) -ax[0].set_title("non-REM LFP Decomposition") -ax[0].set_xlabel("Freq (Hz)") -ax[0].set_ylabel("Frequency Power") -welch = nap.compute_welch_spectogram(REM_minute, fs=int(FS)) -ax[1].plot( - welch.index, - np.abs(welch.iloc[:, channel]), - alpha=0.5, - label="REM Data", - color="orange", -) -ax[1].set_xlim((0, FS / 2)) -fig.suptitle(f"Welch Decomposition for channel {channel}") -ax[1].set_title("REM LFP Decomposition") -ax[1].set_xlabel("Freq (Hz)") -ax[1].set_ylabel("Frequency Power") - -welch = nap.compute_welch_spectogram(RUN_minute, fs=int(FS)) -ax[2].plot( - welch.index, - np.abs(welch.iloc[:, channel]), - alpha=0.5, - label="Running Data", - color="green", +# Now we will plot it +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot( + fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, + label="LFP Frequency Power", c="blue", linewidth=2 ) -ax[2].set_xlim((0, FS / 2)) -fig.suptitle(f"Welch Decomposition for channel {channel}") -ax[2].set_title("Running LFP Decomposition") -ax[2].set_xlabel("Freq (Hz)") -ax[2].set_ylabel("Frequency Power") -# ax.legend() -plt.show() +ax.set_xlabel("Freq (Hz)") +ax.set_ylabel("Frequency Power") +ax.set_title("LFP Fourier Decomposition") +ax.set_xlim(1, 30) +ax.axvline(9.36, c="orange", label="9.36Hz", alpha=0.5) +ax.axvline(18.74, c="green", label="18.74Hz", alpha=0.5) +ax.legend() +# ax.set_yscale('log') +# ax.set_xscale('log') # %% -# There seems to be some differences presenting themselves - a bump in higher frequencies for waking data? -# Let's explore further with a wavelet decomposition +# *** +# Getting the Wavelet Decomposition +# ----------------------------------- +# It looks like the prominent frequencies in the data may vary over time. For example, it looks like the +# LFP characteristics may be different while the animal is running along the track, and when it is finished. +# Let's generate a wavelet decomposition to look more closely at the changing frequency powers over time. +# We must define the frequency set that we'd like to use for our decomposition; these +# have been manually selected based on the frequencies used in Frey et. al (2021), but +# could also be defined as a linspace or logspace +freqs = np.array( + [ + 2.59, 3.66, 5.18, 8.0, 10.36, 20.72, 29.3, 41.44, 58.59, 82.88, + 117.19, 152.35, 192.19, 200., 234.38, 270.00, 331.5, 390.00, + ] +) +mwt_RUN = nap.compute_wavelet_transform(RUN_Tsd[:, channel], fs=None, freqs=freqs) +#Define wavelet decomposition plotting function def plot_timefrequency( - times, freqs, powers, x_ticks=5, y_ticks=None, ax=None, **kwargs + times, freqs, powers, x_ticks=5, ax=None, **kwargs ): if np.iscomplexobj(powers): powers = abs(powers) @@ -244,325 +197,197 @@ def plot_timefrequency( else: x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - if isinstance(y_ticks, int): - y_ticks_pos = np.linspace(0, freqs.size, y_ticks) - y_ticks = np.round(np.linspace(freqs[0], freqs[-1], y_ticks), 2) - else: - y_ticks = freqs - y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] + y_ticks = freqs + y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) - - -fig = plt.figure(constrained_layout=True, figsize=(10, 50)) -num_cells = 10 + +# And plot +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( [ - ["wd_sws"], - ["lfp_sws"], - ["wd_rem"], - ["lfp_rem"], ["wd_run"], ["lfp_run"], ["pos_run"], ], - height_ratios=[1, 0.2, 1, 0.2, 1, 0.2, 0.2], -) -freqs = np.array( - [ - 2.59, - 3.66, - 5.18, - 8.0, - 10.36, - 14.65, - 20.72, - 29.3, - 41.44, - 58.59, - 82.88, - 117.19, - 150.00, - 190.00, - 234.38, - 270.00, - 331.5, - 390.00, - # 468.75, - # 520.00, - # 570.00, - # 624.0, - ] -) -mwt_SWS = nap.compute_wavelet_transform(SWS_minute[:, channel], fs=None, freqs=freqs) -plot_timefrequency( - SWS_minute.index.values[:], - freqs[:], - np.transpose(mwt_SWS[:, :].values), - ax=axd["wd_sws"], -) -axd["wd_sws"].set_title(f"non-REM Data Wavelet Decomposition: Channel {channel}") - -mwt_REM = nap.compute_wavelet_transform(REM_minute[:, channel], fs=None, freqs=freqs) -plot_timefrequency( - REM_minute.index.values[:], - freqs[:], - np.transpose(mwt_REM[:, :].values), - ax=axd["wd_rem"], + height_ratios=[1, 0.2, 0.4], ) -axd["wd_rem"].set_title(f"REM Data Wavelet Decomposition: Channel {channel}") - -mwt_RUN = nap.compute_wavelet_transform(RUN_minute[:, channel], fs=None, freqs=freqs) plot_timefrequency( - RUN_minute.index.values[:], + RUN_Tsd.index.values[:], freqs[:], np.transpose(mwt_RUN[:, :].values), ax=axd["wd_run"], ) -axd["wd_run"].set_title(f"Running Data Wavelet Decomposition: Channel {channel}") - -axd["lfp_sws"].plot(SWS_minute) -axd["lfp_rem"].plot(REM_minute) -axd["lfp_run"].plot(RUN_minute) -axd["pos_run"].plot(RUN_position) -axd["pos_run"].margins(0) -for k in ["lfp_sws", "lfp_rem", "lfp_run"]: +axd["wd_run"].set_title(f"Wavelet Decomposition") +axd["lfp_run"].plot(RUN_Tsd) +axd["pos_run"].plot(RUN_pos) +axd["pos_run"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["pos_run"].set_ylabel("Lin. Position (cm)") +for k in ["lfp_run", "pos_run"]: axd[k].margins(0) - axd[k].set_ylabel("LFP (v)") + if k != "pos_run": + axd[k].set_ylabel("LFP (v)") axd[k].get_xaxis().set_visible(False) - axd[k].spines["top"].set_visible(False) - axd[k].spines["right"].set_visible(False) - axd[k].spines["bottom"].set_visible(False) - axd[k].spines["left"].set_visible(False) -plt.show() - -# %%g -freq = 3 -interval = (REM_minute_interval["start"] + 0, REM_minute_interval["start"] + 5) -REM_second = REM_minute.restrict(nap.IntervalSet(interval[0], interval[1])) -mwt_REM_second = mwt_REM.restrict(nap.IntervalSet(interval[0], interval[1])) -fig, ax = plt.subplots(1) -ax.plot(REM_second.index.values, REM_second[:, channel], alpha=0.5, label="Wake Data") -ax.plot( - REM_second.index.values, - mwt_REM_second[:, freq].values.real, - label="Theta oscillations", -) -ax.set_title(f"{freqs[freq]}Hz oscillation power.") -plt.show() + for spine in ["top", "right", "bottom", "left"]: + axd[k].spines[spine].set_visible(False) # %% -# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during wakeful theta oscillation - -fig = plt.figure(constrained_layout=True, figsize=(10, 50)) -num_cells = 10 +# *** +# Visualizing Theta Band Power +# ----------------------------------- +# There seems to be a strong theta frequency present in the data during the maze traversal. +# Let's plot the estimated 8Hz component of the wavelet decomposition on top of our data, and see how well +# they match up +theta_freq_index = 3 +theta_band_reconstruction = mwt_RUN[:, theta_freq_index].values.real +theta_band_power_envelope = np.abs(mwt_RUN[:, theta_freq_index].values) + +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( [ - ["raw_lfp"] * 2, - ["wavelet"] * 2, - ["fit_wavelet"] * 2, - ["wavelet_power"] * 2, - ["wavelet_phase"] * 2, - ] - + [[f"spikes_phasetime_{i}", f"spikephase_hist_{i}"] for i in range(num_cells)], + ["lfp_run"], + ["pos_run"], + ], + height_ratios=[1, 0.3], ) +axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], + alpha=0.5, + label="LFP Data") +axd["lfp_run"].plot( + RUN_Tsd.index.values, + theta_band_reconstruction, + label=f"{freqs[theta_freq_index]}Hz oscillations", +) +axd["lfp_run"].plot( + RUN_Tsd.index.values, + theta_band_power_envelope, + label=f"{freqs[theta_freq_index]}Hz power envelope", +) -# _, ax = plt.subplots(25, figsize=(10, 50)) -mwt_REM = np.transpose(mwt_REM_second) -axd["raw_lfp"].plot(REM_second.index, REM_second.values[:, 0]) -axd["raw_lfp"].margins(0) -plot_timefrequency(REM_second.index, freqs, np.abs(mwt_REM[:, :]), ax=axd["wavelet"]) - -axd["fit_wavelet"].plot(REM_second.index, REM_second.values[:, 0]) -axd["fit_wavelet"].plot(REM_second.index, mwt_REM[freq, :].real) -axd["fit_wavelet"].set_title(f"{freqs[freq]}Hz") -axd["fit_wavelet"].margins(0) - -axd["wavelet_power"].plot(REM_second.index, np.abs(mwt_REM[freq, :])) -axd["wavelet_power"].margins(0) -# ax[3].plot(lfp.index, lfp.values[:,0]) -axd["wavelet_phase"].plot(REM_second.index, np.angle(mwt_REM[freq, :])) -axd["wavelet_phase"].margins(0) - -spikes = {} -for i in data["units"].index: - spikes[i] = data["units"][i].times()[ - (data["units"][i].times() > interval[0]) - & (data["units"][i].times() < interval[1]) - ] - -phase = {} -for i in spikes.keys(): - phase_i = [] - for spike in spikes[i]: - phase_i.append( - np.angle(mwt_REM[freq, np.argmin(np.abs(REM_second.index.values - spike))]) - ) - phase[i] = np.array(phase_i) - -spikes = {k: v for k, v in spikes.items() if len(v) > 20} -phase = {k: v for k, v in phase.items() if len(v) > 20} - -variances = { - key: scipy.stats.circvar(value, low=-np.pi, high=np.pi) - for key, value in phase.items() -} -spikes = dict(sorted(spikes.items(), key=lambda item: variances[item[0]])) -phase = dict(sorted(phase.items(), key=lambda item: variances[item[0]])) - -for i in range(num_cells): - axd[f"spikes_phasetime_{i}"].scatter( - spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]] - ) - axd[f"spikes_phasetime_{i}"].set_xlim(interval[0], interval[1]) - axd[f"spikes_phasetime_{i}"].set_ylim(-np.pi, np.pi) - axd[f"spikes_phasetime_{i}"].set_xlabel("time (s)") - axd[f"spikes_phasetime_{i}"].set_ylabel("phase") - - axd[f"spikephase_hist_{i}"].hist( - phase[list(phase.keys())[i]], bins=np.linspace(-np.pi, np.pi, 10) - ) - axd[f"spikephase_hist_{i}"].set_xlim(-np.pi, np.pi) +axd["lfp_run"].set_ylabel("LFP (v)") +axd["lfp_run"].set_xlabel("Time (s)") +axd["lfp_run"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") +axd["pos_run"].plot(RUN_pos) +[axd[k].margins(0) for k in ["lfp_run", "pos_run"]] +[axd["pos_run"].spines[sp].set_visible(False) for sp in ["top", "right", "bottom", "left"]] +axd["pos_run"].get_xaxis().set_visible(False) +axd["pos_run"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["pos_run"].set_ylabel("Lin. Position (cm)") +axd["lfp_run"].legend() -plt.tight_layout() -plt.show() +# %% +# *** +# Visualizing Sharp Wave Ripple Power +# ----------------------------------- +# There also seem to be peaks in the 200Hz frequency power after traversal of thew maze is complete. +# Let's plot the LFP, along with the 200Hz frequency power, to see if we can isolate these peaks and +# see what's going on. +ripple_freq_idx = 13 +ripple_power = np.abs(mwt_RUN[:, ripple_freq_idx].values) -# %% -# Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data -freq = 12 -# interval = (10, 15) -interval = (SWS_minute_interval["start"] + 30, SWS_minute_interval["start"] + 50) -SWS_second = SWS_minute.restrict(nap.IntervalSet(interval[0], interval[1])) -mwt_SWS_second = mwt_SWS.restrict(nap.IntervalSet(interval[0], interval[1])) -_, ax = plt.subplots(1) -ax.plot(SWS_second[:, channel], alpha=0.5, label="Wake Data") -ax.plot( - SWS_second.index.values, - mwt_SWS_second[:, freq].values.real, - label="Slow Wave Oscillations", +fig = plt.figure(constrained_layout=True, figsize=(10, 5)) +axd = fig.subplot_mosaic( + [ + ["lfp_run"], + ["rip_pow"], + ], + height_ratios=[1, 0.4], ) -ax.set_title(f"{freqs[freq]}Hz oscillation power") -plt.show() +axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], label="LFP Data") +axd["lfp_run"].set_ylabel("LFP (v)") +axd["lfp_run"].set_xlabel("Time (s)") +axd["lfp_run"].margins(0) +axd["lfp_run"].set_title(f"Traversal & Post-Traversal LFP") +axd["rip_pow"].plot(RUN_Tsd.index.values, + ripple_power + ) +axd["rip_pow"].margins(0) +axd["rip_pow"].get_xaxis().set_visible(False) +axd["rip_pow"].spines["top"].set_visible(False) +axd["rip_pow"].spines["right"].set_visible(False) +axd["rip_pow"].spines["bottom"].set_visible(False) +axd["rip_pow"].spines["left"].set_visible(False) +axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["rip_pow"].set_ylabel(f"{freqs[ripple_freq_idx]}Hz Power") # %% -# Let's plot spike phase, time scatter plots to see if spikes display phase characteristics during wakeful theta oscillation - -fig = plt.figure(constrained_layout=True, figsize=(10, 50)) -num_cells = 5 +# *** +# Isolating Ripple Times +# ----------------------------------- +# We can see one significant peak in the 200Hz frequency power. Let's smooth the power curve and threshold +# to try to isolate this event. + +# define our threshold +threshold = 100 +# smooth our wavelet power +window_size = 51 +window = np.ones(window_size) / window_size +smoother_swr_power = np.convolve(np.abs(mwt_RUN[:, ripple_freq_idx].values), window, mode='same') +# isolate our ripple periods +is_ripple = smoother_swr_power > threshold +start_idx = None +ripple_periods = [] +for i in range(len(RUN_Tsd.index.values)): + if is_ripple[i] and start_idx is None: + start_idx = i + elif not is_ripple[i] and start_idx is not None: + axd["rip_pow"].plot(RUN_Tsd.index.values[start_idx:i], smoother_swr_power[start_idx:i], color='red', linewidth=2) + ripple_periods.append( (start_idx, i) ) + start_idx = None + +# plot of captured ripple periods +fig = plt.figure(constrained_layout=True, figsize=(10, 5)) axd = fig.subplot_mosaic( [ - ["raw_lfp"] * 2, - ["wavelet"] * 2, - ["fit_wavelet"] * 2, - ["wavelet_power"] * 2, - ["wavelet_phase"] * 2, - ] - + [[f"spikes_phasetime_{i}", f"spikephase_hist_{i}"] for i in range(num_cells)], + ["lfp_run"], + ["rip_pow"], + ], + height_ratios=[1, 0.4], ) - - -# _, ax = plt.subplots(25, figsize=(10, 50)) -mwt_SWS = np.transpose(mwt_SWS_second) -axd["raw_lfp"].plot(SWS_second.index, SWS_second.values[:, 0]) -axd["raw_lfp"].margins(0) - -plot_timefrequency(SWS_second.index, freqs, np.abs(mwt_SWS[:, :]), ax=axd["wavelet"]) - -axd["fit_wavelet"].plot(SWS_second.index, SWS_second.values[:, 0]) -axd["fit_wavelet"].plot(SWS_second.index, mwt_SWS[freq, :].real) -axd["fit_wavelet"].set_title(f"{freqs[freq]}Hz") -axd["fit_wavelet"].margins(0) - -axd["wavelet_power"].plot(SWS_second.index, np.abs(mwt_SWS[freq, :])) -axd["wavelet_power"].margins(0) -axd["wavelet_phase"].plot(SWS_second.index, np.angle(mwt_SWS[freq, :])) -axd["wavelet_phase"].margins(0) - -spikes = {} -for i in data["units"].index: - spikes[i] = data["units"][i].times()[ - (data["units"][i].times() > interval[0]) - & (data["units"][i].times() < interval[1]) - ] - -phase = {} -for i in spikes.keys(): - phase_i = [] - for spike in spikes[i]: - phase_i.append( - np.angle(mwt_SWS[freq, np.argmin(np.abs(SWS_second.index.values - spike))]) - ) - phase[i] = np.array(phase_i) - -spikes = {k: v for k, v in spikes.items() if len(v) > 0} -phase = {k: v for k, v in phase.items() if len(v) > 0} - -for i in range(num_cells): - axd[f"spikes_phasetime_{i}"].scatter( - spikes[list(spikes.keys())[i]], phase[list(phase.keys())[i]] - ) - axd[f"spikes_phasetime_{i}"].set_xlim(interval[0], interval[1]) - axd[f"spikes_phasetime_{i}"].set_ylim(-np.pi, np.pi) - axd[f"spikes_phasetime_{i}"].set_xlabel("time (s)") - axd[f"spikes_phasetime_{i}"].set_ylabel("phase") - - axd[f"spikephase_hist_{i}"].hist( - phase[list(phase.keys())[i]], bins=np.linspace(-np.pi, np.pi, 10) - ) - axd[f"spikephase_hist_{i}"].set_xlim(-np.pi, np.pi) - -plt.tight_layout() -plt.show() +axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], label="LFP Data") +axd["rip_pow"].plot(RUN_Tsd.index.values, + smoother_swr_power + ) +for r in ripple_periods: + axd["rip_pow"].plot(RUN_Tsd.index.values[r[0]:r[1]], smoother_swr_power[r[0]:r[1]], color='red', linewidth=2) +axd["lfp_run"].set_ylabel("LFP (v)") +axd["lfp_run"].set_xlabel("Time (s)") +axd["lfp_run"].set_title(f"{freqs[ripple_freq_idx]}Hz oscillation power.") +axd["rip_pow"].axhline(threshold) +[axd[k].margins(0) for k in ["lfp_run", "rip_pow"]] +[axd["rip_pow"].spines[sp].set_visible(False) for sp in ["top", "left", "right"]] +axd["rip_pow"].get_xaxis().set_visible(False) +axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["rip_pow"].set_ylabel(f"{freqs[ripple_freq_idx]}Hz Power") # %% -# Let's focus on the sleeping data. Let's see if we can isolate the slow wave oscillations from the data -# interval = (10, 15) - -RUN_minute_interval = nap.IntervalSet( - data["forward_ep"]["start"][0], data["forward_ep"]["end"][-1] -) +# *** +# Plotting a Sharp Wave Ripple +# ----------------------------------- +# Let's zoom in on out detected ripples and have a closer look! -RUN_minute = nap.TsdFrame( - t=data["eeg"].restrict(RUN_minute_interval).index.values, - d=data["eeg"].restrict(RUN_minute_interval).values, - time_support=data["eeg"].restrict(RUN_minute_interval).time_support, -) +# Filter out ripples which do not last long enough +ripple_periods = [r for r in ripple_periods if r[1]-r[0] > 20] -RUN_position = nap.TsdFrame( - t=data["position"].restrict(RUN_minute_interval).index.values[:], - d=data["position"].restrict(RUN_minute_interval), - time_support=data["position"].restrict(RUN_minute_interval).time_support, +# And plot! +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +buffer = 200 +ax.plot( + RUN_Tsd.index.values[r[0]-buffer:r[1]+buffer], + RUN_Tsd[r[0]-buffer:r[1]+buffer], + color="blue", + label="Non-SWR LFP" ) - -mwt_RUN = nap.compute_wavelet_transform( - RUN_minute[:, channel], freqs=freqs, fs=None, norm=None, n_cycles=3.5, scaling=1 +ax.plot( + RUN_Tsd.index.values[r[0]:r[1]], + RUN_Tsd[r[0]:r[1]], + color="red", + label="SWR", + linewidth=2 ) - -for run in range(len(data["forward_ep"]["start"])): - interval = ( - data["forward_ep"]["start"][run], - data["forward_ep"]["end"][run] + 5.0, - ) - if interval[1] - interval[0] < 6: - continue - print(interval) - RUN_second_r = RUN_minute.restrict(nap.IntervalSet(interval[0], interval[1])) - RUN_position_r = RUN_position.restrict(nap.IntervalSet(interval[0], interval[1])) - mwt_RUN_second_r = mwt_RUN.restrict(nap.IntervalSet(interval[0], interval[1])) - _, ax = plt.subplots(3) - plot_timefrequency( - RUN_second_r.index.values[:], - freqs[:], - np.transpose(mwt_RUN_second_r[:, :].values), - ax=ax[0], - ) - ax[1].plot(RUN_second_r[:, channel], alpha=0.5, label="Wake Data") - ax[1].margins(0) - - ax[2].plot(RUN_position, alpha=0.5, label="Wake Data") - ax[2].set_xlim( - RUN_second_r[:, channel].index.min(), RUN_second_r[:, channel].index.max() - ) - ax[2].margins(0) - plt.show() +ax.margins(0) +ax.set_xlabel("Time (s)") +ax.set_ylabel("LFP (v)") +ax.legend() +ax.set_title("Sharp Wave Ripple Visualization") diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 8c0a0fc7..3dcd06db 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -13,7 +13,7 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): """ - Performs numpy fft on sig, returns output + Performs numpy fft on sig, returns output. Pynapple assumes a constant sampling rate for sig. ---------- sig : pynapple.Tsd or pynapple.TsdFrame @@ -24,6 +24,11 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): The epoch to calculate the fft on. Must be length 1. full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values + + Notes + ----- + compute_spectogram computes fft on only a single epoch of data. This epoch be given with the ep + parameter otherwise will be sig.time_support, but it must only be a single epoch. """ if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): raise TypeError( @@ -100,7 +105,7 @@ def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): ) -def _create_freqs(freq_start, freq_stop, log_scaling=False, freq_step=1, log_base=np.e): +def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_base=np.e): """ Creates an array of frequencies. @@ -110,10 +115,10 @@ def _create_freqs(freq_start, freq_stop, log_scaling=False, freq_step=1, log_bas Starting value for the frequency definition. freq_stop: float Stopping value for the frequency definition, inclusive. - log_scaling: Bool - If True, will use log spacing with base log_base for frequency spacing. Default False. freq_step: float, optional Step value, for linearly spaced values between start and stop. + log_scaling: Bool + If True, will use log spacing with base log_base for frequency spacing. Default False. log_base: float If log_scaling==True, this defines the base of the log to use. @@ -136,7 +141,7 @@ def compute_wavelet_transform( Parameters ---------- - sig : pynapple.Tsd or pynapple.TsdFrame + sig : pynapple.Tsd, pynapple.TsdFrame or pynapple.TsdTensor Time series. freqs : 1d array or list of float If array, frequency values to estimate with morlet wavelets. @@ -159,7 +164,7 @@ def compute_wavelet_transform( Returns ------- - mwt : 2d array + pynapple.TsdFrame or pynapple.TsdTensor : 2d array Time frequency representation of the input signal. Notes @@ -186,11 +191,11 @@ def compute_wavelet_transform( output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - mwt = np.zeros( + cwt = np.zeros( [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex ) - filter_bank = _generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision) + filter_bank = generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision) for f_i, filter in enumerate(filter_bank): convolved = sig.convolve(np.transpose(np.asarray([filter.real, filter.imag]))) convolved = convolved[:, :, 0].values + convolved[:, :, 1].values * 1j @@ -202,19 +207,19 @@ def compute_wavelet_transform( coef = np.insert( coef, 1, coef[0], axis=0 ) # slightly hacky line, necessary to make output correct shape - mwt[:, f_i, :] = coef if len(coef.shape) == 2 else np.expand_dims(coef, axis=1) + cwt[:, f_i, :] = coef if len(coef.shape) == 2 else np.expand_dims(coef, axis=1) if len(output_shape) == 2: return nap.TsdFrame( - t=sig.index, d=mwt.reshape(output_shape), time_support=sig.time_support + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support ) return nap.TsdTensor( - t=sig.index, d=mwt.reshape(output_shape), time_support=sig.time_support + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support ) -def _generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision): +def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=10): """ Parameters @@ -253,7 +258,8 @@ def _generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision): if len(int_psi_scale) > max_len: max_len = len(int_psi_scale) filter_bank.append(int_psi_scale) - return filter_bank + filter_bank = [np.pad(arr, ((max_len - len(arr)) // 2, (max_len - len(arr) + 1) // 2), constant_values=0.0) for arr in filter_bank] + return np.array(filter_bank) def _integrate(arr, step): diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 5055cdbf..87686d18 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -75,6 +75,30 @@ def test_compute_welch_spectogram(): def test_compute_wavelet_transform(): + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t*50*np.pi*2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 50 + assert mwt.shape == (1000, 10) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 20 + assert mwt.shape == (1000, 10) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t * 70 * np.pi * 2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 70 + assert mwt.shape == (1000, 10) + t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) freqs = np.linspace(1, 600, 10) @@ -104,4 +128,5 @@ def test_compute_wavelet_transform(): if __name__ == "__main__": - test_compute_welch_spectogram() + test_compute_wavelet_transform() + # test_compute_welch_spectogram() From b00bf23a22de6321e680a51a8ff0b838c6ac31da Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 16:45:43 +0100 Subject: [PATCH 045/195] linting --- docs/examples/tutorial_signal_processing.py | 121 ++++++++++++-------- pynapple/process/signal_processing.py | 9 +- tests/test_signal_processing.py | 2 +- 3 files changed, 84 insertions(+), 48 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 7f2383a9..6379c9be 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -77,13 +77,13 @@ data["forward_ep"]["end"][run_index] + 4.0, ) RUN_Tsd = nap.TsdFrame( - t=data["eeg"].restrict(RUN_interval).index.values - - data["forward_ep"]["start"][run_index], + t=data["eeg"].restrict(RUN_interval).index.values + - data["forward_ep"]["start"][run_index], d=data["eeg"].restrict(RUN_interval).values, ) RUN_pos = nap.TsdFrame( - t=data["position"].restrict(RUN_interval).index.values - - data["forward_ep"]["start"][run_index], + t=data["position"].restrict(RUN_interval).index.values + - data["forward_ep"]["start"][run_index], d=data["position"].restrict(RUN_interval).asarray(), ) # The given dataset has only one channel, so we set channel = 0 here @@ -96,10 +96,7 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( - [ - ["ephys"], - ["pos"] - ], + [["ephys"], ["pos"]], height_ratios=[1, 0.2], ) @@ -107,24 +104,25 @@ RUN_Tsd[:, channel].restrict( nap.IntervalSet( 0.0, - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index] + data["forward_ep"]["end"][run_index] + - data["forward_ep"]["start"][run_index], ) ), label="Traversal LFP Data", - color="green" + color="green", ) axd["ephys"].plot( RUN_Tsd[:, channel].restrict( nap.IntervalSet( - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index], - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index] + 5.0, + data["forward_ep"]["end"][run_index] + - data["forward_ep"]["start"][run_index], + data["forward_ep"]["end"][run_index] + - data["forward_ep"]["start"][run_index] + + 5.0, ) ), label="Post Traversal LFP Data", - color="blue" + color="blue", ) axd["ephys"].set_title("Traversal & Post Traversal LFP") axd["ephys"].set_ylabel("LFP (v)") @@ -149,8 +147,12 @@ # Now we will plot it fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot( - fft.index, np.abs(fft.iloc[:, channel]), alpha=0.5, - label="LFP Frequency Power", c="blue", linewidth=2 + fft.index, + np.abs(fft.iloc[:, channel]), + alpha=0.5, + label="LFP Frequency Power", + c="blue", + linewidth=2, ) ax.set_xlabel("Freq (Hz)") ax.set_ylabel("Frequency Power") @@ -175,16 +177,31 @@ # could also be defined as a linspace or logspace freqs = np.array( [ - 2.59, 3.66, 5.18, 8.0, 10.36, 20.72, 29.3, 41.44, 58.59, 82.88, - 117.19, 152.35, 192.19, 200., 234.38, 270.00, 331.5, 390.00, + 2.59, + 3.66, + 5.18, + 8.0, + 10.36, + 20.72, + 29.3, + 41.44, + 58.59, + 82.88, + 117.19, + 152.35, + 192.19, + 200.0, + 234.38, + 270.00, + 331.5, + 390.00, ] ) mwt_RUN = nap.compute_wavelet_transform(RUN_Tsd[:, channel], fs=None, freqs=freqs) -#Define wavelet decomposition plotting function -def plot_timefrequency( - times, freqs, powers, x_ticks=5, ax=None, **kwargs -): + +# Define wavelet decomposition plotting function +def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): if np.iscomplexobj(powers): powers = abs(powers) ax.imshow(powers, aspect="auto", **kwargs) @@ -200,7 +217,8 @@ def plot_timefrequency( y_ticks = freqs y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) - + + # And plot fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( @@ -250,9 +268,9 @@ def plot_timefrequency( height_ratios=[1, 0.3], ) -axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], - alpha=0.5, - label="LFP Data") +axd["lfp_run"].plot( + RUN_Tsd.index.values, RUN_Tsd[:, channel], alpha=0.5, label="LFP Data" +) axd["lfp_run"].plot( RUN_Tsd.index.values, theta_band_reconstruction, @@ -269,7 +287,10 @@ def plot_timefrequency( axd["lfp_run"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") axd["pos_run"].plot(RUN_pos) [axd[k].margins(0) for k in ["lfp_run", "pos_run"]] -[axd["pos_run"].spines[sp].set_visible(False) for sp in ["top", "right", "bottom", "left"]] +[ + axd["pos_run"].spines[sp].set_visible(False) + for sp in ["top", "right", "bottom", "left"] +] axd["pos_run"].get_xaxis().set_visible(False) axd["pos_run"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) axd["pos_run"].set_ylabel("Lin. Position (cm)") @@ -299,9 +320,7 @@ def plot_timefrequency( axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].margins(0) axd["lfp_run"].set_title(f"Traversal & Post-Traversal LFP") -axd["rip_pow"].plot(RUN_Tsd.index.values, - ripple_power - ) +axd["rip_pow"].plot(RUN_Tsd.index.values, ripple_power) axd["rip_pow"].margins(0) axd["rip_pow"].get_xaxis().set_visible(False) axd["rip_pow"].spines["top"].set_visible(False) @@ -323,7 +342,9 @@ def plot_timefrequency( # smooth our wavelet power window_size = 51 window = np.ones(window_size) / window_size -smoother_swr_power = np.convolve(np.abs(mwt_RUN[:, ripple_freq_idx].values), window, mode='same') +smoother_swr_power = np.convolve( + np.abs(mwt_RUN[:, ripple_freq_idx].values), window, mode="same" +) # isolate our ripple periods is_ripple = smoother_swr_power > threshold start_idx = None @@ -332,8 +353,13 @@ def plot_timefrequency( if is_ripple[i] and start_idx is None: start_idx = i elif not is_ripple[i] and start_idx is not None: - axd["rip_pow"].plot(RUN_Tsd.index.values[start_idx:i], smoother_swr_power[start_idx:i], color='red', linewidth=2) - ripple_periods.append( (start_idx, i) ) + axd["rip_pow"].plot( + RUN_Tsd.index.values[start_idx:i], + smoother_swr_power[start_idx:i], + color="red", + linewidth=2, + ) + ripple_periods.append((start_idx, i)) start_idx = None # plot of captured ripple periods @@ -346,11 +372,14 @@ def plot_timefrequency( height_ratios=[1, 0.4], ) axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], label="LFP Data") -axd["rip_pow"].plot(RUN_Tsd.index.values, - smoother_swr_power - ) +axd["rip_pow"].plot(RUN_Tsd.index.values, smoother_swr_power) for r in ripple_periods: - axd["rip_pow"].plot(RUN_Tsd.index.values[r[0]:r[1]], smoother_swr_power[r[0]:r[1]], color='red', linewidth=2) + axd["rip_pow"].plot( + RUN_Tsd.index.values[r[0] : r[1]], + smoother_swr_power[r[0] : r[1]], + color="red", + linewidth=2, + ) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].set_title(f"{freqs[ripple_freq_idx]}Hz oscillation power.") @@ -368,23 +397,23 @@ def plot_timefrequency( # Let's zoom in on out detected ripples and have a closer look! # Filter out ripples which do not last long enough -ripple_periods = [r for r in ripple_periods if r[1]-r[0] > 20] +ripple_periods = [r for r in ripple_periods if r[1] - r[0] > 20] # And plot! fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) buffer = 200 ax.plot( - RUN_Tsd.index.values[r[0]-buffer:r[1]+buffer], - RUN_Tsd[r[0]-buffer:r[1]+buffer], + RUN_Tsd.index.values[r[0] - buffer : r[1] + buffer], + RUN_Tsd[r[0] - buffer : r[1] + buffer], color="blue", - label="Non-SWR LFP" + label="Non-SWR LFP", ) ax.plot( - RUN_Tsd.index.values[r[0]:r[1]], - RUN_Tsd[r[0]:r[1]], + RUN_Tsd.index.values[r[0] : r[1]], + RUN_Tsd[r[0] : r[1]], color="red", label="SWR", - linewidth=2 + linewidth=2, ) ax.margins(0) ax.set_xlabel("Time (s)") diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 3dcd06db..fe000c24 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -258,7 +258,14 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 if len(int_psi_scale) > max_len: max_len = len(int_psi_scale) filter_bank.append(int_psi_scale) - filter_bank = [np.pad(arr, ((max_len - len(arr)) // 2, (max_len - len(arr) + 1) // 2), constant_values=0.0) for arr in filter_bank] + filter_bank = [ + np.pad( + arr, + ((max_len - len(arr)) // 2, (max_len - len(arr) + 1) // 2), + constant_values=0.0, + ) + for arr in filter_bank + ] return np.array(filter_bank) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 87686d18..81c6b024 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -76,7 +76,7 @@ def test_compute_welch_spectogram(): def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t*50*np.pi*2), t=t) + sig = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) freqs = np.linspace(10, 100, 10) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] From 81a076e28c5606897f1c9add3191308a2c8ad668 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 19:23:38 +0100 Subject: [PATCH 046/195] phase preference notebook added --- docs/examples/tutorial_phase_preferences.py | 359 ++++++++++++++++++++ docs/examples/tutorial_signal_processing.py | 2 +- 2 files changed, 360 insertions(+), 1 deletion(-) create mode 100644 docs/examples/tutorial_phase_preferences.py diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py new file mode 100644 index 00000000..ce3f35ba --- /dev/null +++ b/docs/examples/tutorial_phase_preferences.py @@ -0,0 +1,359 @@ +# -*- coding: utf-8 -*- +""" +Grosmark & Buzsáki (2016) Tutorial 2 +============ + +In the previous [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/) tutorial, +we learned how to use Pynapple's signal processing tools with Local Field Potential data. Specifically, we +used wavelet decompositions to isolate Theta band activity during active traversal of a linear track, +as well as to find Sharp Wave Ripples which occurred after traversal. + +In this tutorial we will learn how to isolate phase information from our wavelet decomposition and combine it +with spiking data, to find phase preferences of spiking units. + +Specifically, we will examine LFP and spiking data from a period of REM sleep, after traversal of a linear track. + +This tutorial was made by Kipp Freud. +""" + +# %% +# !!! warning +# This tutorial uses matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib requests tqdm` +# +# First, import the necessary libraries: + +import math +import os + +# ..todo: remove +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import requests +import scipy +import tqdm + +import pynapple as nap + +matplotlib.use("TkAgg") + +# %% +# *** +# Downloading the data +# ------------------ +# Let's download the data and save it locally + +path = "Achilles_10252013_EEG.nwb" +if path not in os.listdir("."): + r = requests.get(f"https://osf.io/2dfvp/download", stream=True) + block_size = 1024 * 1024 + with open(path, "wb") as f: + for data in tqdm.tqdm( + r.iter_content(block_size), + unit="MB", + unit_scale=True, + total=math.ceil(int(r.headers.get("content-length", 0)) // block_size), + ): + f.write(data) + + +# %% +# *** +# Loading the data +# ------------------ +# Let's load and print the full dataset. + +data = nap.load_file(path) +FS = len(data["eeg"].index[:]) / (data["eeg"].index[-1] - data["eeg"].index[0]) +print(data) + + +# %% +# *** +# Selecting slices +# ----------------------------------- +# Let's consider a 10-second slice of data taken during REM sleep + +# Define the IntervalSet for this run and instantiate both LFP and +# Position TsdFrame objects +REM_minute_interval = nap.IntervalSet( + data["rem"]["start"][0] + 90.0, + data["rem"]["start"][0] + 100.0, +) +REM_Tsd = nap.TsdFrame( + t=data["eeg"].restrict(REM_minute_interval).index.values + - data["eeg"].restrict(REM_minute_interval).index.values.min(), + d=data["eeg"].restrict(REM_minute_interval).values, +) + +# We will also extract spike times from all units in our dataset +# which occur during our specified interval +spikes = {} +for i in data["units"].index: + spikes[i] = ( + data["units"][i].times()[ + (data["units"][i].times() > REM_minute_interval["start"][0]) + & (data["units"][i].times() < REM_minute_interval["end"][0]) + ] + - data["eeg"].restrict(REM_minute_interval).index.values.min() + ) + +# The given dataset has only one channel, so we set channel = 0 here +channel = 0 + +# %% +# *** +# Plotting the LFP Activity +# ----------------------------------- +# We should first plot our REM Local Field Potential data. + +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) + +ax.plot( + REM_Tsd[:, channel], + label="REM LFP Data", + color="green", +) +ax.set_title("REM Local Field Potential") +ax.set_ylabel("LFP (v)") +ax.set_xlabel("time (s)") +ax.margins(0) +ax.legend() +plt.show() + +# %% +# *** +# Getting the Wavelet Decomposition +# ----------------------------------- +# As we would expect, it looks like we have a very strong theta oscillation within our data +# - this is a common feature of REM sleep. Let's perform a wavelet decomposition, +# as we did in the last tutorial, to see get a more informative breakdown of the +# frequencies present in the data. + +# We must define the frequency set that we'd like to use for our decomposition; +# these have been manually selected based on the frequencies used in +# Frey et. al (2021), but could also be defined as a linspace or logspace +freqs = np.array( + [ + 2.59, + 3.66, + 5.18, + 8.0, + 10.36, + 20.72, + 29.3, + 41.44, + 58.59, + 82.88, + 117.19, + 152.35, + 192.19, + 200.0, + 234.38, + 270.00, + 331.5, + 390.00, + ] +) +mwt_REM = nap.compute_wavelet_transform(REM_Tsd[:, channel], fs=None, freqs=freqs) + + +# Define wavelet decomposition plotting function +def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): + if np.iscomplexobj(powers): + powers = abs(powers) + ax.imshow(powers, aspect="auto", **kwargs) + ax.invert_yaxis() + ax.set_xlabel("Time (s)") + ax.set_ylabel("Frequency (Hz)") + if isinstance(x_ticks, int): + x_tick_pos = np.linspace(0, times.size, x_ticks) + x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) + else: + x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] + ax.set(xticks=x_tick_pos, xticklabels=x_ticks) + y_ticks = freqs + y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] + ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + + +# And plot it +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +axd = fig.subplot_mosaic( + [ + ["wd_rem"], + ["lfp_rem"], + ], + height_ratios=[1, 0.2], +) +plot_timefrequency( + REM_Tsd.index.values[:], + freqs[:], + np.transpose(mwt_REM[:, :].values), + ax=axd["wd_rem"], +) +axd["wd_rem"].set_title(f"Wavelet Decomposition") +axd["lfp_rem"].plot(REM_Tsd) +axd["lfp_rem"].margins(0) +axd["lfp_rem"].set_ylabel("LFP (v)") +axd["lfp_rem"].get_xaxis().set_visible(False) +for spine in ["top", "right", "bottom", "left"]: + axd["lfp_rem"].spines[spine].set_visible(False) +plt.show() + +# %% +# *** +# Visualizing Theta Band Power and Phase +# ----------------------------------- +# There seems to be a strong theta frequency present in the data during the maze traversal. +# Let's plot the estimated 8Hz component of the wavelet decomposition on top of our data, and see how well +# they match up. We will also extract and plot the phase of the 8Hz wavelet from the decomposition. +theta_freq_index = 3 +theta_band_reconstruction = mwt_REM[:, theta_freq_index].values.real +# calculating phase here +theta_band_phase = np.angle(mwt_REM[:, theta_freq_index].values) + +fig = plt.figure(constrained_layout=True, figsize=(10, 5)) +axd = fig.subplot_mosaic( + [ + ["theta_pow"], + ["phase"], + ], + height_ratios=[0.4, 0.2], +) + +axd["theta_pow"].plot( + REM_Tsd.index.values, REM_Tsd[:, channel], alpha=0.5, label="LFP Data - REM" +) +axd["theta_pow"].plot( + REM_Tsd.index.values, + theta_band_reconstruction, + label=f"{freqs[theta_freq_index]}Hz oscillations", +) +axd["theta_pow"].set_ylabel("LFP (v)") +axd["theta_pow"].set_xlabel("Time (s)") +axd["theta_pow"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") # +axd["theta_pow"].legend() +axd["phase"].plot(theta_band_phase) +[axd[k].margins(0) for k in ["theta_pow", "phase"]] +axd["phase"].set_ylabel("Phase") +plt.show() + + +# %% +# *** +# Finding Phase of Spikes +# ----------------------------------- +# Now that we have the phase of our theta wavelet, and our spike times, we can find the theta phase at which every +# spike occurs + +# We will start by throwing away cells which do not have enough +# spikes during our interval +spikes = {k: v for k, v in spikes.items() if len(v) > 20} +# Get phase of each spike +phase = {} +for i in spikes.keys(): + phase_i = [] + for spike in spikes[i]: + phase_i.append( + np.angle( + mwt_REM[ + np.argmin(np.abs(REM_Tsd.index.values - spike)), theta_freq_index + ] + ) + ) + phase[i] = np.array(phase_i) + +# Let's plot phase histograms for the first six units to see if there's +# any obvious preferences +fig, ax = plt.subplots(2, 3, constrained_layout=True, figsize=(10, 6)) +for ri in range(2): + for ci in range(3): + ax[ri, ci].hist( + phase[list(phase.keys())[ri * 3 + ci]], + bins=np.linspace(-np.pi, np.pi, 10), + density=True, + ) + ax[ri, ci].set_xlabel("Phase (rad)") + ax[ri, ci].set_ylabel("Density") + ax[ri, ci].set_title(f"Unit {list(phase.keys())[ri*3 + ci]}") +fig.suptitle("Phase Preference Histograms of First 6 Units") +plt.show() + +# %% +# *** +# Isolating Strong Phase Preferences +# ----------------------------------- +# It looks like there could be some phase preferences happening here, but there's a lot of cells to go through. +# Now that we have our phases of firing for each unit, we can sort the units by the circular variance of the phase +# of their spikes, to isolate the cells with the strongest phase preferences without manual inspection. + +variances = { + key: scipy.stats.circvar(value, low=-np.pi, high=np.pi) + for key, value in phase.items() +} +spikes = dict(sorted(spikes.items(), key=lambda item: variances[item[0]])) +phase = dict(sorted(phase.items(), key=lambda item: variances[item[0]])) + +# Now let's plot phase histograms for the six units with the least +# varied phase of spikes. +fig, ax = plt.subplots(2, 3, constrained_layout=True, figsize=(10, 6)) +for ri in range(2): + for ci in range(3): + ax[ri, ci].hist( + phase[list(phase.keys())[ri * 3 + ci]], + bins=np.linspace(-np.pi, np.pi, 10), + density=True, + ) + ax[ri, ci].set_xlabel("Phase (rad)") + ax[ri, ci].set_ylabel("Density") + ax[ri, ci].set_title(f"Unit {list(phase.keys())[ri*3 + ci]}") +fig.suptitle( + "Phase Preference Histograms of 6 Units with " + "Highest Phase Preference" +) +plt.show() + +# %% +# *** +# Visualizing Phase Preferences +# ----------------------------------- +# There is definitely some strong phase preferences happening here. Let's visualize the firing preferences +# of the 6 cells we've isolated to get an impression of just how striking these preferences are. + +fig = plt.figure(constrained_layout=True, figsize=(10, 12)) +axd = fig.subplot_mosaic( + [ + ["lfp_run"], + ["phase_0"], + ["phase_1"], + ["phase_2"], + ["phase_3"], + ["phase_4"], + ["phase_5"], + ], + height_ratios=[0.4, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], +) +[axd[k].margins(0) for k in ["lfp_run"] + [f"phase_{i}" for i in range(6)]] +axd["lfp_run"].plot( + REM_Tsd.index.values, REM_Tsd[:, channel], alpha=0.5, label="LFP Data - REM" +) +axd["lfp_run"].plot( + REM_Tsd.index.values, + theta_band_reconstruction, + label=f"{freqs[theta_freq_index]}Hz oscillations", +) +axd["lfp_run"].set_ylabel("LFP (v)") +axd["lfp_run"].set_xlabel("Time (s)") +axd["lfp_run"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") +axd["lfp_run"].legend() +for i in range(6): + axd[f"phase_{i}"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.2) + axd[f"phase_{i}"].scatter( + spikes[list(spikes.keys())[i]], phase[list(spikes.keys())[i]] + ) + axd[f"phase_{i}"].set_ylabel("Phase") + axd[f"phase_{i}"].set_title(f"Unit {list(spikes.keys())[i]}") +fig.suptitle("Phase Preference Visualizations") +plt.show() diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 6379c9be..34941449 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Grosmark & Buzsáki (2016) Tutorial +Grosmark & Buzsáki (2016) Tutorial 1 ============ This tutorial demonstrates how we can use the signal processing tools within Pynapple to aid with data analysis. We will examine the dataset from [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/). From ebdb64ca26f5f447a79923738156357a06b4d848 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 19:27:28 +0100 Subject: [PATCH 047/195] remove unused import --- docs/examples/tutorial_phase_preferences.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index ce3f35ba..59ba685b 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -27,8 +27,6 @@ import math import os -# ..todo: remove -import matplotlib import matplotlib.pyplot as plt import numpy as np import requests @@ -37,8 +35,6 @@ import pynapple as nap -matplotlib.use("TkAgg") - # %% # *** # Downloading the data From 0f17883b52325e716371363e17a885bf6dab676a Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 19:50:25 +0100 Subject: [PATCH 048/195] simplified compute_wavelet_transform, added tests --- pynapple/process/signal_processing.py | 28 ++++++++++++--------------- tests/test_signal_processing.py | 21 +++++++++++++++----- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index fe000c24..936ae6a3 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -191,23 +191,19 @@ def compute_wavelet_transform( output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - cwt = np.zeros( - [sig.values.shape[0], len(freqs), sig.values.shape[1]], dtype=complex - ) - filter_bank = generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision) - for f_i, filter in enumerate(filter_bank): - convolved = sig.convolve(np.transpose(np.asarray([filter.real, filter.imag]))) - convolved = convolved[:, :, 0].values + convolved[:, :, 1].values * 1j - coef = -np.diff(convolved, axis=0) - if norm == "sss": - coef *= -np.sqrt(scaling) / (freqs[f_i] / fs) - elif norm == "amp": - coef *= -scaling / (freqs[f_i] / fs) - coef = np.insert( - coef, 1, coef[0], axis=0 - ) # slightly hacky line, necessary to make output correct shape - cwt[:, f_i, :] = coef if len(coef.shape) == 2 else np.expand_dims(coef, axis=1) + convolved_real = sig.convolve(np.transpose(filter_bank.real)) + convolved_imag = sig.convolve(np.transpose(filter_bank.imag)) + convolved = convolved_real.values + convolved_imag.values * 1j + coef = -np.diff(convolved, axis=0) + if norm == "sss": + coef *= coef * (-np.sqrt(scaling) / (freqs / fs)) + elif norm == "amp": + coef *= -scaling / (freqs / fs) + coef = np.insert( + coef, 1, coef[0, :], axis=0 + ) # slightly hacky line, necessary to make output correct shape + cwt = np.swapaxes(coef, 1, 2) if len(output_shape) == 2: return nap.TsdFrame( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 81c6b024..8edc4bd7 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -91,6 +91,22 @@ def test_compute_wavelet_transform(): assert mpf == 20 assert mwt.shape == (1000, 10) + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="sss") + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 20 + assert mwt.shape == (1000, 10) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="amp") + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 20 + assert mwt.shape == (1000, 10) + t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.sin(t * 70 * np.pi * 2), t=t) freqs = np.linspace(10, 100, 10) @@ -125,8 +141,3 @@ def test_compute_wavelet_transform(): with pytest.raises(ValueError) as e_info: nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) assert str(e_info.value) == "Number of cycles must be a positive number." - - -if __name__ == "__main__": - test_compute_wavelet_transform() - # test_compute_welch_spectogram() From 5af389de2aadeb40f43dbdc2c64a14d50238c16b Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 19:56:02 +0100 Subject: [PATCH 049/195] removed time zeroing for doc examples --- docs/examples/tutorial_phase_preferences.py | 17 +++++---------- docs/examples/tutorial_signal_processing.py | 24 ++++++--------------- 2 files changed, 11 insertions(+), 30 deletions(-) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 59ba685b..b4f1f6ec 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -78,23 +78,16 @@ data["rem"]["start"][0] + 90.0, data["rem"]["start"][0] + 100.0, ) -REM_Tsd = nap.TsdFrame( - t=data["eeg"].restrict(REM_minute_interval).index.values - - data["eeg"].restrict(REM_minute_interval).index.values.min(), - d=data["eeg"].restrict(REM_minute_interval).values, -) +REM_Tsd = data["eeg"].restrict(REM_minute_interval) # We will also extract spike times from all units in our dataset # which occur during our specified interval spikes = {} for i in data["units"].index: - spikes[i] = ( - data["units"][i].times()[ - (data["units"][i].times() > REM_minute_interval["start"][0]) - & (data["units"][i].times() < REM_minute_interval["end"][0]) - ] - - data["eeg"].restrict(REM_minute_interval).index.values.min() - ) + spikes[i] = data["units"][i].times()[ + (data["units"][i].times() > REM_minute_interval["start"][0]) + & (data["units"][i].times() < REM_minute_interval["end"][0]) + ] # The given dataset has only one channel, so we set channel = 0 here channel = 0 diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 34941449..b5d786fe 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -76,16 +76,9 @@ data["forward_ep"]["start"][run_index], data["forward_ep"]["end"][run_index] + 4.0, ) -RUN_Tsd = nap.TsdFrame( - t=data["eeg"].restrict(RUN_interval).index.values - - data["forward_ep"]["start"][run_index], - d=data["eeg"].restrict(RUN_interval).values, -) -RUN_pos = nap.TsdFrame( - t=data["position"].restrict(RUN_interval).index.values - - data["forward_ep"]["start"][run_index], - d=data["position"].restrict(RUN_interval).asarray(), -) +RUN_Tsd = data["eeg"].restrict(RUN_interval) +RUN_pos = data["position"].restrict(RUN_interval) + # The given dataset has only one channel, so we set channel = 0 here channel = 0 @@ -103,9 +96,7 @@ axd["ephys"].plot( RUN_Tsd[:, channel].restrict( nap.IntervalSet( - 0.0, - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index], + data["forward_ep"]["start"][run_index], data["forward_ep"]["end"][run_index] ) ), label="Traversal LFP Data", @@ -114,11 +105,8 @@ axd["ephys"].plot( RUN_Tsd[:, channel].restrict( nap.IntervalSet( - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index], - data["forward_ep"]["end"][run_index] - - data["forward_ep"]["start"][run_index] - + 5.0, + data["forward_ep"]["end"][run_index], + data["forward_ep"]["end"][run_index] + 5.0, ) ), label="Post Traversal LFP Data", From 072fbe396473ec5ca99a0a4405890b24247c52e7 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 20:57:17 +0100 Subject: [PATCH 050/195] minor changes, wavelet API v0 --- docs/examples/tutorial_wavelet_api.py | 178 ++++++++++++++++++++++++++ pynapple/process/__init__.py | 1 + pynapple/process/signal_processing.py | 6 +- 3 files changed, 182 insertions(+), 3 deletions(-) create mode 100644 docs/examples/tutorial_wavelet_api.py diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py new file mode 100644 index 00000000..d5552249 --- /dev/null +++ b/docs/examples/tutorial_wavelet_api.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- +""" +Wavelet API tutorial +============ + +Working with Wavelets. + +See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. + +This tutorial was made by Kipp Freud. + +""" + +# %% +# !!! warning +# This tutorial uses matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib requests tqdm` +# +# Now, import the necessary libraries: + +import matplotlib +matplotlib.use("TkAgg") +import matplotlib.pyplot as plt +import numpy as np + +import pynapple as nap + +# %% +# *** +# Generating a dummy signal +# ------------------ +# Let's generate a dummy signal to analyse with wavelets! + +# Our dummy dataset will contain two components, a low frequency 3Hz sinusoid combined +# with a weaker 25Hz sinusoid. +t = np.linspace(0,10, 10000) +sig = nap.Tsd(d=np.sin(t * (5+t) * np.pi * 2) + np.sin(t * 3 * np.pi * 2), t=t) +# Plot it +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3)) +ax.plot(sig) +ax.margins(0) +plt.show() + +# %% +# *** +# Getting our Morlet wavelet filter bank +# ------------------ + +freqs = np.linspace(1,25, num=25) +filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=1.5, scaling=1.0, precision=10) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +offset = 0.2 +for f_i in range(filter_bank.shape[0]): + ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), + filter_bank[f_i,:]+offset*f_i) + ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') +ax.margins(0) +ax.yaxis.set_visible(False) +ax.spines['left'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.spines['top'].set_visible(False) +ax.set_xlim(-2, 2) +ax.set_xlabel("Time (s)") +ax.set_title("Morlet Wavelet Filter Bank") +plt.show() + +# %% +# *** +# Effect of n_cycles +# ------------------ + +freqs = np.linspace(1,25, num=25) +filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=1.0, precision=10) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +offset = 0.2 +for f_i in range(filter_bank.shape[0]): + ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), + filter_bank[f_i,:]+offset*f_i) + ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') +ax.margins(0) +ax.yaxis.set_visible(False) +ax.spines['left'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.spines['top'].set_visible(False) +ax.set_xlim(-2, 2) +ax.set_xlabel("Time (s)") +ax.set_title("Morlet Wavelet Filter Bank") +plt.show() + +# %% +# *** +# Effect of scaling +# ------------------ + +freqs = np.linspace(1,25, num=25) +filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=2.0, precision=10) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +offset = 0.2 +for f_i in range(filter_bank.shape[0]): + ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), + filter_bank[f_i,:]+offset*f_i) + ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') +ax.margins(0) +ax.yaxis.set_visible(False) +ax.spines['left'].set_visible(False) +ax.spines['right'].set_visible(False) +ax.spines['top'].set_visible(False) +ax.set_xlim(-2, 2) +ax.set_xlabel("Time (s)") +ax.set_title("Morlet Wavelet Filter Bank") +plt.show() + + +# %% +# *** +# Decomposing the dummy signal +# ------------------ + +mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=15) + +def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): + if np.iscomplexobj(powers): + powers = abs(powers) + ax.imshow(powers, aspect="auto", **kwargs) + ax.invert_yaxis() + ax.set_xlabel("Time (s)") + ax.set_ylabel("Frequency (Hz)") + if isinstance(x_ticks, int): + x_tick_pos = np.linspace(0, times.size, x_ticks) + x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) + else: + x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] + ax.set(xticks=x_tick_pos, xticklabels=x_ticks) + y_ticks = freqs + y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] + ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + +fig, ax = plt.subplots(1) +plot_timefrequency( + mwt.index.values[:], + freqs[:], + np.transpose(mwt[:, :].values), + ax=ax, +) +plt.show() + + + +# %% +# *** +# Increasing n_cycles increases resolution of decomposition +# ------------------ + +mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=10) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +plot_timefrequency( + mwt.index.values[:], + freqs[:], + np.transpose(mwt[:, :].values), + ax=ax, +) +plt.show() + +# %% +# *** +# Increasing n_cycles increases resolution of decomposition +# ------------------ + +mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=2.0, precision=10) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +plot_timefrequency( + mwt.index.values[:], + freqs[:], + np.transpose(mwt[:, :].values), + ax=ax, +) +plt.show() diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index fb7e22b9..3db9771d 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -19,6 +19,7 @@ compute_spectogram, compute_wavelet_transform, compute_welch_spectogram, + generate_morlet_filterbank ) from .tuning_curves import ( compute_1d_mutual_info, diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 936ae6a3..28329092 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -224,7 +224,7 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 If array, frequency values to estimate with morlet wavelets. If list, define the frequency range, as [freq_start, freq_stop, freq_step]. The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. - fs : float or None + fs : float Sampling rate, in Hz. n_cycles : float or 1d array Length of the filter, as the number of cycles for each frequency. @@ -236,8 +236,8 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 Returns ------- - filter_bank : list[np.ndarray] - list of morlet wavelet filters of the frequencies given + filter_bank : np.ndarray + list of Morlet wavelet filters of the frequencies given """ filter_bank = [] morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) From ffdc3d972e24850d42442ecfcc47fcc51cec5619 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Wed, 17 Jul 2024 20:58:59 +0100 Subject: [PATCH 051/195] linting --- docs/examples/tutorial_wavelet_api.py | 77 ++++++++++++++++----------- pynapple/process/__init__.py | 2 +- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py index d5552249..55221265 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/examples/tutorial_wavelet_api.py @@ -20,6 +20,7 @@ # Now, import the necessary libraries: import matplotlib + matplotlib.use("TkAgg") import matplotlib.pyplot as plt import numpy as np @@ -34,8 +35,8 @@ # Our dummy dataset will contain two components, a low frequency 3Hz sinusoid combined # with a weaker 25Hz sinusoid. -t = np.linspace(0,10, 10000) -sig = nap.Tsd(d=np.sin(t * (5+t) * np.pi * 2) + np.sin(t * 3 * np.pi * 2), t=t) +t = np.linspace(0, 10, 10000) +sig = nap.Tsd(d=np.sin(t * (5 + t) * np.pi * 2) + np.sin(t * 3 * np.pi * 2), t=t) # Plot it fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3)) ax.plot(sig) @@ -47,19 +48,22 @@ # Getting our Morlet wavelet filter bank # ------------------ -freqs = np.linspace(1,25, num=25) -filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=1.5, scaling=1.0, precision=10) +freqs = np.linspace(1, 25, num=25) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=1.5, scaling=1.0, precision=10 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) offset = 0.2 for f_i in range(filter_bank.shape[0]): - ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), - filter_bank[f_i,:]+offset*f_i) - ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') + ax.plot( + np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i + ) + ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") ax.margins(0) ax.yaxis.set_visible(False) -ax.spines['left'].set_visible(False) -ax.spines['right'].set_visible(False) -ax.spines['top'].set_visible(False) +ax.spines["left"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["top"].set_visible(False) ax.set_xlim(-2, 2) ax.set_xlabel("Time (s)") ax.set_title("Morlet Wavelet Filter Bank") @@ -70,19 +74,22 @@ # Effect of n_cycles # ------------------ -freqs = np.linspace(1,25, num=25) -filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=1.0, precision=10) +freqs = np.linspace(1, 25, num=25) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=7.5, scaling=1.0, precision=10 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) offset = 0.2 for f_i in range(filter_bank.shape[0]): - ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), - filter_bank[f_i,:]+offset*f_i) - ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') + ax.plot( + np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i + ) + ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") ax.margins(0) ax.yaxis.set_visible(False) -ax.spines['left'].set_visible(False) -ax.spines['right'].set_visible(False) -ax.spines['top'].set_visible(False) +ax.spines["left"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["top"].set_visible(False) ax.set_xlim(-2, 2) ax.set_xlabel("Time (s)") ax.set_title("Morlet Wavelet Filter Bank") @@ -93,19 +100,22 @@ # Effect of scaling # ------------------ -freqs = np.linspace(1,25, num=25) -filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=2.0, precision=10) +freqs = np.linspace(1, 25, num=25) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=7.5, scaling=2.0, precision=10 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) offset = 0.2 for f_i in range(filter_bank.shape[0]): - ax.plot(np.linspace(-8, 8, filter_bank.shape[1]), - filter_bank[f_i,:]+offset*f_i) - ax.text(-2.2, offset*f_i, f"{np.round(freqs[f_i], 2)}Hz", va='center', ha='left') + ax.plot( + np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i + ) + ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") ax.margins(0) ax.yaxis.set_visible(False) -ax.spines['left'].set_visible(False) -ax.spines['right'].set_visible(False) -ax.spines['top'].set_visible(False) +ax.spines["left"].set_visible(False) +ax.spines["right"].set_visible(False) +ax.spines["top"].set_visible(False) ax.set_xlim(-2, 2) ax.set_xlabel("Time (s)") ax.set_title("Morlet Wavelet Filter Bank") @@ -117,7 +127,10 @@ # Decomposing the dummy signal # ------------------ -mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=15) +mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=15 +) + def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): if np.iscomplexobj(powers): @@ -136,6 +149,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + fig, ax = plt.subplots(1) plot_timefrequency( mwt.index.values[:], @@ -146,13 +160,14 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): plt.show() - # %% # *** # Increasing n_cycles increases resolution of decomposition # ------------------ -mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=10) +mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=10 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( mwt.index.values[:], @@ -167,7 +182,9 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Increasing n_cycles increases resolution of decomposition # ------------------ -mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=2.0, precision=10) +mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=2.0, precision=10 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( mwt.index.values[:], diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 3db9771d..0986cc6d 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -19,7 +19,7 @@ compute_spectogram, compute_wavelet_transform, compute_welch_spectogram, - generate_morlet_filterbank + generate_morlet_filterbank, ) from .tuning_curves import ( compute_1d_mutual_info, From 804565e4212027aa30e2547e21b13190afa4c7cf Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 22:57:15 +0200 Subject: [PATCH 052/195] removed file from test --- tests/npzfilestest/tsd2.json | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 tests/npzfilestest/tsd2.json diff --git a/tests/npzfilestest/tsd2.json b/tests/npzfilestest/tsd2.json deleted file mode 100644 index b64c10cb..00000000 --- a/tests/npzfilestest/tsd2.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "time": "2024-07-17 12:27:07.477318", - "info": "Test description" -} \ No newline at end of file From b302e7dfd8d042381f2e74e042ad84d7b5708368 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 23:02:41 +0200 Subject: [PATCH 053/195] ignore folder generated during tests --- .gitignore | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 04a63643..5e7adc67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,14 @@ *.nwb *.pickle *.py.md5 -*.npz +#*.npz /docs/generated/gallery/*.md /docs/generated/gallery/*.ipynb /docs/generated/gallery/*.py /docs/generated/gallery/*.zip +/tests/npzfilestest + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] From 1be64e3c440fd0c465dd30511faa614e0060c5b0 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 23:08:48 +0200 Subject: [PATCH 054/195] Update pynapple/io/interface_npz.py Co-authored-by: Guillaume Viejo --- pynapple/io/interface_npz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/io/interface_npz.py b/pynapple/io/interface_npz.py index 932813f1..d0ca4ab7 100644 --- a/pynapple/io/interface_npz.py +++ b/pynapple/io/interface_npz.py @@ -31,7 +31,7 @@ def _find_class_from_variables(file_variables, data_ndims=None): return "Tsd" if data_ndims == 1 else "TsdTensor" - for possible_type, espected_variables in EXPECTED_ENTRIES.items(): + for possible_type, expected_variables in EXPECTED_ENTRIES.items(): if espected_variables.issubset(file_variables): return possible_type From dbea5ac87d817954f3b39d9c8ff00aa0a8cf4a00 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 23:09:14 +0200 Subject: [PATCH 055/195] Update pynapple/io/interface_npz.py Co-authored-by: Guillaume Viejo --- pynapple/io/interface_npz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/io/interface_npz.py b/pynapple/io/interface_npz.py index d0ca4ab7..fd648ce5 100644 --- a/pynapple/io/interface_npz.py +++ b/pynapple/io/interface_npz.py @@ -32,7 +32,7 @@ def _find_class_from_variables(file_variables, data_ndims=None): return "Tsd" if data_ndims == 1 else "TsdTensor" for possible_type, expected_variables in EXPECTED_ENTRIES.items(): - if espected_variables.issubset(file_variables): + if expected_variables.issubset(file_variables): return possible_type return "npz" From 797bb0a4cd2728a53ba09a019c055f9146d903a7 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Wed, 17 Jul 2024 23:13:11 +0200 Subject: [PATCH 056/195] blacked --- pynapple/core/base_class.py | 26 ++------------------------ pynapple/core/interval_set.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 79f13486..802e64e5 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -432,8 +432,8 @@ def _from_npz_reader(cls, file): Parameters ---------- - file : npz file interface - The file interface to read from + file : NPZFile object + opened npz file interface. Returns ------- @@ -445,25 +445,3 @@ def _from_npz_reader(cls, file): } iset = IntervalSet(start=file["start"], end=file["end"]) return cls(time_support=iset, **kwargs) - - # def find_gaps(self, min_gap, time_units='s'): - # """ - # finds gaps in a tsd larger than min_gap. Return an IntervalSet. - # Epochs are defined by adding and removing 1 microsecond to the time index. - - # Parameters - # ---------- - # min_gap : float - # The minimum interval size considered to be a gap (default is second). - # time_units : str, optional - # Time units of min_gap ('us', 'ms', 's' [default]) - # """ - # min_gap = format_timestamps(np.array([min_gap]), time_units)[0] - - # time_array = self.index - # starts = self.time_support.start - # ends = self.time_support.end - - # s, e = jitfind_gaps(time_array, starts, ends, min_gap) - - # return nap.IntervalSet(s, e) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index b3df9de3..5ba835d1 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -672,19 +672,19 @@ def save(self, filename): @classmethod def _from_npz_reader(cls, file): - """ - Load an IntervalSet object from a npz file. + """Load an IntervalSet object from a npz file. + + The file should contain the keys 'start', 'end' and 'type'. + The 'type' key should be 'IntervalSet'. - The file should contain the keys 'start', 'end' and 'type'. The 'type' key should be 'IntervalSet'. + Parameters + ---------- + file : NPZFile object + opened npz file interface. Returns ------- IntervalSet The IntervalSet object - - Raises - ------ - RuntimeError - If the file does not contain the 'start', 'end' and 'type' keys. """ return cls(start=file["start"], end=file["end"]) From 8cc3210835c603b7e800fe148d70e31df5bee1ef Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Thu, 18 Jul 2024 00:06:28 +0200 Subject: [PATCH 057/195] neurosuite --- pynapple/io/loader.py | 37 ++++++++++++++++++------------------- pynapple/io/neurosuite.py | 39 +++++++++++++++++++-------------------- 2 files changed, 37 insertions(+), 39 deletions(-) diff --git a/pynapple/io/loader.py b/pynapple/io/loader.py index e08c1b95..cfe88595 100644 --- a/pynapple/io/loader.py +++ b/pynapple/io/loader.py @@ -10,6 +10,7 @@ @author: Guillaume Viejo """ import os +from pathlib import Path import warnings import pandas as pd @@ -56,23 +57,26 @@ class BaseLoader(object): """ def __init__(self, path=None): - self.path = path + self.path = Path(path) - file_found = False # Check if a pynapplenwb folder exist - if self.path is not None: - nwb_path = os.path.join(self.path, "pynapplenwb") - if os.path.exists(nwb_path): - files = os.listdir(nwb_path) - if len([f for f in files if f.endswith(".nwb")]): - file_found = True - self.load_data(path) - - # Starting the GUI - if not file_found: + nwb_path = self.path / "pynapplenwb" + files = list(nwb_path.glob("*.nwb")) + + if len(files) > 0: + self.load_data() + else: raise RuntimeError(get_error_text(path)) - - def load_data(self, path): + + @property + def nwbfilepath(self): + try: + nwbfilepath = next(self.path.glob("pynapplenwb/*nwb")) + except StopIteration: + raise FileNotFoundError("No NWB file found in {}".format(self.path / "pynapplenwb")) + return nwbfilepath + + def load_data(self): """ Load NWB data saved with pynapple in the pynapplenwb folder @@ -81,11 +85,6 @@ def load_data(self, path): path : str Path to the session folder """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) io = NWBHDF5IO(self.nwbfilepath, "r+") nwbfile = io.read() diff --git a/pynapple/io/neurosuite.py b/pynapple/io/neurosuite.py index 952853a5..506fc262 100755 --- a/pynapple/io/neurosuite.py +++ b/pynapple/io/neurosuite.py @@ -12,7 +12,8 @@ @author: Guillaume Viejo """ -import os +# import os +from pathlib import Path import sys import numpy as np @@ -37,14 +38,15 @@ def __init__(self, path): path : str The path to the data. """ - self.basename = os.path.basename(path) + path = Path(path) + self.basename = path.name self.time_support = None super().__init__(path) - self.load_nwb_spikes(path) + self.load_nwb_spikes() - def load_nwb_spikes(self, path): + def load_nwb_spikes(self): """ Read the NWB spikes to extract the spike times. @@ -58,11 +60,6 @@ def load_nwb_spikes(self, path): TYPE Description """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() @@ -129,16 +126,16 @@ def load_lfp( The lfp in a time series format """ if filename is not None: - filepath = os.path.join(self.path, filename) + filepath = self.path / filename else: - listdir = os.listdir(self.path) - eegfile = [f for f in listdir if f.endswith(extension)] + eegfile = list(filepath.glob(f"*{extension}")) + if not len(eegfile): raise RuntimeError( "Path {} contains no {} files;".format(self.path, extension) ) - filepath = os.path.join(self.path, eegfile[0]) + filepath = eegfile[0] self.load_neurosuite_xml(self.path) @@ -200,9 +197,10 @@ def read_neuroscope_intervals(self, name=None, path2file=None): isets = self.load_nwb_intervals(name) if isinstance(isets, nap.IntervalSet): return isets + if name is not None and path2file is None: - path2file = os.path.join(self.path, self.basename + "." + name + ".evt") - if path2file is not None: + path2file = self.path / self.basename + "." + name + ".evt" + if path2file is not None: #TODO maybe useless conditional? try: # df = pd.read_csv(path2file, delimiter=' ', usecols = [0], header = None) tmp = np.genfromtxt(path2file)[:, 0] @@ -244,7 +242,7 @@ def write_neuroscope_intervals(self, extension, isets, name): ) ).T.flatten() - evt_file = os.path.join(self.path, self.basename + extension) + evt_file = self.path / (self.basename + extension) f = open(evt_file, "w") for t, n in zip(datatowrite, texttowrite): @@ -281,7 +279,7 @@ def load_mean_waveforms(self, epoch=None, waveform_window=None, spike_count=1000 waveform_window = nap.IntervalSet(start=-0.5, end=1, time_units="ms") spikes = self.spikes - if not os.path.exists(self.path): # check if path exists + if not self.path.exists(): # check if path exists print("The path " + self.path + " doesn't exist; Exiting ...") sys.exit() @@ -304,11 +302,12 @@ def load_mean_waveforms(self, epoch=None, waveform_window=None, spike_count=1000 epend = int(epoch.as_units("s")["end"].values[0] * fs) # Find dat file - files = os.listdir(self.path) - dat_files = np.sort([f for f in files if "dat" in f and f[0] != "."]) + #files = os.listdir(self.path) + # dat_files = np.sort([f for f in files if "dat" in f and f[0] != "."]) # Need n_samples collected in the entire recording from dat file to load - file = os.path.join(self.path, dat_files[0]) + # file = self.path / dat_files[0] + file = next(self.path.glob("^[^.][^.]*.dat")) f = open( file, "rb" ) # open file to get number of samples collected in the entire recording From 6063c5b0dc27edecd878f091929ed97989fc4537 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Thu, 18 Jul 2024 00:09:37 +0200 Subject: [PATCH 058/195] phy --- pynapple/io/phy.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/pynapple/io/phy.py b/pynapple/io/phy.py index ae0b79a8..6e052cfd 100644 --- a/pynapple/io/phy.py +++ b/pynapple/io/phy.py @@ -6,8 +6,7 @@ @author: Sara Mahallati, Guillaume Viejo """ -import os - +from pathlib import Path import numpy as np from pynwb import NWBHDF5IO @@ -29,14 +28,16 @@ def __init__(self, path): path : str or Path object The path to the data. """ - self.basename = os.path.basename(path) + path = Path(path) + + self.basename = path.name self.time_support = None super().__init__(path) - self.load_nwb_spikes(path) + self.load_nwb_spikes() - def load_nwb_spikes(self, path): + def load_nwb_spikes(self): """Read the NWB spikes to extract the spike times. Returns @@ -44,11 +45,6 @@ def load_nwb_spikes(self, path): TYPE Description """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() From 4ef312ca54a1b7702605e97f3740e4492ee86fd8 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Thu, 18 Jul 2024 00:12:23 +0200 Subject: [PATCH 059/195] cnmfe --- pynapple/io/cnmfe.py | 44 +++++++++++++------------------------------- 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/pynapple/io/cnmfe.py b/pynapple/io/cnmfe.py index c3266362..d6ad293e 100644 --- a/pynapple/io/cnmfe.py +++ b/pynapple/io/cnmfe.py @@ -11,7 +11,7 @@ # @Last Modified by: gviejo # @Last Modified time: 2023-11-16 13:14:54 -import os +from pathlib import Path from pynwb import NWBHDF5IO @@ -43,13 +43,14 @@ def __init__(self, path): path : str The path to the data. """ - self.basename = os.path.basename(path) + path = Path(path) + self.basename = path.name super().__init__(path) - self.load_cnmfe_nwb(path) + self.load_cnmfe_nwb() - def load_cnmfe_nwb(self, path): + def load_cnmfe_nwb(self): """ Load the calcium transient and spatial footprint from nwb @@ -58,13 +59,6 @@ def load_cnmfe_nwb(self, path): path : str Path to the session """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) - io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() @@ -110,13 +104,14 @@ def __init__(self, path): path : str The path to the data. """ - self.basename = os.path.basename(path) + path = Path(path) + self.basename = path.name super().__init__(path) - self.load_cnmfe_nwb(path) + self.load_cnmfe_nwb() - def load_cnmfe_nwb(self, path): + def load_cnmfe_nwb(self): """ Load the calcium transient and spatial footprint from nwb @@ -125,13 +120,6 @@ def load_cnmfe_nwb(self, path): path : str Path to the session """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) - io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() @@ -178,13 +166,14 @@ def __init__(self, path): path : str The path to the data. """ - self.basename = os.path.basename(path) + path = Path(path) + self.basename = path.name super().__init__(path) - self.load_cnmfe_nwb(path) + self.load_cnmfe_nwb() - def load_cnmfe_nwb(self, path): + def load_cnmfe_nwb(self): """ Load the calcium transient and spatial footprint from nwb @@ -193,13 +182,6 @@ def load_cnmfe_nwb(self, path): path : str Path to the session """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) - io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() From 3d9132095380095f25f905c7bf6997e009ba7f72 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Thu, 18 Jul 2024 01:24:56 +0200 Subject: [PATCH 060/195] all tests fixed after refactoring --- pynapple/core/base_class.py | 23 ++++++++++- pynapple/core/interval_set.py | 22 +---------- pynapple/core/time_series.py | 74 ++--------------------------------- pynapple/core/ts_group.py | 21 +--------- pynapple/core/utils.py | 33 ++++++++++++++++ pynapple/io/folder.py | 62 ++++++++++++++--------------- pynapple/io/interface_npz.py | 5 ++- pynapple/io/interface_nwb.py | 26 ++++++------ pynapple/io/loader.py | 10 +++-- pynapple/io/misc.py | 74 ++++++++++++++++------------------- pynapple/io/neurosuite.py | 11 +++--- pynapple/io/phy.py | 1 + pynapple/io/suite2p.py | 12 ++---- tests/test_interval_set.py | 8 ++-- tests/test_misc.py | 2 +- tests/test_nwb.py | 2 +- tests/test_time_series.py | 30 ++++++-------- tests/test_ts_group.py | 8 ++-- 18 files changed, 182 insertions(+), 242 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 802e64e5..22ebdaea 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -11,7 +11,7 @@ from ._core_functions import _count, _restrict, _value_from from .interval_set import IntervalSet from .time_index import TsIndex -from .utils import convert_to_numpy_array +from .utils import check_filename, convert_to_numpy_array class Base(abc.ABC): @@ -426,6 +426,27 @@ def get(self, start, end=None, time_units="s"): idx_end = np.searchsorted(time_array, end, side="right") return self[idx_start:idx_end] + def _get_filename(self, filename): + """Check if the filename is valid and return the path + + Parameters + ---------- + filename : str or Path + The filename + + Returns + ------- + Path + The path to the file + + Raises + ------ + RuntimeError + If the filename is a directory or the parent does not exist + """ + + return check_filename(filename) + @classmethod def _from_npz_reader(cls, file): """Load a time series object from a npz file interface. diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 5ba835d1..b39e6046 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -40,7 +40,6 @@ """ import importlib -import os import warnings from numbers import Number @@ -61,6 +60,7 @@ from .utils import ( _get_terminal_size, _IntervalSetSliceHelper, + check_filename, convert_to_numpy_array, is_array_like, ) @@ -643,26 +643,8 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. {} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) - np.savez( - filename, + check_filename(filename), start=self.values[:, 0], end=self.values[:, 1], type=np.array(["IntervalSet"], dtype=np.str_), diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index f3de4cf6..c7cd0e4c 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -17,7 +17,6 @@ import abc import importlib -import os import warnings from numbers import Number @@ -830,23 +829,7 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. {} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) + filename = self._get_filename(filename) np.savez( filename, @@ -1110,23 +1093,7 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. {} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) + filename = self._get_filename(filename) cols_name = self.columns if cols_name.dtype == np.dtype("O"): @@ -1437,24 +1404,7 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. {} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) - + filename = self._get_filename(filename) np.savez( filename, t=self.index.values, @@ -1730,23 +1680,7 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. {} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) + filename = self._get_filename(filename) np.savez( filename, diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index f96ed32e..3dfc0921 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -4,7 +4,6 @@ """ -import os import warnings from collections import UserDict from collections.abc import Hashable @@ -21,7 +20,7 @@ from .interval_set import IntervalSet from .time_index import TsIndex from .time_series import BaseTsd, Ts, Tsd, TsdFrame, is_array_like -from .utils import _get_terminal_size, convert_to_numpy_array +from .utils import _get_terminal_size, check_filename, convert_to_numpy_array def _union_intervals(i_sets): @@ -1325,23 +1324,7 @@ def save(self, filename): RuntimeError If filename is not str, path does not exist or filename is a directory. """ - if not isinstance(filename, str): - raise RuntimeError("Invalid type; please provide filename as string") - - if os.path.isdir(filename): - raise RuntimeError( - "Invalid filename input. {} is directory.".format(filename) - ) - - if not filename.lower().endswith(".npz"): - filename = filename + ".npz" - - dirname = os.path.dirname(filename) - - if len(dirname) and not os.path.exists(dirname): - raise RuntimeError( - "Path {} does not exist.".format(os.path.dirname(filename)) - ) + filename = check_filename(filename) dicttosave = {"type": np.array(["TsGroup"], dtype=np.str_)} for k in self._metadata.columns: diff --git a/pynapple/core/utils.py b/pynapple/core/utils.py index aa27eb7c..711aa0c8 100644 --- a/pynapple/core/utils.py +++ b/pynapple/core/utils.py @@ -6,6 +6,7 @@ import warnings from itertools import combinations from numbers import Number +from pathlib import Path import numpy as np @@ -403,3 +404,35 @@ def __getitem__(self, key): raise IndexError else: raise IndexError + + +def check_filename(filename): + """Check if the filename is valid and return the path + + Parameters + ---------- + filename : str or Path + The filename + + Returns + ------- + Path + The path to the file + + Raises + ------ + RuntimeError + If the filename is a directory or the parent does not exist + """ + filename = Path(filename).resolve() + + if filename.is_dir(): + raise RuntimeError("Invalid filename input. {} is directory.".format(filename)) + + filename = filename.with_suffix(".npz") + + parent_folder = filename.parent + if not parent_folder.exists(): + raise RuntimeError("Path {} does not exist.".format(parent_folder)) + + return filename diff --git a/pynapple/io/folder.py b/pynapple/io/folder.py index 60042bd9..de1b9ef5 100644 --- a/pynapple/io/folder.py +++ b/pynapple/io/folder.py @@ -12,10 +12,10 @@ import json -import os import string from collections import UserDict from datetime import datetime +from pathlib import Path from rich.console import Console # , ConsoleOptions, RenderResult from rich.panel import Panel @@ -30,27 +30,29 @@ def _find_files(path, extension=".npz"): Parameters ---------- - path : TYPE - Description + path : str or Path + The directory path where files will be searched. extension : str, optional - Description + The file extension to look for, default is ".npz". Returns ------- - TYPE - Description + dict + Dictionary with filenames (without extension and whitespace) as keys + and NPZFile or NWBFile objects as values. """ + extension = extension if extension.startswith(".") else "." + extension + path = Path(path) # Ensure path is a Path object files = {} - for f in os.scandir(path): - if f.is_file() and f.name.endswith(extension): - if extension == "npz": - filename = os.path.splitext(os.path.basename(f.path))[0] - filename.translate({ord(c): None for c in string.whitespace}) - files[filename] = NPZFile(f.path) - elif extension == "nwb": - filename = os.path.splitext(os.path.basename(f.path))[0] - filename.translate({ord(c): None for c in string.whitespace}) - files[filename] = NWBFile(f.path) + extensions_dict = {".npz": NPZFile, ".nwb": NWBFile} + assert extension in extensions_dict.keys(), f"Extension {extension} not supported" + + for f in path.iterdir(): + if f.is_file() and f.suffix == extension: + filename = f.stem + filename = filename.translate({ord(c): None for c in string.whitespace}) + files[filename] = extensions_dict[extension](f) + return files @@ -108,9 +110,9 @@ def __init__(self, path): # , exclude=(), max_depth=4): path : str Path to the folder """ - path = path.rstrip("/") + path = Path(path) self.path = path - self.name = os.path.basename(path) + self.name = self.path.name self._basic_view = Tree( ":open_file_folder: {}".format(self.name), guide_style="blue" ) @@ -118,16 +120,15 @@ def __init__(self, path): # , exclude=(), max_depth=4): # Search sub-folders subfolds = [ - f.path - for f in os.scandir(path) - if f.is_dir() and not f.name.startswith(".") + p for p in path.iterdir() if p.is_dir() and not p.name.startswith(".") ] + subfolds.sort() self.subfolds = {} for s in subfolds: - sub = os.path.basename(s) + sub = s.name self.subfolds[sub] = Folder(s) self._basic_view.add(":open_file_folder: [blue]" + sub) @@ -244,14 +245,14 @@ def save(self, name, obj, description=""): description : str, optional Metainformation added as a json sidecar. """ - filepath = os.path.join(self.path, name) + filepath = self.path / (name + ".npz") obj.save(filepath) - self.npz_files[name] = NPZFile(filepath + ".npz") + self.npz_files[name] = NPZFile(filepath) self.data[name] = obj metadata = {"time": str(datetime.now()), "info": str(description)} - with open(os.path.join(self.path, name + ".json"), "w") as ff: + with open(self.path / (name + ".json"), "w") as ff: json.dump(metadata, ff, indent=2) # regenerate the tree view @@ -295,19 +296,18 @@ def metadata(self, name): Name of the npz file """ # Search for json first - json_filename = os.path.join(self.path, name + ".json") - if os.path.isfile(json_filename): + json_filename = self.path / (name + ".json") + title = self.path / (name + ".npz") + if json_filename.exists(): with open(json_filename, "r") as ff: metadata = json.load(ff) text = "\n".join([" : ".join(it) for it in metadata.items()]) - panel = Panel.fit( - text, border_style="green", title=os.path.join(self.path, name + ".npz") - ) + panel = Panel.fit(text, border_style="green", title=title) else: panel = Panel.fit( "No metadata", border_style="red", - title=os.path.join(self.path, name + ".npz"), + title=title, ) with Console() as console: console.print(panel) diff --git a/pynapple/io/interface_npz.py b/pynapple/io/interface_npz.py index fd648ce5..22da63bd 100644 --- a/pynapple/io/interface_npz.py +++ b/pynapple/io/interface_npz.py @@ -7,7 +7,7 @@ # @Last Modified time: 2024-04-02 14:32:25 -import os +from pathlib import Path import numpy as np @@ -63,8 +63,9 @@ def __init__(self, path): path : str Valid path to a NPZ file """ + path = Path(path) self.path = path - self.name = os.path.basename(path) + self.name = path.name self.file = np.load(self.path, allow_pickle=True) type_ = "" diff --git a/pynapple/io/interface_nwb.py b/pynapple/io/interface_nwb.py index 124b9787..e94cfcb4 100644 --- a/pynapple/io/interface_nwb.py +++ b/pynapple/io/interface_nwb.py @@ -15,6 +15,7 @@ import warnings from collections import UserDict from numbers import Number +from pathlib import Path import numpy as np import pynwb @@ -386,22 +387,21 @@ def __init__(self, file, lazy_loading=True): RuntimeError If file is not an instance of NWBFile """ - if isinstance(file, str): - if os.path.exists(file): - self.path = file - self.name = os.path.basename(file).split(".")[0] - self.io = NWBHDF5IO(file, "r") - self.nwb = self.io.read() - else: - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file) - elif isinstance(file, pynwb.file.NWBFile): + # TODO: do we really need to have instantiation from file and object in the same place? + + if isinstance(file, pynwb.file.NWBFile): self.nwb = file self.name = self.nwb.session_id - else: - raise RuntimeError( - "unrecognized argument. Please provide path to a valid NWB file or open NWB file." - ) + path = Path(file) + + if path.exists(): + self.path = path + self.name = path.stem + self.io = NWBHDF5IO(path, "r") + self.nwb = self.io.read() + else: + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), file) self.data = _extract_compatible_data_from_nwbfile(self.nwb) self.key_to_id = {k: self.data[k]["id"] for k in self.data.keys()} diff --git a/pynapple/io/loader.py b/pynapple/io/loader.py index cfe88595..8e03cc8a 100644 --- a/pynapple/io/loader.py +++ b/pynapple/io/loader.py @@ -10,8 +10,8 @@ @author: Guillaume Viejo """ import os -from pathlib import Path import warnings +from pathlib import Path import pandas as pd from pynwb import NWBHDF5IO, TimeSeries @@ -62,18 +62,20 @@ def __init__(self, path=None): # Check if a pynapplenwb folder exist nwb_path = self.path / "pynapplenwb" files = list(nwb_path.glob("*.nwb")) - + if len(files) > 0: self.load_data() else: raise RuntimeError(get_error_text(path)) - + @property def nwbfilepath(self): try: nwbfilepath = next(self.path.glob("pynapplenwb/*nwb")) except StopIteration: - raise FileNotFoundError("No NWB file found in {}".format(self.path / "pynapplenwb")) + raise FileNotFoundError( + "No NWB file found in {}".format(self.path / "pynapplenwb") + ) return nwbfilepath def load_data(self): diff --git a/pynapple/io/misc.py b/pynapple/io/misc.py index 2dff3fc1..03128adc 100644 --- a/pynapple/io/misc.py +++ b/pynapple/io/misc.py @@ -4,8 +4,8 @@ Various io functions """ -import os import warnings +from pathlib import Path from xml.dom import minidom import numpy as np @@ -49,23 +49,23 @@ def load_file(path, lazy_loading=None): FileNotFoundError If file is missing """ - if os.path.isfile(path): - if path.endswith(".npz"): - if lazy_loading: - warnings.warn("Lazy loading is not supported for NPZ files") - return NPZFile(path).load() - - elif path.endswith(".nwb"): - # preserves class init default: - kwargs_for_lazyloading = ( - {} if lazy_loading is None else {"lazy_loading": lazy_loading} - ) - - return NWBFile(path, **kwargs_for_lazyloading) - else: - raise RuntimeError("File format not supported") + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"File {path} does not exist") + + if path.suffix == ".npz": + if lazy_loading: + warnings.warn("Lazy loading is not supported for NPZ files") + return NPZFile(path).load() + + elif path.suffix == ".nwb": + # preserves class init default: + kwargs_for_lazyloading = ( + {} if lazy_loading is None else {"lazy_loading": lazy_loading} + ) + return NWBFile(path, **kwargs_for_lazyloading) else: - raise FileNotFoundError("File {} does not exist".format(path)) + raise RuntimeError("File format not supported") def load_folder(path): @@ -88,13 +88,13 @@ def load_folder(path): RuntimeError If folder is missing """ - if os.path.isdir(path): - return Folder(path) - else: - raise RuntimeError("Folder {} does not exist".format(path)) + path = Path(path) + if not path.exists(): + raise FileNotFoundError(f"Folder {path} does not exist") + return Folder(path) -def load_session(path=None, session_type=None): +def load_session(path, session_type=None): """ %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% % WARNING : THIS FUNCTION IS DEPRECATED % @@ -124,9 +124,8 @@ def load_session(path=None, session_type=None): A class holding all the data from the session. """ - if path: - if not os.path.isdir(path): - raise RuntimeError("Path {} is not found.".format(path)) + path = Path(path) + assert path.exists(), f"Folder {path} does not exist" if isinstance(session_type, str): session_type = session_type.lower() @@ -196,13 +195,14 @@ def load_eeg( """ # Need to check if a xml file exists - path = os.path.dirname(filepath) - basename = os.path.basename(filepath).split(".")[0] - listdir = os.listdir(path) + filepath = Path(filepath) + path = filepath.parent + basename = filepath.name.split(".")[0] + listdir = list(path.glob("*")) if frequency is None or n_channels is None: if basename + ".xml" in listdir: - xmlpath = os.path.join(path, basename + ".xml") + xmlpath = path / (basename + ".xml") xmldoc = minidom.parse(xmlpath) else: raise RuntimeError( @@ -280,18 +280,12 @@ def append_NWB_LFP(path, lfp, channel=None): If no channel is specify when passing a Tsd """ - new_path = os.path.join(path, "pynapplenwb") + path = Path(path) + new_path = path / "pynapplenwb" nwb_path = "" - if os.path.exists(new_path): - nwbfilename = [f for f in os.listdir(new_path) if f.endswith(".nwb")] - if len(nwbfilename): - nwb_path = os.path.join(path, "pynapplenwb", nwbfilename[0]) - else: - nwbfilename = [f for f in os.listdir(path) if f.endswith(".nwb")] - if len(nwbfilename): - nwb_path = os.path.join(path, "pynapplenwb", nwbfilename[0]) - - if len(nwb_path) == 0: + try: + nwb_path = next(new_path.glob("*.nwb")) + except StopIteration: raise RuntimeError("Can't find nwb file in {}".format(path)) if isinstance(lfp, nap.TsdFrame): diff --git a/pynapple/io/neurosuite.py b/pynapple/io/neurosuite.py index 506fc262..605b59f6 100755 --- a/pynapple/io/neurosuite.py +++ b/pynapple/io/neurosuite.py @@ -12,9 +12,8 @@ @author: Guillaume Viejo """ -# import os -from pathlib import Path import sys +from pathlib import Path import numpy as np import pandas as pd @@ -197,10 +196,10 @@ def read_neuroscope_intervals(self, name=None, path2file=None): isets = self.load_nwb_intervals(name) if isinstance(isets, nap.IntervalSet): return isets - + if name is not None and path2file is None: path2file = self.path / self.basename + "." + name + ".evt" - if path2file is not None: #TODO maybe useless conditional? + if path2file is not None: # TODO maybe useless conditional? try: # df = pd.read_csv(path2file, delimiter=' ', usecols = [0], header = None) tmp = np.genfromtxt(path2file)[:, 0] @@ -302,8 +301,8 @@ def load_mean_waveforms(self, epoch=None, waveform_window=None, spike_count=1000 epend = int(epoch.as_units("s")["end"].values[0] * fs) # Find dat file - #files = os.listdir(self.path) - # dat_files = np.sort([f for f in files if "dat" in f and f[0] != "."]) + # files = os.listdir(self.path) + # dat_files = np.sort([f for f in files if "dat" in f and f[0] != "."]) # Need n_samples collected in the entire recording from dat file to load # file = self.path / dat_files[0] diff --git a/pynapple/io/phy.py b/pynapple/io/phy.py index 6e052cfd..92568463 100644 --- a/pynapple/io/phy.py +++ b/pynapple/io/phy.py @@ -7,6 +7,7 @@ """ from pathlib import Path + import numpy as np from pynwb import NWBHDF5IO diff --git a/pynapple/io/suite2p.py b/pynapple/io/suite2p.py index 3a5fb045..c3242241 100644 --- a/pynapple/io/suite2p.py +++ b/pynapple/io/suite2p.py @@ -13,7 +13,7 @@ """ -import os +from pathlib import Path import numpy as np import pandas as pd @@ -60,7 +60,8 @@ def __init__(self, path): path : str The path of the session """ - self.basename = os.path.basename(path) + path = Path(path) + self.basename = path.name super().__init__(path) @@ -75,13 +76,6 @@ def load_suite2p_nwb(self, path): path : str Path to the session """ - self.nwb_path = os.path.join(path, "pynapplenwb") - if not os.path.exists(self.nwb_path): - raise RuntimeError("Path {} does not exist.".format(self.nwb_path)) - - self.nwbfilename = [f for f in os.listdir(self.nwb_path) if "nwb" in f][0] - self.nwbfilepath = os.path.join(self.nwb_path, self.nwbfilename) - io = NWBHDF5IO(self.nwbfilepath, "r") nwbfile = io.read() diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index bd4b5631..f93909cc 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -6,6 +6,7 @@ import pytest import warnings from .mock import MockArray +from pathlib import Path def test_create_iset(): @@ -480,18 +481,17 @@ def test_save_npz(): end = np.around(np.array([5, 15, 20], dtype=np.float64), 9) ep = nap.IntervalSet(start=start,end=end) - with pytest.raises(RuntimeError) as e: + with pytest.raises(TypeError) as e: ep.save(dict) - assert str(e.value) == "Invalid type; please provide filename as string" with pytest.raises(RuntimeError) as e: ep.save('./') - assert str(e.value) == "Invalid filename input. {} is directory.".format("./") + assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve()) fake_path = './fake/path' with pytest.raises(RuntimeError) as e: ep.save(fake_path+'/file.npz') - assert str(e.value) == "Path {} does not exist.".format(fake_path) + assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) ep.save("ep.npz") os.listdir('.') diff --git a/tests/test_misc.py b/tests/test_misc.py index 47024e4b..003316a1 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -62,7 +62,7 @@ def test_load_folder(path): assert isinstance(folder, nap.io.Folder) def test_load_folder_foldernotfound(): - with pytest.raises(RuntimeError) as e: + with pytest.raises(FileNotFoundError) as e: nap.load_folder("MissingFolder") assert str(e.value) == "Folder MissingFolder does not exist" diff --git a/tests/test_nwb.py b/tests/test_nwb.py index f27cbafa..a438aa04 100644 --- a/tests/test_nwb.py +++ b/tests/test_nwb.py @@ -96,7 +96,7 @@ def test_NWBFile_missing_file(): def test_NWBFile_wrong_input(): - with pytest.raises(RuntimeError): + with pytest.raises(TypeError): nap.NWBFile(1) def test_wrong_key(): diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 906486c6..03879fed 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1,7 +1,7 @@ """Tests of time series for `pynapple` package.""" import pickle - +from pathlib import Path import numpy as np import pandas as pd import pytest @@ -749,18 +749,17 @@ def test_to_tsgroup(self, tsd): def test_save_npz(self, tsd): import os - with pytest.raises(RuntimeError) as e: + with pytest.raises(TypeError) as e: tsd.save(dict) - assert str(e.value) == "Invalid type; please provide filename as string" with pytest.raises(RuntimeError) as e: tsd.save('./') - assert str(e.value) == "Invalid filename input. {} is directory.".format("./") + assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve()) fake_path = './fake/path' with pytest.raises(RuntimeError) as e: tsd.save(fake_path+'/file.npz') - assert str(e.value) == "Path {} does not exist.".format(fake_path) + assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) tsd.save("tsd.npz") os.listdir('.') @@ -990,18 +989,17 @@ def test_bin_average_with_ep(self, tsdframe): def test_save_npz(self, tsdframe): import os - with pytest.raises(RuntimeError) as e: + with pytest.raises(TypeError) as e: tsdframe.save(dict) - assert str(e.value) == "Invalid type; please provide filename as string" with pytest.raises(RuntimeError) as e: tsdframe.save('./') - assert str(e.value) == "Invalid filename input. {} is directory.".format("./") + assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve()) fake_path = './fake/path' with pytest.raises(RuntimeError) as e: tsdframe.save(fake_path+'/file.npz') - assert str(e.value) == "Path {} does not exist.".format(fake_path) + assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) tsdframe.save("tsdframe.npz") os.listdir('.') @@ -1113,18 +1111,17 @@ def test_str_(self, ts): def test_save_npz(self, ts): import os - with pytest.raises(RuntimeError) as e: + with pytest.raises(TypeError) as e: ts.save(dict) - assert str(e.value) == "Invalid type; please provide filename as string" with pytest.raises(RuntimeError) as e: ts.save('./') - assert str(e.value) == "Invalid filename input. {} is directory.".format("./") + assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve()) fake_path = './fake/path' with pytest.raises(RuntimeError) as e: ts.save(fake_path+'/file.npz') - assert str(e.value) == "Path {} does not exist.".format(fake_path) + assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) ts.save("ts.npz") os.listdir('.') @@ -1354,18 +1351,17 @@ def test_bin_average(self, tsdtensor): def test_save_npz(self, tsdtensor): import os - with pytest.raises(RuntimeError) as e: + with pytest.raises(TypeError) as e: tsdtensor.save(dict) - assert str(e.value) == "Invalid type; please provide filename as string" with pytest.raises(RuntimeError) as e: tsdtensor.save('./') - assert str(e.value) == "Invalid filename input. {} is directory.".format("./") + assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve()) fake_path = './fake/path' with pytest.raises(RuntimeError) as e: tsdtensor.save(fake_path+'/file.npz') - assert str(e.value) == "Path {} does not exist.".format(fake_path) + assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) tsdtensor.save("tsdtensor.npz") os.listdir('.') diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index 3ade3c94..eb7449cd 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -14,6 +14,7 @@ import numpy as np import pandas as pd import pytest +from pathlib import Path import pynapple as nap @@ -529,18 +530,17 @@ def test_save_npz(self, group): tsgroup = nap.TsGroup(group, meta = np.arange(len(group), dtype=np.int64), meta2 = np.array(['a', 'b', 'c'])) - with pytest.raises(RuntimeError) as e: + with pytest.raises(TypeError) as e: tsgroup.save(dict) - assert str(e.value) == "Invalid type; please provide filename as string" with pytest.raises(RuntimeError) as e: tsgroup.save('./') - assert str(e.value) == "Invalid filename input. {} is directory.".format("./") + assert str(e.value) == "Invalid filename input. {} is directory.".format(Path("./").resolve()) fake_path = './fake/path' with pytest.raises(RuntimeError) as e: tsgroup.save(fake_path+'/file.npz') - assert str(e.value) == "Path {} does not exist.".format(fake_path) + assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) tsgroup.save("tsgroup.npz") os.listdir('.') From 894ff909be0f015ba92d57896d752c9aebed1e55 Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Thu, 18 Jul 2024 01:53:17 +0200 Subject: [PATCH 061/195] complete pathlib transition --- tests/test_folder.py | 34 ++++++++++---------------- tests/test_interval_set.py | 11 ++++----- tests/test_lazy_loading.py | 1 - tests/test_misc.py | 31 +++++++++++++----------- tests/test_npz_file.py | 46 +++++++++++++++++------------------ tests/test_time_series.py | 49 +++++++++++++------------------------- tests/test_ts_group.py | 14 ++++------- 7 files changed, 78 insertions(+), 108 deletions(-) diff --git a/tests/test_folder.py b/tests/test_folder.py index 21c2da72..c40d7083 100644 --- a/tests/test_folder.py +++ b/tests/test_folder.py @@ -11,25 +11,19 @@ import pandas as pd import pytest import warnings -import os +from pathlib import Path +import shutil import json # look for tests folder -path = os.getcwd() -if os.path.basename(path) == 'pynapple': - path = os.path.join(path, "tests") - -path = os.path.join(path, "npzfilestest") -if not os.path.isdir(path): - os.mkdir(path) -# path2 = os.path.join(path, "sub") -# if not os.path.isdir(path): -# os.mkdir(path2) - -# Cleaning -for root, dirs, files in os.walk(path): - for f in files: - os.remove(os.path.join(root, f)) +path = Path(__file__).parent +if path.name == 'pynapple': + path = path / "tests" +path = path / "npzfilestest" + +# Recursively remove the folder: +shutil.rmtree(path, ignore_errors=True) +path.mkdir(exist_ok=True, parents=True) # Populate the folder data = { @@ -45,9 +39,7 @@ } for k, d in data.items(): - d.save(os.path.join(path, k+".npz")) -# for k, d in data.items(): -# d.save(os.path.join(path, "sub", k+".npz")) + d.save(path / (k+".npz")) @pytest.mark.parametrize("path", [path]) def test_load_folder(path): @@ -78,11 +70,11 @@ def test_save(folder): assert isinstance(folder['tsd2'], nap.Tsd) - files = os.listdir(folder.path) + files = [f.name for f in path.iterdir()] assert "tsd2.json" in files # check json - metadata = json.load(open(os.path.join(path, "tsd2.json"), "r")) + metadata = json.load(open(path / "tsd2.json", "r")) assert "time" in metadata.keys() assert "info" in metadata.keys() assert "Test description" == metadata["info"] diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index f93909cc..a07837a9 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -475,7 +475,6 @@ def test_str_(): assert isinstance(ep.__str__(), str) def test_save_npz(): - import os start = np.around(np.array([0, 10, 16], dtype=np.float64), 9) end = np.around(np.array([5, 15, 20], dtype=np.float64), 9) @@ -494,12 +493,10 @@ def test_save_npz(): assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) ep.save("ep.npz") - os.listdir('.') - assert "ep.npz" in os.listdir(".") + assert "ep.npz" in [f.name for f in Path('.').iterdir()] ep.save("ep2") - os.listdir('.') - assert "ep2.npz" in os.listdir(".") + assert "ep2.npz" in [f.name for f in Path('.').iterdir()] file = np.load("ep.npz") @@ -511,8 +508,8 @@ def test_save_npz(): np.testing.assert_array_almost_equal(file['end'], end) # Cleaning - os.remove("ep.npz") - os.remove("ep2.npz") + Path("ep.npz").unlink() + Path("ep2.npz").unlink() diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index 0e3aa929..6acc4f30 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -1,4 +1,3 @@ -import os.path import warnings from contextlib import nullcontext as does_not_raise from pathlib import Path diff --git a/tests/test_misc.py b/tests/test_misc.py index 003316a1..ba5b91f5 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -11,24 +11,27 @@ import pandas as pd import pytest import warnings -import os +from pathlib import Path +import shutil # look for tests folder -path = os.getcwd() -if os.path.basename(path) == 'pynapple': - path = os.path.join(path, "tests") +path = Path(__file__).parent +if path.name == 'pynapple': + path = path / "tests" +path = path / "npzfilestest" + +# Recursively remove the folder: +shutil.rmtree(path, ignore_errors=True) +path.mkdir(exist_ok=True, parents=True) + +path2 = path.parent / "sub" +path2.mkdir(exist_ok=True, parents=True) -path = os.path.join(path, "npzfilestest") -if not os.path.isdir(path): - os.mkdir(path) -path2 = os.path.join(path, "sub") -if not os.path.isdir(path): - os.mkdir(path2) @pytest.mark.parametrize("path", [path]) def test_load_file(path): tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) - file_path = os.path.join(path, "tsd.npz") + file_path = path / "tsd.npz" tsd.save(file_path) tsd2 = nap.load_file(file_path) @@ -37,7 +40,7 @@ def test_load_file(path): np.testing.assert_array_equal(tsd.values, tsd2.values) np.testing.assert_array_equal(tsd.time_support.values, tsd2.time_support.values) - os.remove(file_path) + file_path.unlink() @pytest.mark.parametrize("path", [path]) def test_load_file_filenotfound(path): @@ -48,13 +51,13 @@ def test_load_file_filenotfound(path): @pytest.mark.parametrize("path", [path]) def test_load_wrong_format(path): - file_path = os.path.join(path, "test.npy") + file_path = path / "test.npy" np.save(file_path, np.random.rand(10)) with pytest.raises(RuntimeError) as e: nap.load_file(file_path) assert str(e.value) == "File format not supported" - os.remove(file_path) + file_path.unlink() @pytest.mark.parametrize("path", [path]) def test_load_folder(path): diff --git a/tests/test_npz_file.py b/tests/test_npz_file.py index c69805ba..f9de9158 100644 --- a/tests/test_npz_file.py +++ b/tests/test_npz_file.py @@ -11,24 +11,22 @@ import pandas as pd import pytest import warnings -import os +from pathlib import Path +import shutil # look for tests folder -path = os.getcwd() -if os.path.basename(path) == 'pynapple': - path = os.path.join(path, "tests") - -path = os.path.join(path, "npzfilestest") -if not os.path.isdir(path): - os.mkdir(path) -path2 = os.path.join(path, "sub") -if not os.path.isdir(path): - os.mkdir(path2) - -# Cleaning -for root, dirs, files in os.walk(path): - for f in files: - os.remove(os.path.join(root, f)) +path = Path(__file__).parent +if path.name == 'pynapple': + path = path / "tests" +path = path / "npzfilestest" + +# Recursively remove the folder: +shutil.rmtree(path, ignore_errors=True) +path.mkdir(exist_ok=True, parents=True) + +path2 = path.parent / "sub" +path2.mkdir(exist_ok=True, parents=True) + # Populate the folder data = { @@ -43,12 +41,12 @@ "iset":nap.IntervalSet(start=np.array([0.0, 5.0]), end=np.array([1.0, 6.0])) } for k, d in data.items(): - d.save(os.path.join(path, k+".npz")) + d.save(path / (k+".npz")) @pytest.mark.parametrize("path", [path]) def test_init(path): tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) - file_path = os.path.join(path, "tsd.npz") + file_path = path / "tsd.npz" tsd.save(file_path) file = nap.NPZFile(file_path) assert isinstance(file, nap.NPZFile) @@ -58,7 +56,7 @@ def test_init(path): @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['tsd', 'ts', 'tsdframe', 'tsgroup', 'iset']) def test_load(path, k): - file_path = os.path.join(path, k+".npz") + file_path = path / (k+".npz") file = nap.NPZFile(file_path) tmp = file.load() assert isinstance(tmp, type(data[k])) @@ -66,7 +64,7 @@ def test_load(path, k): @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['tsgroup']) def test_load_tsgroup(path, k): - file_path = os.path.join(path, k+".npz") + file_path = path / (k+".npz") file = nap.NPZFile(file_path) tmp = file.load() assert isinstance(tmp, type(data[k])) @@ -79,7 +77,7 @@ def test_load_tsgroup(path, k): @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['tsd']) def test_load_tsd(path, k): - file_path = os.path.join(path, k+".npz") + file_path = path / (k+".npz") file = nap.NPZFile(file_path) tmp = file.load() assert isinstance(tmp, type(data[k])) @@ -91,7 +89,7 @@ def test_load_tsd(path, k): @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['ts']) def test_load_ts(path, k): - file_path = os.path.join(path, k+".npz") + file_path = path / (k+".npz") file = nap.NPZFile(file_path) tmp = file.load() assert isinstance(tmp, type(data[k])) @@ -103,7 +101,7 @@ def test_load_ts(path, k): @pytest.mark.parametrize("path", [path]) @pytest.mark.parametrize("k", ['tsdframe']) def test_load_tsdframe(path, k): - file_path = os.path.join(path, k+".npz") + file_path = path / (k+".npz") file = nap.NPZFile(file_path) tmp = file.load() assert isinstance(tmp, type(data[k])) @@ -116,7 +114,7 @@ def test_load_tsdframe(path, k): @pytest.mark.parametrize("path", [path]) def test_load_non_npz(path): - file_path = os.path.join(path, "random.npz") + file_path = path / "random.npz" tmp = np.random.rand(100) np.savez(file_path, a=tmp) file = nap.NPZFile(file_path) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 03879fed..0942490f 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd import pytest +from pathlib import Path import pynapple as nap @@ -747,8 +748,6 @@ def test_to_tsgroup(self, tsd): np.testing.assert_array_almost_equal(tsgroup[i].index, t[i]) def test_save_npz(self, tsd): - import os - with pytest.raises(TypeError) as e: tsd.save(dict) @@ -762,12 +761,10 @@ def test_save_npz(self, tsd): assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) tsd.save("tsd.npz") - os.listdir('.') - assert "tsd.npz" in os.listdir(".") + assert "tsd.npz" in [f.name for f in Path('.').iterdir()] tsd.save("tsd2") - os.listdir('.') - assert "tsd2.npz" in os.listdir(".") + assert "tsd2.npz" in [f.name for f in Path('.').iterdir()] file = np.load("tsd.npz") @@ -782,8 +779,8 @@ def test_save_npz(self, tsd): np.testing.assert_array_almost_equal(file['start'], tsd.time_support.start) np.testing.assert_array_almost_equal(file['end'], tsd.time_support.end) - os.remove("tsd.npz") - os.remove("tsd2.npz") + Path("tsd.npz").unlink() + Path("tsd2.npz").unlink() def test_interpolate(self, tsd): @@ -987,8 +984,6 @@ def test_bin_average_with_ep(self, tsdframe): np.testing.assert_array_almost_equal(meantsd.values, tmp.loc[np.arange(1,5)].values) def test_save_npz(self, tsdframe): - import os - with pytest.raises(TypeError) as e: tsdframe.save(dict) @@ -1002,12 +997,10 @@ def test_save_npz(self, tsdframe): assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) tsdframe.save("tsdframe.npz") - os.listdir('.') - assert "tsdframe.npz" in os.listdir(".") + assert "tsdframe.npz" in [f.name for f in Path('.').iterdir()] tsdframe.save("tsdframe2") - os.listdir('.') - assert "tsdframe2.npz" in os.listdir(".") + assert "tsdframe2.npz" in [f.name for f in Path('.').iterdir()] file = np.load("tsdframe.npz") @@ -1024,8 +1017,8 @@ def test_save_npz(self, tsdframe): np.testing.assert_array_almost_equal(file['end'], tsdframe.time_support.end) np.testing.assert_array_almost_equal(file['columns'], tsdframe.columns) - os.remove("tsdframe.npz") - os.remove("tsdframe2.npz") + Path("tsdframe.npz").unlink() + Path("tsdframe2.npz").unlink() def test_interpolate(self, tsdframe): @@ -1109,8 +1102,6 @@ def test_str_(self, ts): assert isinstance(ts.__str__(), str) def test_save_npz(self, ts): - import os - with pytest.raises(TypeError) as e: ts.save(dict) @@ -1124,12 +1115,10 @@ def test_save_npz(self, ts): assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) ts.save("ts.npz") - os.listdir('.') - assert "ts.npz" in os.listdir(".") + assert "ts.npz" in [f.name for f in Path('.').iterdir()] ts.save("ts2") - os.listdir('.') - assert "ts2.npz" in os.listdir(".") + assert "ts2.npz" in [f.name for f in Path('.').iterdir()] file = np.load("ts.npz") @@ -1142,8 +1131,8 @@ def test_save_npz(self, ts): np.testing.assert_array_almost_equal(file['start'], ts.time_support.start) np.testing.assert_array_almost_equal(file['end'], ts.time_support.end) - os.remove("ts.npz") - os.remove("ts2.npz") + Path("ts.npz").unlink() + Path("ts2.npz").unlink() def test_fillna(self, ts): with pytest.raises(AssertionError): @@ -1349,8 +1338,6 @@ def test_bin_average(self, tsdtensor): np.testing.assert_array_almost_equal(meantsd.values, tmp) def test_save_npz(self, tsdtensor): - import os - with pytest.raises(TypeError) as e: tsdtensor.save(dict) @@ -1364,12 +1351,10 @@ def test_save_npz(self, tsdtensor): assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) tsdtensor.save("tsdtensor.npz") - os.listdir('.') - assert "tsdtensor.npz" in os.listdir(".") + assert "tsdtensor.npz" in [f.name for f in Path('.').iterdir()] tsdtensor.save("tsdtensor2") - os.listdir('.') - assert "tsdtensor2.npz" in os.listdir(".") + assert "tsdtensor2.npz" in [f.name for f in Path('.').iterdir()] file = np.load("tsdtensor.npz") @@ -1384,8 +1369,8 @@ def test_save_npz(self, tsdtensor): np.testing.assert_array_almost_equal(file['start'], tsdtensor.time_support.start) np.testing.assert_array_almost_equal(file['end'], tsdtensor.time_support.end) - os.remove("tsdtensor.npz") - os.remove("tsdtensor2.npz") + Path("tsdtensor.npz").unlink() + Path("tsdtensor2.npz").unlink() def test_interpolate(self, tsdtensor): diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index eb7449cd..b4743911 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -520,8 +520,6 @@ def test_to_tsd_runtime_errors(self, group): def test_save_npz(self, group): - import os - group = { 0: nap.Tsd(t=np.arange(0, 20), d = np.random.rand(20)), 1: nap.Tsd(t=np.arange(0, 20, 0.5), d=np.random.rand(40)), @@ -543,12 +541,10 @@ def test_save_npz(self, group): assert str(e.value) == "Path {} does not exist.".format(Path(fake_path).resolve()) tsgroup.save("tsgroup.npz") - os.listdir('.') - assert "tsgroup.npz" in os.listdir(".") + assert "tsgroup.npz" in [f.name for f in Path('.').iterdir()] tsgroup.save("tsgroup2") - os.listdir('.') - assert "tsgroup2.npz" in os.listdir(".") + assert "tsgroup2.npz" in [f.name for f in Path('.').iterdir()] file = np.load("tsgroup.npz") @@ -589,9 +585,9 @@ def test_save_npz(self, group): assert 'd' not in list(file.keys()) np.testing.assert_array_almost_equal(file['t'], tsgroup3[0].index) - os.remove("tsgroup.npz") - os.remove("tsgroup2.npz") - os.remove("tsgroup3.npz") + Path("tsgroup.npz").unlink() + Path("tsgroup2.npz").unlink() + Path("tsgroup3.npz").unlink() @pytest.mark.parametrize( "keys, expectation", From 54d8cb6ebaf2b0828974aa44edc84c70276e0d4d Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 18 Jul 2024 20:20:56 +0100 Subject: [PATCH 062/195] better wavelet API notebook --- docs/examples/tutorial_wavelet_api.py | 306 +++++++++++++++++++------- pynapple/process/signal_processing.py | 2 +- 2 files changed, 223 insertions(+), 85 deletions(-) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py index 55221265..e37e6c40 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/examples/tutorial_wavelet_api.py @@ -29,109 +29,92 @@ # %% # *** -# Generating a dummy signal +# Generating a Dummy Signal # ------------------ # Let's generate a dummy signal to analyse with wavelets! +# +# Our dummy dataset will contain two components, a low frequency 2Hz sinusoid combined +# with a sinusoid which increases frequency from 5 to 15 Hz throughout the signal. -# Our dummy dataset will contain two components, a low frequency 3Hz sinusoid combined -# with a weaker 25Hz sinusoid. -t = np.linspace(0, 10, 10000) -sig = nap.Tsd(d=np.sin(t * (5 + t) * np.pi * 2) + np.sin(t * 3 * np.pi * 2), t=t) -# Plot it -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3)) -ax.plot(sig) -ax.margins(0) -plt.show() +Fs = 2000 +t = np.linspace(0, 5, Fs*5) +two_hz_phase = t * 2 * np.pi * 2 +two_hz_component = np.sin(two_hz_phase) +increasing_freq_component = np.sin(t * (5+t) * np.pi * 2) +sig = nap.Tsd(d=two_hz_component + increasing_freq_component + + np.random.normal(0,0.1,10000), t=t) # %% -# *** -# Getting our Morlet wavelet filter bank -# ------------------ - -freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=1.5, scaling=1.0, precision=10 -) -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -offset = 0.2 -for f_i in range(filter_bank.shape[0]): - ax.plot( - np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i - ) - ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") -ax.margins(0) -ax.yaxis.set_visible(False) -ax.spines["left"].set_visible(False) -ax.spines["right"].set_visible(False) -ax.spines["top"].set_visible(False) -ax.set_xlim(-2, 2) -ax.set_xlabel("Time (s)") -ax.set_title("Morlet Wavelet Filter Bank") +# Lets plot it. +fig, ax = plt.subplots(3, constrained_layout=True, figsize=(10, 5)) +ax[0].plot(t, two_hz_component) +ax[1].plot(t, increasing_freq_component) +ax[2].plot(sig) +ax[0].set_title("2Hz Component") +ax[1].set_title("Increasing Frequency Component") +ax[2].set_title("Dummy Signal") +[ax[i].margins(0) for i in range(3)] +[ax[i].set_ylim(-2.5,2.5) for i in range(3)] +[ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(3)] +[ax[i].set_xlabel("Time (s)") for i in range(3)] +[ax[i].set_ylabel("Signal") for i in range(3)] +[ax[i].set_ylim(-2.5,2.5) for i in range(3)] plt.show() + # %% # *** -# Effect of n_cycles +# Getting our Morlet Wavelet Filter Bank # ------------------ +# We will be decomposing our dummy signal using wavelets of different frequencies. These wavelets +# can be examined using the `generate_morlet_filterbank' function. Here we will use the default parameters +# to define a Morlet filter bank with which we will later use to deconstruct the signal. +# Define the frequency of the wavelets in our filter bank freqs = np.linspace(1, 25, num=25) +# Get the filter bank filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=7.5, scaling=1.0, precision=10 + freqs, Fs, n_cycles=1.5, scaling=1.0, precision=20 ) -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -offset = 0.2 -for f_i in range(filter_bank.shape[0]): - ax.plot( - np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i - ) - ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") -ax.margins(0) -ax.yaxis.set_visible(False) -ax.spines["left"].set_visible(False) -ax.spines["right"].set_visible(False) -ax.spines["top"].set_visible(False) -ax.set_xlim(-2, 2) -ax.set_xlabel("Time (s)") -ax.set_title("Morlet Wavelet Filter Bank") -plt.show() # %% -# *** -# Effect of scaling -# ------------------ - -freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=7.5, scaling=2.0, precision=10 -) -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -offset = 0.2 -for f_i in range(filter_bank.shape[0]): - ax.plot( - np.linspace(-8, 8, filter_bank.shape[1]), filter_bank[f_i, :] + offset * f_i - ) - ax.text(-2.2, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") -ax.margins(0) -ax.yaxis.set_visible(False) -ax.spines["left"].set_visible(False) -ax.spines["right"].set_visible(False) -ax.spines["top"].set_visible(False) -ax.set_xlim(-2, 2) -ax.set_xlabel("Time (s)") -ax.set_title("Morlet Wavelet Filter Bank") -plt.show() +# Lets plot it. +def plot_filterbank(filter_bank, freqs, title): + fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) + offset = 0.2 + for f_i in range(filter_bank.shape[0]): + ax.plot( + np.linspace(-8, 8, filter_bank.shape[1]), + filter_bank[f_i, :].real + offset * f_i + ) + ax.text(-2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") + ax.margins(0) + ax.yaxis.set_visible(False) + [ax.spines[sp].set_visible(False) for sp in ["left", "right", "top"]] + ax.set_xlim(-2, 2) + ax.set_xlabel("Time (s)") + ax.set_title(title) + plt.show() +title = "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=1.0" +plot_filterbank(filter_bank, freqs, title) # %% # *** -# Decomposing the dummy signal +# Decomposing the Dummy Signal # ------------------ +# Here we will use the `compute_wavelet_transform' function to decompose our signal using the filter bank shown +# above. Wavelet decomposition breaks down a signal into its constituent wavelets, capturing both time and +# frequency information for analysis. We will calculate this decomposition and plot it's corresponding +# scalogram. +# Compute the wavelet transform using the parameters above mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=15 + sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=20 ) - +# %% +# Lets plot it. def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): if np.iscomplexobj(powers): powers = abs(powers) @@ -150,24 +133,131 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) -fig, ax = plt.subplots(1) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) plot_timefrequency( mwt.index.values[:], freqs[:], np.transpose(mwt[:, :].values), ax=ax, ) +ax.set_title("Wavelet Decomposition Scalogram") +plt.show() + +# %% +# *** +# Reconstructing the Slow Oscillation and Phase +# ------------------ +# We can see that the decomposition has picked up on the 2Hz component of the signal, as well as the component with +# increasing frequency. In this section, we will extract just the 2Hz component from the wavelet decomposition, +# and see how it compares to the original section. + +# Get the index of the 2Hz frequency +two_hz_freq_idx = np.where(freqs == 2.)[0] +# The 2Hz component is the real component of the wavelet decomposition at this index +slow_oscillation = mwt[:, two_hz_freq_idx].values.real +# The 2Hz wavelet phase is the angle of the wavelet decomposition at this index +slow_oscillation_phase = np.angle(mwt[:, two_hz_freq_idx].values) + +# %% +# Lets plot it. +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +axd = fig.subplot_mosaic( + [["signal"], ["phase"]], + height_ratios=[1, 0.4], +) +axd["signal"].plot(sig, label="Raw Signal", alpha=0.5) +axd["signal"].plot(t, slow_oscillation, label="2Hz Reconstruction") +axd["signal"].legend() +axd["phase"].plot(t, slow_oscillation_phase, alpha=0.5) +axd["phase"].set_ylabel("Phase (rad)") +axd["signal"].set_ylabel("Signal") +axd["phase"].set_xlabel("Time (s)") +[axd[f].spines[sp].set_visible(False) for sp in ["right", "top"] for f in ["phase", "signal"]] +axd["signal"].get_xaxis().set_visible(False) +axd["signal"].spines["bottom"].set_visible(False) +[axd[k].margins(0) for k in ["signal", "phase"]] +axd["signal"].set_ylim(-2.5,2.5) +axd["phase"].set_ylim(-np.pi, np.pi) plt.show() +# %% +# *** +# Adding in the 15Hz Oscillation +# ------------------ +# Let's see what happens if we also add the 15 Hz component of the wavelet decomposition to the reconstruction. We +# will extract the 15 Hz components, and also the 15Hz wavelet power over time. The wavelet power tells us to what +# extent the 15 Hz frequency is present in our signal at different times. +# +# Finally, we will add this 15 Hz reconstruction to the one shown above, to see if it improves out reconstructed +# signal. + +# Get the index of the 15 Hz frequency +fifteen_hz_freq_idx = np.where(freqs == 15.)[0] +# The 15 Hz component is the real component of the wavelet decomposition at this index +fifteenHz_oscillation = mwt[:, fifteen_hz_freq_idx].values.real +# The 15 Hz poser is the absolute value of the wavelet decomposition at this index +fifteenHz_oscillation_power = np.abs(mwt[:, fifteen_hz_freq_idx].values) + +# %% +# Lets plot it. +fig, ax = plt.subplots(2, constrained_layout=True, figsize=(10, 6)) +ax[0].plot(t, fifteenHz_oscillation, label="15Hz Reconstruction") +ax[0].plot(t, fifteenHz_oscillation_power, label="15Hz Power") +ax[1].plot(sig, label="Raw Signal", alpha=0.5) +ax[1].plot(t, slow_oscillation+fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction") +[ax[i].set_ylim(-2.5,2.5) for i in range(2)] +[ax[i].margins(0) for i in range(2)] +[ax[i].legend() for i in range(2)] +[ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(2)] +ax[0].get_xaxis().set_visible(False) +ax[0].spines["bottom"].set_visible(False) +ax[1].set_xlabel("Time (s)") +[ax[i].set_ylabel("Signal") for i in range(2)] +plt.show() # %% # *** -# Increasing n_cycles increases resolution of decomposition +# Adding ALL the Oscillations! # ------------------ +# Let's now add together the real components of all frequency bands to recreate a version of the original signal. +combined_oscillations = mwt.values.real.sum(axis=1) + +# %% +# Lets plot it. +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot(sig, alpha=0.5, label="Signal") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +[ax.spines[sp].set_visible(False) for sp in ["right", "top"]] +ax.set_xlabel("Time (s)") +ax.set_ylabel("Signal") +ax.set_title("Wavelet Reconstruction of Signal") +ax.margins(0) +ax.legend() +plt.show() + + + +# %% +# *** +# Effect of n_cycles +# ------------------ + +freqs = np.linspace(1, 25, num=25) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=7.5, scaling=1.0, precision=20 +) + +plot_filterbank(filter_bank, freqs, + "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=1.0") + +# %% +# *** +# Let's see what effect this has on the Wavelet Scalogram which is generated... mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=10 + sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=20 ) + fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( mwt.index.values[:], @@ -175,21 +265,69 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): np.transpose(mwt[:, :].values), ax=ax, ) + +# %% +# *** +# And let's see if that has an effect on the reconstructed version of the signal + +combined_oscillations = mwt.values.real.sum(axis=1) + +# %% +# Lets plot it. +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot(sig, alpha=0.5, label="Signal") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +[ax.spines[sp].set_visible(False) for sp in ["right", "top"]] +ax.set_xlabel("Time (s)") +ax.set_ylabel("Signal") +ax.set_title("Wavelet Reconstruction of Signal") +ax.margins(0) +ax.legend() plt.show() + # %% # *** -# Increasing n_cycles increases resolution of decomposition +# Effect of scaling # ------------------ -mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, n_cycles=7.5, scaling=2.0, precision=10 +freqs = np.linspace(1, 25, num=25) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=1.5, scaling=2.0, precision=20 ) + +plot_filterbank(filter_bank, freqs, + "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=2.0") + +# %% +# *** +# Let's see what effect this has on the Wavelet Scalogram which is generated... fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) +mwt = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=2.0, precision=20 +) plot_timefrequency( mwt.index.values[:], freqs[:], np.transpose(mwt[:, :].values), ax=ax, ) + +# %% +# *** +# And let's see if that has an effect on the reconstructed version of the signal + +combined_oscillations = mwt.values.real.sum(axis=1) + +# %% +# Lets plot it. +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot(sig, alpha=0.5, label="Signal") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +[ax.spines[sp].set_visible(False) for sp in ["right", "top"]] +ax.set_xlabel("Time (s)") +ax.set_ylabel("Signal") +ax.set_title("Wavelet Reconstruction of Signal") +ax.margins(0) +ax.legend() plt.show() diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 28329092..7cbacc05 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -197,7 +197,7 @@ def compute_wavelet_transform( convolved = convolved_real.values + convolved_imag.values * 1j coef = -np.diff(convolved, axis=0) if norm == "sss": - coef *= coef * (-np.sqrt(scaling) / (freqs / fs)) + coef *= (-np.sqrt(scaling) / (freqs / fs)) elif norm == "amp": coef *= -scaling / (freqs / fs) coef = np.insert( From 3d1ab70b24c6e9ab8dac9945d2d74bcc06e4504a Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 18 Jul 2024 20:21:22 +0100 Subject: [PATCH 063/195] removed tkagg --- docs/examples/tutorial_wavelet_api.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py index e37e6c40..8367a33d 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/examples/tutorial_wavelet_api.py @@ -19,9 +19,6 @@ # # Now, import the necessary libraries: -import matplotlib - -matplotlib.use("TkAgg") import matplotlib.pyplot as plt import numpy as np From 5df9ff045a9492daef2632c8115f480521722e81 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 18 Jul 2024 20:34:10 +0100 Subject: [PATCH 064/195] linting --- docs/examples/tutorial_wavelet_api.py | 56 ++++++++++++++++++--------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py index 8367a33d..b42cb609 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/examples/tutorial_wavelet_api.py @@ -34,12 +34,14 @@ # with a sinusoid which increases frequency from 5 to 15 Hz throughout the signal. Fs = 2000 -t = np.linspace(0, 5, Fs*5) +t = np.linspace(0, 5, Fs * 5) two_hz_phase = t * 2 * np.pi * 2 two_hz_component = np.sin(two_hz_phase) -increasing_freq_component = np.sin(t * (5+t) * np.pi * 2) -sig = nap.Tsd(d=two_hz_component + increasing_freq_component + - np.random.normal(0,0.1,10000), t=t) +increasing_freq_component = np.sin(t * (5 + t) * np.pi * 2) +sig = nap.Tsd( + d=two_hz_component + increasing_freq_component + np.random.normal(0, 0.1, 10000), + t=t, +) # %% # Lets plot it. @@ -51,11 +53,11 @@ ax[1].set_title("Increasing Frequency Component") ax[2].set_title("Dummy Signal") [ax[i].margins(0) for i in range(3)] -[ax[i].set_ylim(-2.5,2.5) for i in range(3)] +[ax[i].set_ylim(-2.5, 2.5) for i in range(3)] [ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(3)] [ax[i].set_xlabel("Time (s)") for i in range(3)] [ax[i].set_ylabel("Signal") for i in range(3)] -[ax[i].set_ylim(-2.5,2.5) for i in range(3)] +[ax[i].set_ylim(-2.5, 2.5) for i in range(3)] plt.show() @@ -74,6 +76,7 @@ freqs, Fs, n_cycles=1.5, scaling=1.0, precision=20 ) + # %% # Lets plot it. def plot_filterbank(filter_bank, freqs, title): @@ -82,9 +85,11 @@ def plot_filterbank(filter_bank, freqs, title): for f_i in range(filter_bank.shape[0]): ax.plot( np.linspace(-8, 8, filter_bank.shape[1]), - filter_bank[f_i, :].real + offset * f_i + filter_bank[f_i, :].real + offset * f_i, + ) + ax.text( + -2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left" ) - ax.text(-2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") ax.margins(0) ax.yaxis.set_visible(False) [ax.spines[sp].set_visible(False) for sp in ["left", "right", "top"]] @@ -93,6 +98,7 @@ def plot_filterbank(filter_bank, freqs, title): ax.set_title(title) plt.show() + title = "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=1.0" plot_filterbank(filter_bank, freqs, title) @@ -110,6 +116,7 @@ def plot_filterbank(filter_bank, freqs, title): sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=20 ) + # %% # Lets plot it. def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): @@ -149,7 +156,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # and see how it compares to the original section. # Get the index of the 2Hz frequency -two_hz_freq_idx = np.where(freqs == 2.)[0] +two_hz_freq_idx = np.where(freqs == 2.0)[0] # The 2Hz component is the real component of the wavelet decomposition at this index slow_oscillation = mwt[:, two_hz_freq_idx].values.real # The 2Hz wavelet phase is the angle of the wavelet decomposition at this index @@ -169,11 +176,15 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): axd["phase"].set_ylabel("Phase (rad)") axd["signal"].set_ylabel("Signal") axd["phase"].set_xlabel("Time (s)") -[axd[f].spines[sp].set_visible(False) for sp in ["right", "top"] for f in ["phase", "signal"]] +[ + axd[f].spines[sp].set_visible(False) + for sp in ["right", "top"] + for f in ["phase", "signal"] +] axd["signal"].get_xaxis().set_visible(False) axd["signal"].spines["bottom"].set_visible(False) [axd[k].margins(0) for k in ["signal", "phase"]] -axd["signal"].set_ylim(-2.5,2.5) +axd["signal"].set_ylim(-2.5, 2.5) axd["phase"].set_ylim(-np.pi, np.pi) plt.show() @@ -189,7 +200,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # signal. # Get the index of the 15 Hz frequency -fifteen_hz_freq_idx = np.where(freqs == 15.)[0] +fifteen_hz_freq_idx = np.where(freqs == 15.0)[0] # The 15 Hz component is the real component of the wavelet decomposition at this index fifteenHz_oscillation = mwt[:, fifteen_hz_freq_idx].values.real # The 15 Hz poser is the absolute value of the wavelet decomposition at this index @@ -201,8 +212,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax[0].plot(t, fifteenHz_oscillation, label="15Hz Reconstruction") ax[0].plot(t, fifteenHz_oscillation_power, label="15Hz Power") ax[1].plot(sig, label="Raw Signal", alpha=0.5) -ax[1].plot(t, slow_oscillation+fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction") -[ax[i].set_ylim(-2.5,2.5) for i in range(2)] +ax[1].plot( + t, slow_oscillation + fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction" +) +[ax[i].set_ylim(-2.5, 2.5) for i in range(2)] [ax[i].margins(0) for i in range(2)] [ax[i].legend() for i in range(2)] [ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(2)] @@ -234,7 +247,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): plt.show() - # %% # *** # Effect of n_cycles @@ -245,8 +257,11 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): freqs, 1000, n_cycles=7.5, scaling=1.0, precision=20 ) -plot_filterbank(filter_bank, freqs, - "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=1.0") +plot_filterbank( + filter_bank, + freqs, + "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=1.0", +) # %% # *** @@ -293,8 +308,11 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): freqs, 1000, n_cycles=1.5, scaling=2.0, precision=20 ) -plot_filterbank(filter_bank, freqs, - "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=2.0") +plot_filterbank( + filter_bank, + freqs, + "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=2.0", +) # %% # *** From c942a792f79c224443878699ef6ae76fdf27834a Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 18 Jul 2024 20:39:14 +0100 Subject: [PATCH 065/195] linting --- pynapple/process/signal_processing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 7cbacc05..d9b79fd4 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -197,7 +197,7 @@ def compute_wavelet_transform( convolved = convolved_real.values + convolved_imag.values * 1j coef = -np.diff(convolved, axis=0) if norm == "sss": - coef *= (-np.sqrt(scaling) / (freqs / fs)) + coef *= -np.sqrt(scaling) / (freqs / fs) elif norm == "amp": coef *= -scaling / (freqs / fs) coef = np.insert( From 04e9d8a99eed5337e7749f4026882465b3b68ba8 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 16:55:25 +0100 Subject: [PATCH 066/195] wavelet api tutorial improved, generate_filterbank returns TdsFrame --- docs/examples/tutorial_wavelet_api.py | 103 +++++++++++++++++++------- pynapple/process/signal_processing.py | 50 ++++--------- 2 files changed, 91 insertions(+), 62 deletions(-) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/examples/tutorial_wavelet_api.py index b42cb609..fdbd162d 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/examples/tutorial_wavelet_api.py @@ -3,7 +3,7 @@ Wavelet API tutorial ============ -Working with Wavelets. +Working with Wavelets! See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. @@ -72,9 +72,7 @@ # Define the frequency of the wavelets in our filter bank freqs = np.linspace(1, 25, num=25) # Get the filter bank -filter_bank = nap.generate_morlet_filterbank( - freqs, Fs, n_cycles=1.5, scaling=1.0, precision=20 -) +filter_bank = nap.generate_morlet_filterbank(freqs, Fs, n_cycles=1.5, scaling=1.0) # %% @@ -82,11 +80,8 @@ def plot_filterbank(filter_bank, freqs, title): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) offset = 0.2 - for f_i in range(filter_bank.shape[0]): - ax.plot( - np.linspace(-8, 8, filter_bank.shape[1]), - filter_bank[f_i, :].real + offset * f_i, - ) + for f_i in range(filter_bank.shape[1]): + ax.plot(filter_bank[:, f_i] + offset * f_i) ax.text( -2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left" ) @@ -112,9 +107,7 @@ def plot_filterbank(filter_bank, freqs, title): # scalogram. # Compute the wavelet transform using the parameters above -mwt = nap.compute_wavelet_transform( - sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=1.0, precision=20 -) +mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=1.0) # %% @@ -237,25 +230,77 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Lets plot it. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") -ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) [ax.spines[sp].set_visible(False) for sp in ["right", "top"]] ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") +ax.set_ylim(-6, 6) ax.margins(0) ax.legend() plt.show() +# %% +# *** +# Parametrization +# ------------------ +# Our reconstruction seems to get the amplitude modulations of our signal correct, but the amplitude is overestimated, +# in particular towards the end of the time period. Often, this is due to a suboptimal choice of parameters, which +# can lead to a low spatial or temporal resolution. Let's explore what changing our parameters does to the +# underlying wavelets. + +freqs = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) +scales = [1.0, 2.0, 3.0] +cycles = [1.0, 2.0, 3.0] + +fig, ax = plt.subplots( + len(scales), len(cycles), constrained_layout=True, figsize=(10, 5) +) +for row_i, sc in enumerate(scales): + for col_i, cyc in enumerate(cycles): + filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, n_cycles=cyc, scaling=sc + ) + ax[row_i, col_i].plot(filter_bank[:, 0]) + ax[row_i, col_i].set_xlim(-15, 15) + ax[row_i, col_i].set_xlabel("Time (s)") + ax[row_i, col_i].set_ylabel("Signal") + [ + ax[row_i, col_i].spines[sp].set_visible(False) + for sp in ["top", "right", "left"] + ] + ax[row_i, col_i].get_yaxis().set_visible(False) + fig.text( + 0.01, + 0.6 / len(scales) + row_i / len(scales), + f"scaling={sc}", + ha="center", + va="center", + rotation="vertical", + fontsize=8, + ) +for col_i, cyc in enumerate(cycles): + ax[0, col_i].set_title(f"n_cycles={cyc}", fontsize=8) +fig.suptitle("Parametrization Visualization") +plt.show() + +# %% +# Increasing n_cycles increases the number of wavelet cycles present in the oscillations (cycles) within the +# Gaussian window of the Morlet wavelet. It essentially controls the trade-off between time resolution +# and frequency resolution. +# +# The scale parameter determines the dilation or compression of the wavelet. It controls the size of the wavelet in +# time, affecting the overall shape of the wavelet. + # %% # *** # Effect of n_cycles # ------------------ +# Let's increase n_cycles to 7.5 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=7.5, scaling=1.0, precision=20 -) +filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=1.0) plot_filterbank( filter_bank, @@ -266,9 +311,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # *** # Let's see what effect this has on the Wavelet Scalogram which is generated... -mwt = nap.compute_wavelet_transform( - sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=1.0, precision=20 -) +mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=1.0) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( @@ -277,6 +320,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): np.transpose(mwt[:, :].values), ax=ax, ) +ax.set_title("Wavelet Decomposition Scalogram") # %% # *** @@ -288,45 +332,47 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Lets plot it. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") -ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) [ax.spines[sp].set_visible(False) for sp in ["right", "top"]] ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") +ax.set_ylim(-6, 6) ax.margins(0) ax.legend() plt.show() +# %% +# There's a small improvement, but perhaps we can do better. + # %% # *** # Effect of scaling # ------------------ +# Let's increase scaling to 2.0 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=1.5, scaling=2.0, precision=20 -) +filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=2.0) plot_filterbank( filter_bank, freqs, - "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=2.0", + "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=2.0", ) # %% # *** # Let's see what effect this has on the Wavelet Scalogram which is generated... fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -mwt = nap.compute_wavelet_transform( - sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=2.0, precision=20 -) +mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=2.0) plot_timefrequency( mwt.index.values[:], freqs[:], np.transpose(mwt[:, :].values), ax=ax, ) +ax.set_title("Wavelet Decomposition Scalogram") # %% # *** @@ -338,11 +384,12 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Lets plot it. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") -ax.plot(t, combined_oscillations, label="Wavelet Reconstruction") +ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) [ax.spines[sp].set_visible(False) for sp in ["right", "top"]] ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") ax.margins(0) +ax.set_ylim(-6, 6) ax.legend() plt.show() diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index d9b79fd4..0c3f6e13 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -51,35 +51,9 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): return ret -def compute_welch_spectogram(sig, fs=None): - """ - Performs Welch's decomposition on sig, returns output. - Estimates the power spectral density of a signal by segmenting it into overlapping sections, applying a - window function to each segment, computing their FFTs, and averaging the resulting periodograms to reduce noise. - - ..todo: remove this or add binsize parameter - ..todo: be careful of border artifacts - - Parameters - ---------- - sig : pynapple.Tsd or pynapple.TsdFrame - Time series. - fs : float, optional - Sampling rate, in Hz. If None, will be calculated from the given signal - """ - if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): - raise TypeError( - "Currently compute_welch_spectogram is only implemented for Tsd or TsdFrame" - ) - if fs is None: - fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) - freqs, spectogram = welch(sig.values, fs=fs, axis=0) - return pd.DataFrame(spectogram, freqs) - - def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ - Defines the complex Morlet wavelet kernel + Defines the complex Morlet wavelet kernel. Parameters ---------- @@ -137,7 +111,7 @@ def compute_wavelet_transform( sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, precision=10, norm=None ): """ - Compute the time-frequency representation of a signal using morlet wavelets. + Compute the time-frequency representation of a signal using Morlet wavelets. Parameters ---------- @@ -192,8 +166,8 @@ def compute_wavelet_transform( sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) filter_bank = generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision) - convolved_real = sig.convolve(np.transpose(filter_bank.real)) - convolved_imag = sig.convolve(np.transpose(filter_bank.imag)) + convolved_real = sig.convolve(filter_bank.real().values) + convolved_imag = sig.convolve(filter_bank.imag().values) convolved = convolved_real.values + convolved_imag.values * 1j coef = -np.diff(convolved, axis=0) if norm == "sss": @@ -217,6 +191,8 @@ def compute_wavelet_transform( def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=10): """ + Generates a Morlet filterbank using the given frequencies and parameters. Can be used purely for visualization, + or to convolve with a pynapple Tsd, TsdFrame, or TsdTensor as part of a wavelet decomposition process. Parameters ---------- @@ -236,14 +212,17 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 Returns ------- - filter_bank : np.ndarray + filter_bank : pynapple.TsdFrame list of Morlet wavelet filters of the frequencies given """ + if len(freqs) == 0: + raise ValueError("Given list of freqs cannot be empty.") filter_bank = [] + time_cutoff = 8 morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) - x = np.linspace(-8, 8, int(2**precision)) + x = np.linspace(-time_cutoff, time_cutoff, int(2**precision)) int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) - max_len = 0 + max_len = -1 for freq in freqs: scale = scaling / (freq / fs) j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) @@ -253,6 +232,9 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 int_psi_scale = int_psi[j][::-1] if len(int_psi_scale) > max_len: max_len = len(int_psi_scale) + time = np.linspace( + -time_cutoff * scaling / freq, time_cutoff * scaling / freq, max_len + ) filter_bank.append(int_psi_scale) filter_bank = [ np.pad( @@ -262,7 +244,7 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 ) for arr in filter_bank ] - return np.array(filter_bank) + return nap.TsdFrame(d=np.array(filter_bank).transpose(), t=time) def _integrate(arr, step): From c49d767451c18afebc573c08ef85f18fd6dad25f Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 16:57:28 +0100 Subject: [PATCH 067/195] welch removed --- pynapple/process/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 0986cc6d..a73dea00 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -18,7 +18,6 @@ from .signal_processing import ( compute_spectogram, compute_wavelet_transform, - compute_welch_spectogram, generate_morlet_filterbank, ) from .tuning_curves import ( From 917f932d20a93d72a3ad5429830921acfec3b6b6 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 16:59:22 +0100 Subject: [PATCH 068/195] welch import removed --- pynapple/process/signal_processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 0c3f6e13..1b83ea34 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd -from scipy.signal import welch import pynapple as nap From 10e47fb0a414fe7be888d89d42473367c38c4446 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 23:22:12 +0100 Subject: [PATCH 069/195] review comments addressed --- .../tutorial_pynapple_wavelets.py} | 47 ++-- docs/examples/tutorial_phase_preferences.py | 240 +++++++++--------- docs/examples/tutorial_signal_processing.py | 181 ++++++------- 3 files changed, 226 insertions(+), 242 deletions(-) rename docs/{examples/tutorial_wavelet_api.py => api_guide/tutorial_pynapple_wavelets.py} (93%) diff --git a/docs/examples/tutorial_wavelet_api.py b/docs/api_guide/tutorial_pynapple_wavelets.py similarity index 93% rename from docs/examples/tutorial_wavelet_api.py rename to docs/api_guide/tutorial_pynapple_wavelets.py index fdbd162d..8cdeb9e1 100644 --- a/docs/examples/tutorial_wavelet_api.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -15,12 +15,15 @@ # !!! warning # This tutorial uses matplotlib for displaying the figure # -# You can install all with `pip install matplotlib requests tqdm` +# You can install all with `pip install matplotlib requests tqdm seaborn` # # Now, import the necessary libraries: import matplotlib.pyplot as plt import numpy as np +import seaborn + +seaborn.set_theme() import pynapple as nap @@ -58,7 +61,6 @@ [ax[i].set_xlabel("Time (s)") for i in range(3)] [ax[i].set_ylabel("Signal") for i in range(3)] [ax[i].set_ylim(-2.5, 2.5) for i in range(3)] -plt.show() # %% @@ -66,7 +68,7 @@ # Getting our Morlet Wavelet Filter Bank # ------------------ # We will be decomposing our dummy signal using wavelets of different frequencies. These wavelets -# can be examined using the `generate_morlet_filterbank' function. Here we will use the default parameters +# can be examined using the `generate_morlet_filterbank` function. Here we will use the default parameters # to define a Morlet filter bank with which we will later use to deconstruct the signal. # Define the frequency of the wavelets in our filter bank @@ -78,7 +80,7 @@ # %% # Lets plot it. def plot_filterbank(filter_bank, freqs, title): - fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) + fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7)) offset = 0.2 for f_i in range(filter_bank.shape[1]): ax.plot(filter_bank[:, f_i] + offset * f_i) @@ -91,7 +93,6 @@ def plot_filterbank(filter_bank, freqs, title): ax.set_xlim(-2, 2) ax.set_xlabel("Time (s)") ax.set_title(title) - plt.show() title = "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=1.0" @@ -101,7 +102,7 @@ def plot_filterbank(filter_bank, freqs, title): # *** # Decomposing the Dummy Signal # ------------------ -# Here we will use the `compute_wavelet_transform' function to decompose our signal using the filter bank shown +# Here we will use the `compute_wavelet_transform` function to decompose our signal using the filter bank shown # above. Wavelet decomposition breaks down a signal into its constituent wavelets, capturing both time and # frequency information for analysis. We will calculate this decomposition and plot it's corresponding # scalogram. @@ -128,9 +129,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): y_ticks = freqs y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + ax.grid(False) -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( mwt.index.values[:], freqs[:], @@ -138,7 +140,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax=ax, ) ax.set_title("Wavelet Decomposition Scalogram") -plt.show() # %% # *** @@ -179,7 +180,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): [axd[k].margins(0) for k in ["signal", "phase"]] axd["signal"].set_ylim(-2.5, 2.5) axd["phase"].set_ylim(-np.pi, np.pi) -plt.show() # %% # *** @@ -216,7 +216,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax[0].spines["bottom"].set_visible(False) ax[1].set_xlabel("Time (s)") [ax[i].set_ylabel("Signal") for i in range(2)] -plt.show() # %% # *** @@ -224,7 +223,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # ------------------ # Let's now add together the real components of all frequency bands to recreate a version of the original signal. -combined_oscillations = mwt.values.real.sum(axis=1) +combined_oscillations = mwt.sum(axis=1) # %% # Lets plot it. @@ -238,7 +237,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax.set_ylim(-6, 6) ax.margins(0) ax.legend() -plt.show() # %% @@ -265,25 +263,18 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax[row_i, col_i].plot(filter_bank[:, 0]) ax[row_i, col_i].set_xlim(-15, 15) ax[row_i, col_i].set_xlabel("Time (s)") - ax[row_i, col_i].set_ylabel("Signal") + ax[row_i, col_i].set_yticks([]) [ ax[row_i, col_i].spines[sp].set_visible(False) for sp in ["top", "right", "left"] ] - ax[row_i, col_i].get_yaxis().set_visible(False) - fig.text( - 0.01, - 0.6 / len(scales) + row_i / len(scales), - f"scaling={sc}", - ha="center", - va="center", - rotation="vertical", - fontsize=8, - ) + if col_i != 0: + ax[row_i, col_i].get_yaxis().set_visible(False) for col_i, cyc in enumerate(cycles): - ax[0, col_i].set_title(f"n_cycles={cyc}", fontsize=8) + ax[0, col_i].set_title(f"n_cycles={cyc}", fontsize=10) +for row_i, scl in enumerate(scales): + ax[row_i, 0].set_ylabel(f"scaling={scl}", fontsize=10) fig.suptitle("Parametrization Visualization") -plt.show() # %% # Increasing n_cycles increases the number of wavelet cycles present in the oscillations (cycles) within the @@ -326,7 +317,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # *** # And let's see if that has an effect on the reconstructed version of the signal -combined_oscillations = mwt.values.real.sum(axis=1) +combined_oscillations = mwt.sum(axis=1) # %% # Lets plot it. @@ -340,7 +331,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax.set_ylim(-6, 6) ax.margins(0) ax.legend() -plt.show() # %% # There's a small improvement, but perhaps we can do better. @@ -378,7 +368,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # *** # And let's see if that has an effect on the reconstructed version of the signal -combined_oscillations = mwt.values.real.sum(axis=1) +combined_oscillations = mwt.sum(axis=1) # %% # Lets plot it. @@ -392,4 +382,3 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax.margins(0) ax.set_ylim(-6, 6) ax.legend() -plt.show() diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index b4f1f6ec..f5af3b66 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Grosmark & Buzsáki (2016) Tutorial 2 +Computing Phase Preferences ============ In the previous [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/) tutorial, @@ -20,7 +20,7 @@ # !!! warning # This tutorial uses matplotlib for displaying the figure # -# You can install all with `pip install matplotlib requests tqdm` +# You can install all with `pip install matplotlib requests tqdm seaborn` # # First, import the necessary libraries: @@ -29,10 +29,14 @@ import matplotlib.pyplot as plt import numpy as np +import pandas as pd import requests import scipy +import seaborn import tqdm +seaborn.set_theme() + import pynapple as nap # %% @@ -62,7 +66,7 @@ # Let's load and print the full dataset. data = nap.load_file(path) -FS = len(data["eeg"].index[:]) / (data["eeg"].index[-1] - data["eeg"].index[0]) +FS = 1250 # We know from the methods of the paper print(data) @@ -75,7 +79,7 @@ # Define the IntervalSet for this run and instantiate both LFP and # Position TsdFrame objects REM_minute_interval = nap.IntervalSet( - data["rem"]["start"][0] + 90.0, + data["rem"]["start"][0] + 95.0, data["rem"]["start"][0] + 100.0, ) REM_Tsd = data["eeg"].restrict(REM_minute_interval) @@ -88,9 +92,7 @@ (data["units"][i].times() > REM_minute_interval["start"][0]) & (data["units"][i].times() < REM_minute_interval["end"][0]) ] - -# The given dataset has only one channel, so we set channel = 0 here -channel = 0 +spikes_tsdg = data["units"].restrict(REM_minute_interval) # %% # *** @@ -98,19 +100,17 @@ # ----------------------------------- # We should first plot our REM Local Field Potential data. -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) - +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3)) ax.plot( - REM_Tsd[:, channel], + REM_Tsd, label="REM LFP Data", - color="green", + color="blue", ) ax.set_title("REM Local Field Potential") ax.set_ylabel("LFP (v)") ax.set_xlabel("time (s)") ax.margins(0) ax.legend() -plt.show() # %% # *** @@ -121,32 +121,15 @@ # as we did in the last tutorial, to see get a more informative breakdown of the # frequencies present in the data. -# We must define the frequency set that we'd like to use for our decomposition; -# these have been manually selected based on the frequencies used in -# Frey et. al (2021), but could also be defined as a linspace or logspace -freqs = np.array( - [ - 2.59, - 3.66, - 5.18, - 8.0, - 10.36, - 20.72, - 29.3, - 41.44, - 58.59, - 82.88, - 117.19, - 152.35, - 192.19, - 200.0, - 234.38, - 270.00, - 331.5, - 390.00, - ] -) -mwt_REM = nap.compute_wavelet_transform(REM_Tsd[:, channel], fs=None, freqs=freqs) + +# We must define the frequency set that we'd like to use for our decomposition +freqs = np.geomspace(5, 200, 25) +# Compute the wavelet transform on our LFP data +mwt_REM = nap.compute_wavelet_transform(REM_Tsd[:, 0], fs=FS, freqs=freqs) + +# %% +# *** +# Now let's plot the calculated wavelet scalogram. # Define wavelet decomposition plotting function @@ -163,9 +146,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): else: x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - y_ticks = freqs + y_ticks = [np.round(f, 2) for f in freqs] y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + ax.grid(False) # And plot it @@ -190,19 +174,24 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): axd["lfp_rem"].get_xaxis().set_visible(False) for spine in ["top", "right", "bottom", "left"]: axd["lfp_rem"].spines[spine].set_visible(False) -plt.show() # %% # *** # Visualizing Theta Band Power and Phase # ----------------------------------- # There seems to be a strong theta frequency present in the data during the maze traversal. -# Let's plot the estimated 8Hz component of the wavelet decomposition on top of our data, and see how well -# they match up. We will also extract and plot the phase of the 8Hz wavelet from the decomposition. -theta_freq_index = 3 +# Let's plot the estimated 7Hz component of the wavelet decomposition on top of our data, and see how well +# they match up. We will also extract and plot the phase of the 7Hz wavelet from the decomposition. +theta_freq_index = np.argmin(np.abs(7 - freqs)) theta_band_reconstruction = mwt_REM[:, theta_freq_index].values.real # calculating phase here -theta_band_phase = np.angle(mwt_REM[:, theta_freq_index].values) +theta_band_phase = nap.Tsd( + t=mwt_REM.index, d=np.angle(mwt_REM[:, theta_freq_index].values) +) + +# %% +# *** +# Now let's plot the theta power and phase, along with the LFP. fig = plt.figure(constrained_layout=True, figsize=(10, 5)) axd = fig.subplot_mosaic( @@ -213,63 +202,78 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): height_ratios=[0.4, 0.2], ) -axd["theta_pow"].plot( - REM_Tsd.index.values, REM_Tsd[:, channel], alpha=0.5, label="LFP Data - REM" -) +axd["theta_pow"].plot(REM_Tsd, alpha=0.5, label="LFP Data - REM") axd["theta_pow"].plot( REM_Tsd.index.values, theta_band_reconstruction, - label=f"{freqs[theta_freq_index]}Hz oscillations", + label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", ) axd["theta_pow"].set_ylabel("LFP (v)") axd["theta_pow"].set_xlabel("Time (s)") -axd["theta_pow"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") # +axd["theta_pow"].set_title( + f"{np.round(freqs[theta_freq_index],2)}Hz oscillation power." +) # axd["theta_pow"].legend() -axd["phase"].plot(theta_band_phase) +axd["phase"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.5) [axd[k].margins(0) for k in ["theta_pow", "phase"]] axd["phase"].set_ylabel("Phase") -plt.show() +axd["phase"].get_xaxis().set_visible(False) # %% # *** # Finding Phase of Spikes # ----------------------------------- -# Now that we have the phase of our theta wavelet, and our spike times, we can find the theta phase at which every -# spike occurs +# Now that we have the phase of our theta wavelet, and our spike times, we can find the phase firing preferences +# of each of the units using the compute_1d_tuning_curves function. +# +# We will start by throwing away cells which do not have a high enough firing rate during our interval. + +# Filter units based on firing rate +spikes_tsdg = spikes_tsdg[spikes_tsdg.rate > 5.0] +# Calculate theta phase firing preferences +tuning_curves = nap.compute_1d_tuning_curves( + group=spikes_tsdg, feature=theta_band_phase, nb_bins=61, minmax=(-np.pi, np.pi) +) + +# %% +# *** +# Now we will plot these preferences as smoothed angular histograms. We will select the first 6 units +# to plot. + + +def smoothAngularTuningCurves(tuning_curves, sigma=2): + tmp = np.concatenate( + (tuning_curves.values, tuning_curves.values, tuning_curves.values) + ) + tmp = scipy.ndimage.gaussian_filter1d(tmp, sigma=sigma, axis=0) + return pd.DataFrame( + index=tuning_curves.index, + data=tmp[tuning_curves.shape[0] : tuning_curves.shape[0] * 2], + columns=tuning_curves.columns, + ) -# We will start by throwing away cells which do not have enough -# spikes during our interval -spikes = {k: v for k, v in spikes.items() if len(v) > 20} -# Get phase of each spike -phase = {} -for i in spikes.keys(): - phase_i = [] - for spike in spikes[i]: - phase_i.append( - np.angle( - mwt_REM[ - np.argmin(np.abs(REM_Tsd.index.values - spike)), theta_freq_index - ] - ) - ) - phase[i] = np.array(phase_i) -# Let's plot phase histograms for the first six units to see if there's -# any obvious preferences -fig, ax = plt.subplots(2, 3, constrained_layout=True, figsize=(10, 6)) -for ri in range(2): - for ci in range(3): - ax[ri, ci].hist( - phase[list(phase.keys())[ri * 3 + ci]], - bins=np.linspace(-np.pi, np.pi, 10), - density=True, - ) - ax[ri, ci].set_xlabel("Phase (rad)") - ax[ri, ci].set_ylabel("Density") - ax[ri, ci].set_title(f"Unit {list(phase.keys())[ri*3 + ci]}") +smoothcurves = smoothAngularTuningCurves(tuning_curves, sigma=2) +fig, axd = plt.subplot_mosaic( + [["phase_0", "phase_1", "phase_2"], ["phase_3", "phase_4", "phase_5"]], + constrained_layout=True, + figsize=(10, 6), + subplot_kw={"projection": "polar"}, +) +for pl_i, sc_i in enumerate(list(smoothcurves)[:6]): + axd[f"phase_{pl_i}"].plot( + list(smoothcurves[sc_i].index) + list([smoothcurves[sc_i].index[0]]), + list(smoothcurves[sc_i].values) + list([smoothcurves[sc_i].values[0]]), + ) + axd[f"phase_{pl_i}"].set_xlabel("Phase (rad)") # Angle in radian, on the X-axis + axd[f"phase_{pl_i}"].set_ylabel( + "Firing Rate (Hz)" + ) # Firing rate in Hz, on the Y-axis + axd[f"phase_{pl_i}"].set_xticks([]) + axd[f"phase_{pl_i}"].set_title(f"Unit {sc_i}") fig.suptitle("Phase Preference Histograms of First 6 Units") -plt.show() + # %% # *** @@ -279,30 +283,38 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Now that we have our phases of firing for each unit, we can sort the units by the circular variance of the phase # of their spikes, to isolate the cells with the strongest phase preferences without manual inspection. -variances = { +# Get phase of each spike +phase = {} +for i in spikes_tsdg: + phase_i = [ + theta_band_phase[np.argmin(np.abs(REM_Tsd.index.values - s.index))] + for s in spikes_tsdg[i] + ] + phase[i] = np.array(phase_i) +phase_var = { key: scipy.stats.circvar(value, low=-np.pi, high=np.pi) for key, value in phase.items() } -spikes = dict(sorted(spikes.items(), key=lambda item: variances[item[0]])) -phase = dict(sorted(phase.items(), key=lambda item: variances[item[0]])) - -# Now let's plot phase histograms for the six units with the least -# varied phase of spikes. -fig, ax = plt.subplots(2, 3, constrained_layout=True, figsize=(10, 6)) -for ri in range(2): - for ci in range(3): - ax[ri, ci].hist( - phase[list(phase.keys())[ri * 3 + ci]], - bins=np.linspace(-np.pi, np.pi, 10), - density=True, - ) - ax[ri, ci].set_xlabel("Phase (rad)") - ax[ri, ci].set_ylabel("Density") - ax[ri, ci].set_title(f"Unit {list(phase.keys())[ri*3 + ci]}") -fig.suptitle( - "Phase Preference Histograms of 6 Units with " + "Highest Phase Preference" +phase_var = dict(sorted(phase_var.items(), key=lambda item: item[1])) + +fig, axd = plt.subplot_mosaic( + [["phase_0", "phase_1", "phase_2"], ["phase_3", "phase_4", "phase_5"]], + constrained_layout=True, + figsize=(10, 6), + subplot_kw={"projection": "polar"}, ) -plt.show() +for pl_i, sc_i in enumerate(list(phase_var.keys())[:6]): + axd[f"phase_{pl_i}"].plot( + list(smoothcurves[sc_i].index) + list([smoothcurves[sc_i].index[0]]), + list(smoothcurves[sc_i].values) + list([smoothcurves[sc_i].values[0]]), + ) + axd[f"phase_{pl_i}"].set_xlabel("Phase (rad)") # Angle in radian, on the X-axis + axd[f"phase_{pl_i}"].set_ylabel( + "Firing Rate (Hz)" + ) # Firing rate in Hz, on the Y-axis + axd[f"phase_{pl_i}"].set_xticks([]) + axd[f"phase_{pl_i}"].set_title(f"Unit {sc_i}") +fig.suptitle("Phase Preference Histograms of 6 Units with Highest Phase Preference ") # %% # *** @@ -311,38 +323,34 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # There is definitely some strong phase preferences happening here. Let's visualize the firing preferences # of the 6 cells we've isolated to get an impression of just how striking these preferences are. -fig = plt.figure(constrained_layout=True, figsize=(10, 12)) +fig = plt.figure(constrained_layout=True, figsize=(10, 8)) axd = fig.subplot_mosaic( [ ["lfp_run"], ["phase_0"], ["phase_1"], ["phase_2"], - ["phase_3"], - ["phase_4"], - ["phase_5"], ], - height_ratios=[0.4, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2], + height_ratios=[0.4, 0.2, 0.2, 0.2], ) -[axd[k].margins(0) for k in ["lfp_run"] + [f"phase_{i}" for i in range(6)]] +[axd[k].margins(0) for k in ["lfp_run"] + [f"phase_{i}" for i in range(3)]] axd["lfp_run"].plot( - REM_Tsd.index.values, REM_Tsd[:, channel], alpha=0.5, label="LFP Data - REM" + REM_Tsd.index.values, REM_Tsd[:, 0], alpha=0.5, label="LFP Data - REM" ) axd["lfp_run"].plot( REM_Tsd.index.values, theta_band_reconstruction, - label=f"{freqs[theta_freq_index]}Hz oscillations", + label=f"{np.round(freqs[theta_freq_index],2)}Hz oscillations", ) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") +axd["lfp_run"].set_title(f"{np.round(freqs[theta_freq_index],2)}Hz oscillation power.") axd["lfp_run"].legend() -for i in range(6): +for i in range(3): axd[f"phase_{i}"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.2) axd[f"phase_{i}"].scatter( - spikes[list(spikes.keys())[i]], phase[list(spikes.keys())[i]] + spikes[list(phase_var.keys())[i]], phase[list(phase_var.keys())[i]] ) axd[f"phase_{i}"].set_ylabel("Phase") - axd[f"phase_{i}"].set_title(f"Unit {list(spikes.keys())[i]}") + axd[f"phase_{i}"].set_title(f"Unit {list(phase_var.keys())[i]}") fig.suptitle("Phase Preference Visualizations") -plt.show() diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index b5d786fe..d9b66be9 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Grosmark & Buzsáki (2016) Tutorial 1 +Computing Wavelet Transform ============ This tutorial demonstrates how we can use the signal processing tools within Pynapple to aid with data analysis. We will examine the dataset from [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/). @@ -16,7 +16,7 @@ # !!! warning # This tutorial uses matplotlib for displaying the figure # -# You can install all with `pip install matplotlib requests tqdm` +# You can install all with `pip install matplotlib requests tqdm seaborn` # # First, import the necessary libraries: @@ -26,8 +26,11 @@ import matplotlib.pyplot as plt import numpy as np import requests +import seaborn import tqdm +seaborn.set_theme() + import pynapple as nap # %% @@ -57,7 +60,7 @@ # Let's load and print the full dataset. data = nap.load_file(path) -FS = len(data["eeg"].index[:]) / (data["eeg"].index[-1] - data["eeg"].index[0]) +FS = 1250 print(data) @@ -78,9 +81,7 @@ ) RUN_Tsd = data["eeg"].restrict(RUN_interval) RUN_pos = data["position"].restrict(RUN_interval) - -# The given dataset has only one channel, so we set channel = 0 here -channel = 0 +print(RUN_Tsd) # %% # *** @@ -90,11 +91,11 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( [["ephys"], ["pos"]], - height_ratios=[1, 0.2], + height_ratios=[1, 0.4], ) axd["ephys"].plot( - RUN_Tsd[:, channel].restrict( + RUN_Tsd[:, 0].restrict( nap.IntervalSet( data["forward_ep"]["start"][run_index], data["forward_ep"]["end"][run_index] ) @@ -103,7 +104,7 @@ color="green", ) axd["ephys"].plot( - RUN_Tsd[:, channel].restrict( + RUN_Tsd[:, 0].restrict( nap.IntervalSet( data["forward_ep"]["end"][run_index], data["forward_ep"]["end"][run_index] + 5.0, @@ -126,17 +127,23 @@ # %% # *** -# Getting the LFP Spectogram +# Getting the LFP Spectrogram # ----------------------------------- -# Let's take the Fourier transforms of one channel for both waking and sleeping and see if differences are present +# Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies. fft = nap.compute_spectogram(RUN_Tsd, fs=int(FS)) +print(fft) + +# %% +# *** +# The returned object is a pandas dataframe which uses frequencies as indexes and spectral power as values. +# +# Now let's plot it -# Now we will plot it fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot( fft.index, - np.abs(fft.iloc[:, channel]), + np.abs(fft.iloc[:, 0]), alpha=0.5, label="LFP Frequency Power", c="blue", @@ -160,32 +167,14 @@ # LFP characteristics may be different while the animal is running along the track, and when it is finished. # Let's generate a wavelet decomposition to look more closely at the changing frequency powers over time. -# We must define the frequency set that we'd like to use for our decomposition; these -# have been manually selected based on the frequencies used in Frey et. al (2021), but -# could also be defined as a linspace or logspace -freqs = np.array( - [ - 2.59, - 3.66, - 5.18, - 8.0, - 10.36, - 20.72, - 29.3, - 41.44, - 58.59, - 82.88, - 117.19, - 152.35, - 192.19, - 200.0, - 234.38, - 270.00, - 331.5, - 390.00, - ] -) -mwt_RUN = nap.compute_wavelet_transform(RUN_Tsd[:, channel], fs=None, freqs=freqs) +# We must define the frequency set that we'd like to use for our decomposition +freqs = np.geomspace(5, 250, 25) +# Compute and print the wavelet transform on our LFP data +mwt_RUN = nap.compute_wavelet_transform(RUN_Tsd[:, 0], fs=FS, freqs=freqs) + +# %% +# *** +# Now let's plot it. # Define wavelet decomposition plotting function @@ -202,20 +191,21 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): else: x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - y_ticks = freqs + y_ticks = [np.round(f, 2) for f in freqs] y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + ax.grid(False) # And plot -fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig = plt.figure(constrained_layout=True, figsize=(10, 8)) axd = fig.subplot_mosaic( [ ["wd_run"], ["lfp_run"], ["pos_run"], ], - height_ratios=[1, 0.2, 0.4], + height_ratios=[1.2, 0.2, 0.6], ) plot_timefrequency( RUN_Tsd.index.values[:], @@ -243,36 +233,42 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # There seems to be a strong theta frequency present in the data during the maze traversal. # Let's plot the estimated 8Hz component of the wavelet decomposition on top of our data, and see how well # they match up -theta_freq_index = 3 + +# Find the index of the frequency closest to theta band +theta_freq_index = np.argmin(np.abs(10 - freqs)) +# Extract its real component, as well as its power envelope theta_band_reconstruction = mwt_RUN[:, theta_freq_index].values.real theta_band_power_envelope = np.abs(mwt_RUN[:, theta_freq_index].values) + +# %% +# *** +# Now let's visualise the theta band component of the signal over time. + fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( [ ["lfp_run"], ["pos_run"], ], - height_ratios=[1, 0.3], + height_ratios=[1, 0.4], ) -axd["lfp_run"].plot( - RUN_Tsd.index.values, RUN_Tsd[:, channel], alpha=0.5, label="LFP Data" -) +axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, 0], alpha=0.5, label="LFP Data") axd["lfp_run"].plot( RUN_Tsd.index.values, theta_band_reconstruction, - label=f"{freqs[theta_freq_index]}Hz oscillations", + label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", ) axd["lfp_run"].plot( RUN_Tsd.index.values, theta_band_power_envelope, - label=f"{freqs[theta_freq_index]}Hz power envelope", + label=f"{np.round(freqs[theta_freq_index], 2)}Hz power envelope", ) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{freqs[theta_freq_index]}Hz oscillation power.") +axd["lfp_run"].set_title(f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.") axd["pos_run"].plot(RUN_pos) [axd[k].margins(0) for k in ["lfp_run", "pos_run"]] [ @@ -292,9 +288,16 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Let's plot the LFP, along with the 200Hz frequency power, to see if we can isolate these peaks and # see what's going on. -ripple_freq_idx = 13 +# Find the index of the frequency closest to sharp wave ripple oscillations +ripple_freq_idx = np.argmin(np.abs(200 - freqs)) +# Extract its power envelope ripple_power = np.abs(mwt_RUN[:, ripple_freq_idx].values) + +# %% +# *** +# Now let's visualise the 200Hz component of the signal over time. + fig = plt.figure(constrained_layout=True, figsize=(10, 5)) axd = fig.subplot_mosaic( [ @@ -303,7 +306,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ], height_ratios=[1, 0.4], ) -axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], label="LFP Data") +axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, 0], label="LFP Data") axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].margins(0) @@ -316,7 +319,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): axd["rip_pow"].spines["bottom"].set_visible(False) axd["rip_pow"].spines["left"].set_visible(False) axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) -axd["rip_pow"].set_ylabel(f"{freqs[ripple_freq_idx]}Hz Power") +axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") # %% # *** @@ -325,32 +328,22 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # We can see one significant peak in the 200Hz frequency power. Let's smooth the power curve and threshold # to try to isolate this event. -# define our threshold -threshold = 100 -# smooth our wavelet power -window_size = 51 -window = np.ones(window_size) / window_size -smoother_swr_power = np.convolve( - np.abs(mwt_RUN[:, ripple_freq_idx].values), window, mode="same" +# Define threshold +threshold = 6000 +# Smooth wavelet power TsdFrame at the SWR frequency +smoother_swr_power = ( + mwt_RUN[:, ripple_freq_idx] + .abs() + .smooth(std=0.025, windowsize=0.2, time_units="s", norm=False) ) -# isolate our ripple periods -is_ripple = smoother_swr_power > threshold -start_idx = None -ripple_periods = [] -for i in range(len(RUN_Tsd.index.values)): - if is_ripple[i] and start_idx is None: - start_idx = i - elif not is_ripple[i] and start_idx is not None: - axd["rip_pow"].plot( - RUN_Tsd.index.values[start_idx:i], - smoother_swr_power[start_idx:i], - color="red", - linewidth=2, - ) - ripple_periods.append((start_idx, i)) - start_idx = None +# Threshold our TsdFrame +is_ripple = smoother_swr_power.threshold(threshold) + + +# %% +# *** +# Now let's plot the threshold ripple power over time. -# plot of captured ripple periods fig = plt.figure(constrained_layout=True, figsize=(10, 5)) axd = fig.subplot_mosaic( [ @@ -359,24 +352,18 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ], height_ratios=[1, 0.4], ) -axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, channel], label="LFP Data") -axd["rip_pow"].plot(RUN_Tsd.index.values, smoother_swr_power) -for r in ripple_periods: - axd["rip_pow"].plot( - RUN_Tsd.index.values[r[0] : r[1]], - smoother_swr_power[r[0] : r[1]], - color="red", - linewidth=2, - ) +axd["lfp_run"].plot(RUN_Tsd[:, 0], label="LFP Data") +axd["rip_pow"].plot(smoother_swr_power) +axd["rip_pow"].plot(is_ripple, color="red", linewidth=2) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{freqs[ripple_freq_idx]}Hz oscillation power.") +axd["lfp_run"].set_title(f"{np.round(freqs[ripple_freq_idx], 2)}Hz oscillation power.") axd["rip_pow"].axhline(threshold) [axd[k].margins(0) for k in ["lfp_run", "rip_pow"]] [axd["rip_pow"].spines[sp].set_visible(False) for sp in ["top", "left", "right"]] axd["rip_pow"].get_xaxis().set_visible(False) axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) -axd["rip_pow"].set_ylabel(f"{freqs[ripple_freq_idx]}Hz Power") +axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") # %% # *** @@ -384,21 +371,21 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # ----------------------------------- # Let's zoom in on out detected ripples and have a closer look! -# Filter out ripples which do not last long enough -ripple_periods = [r for r in ripple_periods if r[1] - r[0] > 20] - -# And plot! fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) -buffer = 200 +buffer = 0.1 ax.plot( - RUN_Tsd.index.values[r[0] - buffer : r[1] + buffer], - RUN_Tsd[r[0] - buffer : r[1] + buffer], + RUN_Tsd.restrict( + nap.IntervalSet( + start=is_ripple.index.min() - buffer, end=is_ripple.index.max() + buffer + ) + ), color="blue", label="Non-SWR LFP", ) ax.plot( - RUN_Tsd.index.values[r[0] : r[1]], - RUN_Tsd[r[0] : r[1]], + RUN_Tsd.restrict( + nap.IntervalSet(start=is_ripple.index.min(), end=is_ripple.index.max()) + ), color="red", label="SWR", linewidth=2, From fa7952efa2dfe64b1cac4b2fcc7d71b845c16b22 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 23:26:35 +0100 Subject: [PATCH 070/195] removing welch tests --- tests/test_signal_processing.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 8edc4bd7..510d29da 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -44,35 +44,6 @@ def test_compute_spectogram(): ) -def test_compute_welch_spectogram(): - t = np.linspace(0, 1, 10000) - sig = nap.TsdFrame( - d=np.random.random((10000, 4)), - t=t, - time_support=nap.IntervalSet(start=[0.1, 0.4], end=[0.2, 0.525]), - ) - r = nap.compute_welch_spectogram(sig) - assert isinstance(r, pd.DataFrame) - assert r.shape[1] == 4 - - t = np.linspace(0, 1, 1024) - sig = nap.Tsd(d=np.random.random(1024), t=t) - r = nap.compute_welch_spectogram(sig) - assert isinstance(r, pd.DataFrame) - - sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) - r = nap.compute_welch_spectogram(sig) - assert isinstance(r, pd.DataFrame) - assert r.shape[1] == 4 - - with pytest.raises(TypeError) as e_info: - nap.compute_welch_spectogram("a_string") - assert ( - str(e_info.value) - == "Currently compute_welch_spectogram is only implemented for Tsd or TsdFrame" - ) - - def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1000) From cac44ff9a8dcd35fbf6647479971f862f0705317 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 23:35:21 +0100 Subject: [PATCH 071/195] fixed broked phase notebook --- docs/examples/tutorial_phase_preferences.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index f5af3b66..3d42c663 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -86,13 +86,7 @@ # We will also extract spike times from all units in our dataset # which occur during our specified interval -spikes = {} -for i in data["units"].index: - spikes[i] = data["units"][i].times()[ - (data["units"][i].times() > REM_minute_interval["start"][0]) - & (data["units"][i].times() < REM_minute_interval["end"][0]) - ] -spikes_tsdg = data["units"].restrict(REM_minute_interval) +spikes = data["units"].restrict(REM_minute_interval) # %% # *** @@ -230,10 +224,10 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # We will start by throwing away cells which do not have a high enough firing rate during our interval. # Filter units based on firing rate -spikes_tsdg = spikes_tsdg[spikes_tsdg.rate > 5.0] +spikes = spikes[spikes.rate > 5.0] # Calculate theta phase firing preferences tuning_curves = nap.compute_1d_tuning_curves( - group=spikes_tsdg, feature=theta_band_phase, nb_bins=61, minmax=(-np.pi, np.pi) + group=spikes, feature=theta_band_phase, nb_bins=61, minmax=(-np.pi, np.pi) ) # %% @@ -285,10 +279,10 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): # Get phase of each spike phase = {} -for i in spikes_tsdg: +for i in spikes: phase_i = [ theta_band_phase[np.argmin(np.abs(REM_Tsd.index.values - s.index))] - for s in spikes_tsdg[i] + for s in spikes[i] ] phase[i] = np.array(phase_i) phase_var = { @@ -349,7 +343,7 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): for i in range(3): axd[f"phase_{i}"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.2) axd[f"phase_{i}"].scatter( - spikes[list(phase_var.keys())[i]], phase[list(phase_var.keys())[i]] + spikes[list(phase_var.keys())[i]].index, phase[list(phase_var.keys())[i]] ) axd[f"phase_{i}"].set_ylabel("Phase") axd[f"phase_{i}"].set_title(f"Unit {list(phase_var.keys())[i]}") From b0bb20f88982683d32a38e58652fd54249e98d5d Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 19 Jul 2024 23:37:25 +0100 Subject: [PATCH 072/195] better comments on phase notebook --- docs/examples/tutorial_phase_preferences.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 3d42c663..57fc036a 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -291,6 +291,11 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): } phase_var = dict(sorted(phase_var.items(), key=lambda item: item[1])) +# %% +# *** +# And now we plot the phase preference histograms of the 6 units with the least variance in the phase of their +# spiking behaviour. + fig, axd = plt.subplot_mosaic( [["phase_0", "phase_1", "phase_2"], ["phase_3", "phase_4", "phase_5"]], constrained_layout=True, From c66fc875354b98cccd40555f23b728f2b4637bad Mon Sep 17 00:00:00 2001 From: Luigi Petrucco Date: Mon, 22 Jul 2024 16:14:35 +0200 Subject: [PATCH 073/195] required test fixes --- tests/test_folder.py | 15 --------------- tests/test_interval_set.py | 20 ++++++++++---------- 2 files changed, 10 insertions(+), 25 deletions(-) diff --git a/tests/test_folder.py b/tests/test_folder.py index c40d7083..4037eb71 100644 --- a/tests/test_folder.py +++ b/tests/test_folder.py @@ -94,18 +94,3 @@ def test_load(path): folder.load() for k in data.keys(): assert type(folder[k]) == type(data[k]) - - - - - - - - - - - - - - - diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index a07837a9..4e338ef7 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -260,11 +260,11 @@ def test_tot_length(): def test_as_units(): ep = nap.IntervalSet(start=0, end=100) - df = pd.DataFrame(data=np.array([[0.0, 100.0]]), columns=["start", "end"]) - pd.testing.assert_frame_equal(df, ep.as_units("s")) - pd.testing.assert_frame_equal(df * 1e3, ep.as_units("ms")) + df = pd.DataFrame(data=np.array([[0.0, 100.0]]), columns=["start", "end"], dtype=np.float64) + pd.testing.assert_frame_equal(df, ep.as_units("s").astype(np.float64)) + pd.testing.assert_frame_equal(df * 1e3, ep.as_units("ms").astype(np.float64)) tmp = df * 1e6 - np.testing.assert_array_almost_equal(tmp.values, ep.as_units("us").values) + np.testing.assert_array_almost_equal(tmp.values, ep.as_units("us").values.astype(np.float64)) def test_intersect(): @@ -498,14 +498,14 @@ def test_save_npz(): ep.save("ep2") assert "ep2.npz" in [f.name for f in Path('.').iterdir()] - file = np.load("ep.npz") + with np.load("ep.npz") as file: - keys = list(file.keys()) - assert 'start' in keys - assert 'end' in keys + keys = list(file.keys()) + assert 'start' in keys + assert 'end' in keys - np.testing.assert_array_almost_equal(file['start'], start) - np.testing.assert_array_almost_equal(file['end'], end) + np.testing.assert_array_almost_equal(file['start'], start) + np.testing.assert_array_almost_equal(file['end'], end) # Cleaning Path("ep.npz").unlink() From 496fbe359f12255fb3814605b22961653fab3df7 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 29 Jul 2024 15:54:15 +0100 Subject: [PATCH 074/195] PR comments addressed, tests added --- docs/examples/tutorial_phase_preferences.py | 5 ---- pynapple/process/signal_processing.py | 32 +++++++++++++++------ tests/test_signal_processing.py | 30 +++++++++++++++++++ 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 57fc036a..18e311c9 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -3,11 +3,6 @@ Computing Phase Preferences ============ -In the previous [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/) tutorial, -we learned how to use Pynapple's signal processing tools with Local Field Potential data. Specifically, we -used wavelet decompositions to isolate Theta band activity during active traversal of a linear track, -as well as to find Sharp Wave Ripples which occurred after traversal. - In this tutorial we will learn how to isolate phase information from our wavelet decomposition and combine it with spiking data, to find phase preferences of spiking units. diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 1b83ea34..914a368b 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -14,6 +14,7 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): """ Performs numpy fft on sig, returns output. Pynapple assumes a constant sampling rate for sig. + Parameters ---------- sig : pynapple.Tsd or pynapple.TsdFrame Time series. @@ -24,6 +25,12 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values + Returns + ------- + pandas.DataFrame + Time frequency representation of the input signal, indexes are frequencies, values + are powers. + Notes ----- compute_spectogram computes fft on only a single epoch of data. This epoch be given with the ep @@ -40,7 +47,7 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): if len(ep) != 1: raise ValueError("Given epoch (or signal time_support) must have length 1") if fs is None: - fs = sig.index.shape[0] / (sig.index.max() - sig.index.min()) + fs = sig.rate fft_result = np.fft.fft(sig.restrict(ep).values, axis=0) fft_freq = np.fft.fftfreq(len(sig.restrict(ep).values), 1 / fs) ret = pd.DataFrame(fft_result, fft_freq) @@ -107,7 +114,7 @@ def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_bas def compute_wavelet_transform( - sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, precision=10, norm=None + sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, precision=16, norm=None ): """ Compute the time-frequency representation of a signal using Morlet wavelets. @@ -128,7 +135,8 @@ def compute_wavelet_transform( scaling : float Scaling factor. precision: int. - Precision of wavelet to use. Default is 8 + Precision of wavelet to use. . Defines the number of timepoints to evaluate the Morlet wavelet at. + Default is 16 norm : {None, 'sss', 'amp'}, optional Normalization method: * None - no normalization @@ -137,9 +145,18 @@ def compute_wavelet_transform( Returns ------- - pynapple.TsdFrame or pynapple.TsdTensor : 2d array + pynapple.TsdFrame or pynapple.TsdTensor Time frequency representation of the input signal. + Examples + -------- + >>> import numpy as np + >>> import pynapple as nap + >>> t = np.linspace(0, 1, 1000) + >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) + >>> freqs = np.linspace(10, 100, 10) + >>> mwt = nap.compute_wavelet_transform(signal, fs=None, freqs=freqs) + Notes ----- This computes the continuous wavelet transform at specified frequencies across time. @@ -158,7 +175,6 @@ def compute_wavelet_transform( fs = sig.rate if isinstance(sig, nap.Tsd): - sig = sig.reshape((sig.shape[0], 1)) output_shape = (sig.shape[0], len(freqs)) else: output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) @@ -176,7 +192,7 @@ def compute_wavelet_transform( coef = np.insert( coef, 1, coef[0, :], axis=0 ) # slightly hacky line, necessary to make output correct shape - cwt = np.swapaxes(coef, 1, 2) + cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef if len(output_shape) == 2: return nap.TsdFrame( @@ -188,7 +204,7 @@ def compute_wavelet_transform( ) -def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=10): +def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=16): """ Generates a Morlet filterbank using the given frequencies and parameters. Can be used purely for visualization, or to convolve with a pynapple Tsd, TsdFrame, or TsdTensor as part of a wavelet decomposition process. @@ -207,7 +223,7 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 scaling : float Scaling factor. precision: int. - Precision of wavelet to use. + Precision of wavelet to use. Defines the number of timepoints to evaluate the Morlet wavelet at. Returns ------- diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 510d29da..bfcd1eec 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import pytest +import pywt import pynapple as nap @@ -45,6 +46,35 @@ def test_compute_spectogram(): def test_compute_wavelet_transform(): + t = np.linspace(0, 1, 1000) + sig = nap.Tsd( + d=np.sin(t * 50 * np.pi * 2) + * np.interp(np.linspace(0, 1, 1000), [0, 0.5, 1], [0, 1, 0]), + t=t, + ) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 50 + assert ( + np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] + == 500 + ) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd( + d=np.sin(t * 10 * np.pi * 2) + * np.interp(np.linspace(0, 1, 1000), [0, 0.5, 1], [0, 1, 0]), + t=t, + ) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 10 + assert ( + np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] + == 500 + ) t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) From 647bbfc10d216f6865ff24bcae195469e122a72d Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 29 Jul 2024 15:59:22 +0100 Subject: [PATCH 075/195] unused import removed --- tests/test_signal_processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index bfcd1eec..8f853591 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -3,7 +3,6 @@ import numpy as np import pandas as pd import pytest -import pywt import pynapple as nap From 057d6553ecb33043fad99e3d27a4a593babe0db8 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Jul 2024 16:55:23 -0400 Subject: [PATCH 076/195] REmoved neurosuite nwb saving of intervals --- pynapple/io/neurosuite.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/pynapple/io/neurosuite.py b/pynapple/io/neurosuite.py index 605b59f6..ca05ded2 100755 --- a/pynapple/io/neurosuite.py +++ b/pynapple/io/neurosuite.py @@ -1,10 +1,3 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Author: gviejo -# @Date: 2022-02-02 20:45:09 -# @Last Modified by: gviejo -# @Last Modified time: 2023-11-16 13:21:34 - """ > :warning: **DEPRECATED**: This will be removed in version 1.0.0. Check [nwbmatic](https://github.com/pynapple-org/nwbmatic) or [neuroconv](https://github.com/catalystneuro/neuroconv) instead. @@ -12,6 +5,7 @@ @author: Guillaume Viejo """ + import sys from pathlib import Path @@ -192,13 +186,13 @@ def read_neuroscope_intervals(self, name=None, path2file=None): Contains two columns corresponding to the start and end of the intervals. """ - if name: - isets = self.load_nwb_intervals(name) - if isinstance(isets, nap.IntervalSet): - return isets + # if name: + # isets = self.load_nwb_intervals(name) + # if isinstance(isets, nap.IntervalSet): + # return isets if name is not None and path2file is None: - path2file = self.path / self.basename + "." + name + ".evt" + path2file = self.path / (self.basename + "." + name + ".evt") if path2file is not None: # TODO maybe useless conditional? try: # df = pd.read_csv(path2file, delimiter=' ', usecols = [0], header = None) @@ -210,7 +204,7 @@ def read_neuroscope_intervals(self, name=None, path2file=None): if name is None: name = path2file.split(".")[-2] print("*** saving file in the nwb as", name) - self.save_nwb_intervals(isets, name) + # self.save_nwb_intervals(isets, name) else: raise ValueError("specify a valid path") return isets From 17017eb0de5b0f3812da4708f87eda53468f1624 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 29 Jul 2024 18:02:20 -0400 Subject: [PATCH 077/195] Test --- pynapple/core/interval_set.py | 84 ++++++++++++++++++++++++++++++----- 1 file changed, 73 insertions(+), 11 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index b39e6046..79dcda53 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -354,6 +354,25 @@ def loc(self): """ return _IntervalSetSliceHelper(self) + @classmethod + def _from_npz_reader(cls, file): + """Load an IntervalSet object from a npz file. + + The file should contain the keys 'start', 'end' and 'type'. + The 'type' key should be 'IntervalSet'. + + Parameters + ---------- + file : NPZFile object + opened npz file interface. + + Returns + ------- + IntervalSet + The IntervalSet object + """ + return cls(start=file["start"], end=file["end"]) + def time_span(self): """ Time span of the interval set. @@ -652,21 +671,64 @@ def save(self, filename): return - @classmethod - def _from_npz_reader(cls, file): - """Load an IntervalSet object from a npz file. - - The file should contain the keys 'start', 'end' and 'type'. - The 'type' key should be 'IntervalSet'. + def split(self, interval_size, time_units="s"): + """Split `IntervalSet` to a new `IntervalSet` with each interval being of size `interval_size`. + + Used mostly for chunking very large dataset or looping throught multiple epoch of same duration. + This function skips the epochs that are shorter than `interval_size`. + Parameters ---------- - file : NPZFile object - opened npz file interface. - + interval_size : Number + Description + time_units : str, optional + time units for the `interval_size` ('us', 'ms', 's' [default]) + Returns ------- IntervalSet - The IntervalSet object + New `IntervalSet` with equal sized intervals + + Raises + ------ + IOError + Description """ - return cls(start=file["start"], end=file["end"]) + if not isinstance(interval_size, Number): + raise IOError("Argument interval_size should of type float or int") + + if not isinstance(time_units, str): + raise IOError("Argument time_units should be of type float or int") + + if len(self) == 0: + return IntervalSet(start=[], end=[]) + + interval_size = TsIndex.format_timestamps( + np.array((interval_size,), dtype=np.float64).ravel(), time_units + )[0] + + durations = self.end - self.start + + new_starts = [] + new_ends = [] + + for i in range(len(self)): + if durations[i] > interval_size: + tmp = np.arange(self.start[i], self.end[i], interval_size) + tmp = np.hstack((tmp, np.array([self.end[i]]))) + + new_starts.append(tmp[0:-1]) + new_ends.append(tmp[1:]) + + new_starts = np.hstack(new_starts) + new_ends = np.hstack(new_ends) + + tokeep = np.round(new_ends - new_starts, nap_config.time_index_precision) + + new_iset = IntervalSet(new_starts, new_ends).drop_short_intervals(interval_size-1e-6) + + return new_iset + + + From a5a6295315a76eec9139d111d0788ac5c1ac5595 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 29 Jul 2024 18:55:29 -0400 Subject: [PATCH 078/195] added get_slice func --- pynapple/core/base_class.py | 54 +++++++++++++++++++++++++++ tests/test_time_series.py | 74 ++++++++++++++++++++++++++++++++++++- 2 files changed, 127 insertions(+), 1 deletion(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 22ebdaea..eb771d7c 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -466,3 +466,57 @@ def _from_npz_reader(cls, file): } iset = IntervalSet(start=file["start"], end=file["end"]) return cls(time_support=iset, **kwargs) + + def get_slice(self, start, end=None, mode="closest", time_unit="s"): + if not isinstance(start, Number): + raise ValueError(f"'start' must be an int or a float. Type {type(start)} provided instead!") + # convert and get index for start + start = TsIndex.format_timestamps(np.array([start]), time_unit)[0] + + # check end + if end is not None and not isinstance(end, Number): + raise ValueError(f"'end' must be an int or a float. Type {type(end)} provided instead!") + + # get index of preceding time value + idx_start = np.searchsorted(self.t, start, side="left") + if idx_start == len(self.t): + idx_start -= 1 + + if mode == "backward": + # subtract one except if self.t[idx_start] is equal to start + idx_start -= (self.t[idx_start] > start) + elif mode == "closest": + di = np.argmin([self.t[idx_start] - start, np.abs(self.t[idx_start - 1] - start)]) + idx_start -= di + + if end is None: + if idx_start < 0: # happens only on backwards + return slice(0, 0) + elif idx_start == len(self.t) - 1 and mode == "forward": + return slice(idx_start, idx_start) + return slice(idx_start, idx_start + 1) + else: + idx_start = max([0, idx_start]) + + # convert and get index for end + end = TsIndex.format_timestamps(np.array([end]), time_unit)[0] + if start > end: + raise ValueError("'start' should not precede 'end'.") + + idx_end = np.searchsorted(self.t, end, side="left") + add_if_forward = 0 + if idx_end == len(self.t): + add_if_forward = idx_start < len(self.t) + idx_end -= 1 + + if mode == "backward": + # remove 1 if self.t[idx_end] is larger than end, except if idx_end is already 0 + idx_end -= (self.t[idx_end] > end) - int(idx_end == 0) + elif mode == "closest": + di = np.argmin([self.t[idx_end] - end, np.abs(self.t[idx_end - 1] - end)]) + idx_end -= di + elif mode == "forward" and idx_end == len(self.t) - 1: + idx_end += add_if_forward # add one if idx_start < len(self.t) + + return slice(idx_start, idx_end) + diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 0942490f..580b4201 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1,11 +1,11 @@ """Tests of time series for `pynapple` package.""" import pickle -from pathlib import Path import numpy as np import pandas as pd import pytest from pathlib import Path +from contextlib import nullcontext as does_not_raise import pynapple as nap @@ -1449,3 +1449,75 @@ def test_pickling(obj): # Ensure time support is the same assert np.all(obj.time_support == unpickled_obj.time_support) + +@pytest.mark.parametrize( + "start, end, expectation", + [ + (1, 3, does_not_raise()), + (3, 1, pytest.raises(ValueError, match="'start' should not precede 'end'")), + (1., 3., does_not_raise()), + (1., None, does_not_raise()), + (None, 3, pytest.raises(ValueError, match="'start' must be an int or a float")), + ("a", 3, pytest.raises(ValueError, match="'start' must be an int or a float")), + (2, "a", pytest.raises(ValueError, match="'end' must be an int or a float")), + + ] +) +@pytest.mark.parametrize("time_unit", ["s", "ms", "us"]) +def test_get_index_value_types(start, end, time_unit, expectation): + ts = nap.Ts(t=np.array([1, 2, 3, 4])) + with expectation: + ts.get_slice(start, end, time_unit=time_unit) + + +@pytest.mark.parametrize( + "start, end, mode, expected_slice, expected_array", + [ + (1, 3, "forward", slice(0, 2), np.array([1, 2])), + (1, 3, "backward", slice(0, 2), np.array([1, 2])), + (1, 3, "closest", slice(0, 2), np.array([1, 2])), + (1, 2.7, "forward", slice(0, 2), np.array([1, 2])), + (1, 2.7, "backward", slice(0, 1), np.array([1])), + (1, 2.7, "closest", slice(0, 2), np.array([1, 2])), + (1, 2.4, "forward", slice(0, 2), np.array([1, 2])), + (1, 2.4, "backward", slice(0, 1), np.array([1])), + (1, 2.4, "closest", slice(0, 1), np.array([1])), + (1.1, 3, "forward", slice(1, 2), np.array([2])), + (1.1, 3, "backward", slice(0, 2), np.array([1, 2])), + (1.1, 3, "closest", slice(0, 2), np.array([1, 2])), + (1.6, 3, "forward", slice(1, 2), np.array([2])), + (1.6, 3, "backward", slice(0, 2), np.array([1, 2])), + (1.6, 3, "closest", slice(1, 2), np.array([2])), + (3, 3, "forward", slice(2, 2), np.array([])), + (3, 3, "backward", slice(2, 2), np.array([])), + (3, 3, "closest", slice(2, 2), np.array([])), + (0, 3, "forward", slice(0, 2), np.array([1, 2])), + (0, 3, "backward", slice(0, 2), np.array([1, 2])), + (0, 3, "closest", slice(0, 2), np.array([1, 2])), + (4, 4, "forward", slice(3, 3), np.array([])), + (4, 4, "backward", slice(3, 3), np.array([])), + (4, 4, "closest", slice(3, 3), np.array([])), + (4, 5, "forward", slice(3, 4), np.array([4])), + (4, 5, "backward", slice(3, 3), np.array([])), + (4, 5, "closest", slice(3, 3), np.array([])), + (0, 1, "forward", slice(0, 0), np.array([])), + (0, 1, "backward", slice(0, 1), np.array([1])), + (0, 1, "closest", slice(0, 0), np.array([])), + (0, None, "forward", slice(0, 1), np.array([1])), + (0, None, "backward", slice(0, 0), np.array([])), + (0, None, "closest", slice(0, 1), np.array([1])), + (1, None, "forward", slice(0, 1), np.array([1])), + (1, None, "backward", slice(0, 1), np.array([1])), + (1, None, "closest", slice(0, 1), np.array([1])), + (5, None, "forward", slice(3, 3), np.array([])), + (5, None, "backward", slice(3, 4), np.array([4])), + (5, None, "closest", slice(3, 4), np.array([4])) + ] +) +def test_get_index_value_types(start, end, mode, expected_slice, expected_array): + ts = nap.Ts(t=np.array([1, 2, 3, 4])) + out_slice = ts.get_slice(start, end=end, mode=mode) + out_array = ts.t[out_slice] + assert out_slice == expected_slice + assert np.all(out_array == expected_array) + From 99e0e47ed44506bb726e29e85d566ef70aa2ccf6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 29 Jul 2024 20:18:40 -0400 Subject: [PATCH 079/195] added comment --- pynapple/core/base_class.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index eb771d7c..e572605d 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -480,23 +480,25 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): # get index of preceding time value idx_start = np.searchsorted(self.t, start, side="left") if idx_start == len(self.t): - idx_start -= 1 + idx_start -= 1 # make sure the index is not out of bound if mode == "backward": - # subtract one except if self.t[idx_start] is equal to start - idx_start -= (self.t[idx_start] > start) + # in order to get the index preceding start + # subtract one except if self.t[idx_start] is exactly equal to start + idx_start -= self.t[idx_start] > start elif mode == "closest": + # subtract 1 if start is closer to the previous index di = np.argmin([self.t[idx_start] - start, np.abs(self.t[idx_start - 1] - start)]) idx_start -= di if end is None: - if idx_start < 0: # happens only on backwards + if idx_start < 0: # happens only on backwards if start < self.t[0] return slice(0, 0) - elif idx_start == len(self.t) - 1 and mode == "forward": + elif idx_start == len(self.t) - 1 and mode == "forward": # happens only on forward if start >= self.t[-1] return slice(idx_start, idx_start) return slice(idx_start, idx_start + 1) else: - idx_start = max([0, idx_start]) + idx_start = max([0, idx_start]) # if taking a range set slice index to 0 # convert and get index for end end = TsIndex.format_timestamps(np.array([end]), time_unit)[0] @@ -506,13 +508,14 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): idx_end = np.searchsorted(self.t, end, side="left") add_if_forward = 0 if idx_end == len(self.t): - add_if_forward = idx_start < len(self.t) - idx_end -= 1 + idx_end -= 1 # make sure the index is not out of bound + add_if_forward = 1 # add back the index if forward if mode == "backward": - # remove 1 if self.t[idx_end] is larger than end, except if idx_end is already 0 + # remove 1 if self.t[idx_end] is larger than end, except if idx_end is 0 idx_end -= (self.t[idx_end] > end) - int(idx_end == 0) elif mode == "closest": + # subtract 1 if end is closer to self.t[idx_end - 1] di = np.argmin([self.t[idx_end] - end, np.abs(self.t[idx_end - 1] - end)]) idx_end -= di elif mode == "forward" and idx_end == len(self.t) - 1: From 9dae2d0431792860fc5ddf62241b41bb0f85b467 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 29 Jul 2024 20:58:18 -0400 Subject: [PATCH 080/195] improved test and docstring --- pynapple/core/base_class.py | 46 ++++++++++++++++++++++++++++++++++++- tests/test_time_series.py | 4 ++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index e572605d..2b544fee 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -468,6 +468,51 @@ def _from_npz_reader(cls, file): return cls(time_support=iset, **kwargs) def get_slice(self, start, end=None, mode="closest", time_unit="s"): + """ + Get a slice from the time series data based on the start and end values with the specified mode. + + Parameters + ---------- + start : int or float + The starting value for the slice. + end : int or float, optional + The ending value for the slice. Defaults to None. + mode : str, optional + The mode for slicing. Can be "forward", "backward", or "closest". Defaults to "closest". + time_unit : str, optional + The time unit for the start and end values. Defaults to "s" (seconds). + + Returns + ------- + slice : slice + If end is not provided: + - For mode == "backward": + - An empty slice for start < self.t[0] + - slice(idx, idx+1) with self.t[idx] <= start < self.t[idx+1] + - For mode == "forward": + - An empty slice for start >= self.t[-1] + - slice(idx, idx+1) with self.t[idx-1] < start <= self.t[idx] + - For mode == "closest": + - slice(idx, idx+1) with the closest index to start + If end is provided: + - For mode == "backward": + - An empty slice if end < self.t[0] + - slice(idx_start, idx_end) with self.t[idx_start] <= start < self.t[idx_start+1] and + self.t[idx_end] <= end < self.t[idx_end+1] + - For mode == "forward": + - An empty slice if start > self.t[-1] + - slice(idx_start, idx_end) with self.t[idx_start-1] <= start < self.t[idx_start] and + self.t[idx_end-1] <= end < self.t[idx_end] + - For mode == "closest": + - slice(idx_start, idx_end) with the closest indices to start and end + + Raises + ------ + ValueError + - If start or end is not a number. + - If start is greater than end. + + """ if not isinstance(start, Number): raise ValueError(f"'start' must be an int or a float. Type {type(start)} provided instead!") # convert and get index for start @@ -522,4 +567,3 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): idx_end += add_if_forward # add one if idx_start < len(self.t) return slice(idx_start, idx_end) - diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 580b4201..8bf99fbd 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1488,6 +1488,10 @@ def test_get_index_value_types(start, end, time_unit, expectation): (1.6, 3, "forward", slice(1, 2), np.array([2])), (1.6, 3, "backward", slice(0, 2), np.array([1, 2])), (1.6, 3, "closest", slice(1, 2), np.array([2])), + (1.6, 1.8, "backward", slice(0, 0), np.array([])), + (1.6, 1.8, "forward", slice(1, 1), np.array([])), + (1.6, 1.8, "closest", slice(1, 1), np.array([])), + (1.4, 1.6, "closest", slice(0, 1), np.array([1])), (3, 3, "forward", slice(2, 2), np.array([])), (3, 3, "backward", slice(2, 2), np.array([])), (3, 3, "closest", slice(2, 2), np.array([])), From e43434f741d6699adeed6a5c33f2e71f95964087 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 29 Jul 2024 21:50:15 -0400 Subject: [PATCH 081/195] linted --- pynapple/core/base_class.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 2b544fee..8bb1a4a9 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -514,13 +514,17 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): """ if not isinstance(start, Number): - raise ValueError(f"'start' must be an int or a float. Type {type(start)} provided instead!") + raise ValueError( + f"'start' must be an int or a float. Type {type(start)} provided instead!" + ) # convert and get index for start start = TsIndex.format_timestamps(np.array([start]), time_unit)[0] # check end if end is not None and not isinstance(end, Number): - raise ValueError(f"'end' must be an int or a float. Type {type(end)} provided instead!") + raise ValueError( + f"'end' must be an int or a float. Type {type(end)} provided instead!" + ) # get index of preceding time value idx_start = np.searchsorted(self.t, start, side="left") @@ -533,13 +537,17 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): idx_start -= self.t[idx_start] > start elif mode == "closest": # subtract 1 if start is closer to the previous index - di = np.argmin([self.t[idx_start] - start, np.abs(self.t[idx_start - 1] - start)]) + di = np.argmin( + [self.t[idx_start] - start, np.abs(self.t[idx_start - 1] - start)] + ) idx_start -= di if end is None: if idx_start < 0: # happens only on backwards if start < self.t[0] return slice(0, 0) - elif idx_start == len(self.t) - 1 and mode == "forward": # happens only on forward if start >= self.t[-1] + elif ( + idx_start == len(self.t) - 1 and mode == "forward" + ): # happens only on forward if start >= self.t[-1] return slice(idx_start, idx_start) return slice(idx_start, idx_start + 1) else: From 9cf959e2d5b67cd0f7139664484fc946c8f4d6bd Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 29 Jul 2024 23:06:38 -0400 Subject: [PATCH 082/195] renamed test --- tests/test_time_series.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 8bf99fbd..5ba10ffc 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1450,6 +1450,7 @@ def test_pickling(obj): # Ensure time support is the same assert np.all(obj.time_support == unpickled_obj.time_support) + @pytest.mark.parametrize( "start, end, expectation", [ @@ -1518,8 +1519,14 @@ def test_get_index_value_types(start, end, time_unit, expectation): (5, None, "closest", slice(3, 4), np.array([4])) ] ) -def test_get_index_value_types(start, end, mode, expected_slice, expected_array): - ts = nap.Ts(t=np.array([1, 2, 3, 4])) +@pytest.mark.parametrize("ts", + [ + nap.Ts(t=np.array([1, 2, 3, 4])), + nap.Tsd(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])), + nap.TsdFrame(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None]), + nap.TsdTensor(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None, None]) + ]) +def test_get_index_value(start, end, mode, expected_slice, expected_array, ts): out_slice = ts.get_slice(start, end=end, mode=mode) out_array = ts.t[out_slice] assert out_slice == expected_slice From e90091c2ef1d8210a22c4d8ba4c0a09cb52c3878 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 30 Jul 2024 08:14:36 -0400 Subject: [PATCH 083/195] speed up closest --- pynapple/core/base_class.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 8bb1a4a9..2aa47b7f 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -537,9 +537,7 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): idx_start -= self.t[idx_start] > start elif mode == "closest": # subtract 1 if start is closer to the previous index - di = np.argmin( - [self.t[idx_start] - start, np.abs(self.t[idx_start - 1] - start)] - ) + di = self.t[idx_start] - start > np.abs(self.t[idx_start - 1] - start) idx_start -= di if end is None: @@ -569,7 +567,7 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): idx_end -= (self.t[idx_end] > end) - int(idx_end == 0) elif mode == "closest": # subtract 1 if end is closer to self.t[idx_end - 1] - di = np.argmin([self.t[idx_end] - end, np.abs(self.t[idx_end - 1] - end)]) + di = self.t[idx_end] - end > np.abs(self.t[idx_end - 1] - end) idx_end -= di elif mode == "forward" and idx_end == len(self.t) - 1: idx_end += add_if_forward # add one if idx_start < len(self.t) From fecd8bddc1d36e96a008ab1d327d015f6ea94419 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 30 Jul 2024 10:36:54 -0400 Subject: [PATCH 084/195] added an n_points --- pynapple/core/base_class.py | 18 ++++++++-- tests/test_time_series.py | 66 ++++++++++++++++++++++++++++++++++--- 2 files changed, 78 insertions(+), 6 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 2aa47b7f..ddaa69d5 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -467,7 +467,7 @@ def _from_npz_reader(cls, file): iset = IntervalSet(start=file["start"], end=file["end"]) return cls(time_support=iset, **kwargs) - def get_slice(self, start, end=None, mode="closest", time_unit="s"): + def _get_slice(self, start, end=None, mode="closest", n_points=None, time_unit="s"): """ Get a slice from the time series data based on the start and end values with the specified mode. @@ -481,6 +481,8 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): The mode for slicing. Can be "forward", "backward", or "closest". Defaults to "closest". time_unit : str, optional The time unit for the start and end values. Defaults to "s" (seconds). + n_points : int + Max number of time point per the slice. This will be used to calculate a step size for the slice. Returns ------- @@ -517,6 +519,10 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): raise ValueError( f"'start' must be an int or a float. Type {type(start)} provided instead!" ) + + if end is None and n_points: + raise ValueError("'n_points' can be used only when 'end' is specified!") + # convert and get index for start start = TsIndex.format_timestamps(np.array([start]), time_unit)[0] @@ -572,4 +578,12 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): elif mode == "forward" and idx_end == len(self.t) - 1: idx_end += add_if_forward # add one if idx_start < len(self.t) - return slice(idx_start, idx_end) + step = None + if n_points: + tot_tps = idx_end - idx_start + if tot_tps > n_points: + rounding = tot_tps % n_points + step = tot_tps // n_points + idx_end -= rounding + + return slice(idx_start, idx_end, step) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 5ba10ffc..df0f384f 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1465,10 +1465,10 @@ def test_pickling(obj): ] ) @pytest.mark.parametrize("time_unit", ["s", "ms", "us"]) -def test_get_index_value_types(start, end, time_unit, expectation): +def test_get_slice_value_types(start, end, time_unit, expectation): ts = nap.Ts(t=np.array([1, 2, 3, 4])) with expectation: - ts.get_slice(start, end, time_unit=time_unit) + ts._get_slice(start, end, time_unit=time_unit) @pytest.mark.parametrize( @@ -1526,9 +1526,67 @@ def test_get_index_value_types(start, end, time_unit, expectation): nap.TsdFrame(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None]), nap.TsdTensor(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None, None]) ]) -def test_get_index_value(start, end, mode, expected_slice, expected_array, ts): - out_slice = ts.get_slice(start, end=end, mode=mode) +def test_get_slice_value(start, end, mode, expected_slice, expected_array, ts): + out_slice = ts._get_slice(start, end=end, mode=mode) out_array = ts.t[out_slice] assert out_slice == expected_slice assert np.all(out_array == expected_array) +@pytest.mark.parametrize( + "end, n_points, expectation", + [ + (1, 3, does_not_raise()), + (None, 3, pytest.raises(ValueError, match="'n_points' can be used only when")), + + ] +) +@pytest.mark.parametrize("time_unit", ["s", "ms", "us"]) +@pytest.mark.parametrize("mode", ["closest", "backward", "forward"]) +def test_get_slice_n_points(end, n_points, expectation, time_unit, mode): + ts = nap.Ts(t=np.array([1, 2, 3, 4])) + with expectation: + ts._get_slice(1, end, n_points=n_points, mode=mode) + + + +@pytest.mark.parametrize( + "start, end, n_points, mode, expected_slice, expected_array", + [ + # smaller than n_points + (1, 2, 2, "forward", slice(0, 1), np.array([1])), + (1, 2, 2, "backward", slice(0, 1), np.array([1])), + (1, 2, 2, "closest", slice(0, 1), np.array([1])), + # larger than n_points + (1, 5, 2, "forward", slice(0, 4, 2), np.array([1, 3])), + (1, 5, 2, "backward", slice(0, 4, 2), np.array([1, 3])), + (1, 5, 2, "closest", slice(0, 4, 2), np.array([1, 3])), + # larger than n_points with rounding down + (1, 5.2, 2, "forward", slice(0, 4, 2), np.array([1, 3])), + (1, 5.2, 2, "backward", slice(0, 4, 2), np.array([1, 3])), + (1, 5.2, 2, "closest", slice(0, 4, 2), np.array([1, 3])), + # larger than n_points with rounding down + (1, 6.2, 2, "forward", slice(0, 6, 3), np.array([1, 4])), + (1, 6.2, 2, "backward", slice(0, 4, 2), np.array([1, 3])), + (1, 6.2, 2, "closest", slice(0, 4, 2), np.array([1, 3])), + # larger than n_points with rounding up + (1, 5.6, 2, "forward", slice(0, 4, 2), np.array([1, 3])), + (1, 5.6, 2, "backward", slice(0, 4, 2), np.array([1, 3])), + (1, 5.6, 2, "closest", slice(0, 4, 2), np.array([1, 3])), + # larger than n_points with rounding up + (1, 6.6, 2, "forward", slice(0, 6, 3), np.array([1, 4])), + (1, 6.6, 2, "backward", slice(0, 4, 2), np.array([1, 3])), + (1, 6.6, 2, "closest", slice(0, 6, 3), np.array([1, 4])), + ] +) +@pytest.mark.parametrize("ts", + [ + nap.Ts(t=np.arange(1, 10)), + nap.Tsd(t=np.arange(1, 10), d=np.arange(1, 10)), + nap.TsdFrame(t=np.arange(1, 10), d=np.arange(1, 10)[:, None]), + nap.TsdTensor(t=np.arange(1, 10), d=np.arange(1, 10)[:, None, None]) + ]) +def test_get_slice_value(start, end, n_points, mode, expected_slice, expected_array, ts): + out_slice = ts._get_slice(start, end=end, mode=mode, n_points=n_points) + out_array = ts.t[out_slice] + assert out_slice == expected_slice + assert np.all(out_array == expected_array) From b01394b78a3ea243c9eaa4bd8a475a7f504bc42f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 30 Jul 2024 11:11:13 -0400 Subject: [PATCH 085/195] linted --- pynapple/core/base_class.py | 39 ++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index ddaa69d5..371011ab 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -481,7 +481,7 @@ def _get_slice(self, start, end=None, mode="closest", n_points=None, time_unit=" The mode for slicing. Can be "forward", "backward", or "closest". Defaults to "closest". time_unit : str, optional The time unit for the start and end values. Defaults to "s" (seconds). - n_points : int + n_points : int, optional Max number of time point per the slice. This will be used to calculate a step size for the slice. Returns @@ -587,3 +587,40 @@ def _get_slice(self, start, end=None, mode="closest", n_points=None, time_unit=" idx_end -= rounding return slice(idx_start, idx_end, step) + + def get_slice(self, start, end=None, mode="closest", time_unit="s"): + """ + Get a slice from the time series data based on the start and end values with the specified mode. + + Parameters + ---------- + start : int or float + The starting value for the slice. + end : int or float, optional + The ending value for the slice. Defaults to None. + mode : str, optional + The mode for slicing. Can be "forward", "backward", or "closest". Defaults to "closest". + time_unit : str, optional + The time unit for the start and end values. Defaults to "s" (seconds). + + Returns + ------- + slice : slice + - If mode = "closest": + the default mode, starts/ends the slice with indices closest to the start/end time provided + - If mode = "backward": + starts/ends the slice with the indices preceding the start/end time provided + - If mode = "forward": + starts/ends the slice with the indices following the start/end time provided + + + Raises + ------ + ValueError + - If start or end is not a number. + - If start is greater than end. + + """ + return self._get_slice( + start, end=end, mode=mode, n_points=None, time_unit=time_unit + ) From 5e50daa3c7816426e41f95dd281850e353d149f3 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 30 Jul 2024 11:15:44 -0400 Subject: [PATCH 086/195] added tests for public --- pynapple/core/base_class.py | 8 ++--- tests/test_time_series.py | 61 +++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 371011ab..301794d8 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -606,13 +606,13 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): Returns ------- slice : slice + A slice determining the start and end indices, with unit step. - If mode = "closest": - the default mode, starts/ends the slice with indices closest to the start/end time provided + Starts/ends the slice with indices closest to the start/end time provided. - If mode = "backward": - starts/ends the slice with the indices preceding the start/end time provided + Starts/ends the slice with the indices preceding the start/end time provided. - If mode = "forward": - starts/ends the slice with the indices following the start/end time provided - + Starts/ends the slice with the indices following the start/end time provided. Raises ------ diff --git a/tests/test_time_series.py b/tests/test_time_series.py index df0f384f..82954637 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1590,3 +1590,64 @@ def test_get_slice_value(start, end, n_points, mode, expected_slice, expected_ar out_array = ts.t[out_slice] assert out_slice == expected_slice assert np.all(out_array == expected_array) + +@pytest.mark.parametrize( + "start, end, mode, expected_slice, expected_array", + [ + (1, 3, "forward", slice(0, 2), np.array([1, 2])), + (1, 3, "backward", slice(0, 2), np.array([1, 2])), + (1, 3, "closest", slice(0, 2), np.array([1, 2])), + (1, 2.7, "forward", slice(0, 2), np.array([1, 2])), + (1, 2.7, "backward", slice(0, 1), np.array([1])), + (1, 2.7, "closest", slice(0, 2), np.array([1, 2])), + (1, 2.4, "forward", slice(0, 2), np.array([1, 2])), + (1, 2.4, "backward", slice(0, 1), np.array([1])), + (1, 2.4, "closest", slice(0, 1), np.array([1])), + (1.1, 3, "forward", slice(1, 2), np.array([2])), + (1.1, 3, "backward", slice(0, 2), np.array([1, 2])), + (1.1, 3, "closest", slice(0, 2), np.array([1, 2])), + (1.6, 3, "forward", slice(1, 2), np.array([2])), + (1.6, 3, "backward", slice(0, 2), np.array([1, 2])), + (1.6, 3, "closest", slice(1, 2), np.array([2])), + (1.6, 1.8, "backward", slice(0, 0), np.array([])), + (1.6, 1.8, "forward", slice(1, 1), np.array([])), + (1.6, 1.8, "closest", slice(1, 1), np.array([])), + (1.4, 1.6, "closest", slice(0, 1), np.array([1])), + (3, 3, "forward", slice(2, 2), np.array([])), + (3, 3, "backward", slice(2, 2), np.array([])), + (3, 3, "closest", slice(2, 2), np.array([])), + (0, 3, "forward", slice(0, 2), np.array([1, 2])), + (0, 3, "backward", slice(0, 2), np.array([1, 2])), + (0, 3, "closest", slice(0, 2), np.array([1, 2])), + (4, 4, "forward", slice(3, 3), np.array([])), + (4, 4, "backward", slice(3, 3), np.array([])), + (4, 4, "closest", slice(3, 3), np.array([])), + (4, 5, "forward", slice(3, 4), np.array([4])), + (4, 5, "backward", slice(3, 3), np.array([])), + (4, 5, "closest", slice(3, 3), np.array([])), + (0, 1, "forward", slice(0, 0), np.array([])), + (0, 1, "backward", slice(0, 1), np.array([1])), + (0, 1, "closest", slice(0, 0), np.array([])), + (0, None, "forward", slice(0, 1), np.array([1])), + (0, None, "backward", slice(0, 0), np.array([])), + (0, None, "closest", slice(0, 1), np.array([1])), + (1, None, "forward", slice(0, 1), np.array([1])), + (1, None, "backward", slice(0, 1), np.array([1])), + (1, None, "closest", slice(0, 1), np.array([1])), + (5, None, "forward", slice(3, 3), np.array([])), + (5, None, "backward", slice(3, 4), np.array([4])), + (5, None, "closest", slice(3, 4), np.array([4])) + ] +) +@pytest.mark.parametrize("ts", + [ + nap.Ts(t=np.array([1, 2, 3, 4])), + nap.Tsd(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])), + nap.TsdFrame(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None]), + nap.TsdTensor(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None, None]) + ]) +def test_get_slice_public(start, end, mode, expected_slice, expected_array, ts): + out_slice = ts._get_slice(start, end=end, mode=mode) + out_array = ts.t[out_slice] + assert out_slice == expected_slice + assert np.all(out_array == expected_array) From edb5229b7b3f31ff5c352db8c726f24a7d5f72ae Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 30 Jul 2024 11:17:26 -0400 Subject: [PATCH 087/195] improved docstrings --- pynapple/core/base_class.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 301794d8..3cbb8a0e 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -482,7 +482,8 @@ def _get_slice(self, start, end=None, mode="closest", n_points=None, time_unit=" time_unit : str, optional The time unit for the start and end values. Defaults to "s" (seconds). n_points : int, optional - Max number of time point per the slice. This will be used to calculate a step size for the slice. + Number of time point that will result from applying the slice. This parameter is used to + calculate a step size for the slice. Returns ------- From a60a1b1f149ba3a1cf4caee037f8412829701073 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 30 Jul 2024 11:30:22 -0400 Subject: [PATCH 088/195] intervalset chunking --- pynapple/core/interval_set.py | 45 +++++++++++++++++++-------------- tests/test_interval_set.py | 47 +++++++++++++++++++++++++++++++++++ 2 files changed, 74 insertions(+), 18 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 79dcda53..37633462 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -333,7 +333,7 @@ def starts(self): The starts of the IntervalSet """ time_series = importlib.import_module(".time_series", "pynapple.core") - return time_series.Ts(t=self.values[:, 0], time_support=self) + return time_series.Ts(t=self.values[:, 0]) @property def ends(self): @@ -345,7 +345,7 @@ def ends(self): The ends of the IntervalSet """ time_series = importlib.import_module(".time_series", "pynapple.core") - return time_series.Ts(t=self.values[:, 1], time_support=self) + return time_series.Ts(t=self.values[:, 1]) @property def loc(self): @@ -673,33 +673,39 @@ def save(self, filename): def split(self, interval_size, time_units="s"): """Split `IntervalSet` to a new `IntervalSet` with each interval being of size `interval_size`. - + Used mostly for chunking very large dataset or looping throught multiple epoch of same duration. This function skips the epochs that are shorter than `interval_size`. - + + Note that intervals are strictly non-overlapping in pynapple. One microsecond is removed from contiguous intervals. + Parameters ---------- interval_size : Number Description time_units : str, optional time units for the `interval_size` ('us', 'ms', 's' [default]) - + Returns ------- IntervalSet New `IntervalSet` with equal sized intervals - + Raises ------ IOError - Description + If `interval_size` is not a Number or is below 0 + If `time_units` is not a string """ if not isinstance(interval_size, Number): raise IOError("Argument interval_size should of type float or int") + if not interval_size>0: + raise IOError("Argument interval_size should be strictly larger than 0") + if not isinstance(time_units, str): - raise IOError("Argument time_units should be of type float or int") + raise IOError("Argument time_units should be of type str") if len(self) == 0: return IntervalSet(start=[], end=[]) @@ -708,27 +714,30 @@ def split(self, interval_size, time_units="s"): np.array((interval_size,), dtype=np.float64).ravel(), time_units )[0] - durations = self.end - self.start + interval_size = np.round(interval_size, nap_config.time_index_precision) + + durations = np.round(self.end - self.start, nap_config.time_index_precision) new_starts = [] new_ends = [] for i in range(len(self)): - if durations[i] > interval_size: + if durations[i] > interval_size: tmp = np.arange(self.start[i], self.end[i], interval_size) tmp = np.hstack((tmp, np.array([self.end[i]]))) - + tmp = np.round(tmp, nap_config.time_index_precision) new_starts.append(tmp[0:-1]) new_ends.append(tmp[1:]) - + new_starts = np.hstack(new_starts) new_ends = np.hstack(new_ends) - tokeep = np.round(new_ends - new_starts, nap_config.time_index_precision) - - new_iset = IntervalSet(new_starts, new_ends).drop_short_intervals(interval_size-1e-6) - - return new_iset - + durations = np.round(new_ends - new_starts, nap_config.time_index_precision) + tokeep = durations >= interval_size + new_starts = new_starts[tokeep] + new_ends = new_ends[tokeep] + # Removing 1 microsecond to have strictly non-overlapping intervals for intervals coming from the same epoch + new_ends -= 1e-6 + return IntervalSet(new_starts, new_ends) diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index 4e338ef7..01d4c6f7 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -266,6 +266,10 @@ def test_as_units(): tmp = df * 1e6 np.testing.assert_array_almost_equal(tmp.values, ep.as_units("us").values.astype(np.float64)) +def test_as_dataframe(): + ep = nap.IntervalSet(start=0, end=100) + df = pd.DataFrame(data=np.array([[0.0, 100.0]]), columns=["start", "end"], dtype=np.float64) + pd.testing.assert_frame_equal(df, ep.as_dataframe()) def test_intersect(): ep = nap.IntervalSet(start=[0, 30], end=[10, 70]) @@ -511,8 +515,51 @@ def test_save_npz(): Path("ep.npz").unlink() Path("ep2.npz").unlink() +def test_split(): + np.random.seed(0) + start = np.round(np.random.uniform(0, 10)) + end = np.round(np.random.uniform(90, 100)) + tmp = np.linspace(start, end, 100) + interval_size = np.round(tmp[1] - tmp[0], 9) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + ep0 = nap.IntervalSet(tmp[0:-1], tmp[1:]) + ep = nap.IntervalSet(tmp[0], tmp[-1]) + ep1 = ep.split(interval_size) + np.testing.assert_array_almost_equal(ep0, ep1) + + # Test with a smaller epochs + start = np.hstack((tmp[0:-1], np.array([200]))) + end = np.hstack((tmp[1:], np.array([200+0.9*interval_size]))) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + ep2 = nap.IntervalSet(start, end) + + ep = nap.IntervalSet([start[0], 200], end[-2:]) + ep1 = ep.split(interval_size) + np.testing.assert_array_almost_equal(ep0, ep1) + + # Empty intervalset + ep = nap.IntervalSet([], []) + assert len(ep.split(1)) == 0 + +def test_split_errors(): + start = [0, 10, 16, 25] + end = [5, 15, 20, 40] + ep = nap.IntervalSet(start=start, end=end) + with pytest.raises(IOError) as e: + ep.split('a') + assert str(e.value) == "Argument interval_size should of type float or int" + with pytest.raises(IOError) as e: + ep.split(0) + assert str(e.value) == "Argument interval_size should be strictly larger than 0" + + with pytest.raises(IOError) as e: + ep.split(1, time_units=1) + assert str(e.value) == "Argument time_units should be of type str" + From 8d867b223c43c73be8cd94b5668d99ba17777922 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 30 Jul 2024 14:21:45 -0400 Subject: [PATCH 089/195] add warnings --- pynapple/core/interval_set.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 37633462..cb425751 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -332,6 +332,11 @@ def starts(self): Ts The starts of the IntervalSet """ + warnings.warn( + "starts is a deprecated function. It will be removed in future versions", + category=DeprecationWarning, + stacklevel=2, + ) time_series = importlib.import_module(".time_series", "pynapple.core") return time_series.Ts(t=self.values[:, 0]) @@ -344,6 +349,11 @@ def ends(self): Ts The ends of the IntervalSet """ + warnings.warn( + "ends is a deprecated function. It will be removed in future versions", + category=DeprecationWarning, + stacklevel=2, + ) time_series = importlib.import_module(".time_series", "pynapple.core") return time_series.Ts(t=self.values[:, 1]) @@ -701,7 +711,7 @@ def split(self, interval_size, time_units="s"): if not isinstance(interval_size, Number): raise IOError("Argument interval_size should of type float or int") - if not interval_size>0: + if not interval_size > 0: raise IOError("Argument interval_size should be strictly larger than 0") if not isinstance(time_units, str): From 993831e5230e379430f608a4124886f668f4743b Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 30 Jul 2024 15:09:23 -0400 Subject: [PATCH 090/195] added examples --- pynapple/core/base_class.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 3cbb8a0e..ea645c18 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -621,6 +621,23 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): - If start or end is not a number. - If start is greater than end. + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + + >>> ts = nap.Ts(t = np.arange(1, 5)) + >>> start, end = 2.2, 3.6 + + >>> # slice over a range + >>> print(ts.get_slice(start, end, mode="closest")) # returns `slice(1, 3, None)` + >>> print(ts.get_slice(start, end, mode="backward")) # returns `slice(1, 2, None)` + >>> print(ts.get_slice(start, end, mode="forward")) # returns `slice(2, 3, None)` + + >>> # slice a single value + >>> print(ts.get_slice(start, None, mode="closest")) # returns `slice(1, 2, None)` + >>> print(ts.get_slice(start, None, mode="backward")) # returns `slice(1, 2, None)` + >>> print(ts.get_slice(start, None, mode="forward")) # returns `slice(2, 3, None)` """ return self._get_slice( start, end=end, mode=mode, n_points=None, time_unit=time_unit From 0a351e03af285d0e0e63062c3ff893bc073cf919 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 30 Jul 2024 15:10:16 -0400 Subject: [PATCH 091/195] fix test public --- tests/test_time_series.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 82954637..6a19663f 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1647,7 +1647,7 @@ def test_get_slice_value(start, end, n_points, mode, expected_slice, expected_ar nap.TsdTensor(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None, None]) ]) def test_get_slice_public(start, end, mode, expected_slice, expected_array, ts): - out_slice = ts._get_slice(start, end=end, mode=mode) + out_slice = ts.get_slice(start, end=end, mode=mode) out_array = ts.t[out_slice] assert out_slice == expected_slice assert np.all(out_array == expected_array) From 0e6665d549de5843b3b965a83c1115358916fea0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 30 Jul 2024 17:09:18 -0400 Subject: [PATCH 092/195] fix example --- pynapple/core/base_class.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index ea645c18..522cd78b 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -624,10 +624,9 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): Examples -------- >>> import pynapple as nap - >>> import numpy as np - >>> ts = nap.Ts(t = np.arange(1, 5)) - >>> start, end = 2.2, 3.6 + >>> ts = nap.Ts(t = [0, 1, 2, 3]) + >>> start, end = 1.2, 2.6 >>> # slice over a range >>> print(ts.get_slice(start, end, mode="closest")) # returns `slice(1, 3, None)` From e63ffe8d2153c1bea709f7b081da1bfdf128f204 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 30 Jul 2024 22:41:21 +0100 Subject: [PATCH 093/195] removing integrate->conv->diff pipeline --- docs/api_guide/tutorial_pynapple_wavelets.py | 71 +++++++++------- pynapple/process/signal_processing.py | 88 ++++++++++---------- tests/test_signal_processing.py | 28 +++++-- 3 files changed, 108 insertions(+), 79 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index 8cdeb9e1..d34de4f2 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -74,16 +74,18 @@ # Define the frequency of the wavelets in our filter bank freqs = np.linspace(1, 25, num=25) # Get the filter bank -filter_bank = nap.generate_morlet_filterbank(freqs, Fs, n_cycles=1.5, scaling=1.0) +filter_bank = nap.generate_morlet_filterbank( + freqs, Fs, gaussian_width=1.5, window_length=1.0 +) # %% # Lets plot it. def plot_filterbank(filter_bank, freqs, title): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7)) - offset = 0.2 + offset = 1.0 for f_i in range(filter_bank.shape[1]): - ax.plot(filter_bank[:, f_i] + offset * f_i) + ax.plot(filter_bank[:, f_i].real() + offset * f_i) ax.text( -2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left" ) @@ -95,7 +97,7 @@ def plot_filterbank(filter_bank, freqs, title): ax.set_title(title) -title = "Morlet Wavelet Filter Bank (Real Components): n_cycles=1.5, scaling=1.0" +title = "Morlet Wavelet Filter Bank (Real Components): gaussian_width=1.5, window_length=1.0" plot_filterbank(filter_bank, freqs, title) # %% @@ -108,7 +110,9 @@ def plot_filterbank(filter_bank, freqs, title): # scalogram. # Compute the wavelet transform using the parameters above -mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=1.5, scaling=1.0) +mwt = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, gaussian_width=1.5, window_length=1.0 +) # %% @@ -217,13 +221,14 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ax[1].set_xlabel("Time (s)") [ax[i].set_ylabel("Signal") for i in range(2)] + # %% # *** # Adding ALL the Oscillations! # ------------------ # Let's now add together the real components of all frequency bands to recreate a version of the original signal. -combined_oscillations = mwt.sum(axis=1) +combined_oscillations = mwt.sum(axis=1).real() # %% # Lets plot it. @@ -249,19 +254,18 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # underlying wavelets. freqs = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) -scales = [1.0, 2.0, 3.0] -cycles = [1.0, 2.0, 3.0] +window_lengths = [1.0, 2.0, 3.0] +gaussian_width = [1.0, 2.0, 3.0] fig, ax = plt.subplots( - len(scales), len(cycles), constrained_layout=True, figsize=(10, 5) + len(window_lengths), len(gaussian_width), constrained_layout=True, figsize=(10, 8) ) -for row_i, sc in enumerate(scales): - for col_i, cyc in enumerate(cycles): +for row_i, wl in enumerate(window_lengths): + for col_i, gw in enumerate(gaussian_width): filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, n_cycles=cyc, scaling=sc + freqs, 1000, gaussian_width=gw, window_length=wl, precision=12 ) - ax[row_i, col_i].plot(filter_bank[:, 0]) - ax[row_i, col_i].set_xlim(-15, 15) + ax[row_i, col_i].plot(filter_bank[:, 0].real()) ax[row_i, col_i].set_xlabel("Time (s)") ax[row_i, col_i].set_yticks([]) [ @@ -270,14 +274,15 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ] if col_i != 0: ax[row_i, col_i].get_yaxis().set_visible(False) -for col_i, cyc in enumerate(cycles): - ax[0, col_i].set_title(f"n_cycles={cyc}", fontsize=10) -for row_i, scl in enumerate(scales): - ax[row_i, 0].set_ylabel(f"scaling={scl}", fontsize=10) +for col_i, gw in enumerate(gaussian_width): + ax[0, col_i].set_title(f"gaussian_width={gw}", fontsize=10) +for row_i, wl in enumerate(window_lengths): + ax[row_i, 0].set_ylabel(f"window_length={wl}", fontsize=10) fig.suptitle("Parametrization Visualization") + # %% -# Increasing n_cycles increases the number of wavelet cycles present in the oscillations (cycles) within the +# Increasing time_decay increases the number of wavelet cycles present in the oscillations (cycles) within the # Gaussian window of the Morlet wavelet. It essentially controls the trade-off between time resolution # and frequency resolution. # @@ -286,23 +291,27 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # *** -# Effect of n_cycles +# Effect of time_decay # ------------------ -# Let's increase n_cycles to 7.5 and see the effect on the resultant filter bank. +# Let's increase time_decay to 7.5 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=1.0) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, gaussian_width=7.5, window_length=1.0 +) plot_filterbank( filter_bank, freqs, - "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=1.0", + "Morlet Wavelet Filter Bank (Real Components): gaussian_width=7.5, center_frequency=1.0", ) # %% # *** # Let's see what effect this has on the Wavelet Scalogram which is generated... -mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=1.0) +mwt = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=1.0 +) fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) plot_timefrequency( @@ -317,7 +326,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # *** # And let's see if that has an effect on the reconstructed version of the signal -combined_oscillations = mwt.sum(axis=1) +combined_oscillations = mwt.sum(axis=1).real() # %% # Lets plot it. @@ -343,19 +352,23 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # Let's increase scaling to 2.0 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) -filter_bank = nap.generate_morlet_filterbank(freqs, 1000, n_cycles=7.5, scaling=2.0) +filter_bank = nap.generate_morlet_filterbank( + freqs, 1000, gaussian_width=7.5, window_length=2.0 +) plot_filterbank( filter_bank, freqs, - "Morlet Wavelet Filter Bank (Real Components): n_cycles=7.5, scaling=2.0", + "Morlet Wavelet Filter Bank (Real Components): gaussian_width=7.5, center_frequency=2.0", ) # %% # *** # Let's see what effect this has on the Wavelet Scalogram which is generated... fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -mwt = nap.compute_wavelet_transform(sig, fs=Fs, freqs=freqs, n_cycles=7.5, scaling=2.0) +mwt = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0 +) plot_timefrequency( mwt.index.values[:], freqs[:], @@ -368,7 +381,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # *** # And let's see if that has an effect on the reconstructed version of the signal -combined_oscillations = mwt.sum(axis=1) +combined_oscillations = mwt.sum(axis=1).real() # %% # Lets plot it. diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 914a368b..58ad9fbc 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -57,7 +57,7 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): return ret -def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): +def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): """ Defines the complex Morlet wavelet kernel. @@ -65,10 +65,10 @@ def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): ---------- M : int Length of the wavelet - ncycles : float - number of wavelet cycles to use. Default is 1.5 - scaling: float - Scaling factor. Default is 1.0 + gaussian_width : float + Defines width of Gaussian to be used in wavelet creation. + window_length : float + The length of window to be used for wavelet creation. precision: int. Precision of wavelet to use. Default is 8 @@ -79,9 +79,9 @@ def _morlet(M=1024, ncycles=1.5, scaling=1.0, precision=8): """ x = np.linspace(-precision, precision, M) return ( - ((np.pi * ncycles) ** (-0.25)) - * np.exp(-(x**2) / ncycles) - * np.exp(1j * 2 * np.pi * scaling * x) + ((np.pi * gaussian_width) ** (-0.25)) + * np.exp(-(x**2) / gaussian_width) + * np.exp(1j * 2 * np.pi * window_length * x) ) @@ -114,7 +114,7 @@ def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_bas def compute_wavelet_transform( - sig, freqs, fs=None, n_cycles=1.5, scaling=1.0, precision=16, norm=None + sig, freqs, fs=None, gaussian_width=1.5, window_length=1.0, precision=16, norm="l1" ): """ Compute the time-frequency representation of a signal using Morlet wavelets. @@ -129,19 +129,18 @@ def compute_wavelet_transform( The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. fs : float or None Sampling rate, in Hz. Defaults to sig.rate if None is given. - n_cycles : float or 1d array - Length of the filter, as the number of cycles for each frequency. - If 1d array, this defines n_cycles for each frequency. - scaling : float - Scaling factor. + gaussian_width : float + Defines width of Gaussian to be used in wavelet creation. + window_length : float + The length of window to be used for wavelet creation. precision: int. Precision of wavelet to use. . Defines the number of timepoints to evaluate the Morlet wavelet at. Default is 16 - norm : {None, 'sss', 'amp'}, optional + norm : {None, 'l1', 'l2'}, optional Normalization method: * None - no normalization - * 'sss' - divide by the square root of the sum of squares - * 'amp' - divide by the sum of amplitudes + * 'l1' - divide by the sum of amplitudes + * 'l2' - divide by the square root of the sum of squares Returns ------- @@ -164,9 +163,11 @@ def compute_wavelet_transform( if not isinstance(sig, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)): raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") - if isinstance(n_cycles, (int, float, np.number)): - if n_cycles <= 0: - raise ValueError("Number of cycles must be a positive number.") + if isinstance(gaussian_width, (int, float, np.number)): + if gaussian_width <= 0: + raise ValueError("gaussian_width must be a positive number.") + if norm is not None and norm not in ["l1", "l2"]: + raise ValueError("norm parameter must be 'l1', 'l2', or None.") if isinstance(freqs, (tuple, list)): freqs = _create_freqs(*freqs) @@ -180,18 +181,18 @@ def compute_wavelet_transform( output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - filter_bank = generate_morlet_filterbank(freqs, fs, n_cycles, scaling, precision) + filter_bank = generate_morlet_filterbank( + freqs, fs, gaussian_width, window_length, precision + ) convolved_real = sig.convolve(filter_bank.real().values) convolved_imag = sig.convolve(filter_bank.imag().values) convolved = convolved_real.values + convolved_imag.values * 1j - coef = -np.diff(convolved, axis=0) - if norm == "sss": - coef *= -np.sqrt(scaling) / (freqs / fs) - elif norm == "amp": - coef *= -scaling / (freqs / fs) - coef = np.insert( - coef, 1, coef[0, :], axis=0 - ) # slightly hacky line, necessary to make output correct shape + if norm == "l1": + coef = convolved / (fs / freqs) + elif norm == "l2": + coef = convolved / (fs / np.sqrt(freqs)) + else: + coef = convolved cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef if len(output_shape) == 2: @@ -204,7 +205,9 @@ def compute_wavelet_transform( ) -def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=16): +def generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 +): """ Generates a Morlet filterbank using the given frequencies and parameters. Can be used purely for visualization, or to convolve with a pynapple Tsd, TsdFrame, or TsdTensor as part of a wavelet decomposition process. @@ -217,11 +220,10 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. fs : float Sampling rate, in Hz. - n_cycles : float or 1d array - Length of the filter, as the number of cycles for each frequency. - If 1d array, this defines n_cycles for each frequency. - scaling : float - Scaling factor. + gaussian_width : float + Defines width of Gaussian to be used in wavelet creation. + window_length : float + The length of window to be used for wavelet creation. precision: int. Precision of wavelet to use. Defines the number of timepoints to evaluate the Morlet wavelet at. @@ -233,13 +235,15 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 if len(freqs) == 0: raise ValueError("Given list of freqs cannot be empty.") filter_bank = [] - time_cutoff = 8 - morlet_f = _morlet(int(2**precision), ncycles=n_cycles, scaling=scaling) - x = np.linspace(-time_cutoff, time_cutoff, int(2**precision)) - int_psi = np.conj(_integrate(morlet_f, x[1] - x[0])) + cutoff = 8 + morlet_f = _morlet( + int(2**precision), gaussian_width=gaussian_width, window_length=window_length + ) + x = np.linspace(-cutoff, cutoff, int(2**precision)) + int_psi = np.conj(morlet_f) max_len = -1 for freq in freqs: - scale = scaling / (freq / fs) + scale = window_length / (freq / fs) j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) j = j.astype(int) # floor if j[-1] >= int_psi.size: @@ -248,7 +252,7 @@ def generate_morlet_filterbank(freqs, fs, n_cycles=1.5, scaling=1.0, precision=1 if len(int_psi_scale) > max_len: max_len = len(int_psi_scale) time = np.linspace( - -time_cutoff * scaling / freq, time_cutoff * scaling / freq, max_len + -cutoff * window_length / freq, cutoff * window_length / freq, max_len ) filter_bank.append(int_psi_scale) filter_bank = [ @@ -271,7 +275,7 @@ def _integrate(arr, step): arr : np.ndarray wave function to be integrated step : float - Step size of vgiven wave function array + Step size of given wave function array Returns ------- diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 8f853591..9df76ae4 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -45,10 +45,10 @@ def test_compute_spectogram(): def test_compute_wavelet_transform(): - t = np.linspace(0, 1, 1000) + t = np.linspace(0, 1, 1001) sig = nap.Tsd( d=np.sin(t * 50 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1000), [0, 0.5, 1], [0, 1, 0]), + * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), t=t, ) freqs = np.linspace(10, 100, 10) @@ -60,10 +60,10 @@ def test_compute_wavelet_transform(): == 500 ) - t = np.linspace(0, 1, 1000) + t = np.linspace(0, 1, 1001) sig = nap.Tsd( d=np.sin(t * 10 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1000), [0, 0.5, 1], [0, 1, 0]), + * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), t=t, ) freqs = np.linspace(10, 100, 10) @@ -94,7 +94,7 @@ def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="sss") + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="l1") mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] assert mpf == 20 assert mwt.shape == (1000, 10) @@ -102,7 +102,15 @@ def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="amp") + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="l2") + mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] + assert mpf == 20 + assert mwt.shape == (1000, 10) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) + freqs = np.linspace(10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm=None) mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] assert mpf == 20 assert mwt.shape == (1000, 10) @@ -139,5 +147,9 @@ def test_compute_wavelet_transform(): assert mwt.shape == (1024, 10, 4, 2) with pytest.raises(ValueError) as e_info: - nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, n_cycles=-1.5) - assert str(e_info.value) == "Number of cycles must be a positive number." + nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, gaussian_width=-1.5) + assert str(e_info.value) == "gaussian_width must be a positive number." + + +if __name__ == "__main__": + test_compute_wavelet_transform() From 154ee3dbb4d4dfb3abb38d0a534685fea523d306 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 30 Jul 2024 17:46:45 -0400 Subject: [PATCH 094/195] Update tests/test_interval_set.py Co-authored-by: Edoardo Balzani --- tests/test_interval_set.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index 01d4c6f7..00b5298b 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -548,9 +548,8 @@ def test_split_errors(): end = [5, 15, 20, 40] ep = nap.IntervalSet(start=start, end=end) - with pytest.raises(IOError) as e: + with pytest.raises(IOError, match="Argument interval_size should of type float or int"): ep.split('a') - assert str(e.value) == "Argument interval_size should of type float or int" with pytest.raises(IOError) as e: ep.split(0) From 9722d2160f494a375daf36d71ff4795dbd05a289 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 30 Jul 2024 17:47:08 -0400 Subject: [PATCH 095/195] Update pynapple/core/interval_set.py Co-authored-by: Edoardo Balzani --- pynapple/core/interval_set.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index cb425751..9d76b37c 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -728,19 +728,18 @@ def split(self, interval_size, time_units="s"): durations = np.round(self.end - self.start, nap_config.time_index_precision) - new_starts = [] - new_ends = [] - - for i in range(len(self)): - if durations[i] > interval_size: - tmp = np.arange(self.start[i], self.end[i], interval_size) - tmp = np.hstack((tmp, np.array([self.end[i]]))) - tmp = np.round(tmp, nap_config.time_index_precision) - new_starts.append(tmp[0:-1]) - new_ends.append(tmp[1:]) - - new_starts = np.hstack(new_starts) - new_ends = np.hstack(new_ends) + idxs = np.where(durations > interval_size)[0] + size_tmp = (np.ceil((self.end[idxs] - self.start[idxs]) / interval_size)).astype(int) + 1 + new_starts = np.full(size_tmp.sum() - size_tmp.shape[0], np.nan) + new_ends = np.full(size_tmp.sum() - size_tmp.shape[0], np.nan) + i0 = 0 + for cnt, idx in enumerate(idxs): + new_starts[i0:i0 + size_tmp[cnt] - 1] = np.arange(self.start[idx], self.end[idx], interval_size) + new_ends[i0:i0 + size_tmp[cnt] - 2] = new_starts[i0 + 1: i0 + size_tmp[cnt] - 1] + new_ends[i0 + size_tmp[cnt] - 2] = self.end[idx] + i0 += size_tmp[cnt] - 1 + new_starts = np.round(new_starts, nap_config.time_index_precision) + new_ends = np.round(new_ends, nap_config.time_index_precision) durations = np.round(new_ends - new_starts, nap_config.time_index_precision) tokeep = durations >= interval_size From e8d51a58f75397052cdbf8a6085261e539167361 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 30 Jul 2024 17:51:42 -0400 Subject: [PATCH 096/195] Update --- pynapple/core/interval_set.py | 12 +++++++++--- tests/test_interval_set.py | 6 +++--- tests/test_time_series.py | 10 ---------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 9d76b37c..c19ba075 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -729,13 +729,19 @@ def split(self, interval_size, time_units="s"): durations = np.round(self.end - self.start, nap_config.time_index_precision) idxs = np.where(durations > interval_size)[0] - size_tmp = (np.ceil((self.end[idxs] - self.start[idxs]) / interval_size)).astype(int) + 1 + size_tmp = ( + np.ceil((self.end[idxs] - self.start[idxs]) / interval_size) + ).astype(int) + 1 new_starts = np.full(size_tmp.sum() - size_tmp.shape[0], np.nan) new_ends = np.full(size_tmp.sum() - size_tmp.shape[0], np.nan) i0 = 0 for cnt, idx in enumerate(idxs): - new_starts[i0:i0 + size_tmp[cnt] - 1] = np.arange(self.start[idx], self.end[idx], interval_size) - new_ends[i0:i0 + size_tmp[cnt] - 2] = new_starts[i0 + 1: i0 + size_tmp[cnt] - 1] + new_starts[i0 : i0 + size_tmp[cnt] - 1] = np.arange( + self.start[idx], self.end[idx], interval_size + ) + new_ends[i0 : i0 + size_tmp[cnt] - 2] = new_starts[ + i0 + 1 : i0 + size_tmp[cnt] - 1 + ] new_ends[i0 + size_tmp[cnt] - 2] = self.end[idx] i0 += size_tmp[cnt] - 1 new_starts = np.round(new_starts, nap_config.time_index_precision) diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index 00b5298b..2c8fbf47 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -261,15 +261,15 @@ def test_tot_length(): def test_as_units(): ep = nap.IntervalSet(start=0, end=100) df = pd.DataFrame(data=np.array([[0.0, 100.0]]), columns=["start", "end"], dtype=np.float64) - pd.testing.assert_frame_equal(df, ep.as_units("s").astype(np.float64)) - pd.testing.assert_frame_equal(df * 1e3, ep.as_units("ms").astype(np.float64)) + np.testing.assert_array_almost_equal(df.values, ep.as_units("s").values.astype(np.float64)) + np.testing.assert_array_almost_equal(df * 1e3, ep.as_units("ms").values.astype(np.float64)) tmp = df * 1e6 np.testing.assert_array_almost_equal(tmp.values, ep.as_units("us").values.astype(np.float64)) def test_as_dataframe(): ep = nap.IntervalSet(start=0, end=100) df = pd.DataFrame(data=np.array([[0.0, 100.0]]), columns=["start", "end"], dtype=np.float64) - pd.testing.assert_frame_equal(df, ep.as_dataframe()) + np.testing.assert_array_almost_equal(df.values, ep.as_dataframe().values) def test_intersect(): ep = nap.IntervalSet(start=[0, 30], end=[10, 70]) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 0942490f..d46fc47e 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -608,16 +608,6 @@ def test__getitems__(self, tsd): a.time_support, tsd.time_support ) - # def test_loc(self, tsd): - # a = tsd.loc[0:10] # should be 11 elements similar to pandas Series - # b = nap.Tsd(t=tsd.index[0:11], d=tsd.values[0:11]) - # assert isinstance(a, nap.Tsd) - # np.testing.assert_array_almost_equal(a.index, b.index) - # np.testing.assert_array_almost_equal(a.values, b.values) - # pd.testing.assert_frame_equal( - # a.time_support, b.time_support - # ) - def test_count(self, tsd): count = tsd.count(1) assert len(count) == 99 From c08a4f0d59134fdff09cef517f706e9aed34253e Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 30 Jul 2024 18:19:09 -0400 Subject: [PATCH 097/195] added temp path in lazy loading --- tests/test_lazy_loading.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index 6acc4f30..032ff6f2 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -21,8 +21,8 @@ (np.arange(12), "not_an_array", pytest.raises(TypeError, match="Data should be array-like")) ] ) -def test_lazy_load_hdf5_is_array(time, data, expectation): - file_path = Path('data.h5') +def test_lazy_load_hdf5_is_array(time, data, expectation, tmp_path): + file_path = tmp_path / Path('data.h5') try: with h5py.File(file_path, 'w') as f: f.create_dataset('data', data=data) @@ -42,8 +42,8 @@ def test_lazy_load_hdf5_is_array(time, data, expectation): ] ) @pytest.mark.parametrize("convert_flag", [True, False]) -def test_lazy_load_hdf5_is_array(time, data, convert_flag): - file_path = Path('data.h5') +def test_lazy_load_hdf5_is_array(time, data, convert_flag, tmp_path): + file_path = tmp_path / Path('data.h5') try: with h5py.File(file_path, 'w') as f: f.create_dataset('data', data=data) @@ -63,9 +63,9 @@ def test_lazy_load_hdf5_is_array(time, data, convert_flag): @pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))]) @pytest.mark.parametrize("cls", [nap.Tsd, nap.TsdFrame, nap.TsdTensor]) @pytest.mark.parametrize("func", [np.exp, lambda x: x*2]) -def test_lazy_load_hdf5_apply_func(time, data, func,cls): +def test_lazy_load_hdf5_apply_func(time, data, func,cls, tmp_path): """Apply a unary function to a lazy loaded array.""" - file_path = Path('data.h5') + file_path = tmp_path / Path('data.h5') try: if cls is nap.TsdFrame: data = data[:, None] @@ -101,8 +101,8 @@ def test_lazy_load_hdf5_apply_func(time, data, func,cls): ("get", [2, 7]) ] ) -def test_lazy_load_hdf5_apply_method(time, data, method_name, args, cls): - file_path = Path('data.h5') +def test_lazy_load_hdf5_apply_method(time, data, method_name, args, cls, tmp_path): + file_path = tmp_path / Path('data.h5') try: if cls is nap.TsdFrame: data = data[:, None] @@ -133,8 +133,8 @@ def test_lazy_load_hdf5_apply_method(time, data, method_name, args, cls): ("to_tsgroup", [], nap.TsGroup) ] ) -def test_lazy_load_hdf5_apply_method_tsd_specific(time, data, method_name, args, expected_out_type): - file_path = Path('data.h5') +def test_lazy_load_hdf5_apply_method_tsd_specific(time, data, method_name, args, expected_out_type, tmp_path): + file_path = tmp_path / Path('data.h5') try: with h5py.File(file_path, 'w') as f: f.create_dataset('data', data=data) @@ -157,8 +157,8 @@ def test_lazy_load_hdf5_apply_method_tsd_specific(time, data, method_name, args, ("as_dataframe", [], pd.DataFrame), ] ) -def test_lazy_load_hdf5_apply_method_tsdframe_specific(time, data, method_name, args, expected_out_type): - file_path = Path('data.h5') +def test_lazy_load_hdf5_apply_method_tsdframe_specific(time, data, method_name, args, expected_out_type, tmp_path): + file_path = tmp_path / Path('data.h5') try: with h5py.File(file_path, 'w') as f: f.create_dataset('data', data=data[:, None]) @@ -174,8 +174,8 @@ def test_lazy_load_hdf5_apply_method_tsdframe_specific(time, data, method_name, file_path.unlink() -def test_lazy_load_hdf5_tsdframe_loc(): - file_path = Path('data.h5') +def test_lazy_load_hdf5_tsdframe_loc(tmp_path): + file_path = tmp_path / Path('data.h5') data = np.arange(10).reshape(5, 2) try: with h5py.File(file_path, 'w') as f: @@ -227,8 +227,8 @@ def test_lazy_load_function(lazy): @pytest.mark.parametrize("data", [np.ones(10), np.ones((10, 2)), np.ones((10, 2, 2))]) -def test_lazy_load_nwb_no_warnings(data): - file_path = Path('data.h5') +def test_lazy_load_nwb_no_warnings(data, tmp_path): # tmp_path is a default fixture creating a temporary folder + file_path = tmp_path / Path('data.h5') try: with h5py.File(file_path, 'w') as f: @@ -252,11 +252,11 @@ def test_lazy_load_nwb_no_warnings(data): file_path.unlink() -def test_tsgroup_no_warnings(): +def test_tsgroup_no_warnings(tmp_path): # default fixture n_units = 2 try: for k in range(n_units): - file_path = Path(f'data_{k}.h5') + file_path = tmp_path / Path(f'data_{k}.h5') with h5py.File(file_path, 'w') as f: f.create_dataset('spks', data=np.sort(np.random.uniform(0, 10, size=20))) with warnings.catch_warnings(record=True) as w: From bf8d45cf5cacc087372b654c2fb85166c87444c8 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Tue, 30 Jul 2024 18:25:26 -0400 Subject: [PATCH 098/195] fixed tests --- tests/test_lazy_loading.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index 032ff6f2..bb7e2734 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -264,7 +264,7 @@ def test_tsgroup_no_warnings(tmp_path): # default fixture nwbfile = mock_NWBFile() for k in range(n_units): - file_path = Path(f'data_{k}.h5') + file_path = tmp_path / Path(f'data_{k}.h5') spike_times = h5py.File(file_path, "r")['spks'] nwbfile.add_unit(spike_times=spike_times) From 47e0876753000ba772284407ca94b0646ec7b515 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 10:03:54 -0400 Subject: [PATCH 099/195] changing pandas test --- tests/test_jitted.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_jitted.py b/tests/test_jitted.py index f97096d5..333e9cdb 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -248,7 +248,8 @@ def test_jitbin_array(): tsd2 = pd.concat(tsd2) # tsd2 = nap.TsdFrame(tsd2) - pd.testing.assert_frame_equal(tsd3, tsd2) + np.testing.assert_array_almost_equal(tsd3.values, tsd2.values) + np.testing.assert_array_almost_equal(tsd3.index.values, tsd2.index.values) def test_jitintersect(): for i in range(10): From 6c176ea42940e395886f6a280dd3d368a15fb391 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 10:10:52 -0400 Subject: [PATCH 100/195] update --- tests/test_jitted.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/tests/test_jitted.py b/tests/test_jitted.py index 333e9cdb..1dcdb989 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -71,8 +71,9 @@ def test_jitrestrict(): tsd2 = restrict(ep, tsd) ix = nap.core._jitted_functions.jitrestrict(tsd.index, ep.start, ep.end) - tsd3 = pd.Series(index=tsd.index[ix], data=tsd.values[ix]) - pd.testing.assert_series_equal(tsd2, tsd3) + tsd3 = pd.Series(index=tsd.index[ix], data=tsd.values[ix]) + np.testing.assert_array_almost_equal(tsd2.values, tsd3.values) + np.testing.assert_array_almost_equal(tsd2.index.values, tsd3.index.values) def test_jitrestrict_with_count(): for i in range(100): @@ -81,7 +82,9 @@ def test_jitrestrict_with_count(): tsd2 = restrict(ep, tsd) ix, count = nap.core._jitted_functions.jitrestrict_with_count(tsd.index, ep.start, ep.end) tsd3 = pd.Series(index=tsd.index[ix], data=tsd.values[ix]) - pd.testing.assert_series_equal(tsd2, tsd3) + np.testing.assert_array_almost_equal(tsd2.values, tsd3.values) + np.testing.assert_array_almost_equal(tsd2.index.values, tsd3.index.values) + bins = ep.values.ravel() ix = np.array(pd.cut(tsd.index, bins, labels=np.arange(len(bins) - 1, dtype=np.float64))) @@ -146,8 +149,10 @@ def test_jitvalue_from(): tsd2.append(tsd.restrict(ep[j]).as_series().reindex(ix, method="nearest").fillna(0.0)) tsd2 = pd.concat(tsd2) + + np.testing.assert_array_almost_equal(tsd2.values, tsd3.values) + np.testing.assert_array_almost_equal(tsd2.index.values, tsd3.index.values) - pd.testing.assert_series_equal(tsd2, tsd3) def test_jitcount(): for i in range(10): @@ -166,15 +171,15 @@ def test_jitcount(): idx = np.digitize(ts.restrict(ep[j]).index, bins)-1 tmp = np.array([np.sum(idx==j) for j in range(len(bins)-1)]) tmp = nap.Tsd(t = bins[0:-1] + np.diff(bins)/2, d = tmp) - tmp = tmp.restrict(ep[j]) - - # pd.testing.assert_series_equal(tmp, tsd3.restrict(ep.loc[[j]])) + tmp = tmp.restrict(ep[j]) tsd2.append(tmp.as_series()) tsd2 = pd.concat(tsd2) - - pd.testing.assert_series_equal(tsd3.as_series(), tsd2) + + np.testing.assert_array_almost_equal(tsd2.values, tsd3.values) + np.testing.assert_array_almost_equal(tsd2.index.values, tsd3.index.values) + def test_jitbin(): for i in range(10): @@ -210,8 +215,10 @@ def test_jitbin(): tsd2 = pd.concat(tsd2) # tsd2 = nap.Tsd(tsd2) tsd2 = tsd2.fillna(0.0) + + np.testing.assert_array_almost_equal(tsd2.values, tsd3.values) + np.testing.assert_array_almost_equal(tsd2.index.values, tsd3.index.values) - pd.testing.assert_series_equal(tsd3, tsd2) def test_jitbin_array(): for i in range(10): @@ -281,7 +288,6 @@ def test_jitintersect(): ep4 = nap.IntervalSet(start, end) - # pd.testing.assert_frame_equal(ep3, ep4) np.testing.assert_array_almost_equal(ep3, ep4) def test_jitunion(): @@ -314,8 +320,7 @@ def test_jitunion(): stop = df["time"][ix_stop] ep4 = nap.IntervalSet(start, stop) - - # pd.testing.assert_frame_equal(ep3, ep4) + np.testing.assert_array_almost_equal(ep3, ep4) def test_jitdiff(): @@ -354,8 +359,7 @@ def test_jitdiff(): idx = start != end ep4 = nap.IntervalSet(start[idx], end[idx]) - - # pd.testing.assert_frame_equal(ep3, ep4) + np.testing.assert_array_almost_equal(ep3, ep4) def test_jitunion_isets(): @@ -389,8 +393,7 @@ def test_jitunion_isets(): stop = df["time"][ix_stop] ep5 = nap.IntervalSet(start, stop) - - # pd.testing.assert_frame_equal(ep5, ep6) + np.testing.assert_array_almost_equal(ep5, ep6) From 81695c37bedfdcedfdf4390c33bd309c6a289186 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 10:20:50 -0400 Subject: [PATCH 101/195] update --- tests/test_ts_group.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index b4743911..d7ca91d1 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:14:41 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-11 14:42:50 +# @Last Modified time: 2024-07-31 10:20:37 """Tests of ts group for `pynapple` package.""" @@ -128,7 +128,8 @@ def test_create_ts_group_with_metainfo(self, group): ar_info = np.ones(3) * 1 tsgroup = nap.TsGroup(group, sr=sr_info, ar=ar_info) assert tsgroup._metadata.shape == (3, 3) - pd.testing.assert_series_equal(tsgroup._metadata["sr"], sr_info) + np.testing.assert_array_almost_equal(tsgroup._metadata["sr"].values, sr_info.values) + np.testing.assert_array_almost_equal(tsgroup._metadata["sr"].index.values, sr_info.index.values) np.testing.assert_array_almost_equal(tsgroup._metadata["ar"].values, ar_info) def test_add_metainfo(self, group): From 070f3e690a9e4b8e0c96e54a9cfe893c8d800dfb Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 31 Jul 2024 10:26:46 -0400 Subject: [PATCH 102/195] added dtype for count --- pynapple/core/_core_functions.py | 7 ++++--- pynapple/core/base_class.py | 12 ++++++++++-- pynapple/core/time_series.py | 6 ++++-- pynapple/core/ts_group.py | 18 +++++++++++++----- tests/test_time_series.py | 19 ++++++++++++++++++- tests/test_ts_group.py | 32 ++++++++++++++++++++++++++++++++ 6 files changed, 81 insertions(+), 13 deletions(-) diff --git a/pynapple/core/_core_functions.py b/pynapple/core/_core_functions.py index 0ba1e915..9bbb871b 100644 --- a/pynapple/core/_core_functions.py +++ b/pynapple/core/_core_functions.py @@ -27,13 +27,14 @@ def _restrict(time_array, starts, ends): return jitrestrict(time_array, starts, ends) -def _count(time_array, starts, ends, bin_size=None): +def _count(time_array, starts, ends, bin_size=None, dtype=None): if isinstance(bin_size, (float, int)): - return jitcount(time_array, starts, ends, bin_size) + t, d = jitcount(time_array, starts, ends, bin_size) else: _, d = jitrestrict_with_count(time_array, starts, ends) t = starts + (ends - starts) / 2 - return t, d + d = d.astype(dtype) if dtype else d + return t, d def _value_from(time_array, time_target_array, data_target_array, starts, ends): diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 522cd78b..054907e4 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -203,7 +203,7 @@ def value_from(self, data, ep=None): return t, d, time_support, kwargs - def count(self, *args, **kwargs): + def count(self, *args, dtype=None, **kwargs): """ Count occurences of events within bin_size or within a set of bins defined as an IntervalSet. You can call this function in multiple ways : @@ -231,6 +231,8 @@ def count(self, *args, **kwargs): IntervalSet to restrict the operation time_units : str, optional Time units of bin size ('us', 'ms', 's' [default]) + dtype: type, optional + Data type for the count. Default is np.int64. Returns ------- @@ -290,6 +292,12 @@ def count(self, *args, **kwargs): if isinstance(a, IntervalSet): ep = a + if dtype: + try: + dtype = np.dtype(dtype) + except Exception: + raise ValueError(f"{dtype} is not a valid numpy dtype.") + starts = ep.start ends = ep.end @@ -298,7 +306,7 @@ def count(self, *args, **kwargs): time_array = self.index.values - t, d = _count(time_array, starts, ends, bin_size) + t, d = _count(time_array, starts, ends, bin_size, dtype=dtype) return t, d, ep diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index c7cd0e4c..576d8f49 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1565,7 +1565,7 @@ def value_from(self, data, ep=None): return data.__class__(t, d, time_support=time_support, **kwargs) - def count(self, *args, **kwargs): + def count(self, *args, dtype=None, **kwargs): """ Count occurences of events within bin_size or within a set of bins defined as an IntervalSet. You can call this function in multiple ways : @@ -1593,6 +1593,8 @@ def count(self, *args, **kwargs): IntervalSet to restrict the operation time_units : str, optional Time units of bin size ('us', 'ms', 's' [default]) + dtype: type, optional + Data type for the count. Default is np.int64. Returns ------- @@ -1620,7 +1622,7 @@ def count(self, *args, **kwargs): >>> start end >>> 0 100.0 800.0 """ - t, d, ep = super().count(*args, **kwargs) + t, d, ep = super().count(*args, dtype=dtype, **kwargs) return Tsd(t=t, d=d, time_support=ep) def fillna(self, value): diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 3dfc0921..fe3e6fbb 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -606,7 +606,7 @@ def value_from(self, tsd, ep=None): cols = self._metadata.columns.drop("rate") return TsGroup(newgr, time_support=ep, **self._metadata[cols]) - def count(self, *args, **kwargs): + def count(self, *args, dtype=None, **kwargs): """ Count occurences of events within bin_size or within a set of bins defined as an IntervalSet. You can call this function in multiple ways : @@ -634,6 +634,8 @@ def count(self, *args, **kwargs): IntervalSet to restrict the operation time_units : str, optional Time units of bin size ('us', 'ms', 's' [default]) + dtype: type, optional + Data type for the count. Default is np.int64. Returns ------- @@ -702,6 +704,12 @@ def count(self, *args, **kwargs): if isinstance(a, IntervalSet): ep = a + if dtype: + try: + dtype = np.dtype(dtype) + except Exception: + raise ValueError(f"{dtype} is not a valid numpy dtype.") + starts = ep.start ends = ep.end @@ -712,20 +720,20 @@ def count(self, *args, **kwargs): # Call it on first element to pre-allocate the array if len(self) >= 1: time_index, d = _count( - self.data[self.index[0]].index.values, starts, ends, bin_size + self.data[self.index[0]].index.values, starts, ends, bin_size, dtype=dtype ) - count = np.zeros((len(time_index), len(self.index)), dtype=np.int64) + count = np.zeros((len(time_index), len(self.index)), dtype=dtype) count[:, 0] = d for i in range(1, len(self.index)): count[:, i] = _count( - self.data[self.index[i]].index.values, starts, ends, bin_size + self.data[self.index[i]].index.values, starts, ends, bin_size, dtype=dtype )[1] return TsdFrame(t=time_index, d=count, time_support=ep, columns=self.index) else: - time_index, _ = _count(np.array([]), starts, ends, bin_size) + time_index, _ = _count(np.array([]), starts, ends, bin_size, dtype=dtype) return TsdFrame( t=time_index, d=np.empty((len(time_index), 0)), time_support=ep ) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 6a19663f..b097553b 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1201,7 +1201,6 @@ def test_count_with_ep_only(self, ts): assert len(count) == 1 np.testing.assert_array_almost_equal(count.values, np.array([100])) - def test_count_errors(self, ts): with pytest.raises(ValueError): ts.count(bin_size = {}) @@ -1212,6 +1211,24 @@ def test_count_errors(self, ts): with pytest.raises(ValueError): ts.count(time_units = {}) + @pytest.mark.parametrize( + "dtype, expectation", + [ + (None, does_not_raise()), + (float, does_not_raise()), + (int, does_not_raise()), + (np.int32, does_not_raise()), + (np.int64, does_not_raise()), + (np.float32, does_not_raise()), + (np.float64, does_not_raise()), + (1, pytest.raises(ValueError, match=f"1 is not a valid numpy dtype")), + ] + ) + def test_count_dtype(self, dtype, expectation, ts): + with expectation: + count = ts.count(bin_size=0.1, dtype=dtype) + if dtype: + assert np.issubdtype(count.dtype, dtype) #################################################### # Test for tsdtensor diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index b4743911..b4da2d6c 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -38,6 +38,16 @@ def ts_group(): group = nap.TsGroup(data, meta=[10, 11]) return group + +@pytest.fixture +def ts_group_one_group(): + # Placeholder setup for Ts and Tsd objects. Adjust as necessary. + ts1 = nap.Ts(t=np.arange(10)) + data = {1: ts1} + group = nap.TsGroup(data, meta=[10]) + return group + + class TestTsGroup1: def test_create_ts_group(self, group): @@ -880,3 +890,25 @@ def test_pickling(ts_group): # Ensure time support is the same assert np.all(ts_group.time_support == unpickled_obj.time_support) + + +@pytest.mark.parametrize( + "dtype, expectation", + [ + (None, does_not_raise()), + (float, does_not_raise()), + (int, does_not_raise()), + (np.int32, does_not_raise()), + (np.int64, does_not_raise()), + (np.float32, does_not_raise()), + (np.float64, does_not_raise()), + (1, pytest.raises(ValueError, match=f"1 is not a valid numpy dtype")), + ] +) +def test_count_dtype(dtype, expectation, ts_group, ts_group_one_group): + with expectation: + count = ts_group.count(bin_size=0.1, dtype=dtype) + count_one = ts_group_one_group.count(bin_size=0.1, dtype=dtype) + if dtype: + assert np.issubdtype(count.dtype, dtype) + assert np.issubdtype(count_one.dtype, dtype) From b9c4251a7212429a91ff1439c596f81cd44593ef Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 10:37:18 -0400 Subject: [PATCH 103/195] update --- tests/test_lazy_loading.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index bb7e2734..02a25d9d 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -26,9 +26,10 @@ def test_lazy_load_hdf5_is_array(time, data, expectation, tmp_path): try: with h5py.File(file_path, 'w') as f: f.create_dataset('data', data=data) - h5_data = h5py.File(file_path, 'r')["data"] - with expectation: - nap.Tsd(t=time, d=h5_data, load_array=False) + h5_data = h5py.File(file_path, 'r')["data"] + with expectation: + nap.Tsd(t=time, d=h5_data, load_array=False) + finally: # delete file if file_path.exists(): From e94a526b498d19713399eb17cbbd12ad39125a28 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 10:39:58 -0400 Subject: [PATCH 104/195] update --- tests/test_lazy_loading.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index 02a25d9d..c63e9add 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -26,10 +26,10 @@ def test_lazy_load_hdf5_is_array(time, data, expectation, tmp_path): try: with h5py.File(file_path, 'w') as f: f.create_dataset('data', data=data) - h5_data = h5py.File(file_path, 'r')["data"] + with h5py.File(file_path, 'r') as h5_data: with expectation: - nap.Tsd(t=time, d=h5_data, load_array=False) - + nap.Tsd(t=time, d=h5_data['data'], load_array=False) + finally: # delete file if file_path.exists(): From eaada2e8184aab21d3152872159a6380887da430 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 10:50:30 -0400 Subject: [PATCH 105/195] Update --- tests/test_lazy_loading.py | 49 +++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 25 deletions(-) diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index c63e9add..b5309a87 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -23,17 +23,16 @@ ) def test_lazy_load_hdf5_is_array(time, data, expectation, tmp_path): file_path = tmp_path / Path('data.h5') - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - with h5py.File(file_path, 'r') as h5_data: - with expectation: - nap.Tsd(t=time, d=h5_data['data'], load_array=False) - - finally: - # delete file - if file_path.exists(): - file_path.unlink() + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + with h5py.File(file_path, 'r') as h5_data: + with expectation: + nap.Tsd(t=time, d=h5_data['data'], load_array=False) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize( @@ -45,20 +44,20 @@ def test_lazy_load_hdf5_is_array(time, data, expectation, tmp_path): @pytest.mark.parametrize("convert_flag", [True, False]) def test_lazy_load_hdf5_is_array(time, data, convert_flag, tmp_path): file_path = tmp_path / Path('data.h5') - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - tsd = nap.Tsd(t=time, d=h5_data, load_array=convert_flag) - if convert_flag: - assert not isinstance(tsd.d, h5py.Dataset) - else: - assert isinstance(tsd.d, h5py.Dataset) - finally: - # delete file - if file_path.exists(): - file_path.unlink() + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + tsd = nap.Tsd(t=time, d=h5_data, load_array=convert_flag) + if convert_flag: + assert not isinstance(tsd.d, h5py.Dataset) + else: + assert isinstance(tsd.d, h5py.Dataset) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))]) From 1068c6151eebfcc34deb780bc017a525b20599cb Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 11:02:57 -0400 Subject: [PATCH 106/195] Update --- tests/test_lazy_loading.py | 238 ++++++++++++++++++------------------- 1 file changed, 119 insertions(+), 119 deletions(-) diff --git a/tests/test_lazy_loading.py b/tests/test_lazy_loading.py index b5309a87..b140bf0e 100644 --- a/tests/test_lazy_loading.py +++ b/tests/test_lazy_loading.py @@ -66,23 +66,23 @@ def test_lazy_load_hdf5_is_array(time, data, convert_flag, tmp_path): def test_lazy_load_hdf5_apply_func(time, data, func,cls, tmp_path): """Apply a unary function to a lazy loaded array.""" file_path = tmp_path / Path('data.h5') - try: - if cls is nap.TsdFrame: - data = data[:, None] - elif cls is nap.TsdTensor: - data = data[:, None, None] - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - # lazy load and apply function - res = func(cls(t=time, d=h5_data, load_array=False)) - assert isinstance(res, cls) - assert not isinstance(res.d, h5py.Dataset) - finally: - # delete file - if file_path.exists(): - file_path.unlink() + # try: + if cls is nap.TsdFrame: + data = data[:, None] + elif cls is nap.TsdTensor: + data = data[:, None, None] + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + # lazy load and apply function + res = func(cls(t=time, d=h5_data, load_array=False)) + assert isinstance(res, cls) + assert not isinstance(res.d, h5py.Dataset) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))]) @@ -103,24 +103,24 @@ def test_lazy_load_hdf5_apply_func(time, data, func,cls, tmp_path): ) def test_lazy_load_hdf5_apply_method(time, data, method_name, args, cls, tmp_path): file_path = tmp_path / Path('data.h5') - try: - if cls is nap.TsdFrame: - data = data[:, None] - elif cls is nap.TsdTensor: - data = data[:, None, None] - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - # lazy load and apply function - tsd = cls(t=time, d=h5_data, load_array=False) - func = getattr(tsd, method_name) - out = func(*args) - assert not isinstance(out.d, h5py.Dataset) - finally: - # delete file - if file_path.exists(): - file_path.unlink() + # try: + if cls is nap.TsdFrame: + data = data[:, None] + elif cls is nap.TsdTensor: + data = data[:, None, None] + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + # lazy load and apply function + tsd = cls(t=time, d=h5_data, load_array=False) + func = getattr(tsd, method_name) + out = func(*args) + assert not isinstance(out.d, h5py.Dataset) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))]) @@ -135,19 +135,19 @@ def test_lazy_load_hdf5_apply_method(time, data, method_name, args, cls, tmp_pat ) def test_lazy_load_hdf5_apply_method_tsd_specific(time, data, method_name, args, expected_out_type, tmp_path): file_path = tmp_path / Path('data.h5') - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - # lazy load and apply function - tsd = nap.Tsd(t=time, d=h5_data, load_array=False) - func = getattr(tsd, method_name) - assert isinstance(func(*args), expected_out_type) - finally: - # delete file - if file_path.exists(): - file_path.unlink() + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + # lazy load and apply function + tsd = nap.Tsd(t=time, d=h5_data, load_array=False) + func = getattr(tsd, method_name) + assert isinstance(func(*args), expected_out_type) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize("time, data", [(np.arange(12), np.arange(12))]) @@ -159,38 +159,38 @@ def test_lazy_load_hdf5_apply_method_tsd_specific(time, data, method_name, args, ) def test_lazy_load_hdf5_apply_method_tsdframe_specific(time, data, method_name, args, expected_out_type, tmp_path): file_path = tmp_path / Path('data.h5') - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data[:, None]) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - # lazy load and apply function - tsd = nap.TsdFrame(t=time, d=h5_data, load_array=False) - func = getattr(tsd, method_name) - assert isinstance(func(*args), expected_out_type) - finally: - # delete file - if file_path.exists(): - file_path.unlink() + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data[:, None]) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + # lazy load and apply function + tsd = nap.TsdFrame(t=time, d=h5_data, load_array=False) + func = getattr(tsd, method_name) + assert isinstance(func(*args), expected_out_type) + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() def test_lazy_load_hdf5_tsdframe_loc(tmp_path): file_path = tmp_path / Path('data.h5') data = np.arange(10).reshape(5, 2) - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - # get the tsd - h5_data = h5py.File(file_path, 'r')["data"] - # lazy load and apply function - tsd = nap.TsdFrame(t=np.arange(data.shape[0]), d=h5_data, load_array=False).loc[1] - assert isinstance(tsd, nap.Tsd) - assert all(tsd.d == np.array([1, 3, 5, 7, 9])) - - finally: - # delete file - if file_path.exists(): - file_path.unlink() + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + # get the tsd + h5_data = h5py.File(file_path, 'r')["data"] + # lazy load and apply function + tsd = nap.TsdFrame(t=np.arange(data.shape[0]), d=h5_data, load_array=False).loc[1] + assert isinstance(tsd, nap.Tsd) + assert all(tsd.d == np.array([1, 3, 5, 7, 9])) + + # finally: + # # delete file + # if file_path.exists(): + # file_path.unlink() @pytest.mark.parametrize( "lazy", @@ -230,56 +230,56 @@ def test_lazy_load_function(lazy): def test_lazy_load_nwb_no_warnings(data, tmp_path): # tmp_path is a default fixture creating a temporary folder file_path = tmp_path / Path('data.h5') - try: - with h5py.File(file_path, 'w') as f: - f.create_dataset('data', data=data) - time_series = mock_TimeSeries(name="TimeSeries", data=f["data"]) - nwbfile = mock_NWBFile() - nwbfile.add_acquisition(time_series) - nwb = nap.NWBFile(nwbfile) - - with warnings.catch_warnings(record=True) as w: - tsd = nwb["TimeSeries"] - tsd.count(0.1) - assert isinstance(tsd.d, h5py.Dataset) - - if len(w): - if not str(w[0].message).startswith("Converting 'd' to"): - raise RuntimeError - - finally: - if file_path.exists(): - file_path.unlink() + # try: + with h5py.File(file_path, 'w') as f: + f.create_dataset('data', data=data) + time_series = mock_TimeSeries(name="TimeSeries", data=f["data"]) + nwbfile = mock_NWBFile() + nwbfile.add_acquisition(time_series) + nwb = nap.NWBFile(nwbfile) + + with warnings.catch_warnings(record=True) as w: + tsd = nwb["TimeSeries"] + tsd.count(0.1) + assert isinstance(tsd.d, h5py.Dataset) + + if len(w): + if not str(w[0].message).startswith("Converting 'd' to"): + raise RuntimeError + + # finally: + # if file_path.exists(): + # file_path.unlink() def test_tsgroup_no_warnings(tmp_path): # default fixture n_units = 2 - try: + # try: + for k in range(n_units): + file_path = tmp_path / Path(f'data_{k}.h5') + with h5py.File(file_path, 'w') as f: + f.create_dataset('spks', data=np.sort(np.random.uniform(0, 10, size=20))) + with warnings.catch_warnings(record=True) as w: + + nwbfile = mock_NWBFile() + for k in range(n_units): file_path = tmp_path / Path(f'data_{k}.h5') - with h5py.File(file_path, 'w') as f: - f.create_dataset('spks', data=np.sort(np.random.uniform(0, 10, size=20))) - with warnings.catch_warnings(record=True) as w: - - nwbfile = mock_NWBFile() - - for k in range(n_units): - file_path = tmp_path / Path(f'data_{k}.h5') - spike_times = h5py.File(file_path, "r")['spks'] - nwbfile.add_unit(spike_times=spike_times) - - nwb = nap.NWBFile(nwbfile) - tsgroup = nwb["units"] - tsgroup.count(0.1) - - if len(w): - if not str(w[0].message).startswith("Converting 'd' to"): - raise RuntimeError + spike_times = h5py.File(file_path, "r")['spks'] + nwbfile.add_unit(spike_times=spike_times) + + nwb = nap.NWBFile(nwbfile) + tsgroup = nwb["units"] + tsgroup.count(0.1) + + if len(w): + if not str(w[0].message).startswith("Converting 'd' to"): + raise RuntimeError - finally: - for k in range(n_units): - file_path = Path(f'data_{k}.h5') - if file_path.exists(): - file_path.unlink() + # finally: + # for k in range(n_units): + # file_path = Path(f'data_{k}.h5') + # if file_path.exists(): + # file_path.unlink() From 67954afb0b99961ea1cfbab851f8d616a8904a34 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 31 Jul 2024 11:14:59 -0400 Subject: [PATCH 107/195] linted --- pynapple/core/ts_group.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index fe3e6fbb..be6c955c 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -720,7 +720,11 @@ def count(self, *args, dtype=None, **kwargs): # Call it on first element to pre-allocate the array if len(self) >= 1: time_index, d = _count( - self.data[self.index[0]].index.values, starts, ends, bin_size, dtype=dtype + self.data[self.index[0]].index.values, + starts, + ends, + bin_size, + dtype=dtype, ) count = np.zeros((len(time_index), len(self.index)), dtype=dtype) @@ -728,7 +732,11 @@ def count(self, *args, dtype=None, **kwargs): for i in range(1, len(self.index)): count[:, i] = _count( - self.data[self.index[i]].index.values, starts, ends, bin_size, dtype=dtype + self.data[self.index[i]].index.values, + starts, + ends, + bin_size, + dtype=dtype, )[1] return TsdFrame(t=time_index, d=count, time_support=ep, columns=self.index) From 0fe97adc90a9c39f6696c86676ab808263528504 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 11:19:34 -0400 Subject: [PATCH 108/195] Update --- tests/test_misc.py | 6 +++--- tests/test_time_series.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/test_misc.py b/tests/test_misc.py index ba5b91f5..61c12b56 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -2,7 +2,7 @@ # @Author: Guillaume Viejo # @Date: 2023-07-10 12:26:20 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-09-18 16:05:24 +# @Last Modified time: 2024-07-31 11:17:59 """Tests of IO misc functions""" @@ -40,7 +40,7 @@ def test_load_file(path): np.testing.assert_array_equal(tsd.values, tsd2.values) np.testing.assert_array_equal(tsd.time_support.values, tsd2.time_support.values) - file_path.unlink() + # file_path.unlink() @pytest.mark.parametrize("path", [path]) def test_load_file_filenotfound(path): @@ -57,7 +57,7 @@ def test_load_wrong_format(path): nap.load_file(file_path) assert str(e.value) == "File format not supported" - file_path.unlink() + # file_path.unlink() @pytest.mark.parametrize("path", [path]) def test_load_folder(path): diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 74b0a202..6e85f725 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -769,8 +769,8 @@ def test_save_npz(self, tsd): np.testing.assert_array_almost_equal(file['start'], tsd.time_support.start) np.testing.assert_array_almost_equal(file['end'], tsd.time_support.end) - Path("tsd.npz").unlink() - Path("tsd2.npz").unlink() + # Path("tsd.npz").unlink() + # Path("tsd2.npz").unlink() def test_interpolate(self, tsd): @@ -1007,8 +1007,8 @@ def test_save_npz(self, tsdframe): np.testing.assert_array_almost_equal(file['end'], tsdframe.time_support.end) np.testing.assert_array_almost_equal(file['columns'], tsdframe.columns) - Path("tsdframe.npz").unlink() - Path("tsdframe2.npz").unlink() + # Path("tsdframe.npz").unlink() + # Path("tsdframe2.npz").unlink() def test_interpolate(self, tsdframe): @@ -1121,8 +1121,8 @@ def test_save_npz(self, ts): np.testing.assert_array_almost_equal(file['start'], ts.time_support.start) np.testing.assert_array_almost_equal(file['end'], ts.time_support.end) - Path("ts.npz").unlink() - Path("ts2.npz").unlink() + # Path("ts.npz").unlink() + # Path("ts2.npz").unlink() def test_fillna(self, ts): with pytest.raises(AssertionError): @@ -1359,8 +1359,8 @@ def test_save_npz(self, tsdtensor): np.testing.assert_array_almost_equal(file['start'], tsdtensor.time_support.start) np.testing.assert_array_almost_equal(file['end'], tsdtensor.time_support.end) - Path("tsdtensor.npz").unlink() - Path("tsdtensor2.npz").unlink() + # Path("tsdtensor.npz").unlink() + # Path("tsdtensor2.npz").unlink() def test_interpolate(self, tsdtensor): From 059d2c41907527f79a93eccbd36a3d24841a855f Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 15:29:12 -0400 Subject: [PATCH 109/195] Adding new notebook for psd --- docs/api_guide/tutorial_pynapple_spectrum.py | 113 +++++++++++++++++++ pynapple/process/__init__.py | 3 +- pynapple/process/signal_processing.py | 13 ++- 3 files changed, 127 insertions(+), 2 deletions(-) create mode 100644 docs/api_guide/tutorial_pynapple_spectrum.py diff --git a/docs/api_guide/tutorial_pynapple_spectrum.py b/docs/api_guide/tutorial_pynapple_spectrum.py new file mode 100644 index 00000000..cb6ac699 --- /dev/null +++ b/docs/api_guide/tutorial_pynapple_spectrum.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +""" +Power spectral density +====================== + +Working with Wavelets! + +See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. + +This tutorial was made by Kipp Freud. + +""" + +# %% +# !!! warning +# This tutorial uses matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib requests tqdm seaborn` +# +# Now, import the necessary libraries: + +import matplotlib.pyplot as plt +import numpy as np +import seaborn + +seaborn.set_theme() + +import pynapple as nap + +# %% +# *** +# Generating a Dummy Signal +# ------------------ +# Let's generate a dummy signal to analyse with wavelets! +# +# Our dummy dataset will contain two components, a low frequency 2Hz sinusoid combined +# with a sinusoid which increases frequency from 5 to 15 Hz throughout the signal. + +F = [2, 10] + +Fs = 2000 +t = np.arange(0, 100, 1/Fs) +sig = nap.Tsd( + t=t, + d=np.cos(t*2*np.pi*F[0])+np.cos(t*2*np.pi*F[1])+np.random.normal(0, 2, len(t)), + time_support = nap.IntervalSet(0, 10) + ) + +# %% +# Let's plot it +plt.figure() +plt.plot(sig.get(0, 1)) +plt.title("Signal") +plt.show() + + +# %% +# Computing power spectral density (PSD) +# -------------------------------------- +# +# To compute a PSD of a signal, you can use the function `nap.compute_power_spectral_density` + +psd = nap.compute_power_spectral_density(sig) + +# %% +# Pynapple returns a pandas DataFrame. + +print(psd) + +# %% +# It is then easy to plot it. + +plt.figure() +plt.plot(psd) +plt.xlabel("Frequency (Hz)") +plt.ylabel("Amplitude") +plt.show() + +# %% +# Note that the output of the FFT is truncated to positive frequencies. To get positive and negative frequencies, you can set `full_range=True`. +# By default, the function returns the frequencies up to the Nyquist frequency. +# Let's zoom on the first 20 Hz. + +plt.figure() +plt.plot(psd) +plt.xlabel("Frequency (Hz)") +plt.ylabel("Amplitude") +plt.xlim(0, 20) +plt.show() + +# %% +# We find the two frequencies 2 and 10 Hz. +# +# By default, pynapple assumes a constant sampling rate and a single epoch. For example, computing the FFT over more than 1 epoch will raise an error. +double_ep = nap.IntervalSet([0, 50], [20, 100]) + +try: + nap.compute_power_spectral_density(sig, ep=double_ep) +except ValueError as e: + print(e) + + +# %% +# Computing mean PSD +# ------------------ +# +# It is possible to compute an average PSD over multiple epochs with the function `nap.compute_mean_power_spectral_density`. +# +# In this case, the argument `interval_size` determines the duration of each epochs upon which the fft is computed. +# If not epochs is passed, the function will split the `time_support`. + + + diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index a73dea00..53a80375 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -16,7 +16,8 @@ shuffle_ts_intervals, ) from .signal_processing import ( - compute_spectogram, + compute_power_spectral_density, + compute_mean_power_spectral_density, compute_wavelet_transform, generate_morlet_filterbank, ) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 58ad9fbc..44e1d558 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -10,7 +10,7 @@ import pynapple as nap -def compute_spectogram(sig, fs=None, ep=None, full_range=False): +def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): """ Performs numpy fft on sig, returns output. Pynapple assumes a constant sampling rate for sig. @@ -56,6 +56,17 @@ def compute_spectogram(sig, fs=None, ep=None, full_range=False): return ret.loc[ret.index >= 0] return ret +def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False): + + if not (ep is None or isinstance(ep, nap.IntervalSet)): + raise TypeError("ep param must be a pynapple IntervalSet object, or None") + if ep is None: + ep = sig.time_support + + # split_ep = ep.split(interval_size) + + + def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): """ From eec4740bbe019d3dbfcd4ef284e4341d2d7e221c Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 31 Jul 2024 16:52:50 -0400 Subject: [PATCH 110/195] Adding mean psd notebook --- docs/api_guide/tutorial_pynapple_spectrum.py | 42 +++++++----- docs/examples/tutorial_signal_processing.py | 2 +- pynapple/process/signal_processing.py | 69 +++++++++++++++++++- 3 files changed, 95 insertions(+), 18 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_spectrum.py b/docs/api_guide/tutorial_pynapple_spectrum.py index cb6ac699..6fe6d727 100644 --- a/docs/api_guide/tutorial_pynapple_spectrum.py +++ b/docs/api_guide/tutorial_pynapple_spectrum.py @@ -3,12 +3,8 @@ Power spectral density ====================== -Working with Wavelets! - See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. -This tutorial was made by Kipp Freud. - """ # %% @@ -29,12 +25,10 @@ # %% # *** -# Generating a Dummy Signal +# Generating a signal # ------------------ -# Let's generate a dummy signal to analyse with wavelets! +# Let's generate a dummy signal with 2Hz and 10Hz sinusoide with white noise. # -# Our dummy dataset will contain two components, a low frequency 2Hz sinusoid combined -# with a sinusoid which increases frequency from 5 to 15 Hz throughout the signal. F = [2, 10] @@ -42,16 +36,17 @@ t = np.arange(0, 100, 1/Fs) sig = nap.Tsd( t=t, - d=np.cos(t*2*np.pi*F[0])+np.cos(t*2*np.pi*F[1])+np.random.normal(0, 2, len(t)), - time_support = nap.IntervalSet(0, 10) + d=np.cos(t*2*np.pi*F[0])+np.cos(t*2*np.pi*F[1])+np.random.normal(0, 3, len(t)), + time_support = nap.IntervalSet(0, 100) ) # %% # Let's plot it plt.figure() -plt.plot(sig.get(0, 1)) +plt.plot(sig.get(0, 0.4)) plt.title("Signal") -plt.show() +plt.xlabel("Time (s)") + # %% @@ -74,7 +69,7 @@ plt.plot(psd) plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") -plt.show() + # %% # Note that the output of the FFT is truncated to positive frequencies. To get positive and negative frequencies, you can set `full_range=True`. @@ -86,7 +81,7 @@ plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") plt.xlim(0, 20) -plt.show() + # %% # We find the two frequencies 2 and 10 Hz. @@ -106,8 +101,25 @@ # # It is possible to compute an average PSD over multiple epochs with the function `nap.compute_mean_power_spectral_density`. # -# In this case, the argument `interval_size` determines the duration of each epochs upon which the fft is computed. +# In this case, the argument `interval_size` determines the duration of each epochs upon which the FFT is computed. # If not epochs is passed, the function will split the `time_support`. +# +# In this case, the FFT will be computed over epochs of 10 seconds. + +mean_psd = nap.compute_mean_power_spectral_density(sig, interval_size=10.0) +# %% +# Let's compare `mean_psd` to `psd`. + +plt.figure() +plt.plot(psd) +plt.plot(mean_psd) +plt.xlabel("Frequency (Hz)") +plt.ylabel("Amplitude") +plt.xlim(0, 20) + +# %% +# As we can see, `nap.compute_mean_power_spectral_density` was able to smooth out the noise. + diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index d9b66be9..2453e0f2 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -131,7 +131,7 @@ # ----------------------------------- # Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies. -fft = nap.compute_spectogram(RUN_Tsd, fs=int(FS)) +fft = nap.compute_power_spectral_density(RUN_Tsd, fs=int(FS)) print(fft) # %% diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 44e1d558..2a994f04 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -56,16 +56,81 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): return ret.loc[ret.index >= 0] return ret -def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False): +def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False, time_units="s"): + """Compute mean power spectral density by averaging FFT over epochs of same size. + The parameter `interval_size` controls the duration of the epochs. + + Note that this function assumes a constant sampling rate for sig. + Parameters + ---------- + sig : TYPE + Description + interval_size : TYPE + Description + fs : None, optional + Description + ep : None, optional + Description + full_range : bool, optional + Description + time_units : str, optional + Description + + Returns + ------- + TYPE + Description + + Raises + ------ + RuntimeError + Description + TypeError + Description + """ if not (ep is None or isinstance(ep, nap.IntervalSet)): raise TypeError("ep param must be a pynapple IntervalSet object, or None") if ep is None: ep = sig.time_support + if fs is None: + fs = sig.rate + + # Split the ep + split_ep = ep.split(interval_size) + + if len(split_ep) == 0: + raise RuntimeError(f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size") + + # Get the slices of each ep + slices = np.zeros((len(split_ep),2), dtype=int) - # split_ep = ep.split(interval_size) + for i in range(len(split_ep)): + sl = sig.get_slice(split_ep[i,0], split_ep[i,1]) + slices[i,0] = sl.start + slices[i,1] = sl.stop + + # Check what is the signal length + N = np.min(np.diff(slices, 1)) + if N == 0: + raise RuntimeError(f"One epoch doesn't have any signal. Check the parameter ep or the time support if no epoch is passed.") + # Get the freqs + fft_freq = np.fft.fftfreq(N, 1 / fs) + + # Compute the fft + fft_result = np.zeros((N, *sig.shape[1:]), dtype=complex) + + for i in range(len(slices)): + fft_result += np.fft.fft(sig[slices[i,0]:slices[i,1]].values[0:N], axis=0) + + ret = pd.DataFrame(fft_result, fft_freq) + ret.sort_index(inplace=True) + if not full_range: + return ret.loc[ret.index >= 0] + return ret + def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): From 82dfc43adce03e6fcd4074f5454b51e4577c1d9a Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 1 Aug 2024 13:00:26 -0400 Subject: [PATCH 111/195] pushing some failing tests --- pynapple/process/signal_processing.py | 2 +- tests/test_power_spectral_density.py | 60 +++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) create mode 100644 tests/test_power_spectral_density.py diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 2a994f04..59af1130 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -106,7 +106,7 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu slices = np.zeros((len(split_ep),2), dtype=int) for i in range(len(split_ep)): - sl = sig.get_slice(split_ep[i,0], split_ep[i,1]) + sl = sig.get_slice(split_ep[i,0], split_ep[i,1], mode='backward') slices[i,0] = sl.start slices[i,1] = sl.stop diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py new file mode 100644 index 00000000..a604315e --- /dev/null +++ b/tests/test_power_spectral_density.py @@ -0,0 +1,60 @@ +import numpy as np +import pandas as pd +import pytest + +import pynapple as nap + + +def get_signal_and_output(f=2, fs=1000,duration=100,interval_size=10): + t=np.arange(0, duration, 1/fs) + d=np.cos(2*np.pi*f*t) + sig = nap.Tsd(t=t,d=d, time_support=nap.IntervalSet(0,100)) + tmp = d.reshape((int(duration/interval_size),int(fs*interval_size))).T + out = np.sum(np.fft.fft(tmp, axis=0), 1) + freq = np.fft.fftfreq(out.shape[0], 1 / fs) + order = np.argsort(freq) + out = out[order] + freq = freq[order] + return (sig, out, freq) + + +def test_basic(): + sig, out, freq = get_signal_and_output() + + psd = nap.compute_mean_power_spectral_density(sig, 10) + + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq>=0]) + np.testing.assert_array_almost_equal(psd.index.values, freq[freq>=0]) + + + + + +@pytest.mark.parametrize("interval_size, expected_exception", [ + (10, None), # Regular case + (200, RuntimeError), # Interval size too large + (1, RuntimeError) # Epoch too small +]) +@setup_signal_and_params +def test_compute_mean_power_spectral_density(sig, interval_size, expected_exception): + if expected_exception: + with pytest.raises(expected_exception): + compute_mean_power_spectral_density(sig, interval_size) + else: + psd = compute_mean_power_spectral_density(sig, interval_size) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + +@pytest.mark.parametrize("full_range", [True, False]) +@setup_signal_and_params +def test_full_range_option(sig, full_range): + interval_size = 10 # Choose a valid interval size for this test + + psd = compute_mean_power_spectral_density(sig, interval_size, full_range=full_range) + + if full_range: + assert (psd.index >= 0).all() + else: + assert (psd.index >= 0).any() # Partial range should exclude negative frequencies From 0fcbdfda866d01428d8c98a0726fe4b21a7cc75b Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 1 Aug 2024 14:48:25 -0400 Subject: [PATCH 112/195] start edit test --- pynapple/core/base_class.py | 20 ++-- tests/test_time_series.py | 215 ++++++++++++++++++------------------ 2 files changed, 119 insertions(+), 116 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 522cd78b..04ae4dfd 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -467,7 +467,7 @@ def _from_npz_reader(cls, file): iset = IntervalSet(start=file["start"], end=file["end"]) return cls(time_support=iset, **kwargs) - def _get_slice(self, start, end=None, mode="closest", n_points=None, time_unit="s"): + def _get_slice(self, start, end=None, mode="closest_t", n_points=None, time_unit="s"): """ Get a slice from the time series data based on the start and end values with the specified mode. @@ -478,7 +478,7 @@ def _get_slice(self, start, end=None, mode="closest", n_points=None, time_unit=" end : int or float, optional The ending value for the slice. Defaults to None. mode : str, optional - The mode for slicing. Can be "forward", "backward", or "closest". Defaults to "closest". + The mode for slicing. Can be "after_t", "before_t", or "closest". Defaults to "closest_t". time_unit : str, optional The time unit for the start and end values. Defaults to "s" (seconds). n_points : int, optional @@ -489,13 +489,13 @@ def _get_slice(self, start, end=None, mode="closest", n_points=None, time_unit=" ------- slice : slice If end is not provided: - - For mode == "backward": + - For mode == "before_t": - An empty slice for start < self.t[0] - slice(idx, idx+1) with self.t[idx] <= start < self.t[idx+1] - - For mode == "forward": + - For mode == "after_t": - An empty slice for start >= self.t[-1] - slice(idx, idx+1) with self.t[idx-1] < start <= self.t[idx] - - For mode == "closest": + - For mode == "closest_t": - slice(idx, idx+1) with the closest index to start If end is provided: - For mode == "backward": @@ -538,11 +538,11 @@ def _get_slice(self, start, end=None, mode="closest", n_points=None, time_unit=" if idx_start == len(self.t): idx_start -= 1 # make sure the index is not out of bound - if mode == "backward": + if mode == "before_t": # in order to get the index preceding start # subtract one except if self.t[idx_start] is exactly equal to start idx_start -= self.t[idx_start] > start - elif mode == "closest": + elif mode == "closest_t": # subtract 1 if start is closer to the previous index di = self.t[idx_start] - start > np.abs(self.t[idx_start - 1] - start) idx_start -= di @@ -569,14 +569,14 @@ def _get_slice(self, start, end=None, mode="closest", n_points=None, time_unit=" idx_end -= 1 # make sure the index is not out of bound add_if_forward = 1 # add back the index if forward - if mode == "backward": + if mode == "before_t": # remove 1 if self.t[idx_end] is larger than end, except if idx_end is 0 idx_end -= (self.t[idx_end] > end) - int(idx_end == 0) - elif mode == "closest": + elif mode == "closest_t": # subtract 1 if end is closer to self.t[idx_end - 1] di = self.t[idx_end] - end > np.abs(self.t[idx_end - 1] - end) idx_end -= di - elif mode == "forward" and idx_end == len(self.t) - 1: + elif mode == "after_t" and idx_end == len(self.t) - 1: idx_end += add_if_forward # add one if idx_start < len(self.t) step = None diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 6e85f725..fba76973 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1464,49 +1464,52 @@ def test_get_slice_value_types(start, end, time_unit, expectation): @pytest.mark.parametrize( "start, end, mode, expected_slice, expected_array", [ - (1, 3, "forward", slice(0, 2), np.array([1, 2])), - (1, 3, "backward", slice(0, 2), np.array([1, 2])), - (1, 3, "closest", slice(0, 2), np.array([1, 2])), - (1, 2.7, "forward", slice(0, 2), np.array([1, 2])), - (1, 2.7, "backward", slice(0, 1), np.array([1])), - (1, 2.7, "closest", slice(0, 2), np.array([1, 2])), - (1, 2.4, "forward", slice(0, 2), np.array([1, 2])), - (1, 2.4, "backward", slice(0, 1), np.array([1])), - (1, 2.4, "closest", slice(0, 1), np.array([1])), - (1.1, 3, "forward", slice(1, 2), np.array([2])), - (1.1, 3, "backward", slice(0, 2), np.array([1, 2])), - (1.1, 3, "closest", slice(0, 2), np.array([1, 2])), - (1.6, 3, "forward", slice(1, 2), np.array([2])), - (1.6, 3, "backward", slice(0, 2), np.array([1, 2])), - (1.6, 3, "closest", slice(1, 2), np.array([2])), - (1.6, 1.8, "backward", slice(0, 0), np.array([])), - (1.6, 1.8, "forward", slice(1, 1), np.array([])), - (1.6, 1.8, "closest", slice(1, 1), np.array([])), - (1.4, 1.6, "closest", slice(0, 1), np.array([1])), - (3, 3, "forward", slice(2, 2), np.array([])), - (3, 3, "backward", slice(2, 2), np.array([])), - (3, 3, "closest", slice(2, 2), np.array([])), - (0, 3, "forward", slice(0, 2), np.array([1, 2])), - (0, 3, "backward", slice(0, 2), np.array([1, 2])), - (0, 3, "closest", slice(0, 2), np.array([1, 2])), - (4, 4, "forward", slice(3, 3), np.array([])), - (4, 4, "backward", slice(3, 3), np.array([])), - (4, 4, "closest", slice(3, 3), np.array([])), - (4, 5, "forward", slice(3, 4), np.array([4])), - (4, 5, "backward", slice(3, 3), np.array([])), - (4, 5, "closest", slice(3, 3), np.array([])), - (0, 1, "forward", slice(0, 0), np.array([])), - (0, 1, "backward", slice(0, 1), np.array([1])), - (0, 1, "closest", slice(0, 0), np.array([])), - (0, None, "forward", slice(0, 1), np.array([1])), - (0, None, "backward", slice(0, 0), np.array([])), - (0, None, "closest", slice(0, 1), np.array([1])), - (1, None, "forward", slice(0, 1), np.array([1])), - (1, None, "backward", slice(0, 1), np.array([1])), - (1, None, "closest", slice(0, 1), np.array([1])), - (5, None, "forward", slice(3, 3), np.array([])), - (5, None, "backward", slice(3, 4), np.array([4])), - (5, None, "closest", slice(3, 4), np.array([4])) + (1, 3, "after_t", slice(0, 2), np.array([1, 2])), + (1, 3, "before_t", slice(0, 2), np.array([1, 2])), + (1, 3, "closest_t", slice(0, 2), np.array([1, 2])), + (1, 2.7, "after_t", slice(0, 2), np.array([1, 2])), + (1, 2.7, "before_t", slice(0, 1), np.array([1])), + (1, 2.7, "closest_t", slice(0, 2), np.array([1, 2])), + (1, 2.4, "after_t", slice(0, 2), np.array([1, 2])), + (1, 2.4, "before_t", slice(0, 1), np.array([1])), + (1, 2.4, "closest_t", slice(0, 1), np.array([1])), + (1.1, 3, "after_t", slice(1, 2), np.array([2])), + (1.1, 3, "before_t", slice(0, 2), np.array([1, 2])), + (1.1, 3, "closest_t", slice(0, 2), np.array([1, 2])), + (1.6, 3, "after_t", slice(1, 2), np.array([2])), + (1.6, 3, "before_t", slice(0, 2), np.array([1, 2])), + (1.6, 3, "closest_t", slice(1, 2), np.array([2])), + (1.6, 1.8, "before_t", slice(0, 0), np.array([])), + (1.6, 1.8, "after_t", slice(1, 1), np.array([])), + (1.6, 1.8, "closest_t", slice(1, 1), np.array([])), + (1.4, 1.6, "closest_t", slice(0, 1), np.array([1])), + (3, 3, "after_t", slice(2, 2), np.array([])), + (3, 3, "before_t", slice(2, 2), np.array([])), + (3, 3, "closest_t", slice(2, 2), np.array([])), + (0, 3, "after_t", slice(0, 2), np.array([1, 2])), + (0, 3, "before_t", slice(0, 2), np.array([1, 2])), + (0, 3, "closest_t", slice(0, 2), np.array([1, 2])), + (0, 4, "after_t", slice(0, 3), np.array([1, 2, 3])), + (0, 4, "before_t", slice(0, 3), np.array([1, 2, 3])), + (0, 4, "closest_t", slice(0, 3), np.array([1, 2, 3])), + (4, 4, "after_t", slice(3, 3), np.array([])), + (4, 4, "before_t", slice(3, 3), np.array([])), + (4, 4, "closest_t", slice(3, 3), np.array([])), + (4, 5, "after_t", slice(3, 4), np.array([4])), + (4, 5, "before_t", slice(3, 3), np.array([])), + (4, 5, "closest_t", slice(3, 3), np.array([])), + (0, 1, "after_t", slice(0, 0), np.array([])), + (0, 1, "before_t", slice(0, 1), np.array([1])), + (0, 1, "closest_t", slice(0, 0), np.array([])), + (0, None, "after_t", slice(0, 1), np.array([1])), + (0, None, "before_t", slice(0, 0), np.array([])), + (0, None, "closest_t", slice(0, 1), np.array([1])), + (1, None, "after_t", slice(0, 1), np.array([1])), + (1, None, "before_t", slice(0, 1), np.array([1])), + (1, None, "closest_t", slice(0, 1), np.array([1])), + (5, None, "after_t", slice(3, 3), np.array([])), + (5, None, "before_t", slice(3, 4), np.array([4])), + (5, None, "closest_t", slice(3, 4), np.array([4])) ] ) @pytest.mark.parametrize("ts", @@ -1531,7 +1534,7 @@ def test_get_slice_value(start, end, mode, expected_slice, expected_array, ts): ] ) @pytest.mark.parametrize("time_unit", ["s", "ms", "us"]) -@pytest.mark.parametrize("mode", ["closest", "backward", "forward"]) +@pytest.mark.parametrize("mode", ["closest_t", "before_t", "after_t"]) def test_get_slice_n_points(end, n_points, expectation, time_unit, mode): ts = nap.Ts(t=np.array([1, 2, 3, 4])) with expectation: @@ -1543,29 +1546,29 @@ def test_get_slice_n_points(end, n_points, expectation, time_unit, mode): "start, end, n_points, mode, expected_slice, expected_array", [ # smaller than n_points - (1, 2, 2, "forward", slice(0, 1), np.array([1])), - (1, 2, 2, "backward", slice(0, 1), np.array([1])), - (1, 2, 2, "closest", slice(0, 1), np.array([1])), + (1, 2, 2, "after_t", slice(0, 1), np.array([1])), + (1, 2, 2, "before_t", slice(0, 1), np.array([1])), + (1, 2, 2, "closest_t", slice(0, 1), np.array([1])), # larger than n_points - (1, 5, 2, "forward", slice(0, 4, 2), np.array([1, 3])), - (1, 5, 2, "backward", slice(0, 4, 2), np.array([1, 3])), - (1, 5, 2, "closest", slice(0, 4, 2), np.array([1, 3])), + (1, 5, 2, "after_t", slice(0, 4, 2), np.array([1, 3])), + (1, 5, 2, "before_t", slice(0, 4, 2), np.array([1, 3])), + (1, 5, 2, "closest_t", slice(0, 4, 2), np.array([1, 3])), # larger than n_points with rounding down - (1, 5.2, 2, "forward", slice(0, 4, 2), np.array([1, 3])), - (1, 5.2, 2, "backward", slice(0, 4, 2), np.array([1, 3])), - (1, 5.2, 2, "closest", slice(0, 4, 2), np.array([1, 3])), + (1, 5.2, 2, "after_t", slice(0, 4, 2), np.array([1, 3])), + (1, 5.2, 2, "before_t", slice(0, 4, 2), np.array([1, 3])), + (1, 5.2, 2, "closest_t", slice(0, 4, 2), np.array([1, 3])), # larger than n_points with rounding down - (1, 6.2, 2, "forward", slice(0, 6, 3), np.array([1, 4])), - (1, 6.2, 2, "backward", slice(0, 4, 2), np.array([1, 3])), - (1, 6.2, 2, "closest", slice(0, 4, 2), np.array([1, 3])), + (1, 6.2, 2, "after_t", slice(0, 6, 3), np.array([1, 4])), + (1, 6.2, 2, "before_t", slice(0, 4, 2), np.array([1, 3])), + (1, 6.2, 2, "closest_t", slice(0, 4, 2), np.array([1, 3])), # larger than n_points with rounding up - (1, 5.6, 2, "forward", slice(0, 4, 2), np.array([1, 3])), - (1, 5.6, 2, "backward", slice(0, 4, 2), np.array([1, 3])), - (1, 5.6, 2, "closest", slice(0, 4, 2), np.array([1, 3])), + (1, 5.6, 2, "after_t", slice(0, 4, 2), np.array([1, 3])), + (1, 5.6, 2, "before_t", slice(0, 4, 2), np.array([1, 3])), + (1, 5.6, 2, "closest_t", slice(0, 4, 2), np.array([1, 3])), # larger than n_points with rounding up - (1, 6.6, 2, "forward", slice(0, 6, 3), np.array([1, 4])), - (1, 6.6, 2, "backward", slice(0, 4, 2), np.array([1, 3])), - (1, 6.6, 2, "closest", slice(0, 6, 3), np.array([1, 4])), + (1, 6.6, 2, "after_t", slice(0, 6, 3), np.array([1, 4])), + (1, 6.6, 2, "before_t", slice(0, 4, 2), np.array([1, 3])), + (1, 6.6, 2, "closest_t", slice(0, 6, 3), np.array([1, 4])), ] ) @pytest.mark.parametrize("ts", @@ -1575,7 +1578,7 @@ def test_get_slice_n_points(end, n_points, expectation, time_unit, mode): nap.TsdFrame(t=np.arange(1, 10), d=np.arange(1, 10)[:, None]), nap.TsdTensor(t=np.arange(1, 10), d=np.arange(1, 10)[:, None, None]) ]) -def test_get_slice_value(start, end, n_points, mode, expected_slice, expected_array, ts): +def test_get_slice_value_step(start, end, n_points, mode, expected_slice, expected_array, ts): out_slice = ts._get_slice(start, end=end, mode=mode, n_points=n_points) out_array = ts.t[out_slice] assert out_slice == expected_slice @@ -1584,49 +1587,49 @@ def test_get_slice_value(start, end, n_points, mode, expected_slice, expected_ar @pytest.mark.parametrize( "start, end, mode, expected_slice, expected_array", [ - (1, 3, "forward", slice(0, 2), np.array([1, 2])), - (1, 3, "backward", slice(0, 2), np.array([1, 2])), - (1, 3, "closest", slice(0, 2), np.array([1, 2])), - (1, 2.7, "forward", slice(0, 2), np.array([1, 2])), - (1, 2.7, "backward", slice(0, 1), np.array([1])), - (1, 2.7, "closest", slice(0, 2), np.array([1, 2])), - (1, 2.4, "forward", slice(0, 2), np.array([1, 2])), - (1, 2.4, "backward", slice(0, 1), np.array([1])), - (1, 2.4, "closest", slice(0, 1), np.array([1])), - (1.1, 3, "forward", slice(1, 2), np.array([2])), - (1.1, 3, "backward", slice(0, 2), np.array([1, 2])), - (1.1, 3, "closest", slice(0, 2), np.array([1, 2])), - (1.6, 3, "forward", slice(1, 2), np.array([2])), - (1.6, 3, "backward", slice(0, 2), np.array([1, 2])), - (1.6, 3, "closest", slice(1, 2), np.array([2])), - (1.6, 1.8, "backward", slice(0, 0), np.array([])), - (1.6, 1.8, "forward", slice(1, 1), np.array([])), - (1.6, 1.8, "closest", slice(1, 1), np.array([])), - (1.4, 1.6, "closest", slice(0, 1), np.array([1])), - (3, 3, "forward", slice(2, 2), np.array([])), - (3, 3, "backward", slice(2, 2), np.array([])), - (3, 3, "closest", slice(2, 2), np.array([])), - (0, 3, "forward", slice(0, 2), np.array([1, 2])), - (0, 3, "backward", slice(0, 2), np.array([1, 2])), - (0, 3, "closest", slice(0, 2), np.array([1, 2])), - (4, 4, "forward", slice(3, 3), np.array([])), - (4, 4, "backward", slice(3, 3), np.array([])), - (4, 4, "closest", slice(3, 3), np.array([])), - (4, 5, "forward", slice(3, 4), np.array([4])), - (4, 5, "backward", slice(3, 3), np.array([])), - (4, 5, "closest", slice(3, 3), np.array([])), - (0, 1, "forward", slice(0, 0), np.array([])), - (0, 1, "backward", slice(0, 1), np.array([1])), - (0, 1, "closest", slice(0, 0), np.array([])), - (0, None, "forward", slice(0, 1), np.array([1])), - (0, None, "backward", slice(0, 0), np.array([])), - (0, None, "closest", slice(0, 1), np.array([1])), - (1, None, "forward", slice(0, 1), np.array([1])), - (1, None, "backward", slice(0, 1), np.array([1])), - (1, None, "closest", slice(0, 1), np.array([1])), - (5, None, "forward", slice(3, 3), np.array([])), - (5, None, "backward", slice(3, 4), np.array([4])), - (5, None, "closest", slice(3, 4), np.array([4])) + (1, 3, "after_t", slice(0, 2), np.array([1, 2])), + (1, 3, "before_t", slice(0, 2), np.array([1, 2])), + (1, 3, "closest_t", slice(0, 2), np.array([1, 2])), + (1, 2.7, "after_t", slice(0, 2), np.array([1, 2])), + (1, 2.7, "before_t", slice(0, 1), np.array([1])), + (1, 2.7, "closest_t", slice(0, 2), np.array([1, 2])), + (1, 2.4, "after_t", slice(0, 2), np.array([1, 2])), + (1, 2.4, "before_t", slice(0, 1), np.array([1])), + (1, 2.4, "closest_t", slice(0, 1), np.array([1])), + (1.1, 3, "after_t", slice(1, 2), np.array([2])), + (1.1, 3, "before_t", slice(0, 2), np.array([1, 2])), + (1.1, 3, "closest_t", slice(0, 2), np.array([1, 2])), + (1.6, 3, "after_t", slice(1, 2), np.array([2])), + (1.6, 3, "before_t", slice(0, 2), np.array([1, 2])), + (1.6, 3, "closest_t", slice(1, 2), np.array([2])), + (1.6, 1.8, "before_t", slice(0, 0), np.array([])), + (1.6, 1.8, "after_t", slice(1, 1), np.array([])), + (1.6, 1.8, "closest_t", slice(1, 1), np.array([])), + (1.4, 1.6, "closest_t", slice(0, 1), np.array([1])), + (3, 3, "after_t", slice(2, 2), np.array([])), + (3, 3, "before_t", slice(2, 2), np.array([])), + (3, 3, "closest_t", slice(2, 2), np.array([])), + (0, 3, "after_t", slice(0, 2), np.array([1, 2])), + (0, 3, "before_t", slice(0, 2), np.array([1, 2])), + (0, 3, "closest_t", slice(0, 2), np.array([1, 2])), + (4, 4, "after_t", slice(3, 3), np.array([])), + (4, 4, "before_t", slice(3, 3), np.array([])), + (4, 4, "closest_t", slice(3, 3), np.array([])), + (4, 5, "after_t", slice(3, 4), np.array([4])), + (4, 5, "before_t", slice(3, 3), np.array([])), + (4, 5, "closest_t", slice(3, 3), np.array([])), + (0, 1, "after_t", slice(0, 0), np.array([])), + (0, 1, "before_t", slice(0, 1), np.array([1])), + (0, 1, "closest_t", slice(0, 0), np.array([])), + (0, None, "after_t", slice(0, 1), np.array([1])), + (0, None, "before_t", slice(0, 0), np.array([])), + (0, None, "closest_t", slice(0, 1), np.array([1])), + (1, None, "after_t", slice(0, 1), np.array([1])), + (1, None, "before_t", slice(0, 1), np.array([1])), + (1, None, "closest_t", slice(0, 1), np.array([1])), + (5, None, "after_t", slice(3, 3), np.array([])), + (5, None, "before_t", slice(3, 4), np.array([4])), + (5, None, "closest_t", slice(3, 4), np.array([4])) ] ) @pytest.mark.parametrize("ts", From 0d55b170b47d3343a7d3653257a4dd7d2e11a1a9 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 1 Aug 2024 19:43:22 -0400 Subject: [PATCH 113/195] added slice mode restrict --- pynapple/core/base_class.py | 9 ++++++++- tests/test_time_series.py | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 04ae4dfd..6f7a0e8c 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -497,6 +497,8 @@ def _get_slice(self, start, end=None, mode="closest_t", n_points=None, time_unit - slice(idx, idx+1) with self.t[idx-1] < start <= self.t[idx] - For mode == "closest_t": - slice(idx, idx+1) with the closest index to start + - For mode == "restrict": + - slice the indices such that start <= self.t[idx] <= end If end is provided: - For mode == "backward": - An empty slice if end < self.t[0] @@ -524,6 +526,9 @@ def _get_slice(self, start, end=None, mode="closest_t", n_points=None, time_unit if end is None and n_points: raise ValueError("'n_points' can be used only when 'end' is specified!") + if mode == "restrict" and n_points: + raise ValueError("Fixing the number of time points is incompatible with 'restrict' mode.") + # convert and get index for start start = TsIndex.format_timestamps(np.array([start]), time_unit)[0] @@ -551,7 +556,7 @@ def _get_slice(self, start, end=None, mode="closest_t", n_points=None, time_unit if idx_start < 0: # happens only on backwards if start < self.t[0] return slice(0, 0) elif ( - idx_start == len(self.t) - 1 and mode == "forward" + idx_start == len(self.t) - 1 and mode == "after_t" ): # happens only on forward if start >= self.t[-1] return slice(idx_start, idx_start) return slice(idx_start, idx_start + 1) @@ -578,6 +583,8 @@ def _get_slice(self, start, end=None, mode="closest_t", n_points=None, time_unit idx_end -= di elif mode == "after_t" and idx_end == len(self.t) - 1: idx_end += add_if_forward # add one if idx_start < len(self.t) + elif mode == "restrict": + idx_end += int(self.t[idx_end] <= end) step = None if n_points: diff --git a/tests/test_time_series.py b/tests/test_time_series.py index fba76973..6026d1fb 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1509,7 +1509,21 @@ def test_get_slice_value_types(start, end, time_unit, expectation): (1, None, "closest_t", slice(0, 1), np.array([1])), (5, None, "after_t", slice(3, 3), np.array([])), (5, None, "before_t", slice(3, 4), np.array([4])), - (5, None, "closest_t", slice(3, 4), np.array([4])) + (5, None, "closest_t", slice(3, 4), np.array([4])), + (1, 3, "restrict", slice(0, 3), np.array([1, 2, 3])), + (1, 2.7, "restrict", slice(0, 2), np.array([1, 2])), + (1, 2.4, "restrict", slice(0, 2), np.array([1, 2])), + (1.1, 3, "restrict", slice(1, 3), np.array([2, 3])), + (1.6, 3, "restrict", slice(1, 3), np.array([2, 3])), + (1.6, 1.8, "restrict", slice(1, 1), np.array([])), + (1.4, 1.6, "restrict", slice(1, 1), np.array([])), + (3, 3, "restrict", slice(2, 3), np.array([3])), + (0, 3, "restrict", slice(0, 3), np.array([1, 2, 3])), + (0, 4, "restrict", slice(0, 4), np.array([1, 2, 3, 4])), + (4, 4, "restrict", slice(3, 4), np.array([4])), + (4, 5, "restrict", slice(3, 4), np.array([4])), + (0, 1, "restrict", slice(0, 1), np.array([1])), + ] ) @pytest.mark.parametrize("ts", @@ -1524,6 +1538,25 @@ def test_get_slice_value(start, end, mode, expected_slice, expected_array, ts): out_array = ts.t[out_slice] assert out_slice == expected_slice assert np.all(out_array == expected_array) + if mode == "restrict": + iset = nap.IntervalSet(start, end) + out_restrict = ts.restrict(iset) + assert np.all(out_restrict.t == out_array) + + +def test_get_slice_restrict_random_val_value(): + np.random.seed(123) + ts = nap.Ts(np.linspace(0.2, 0.8, 100)) + se_vec = np.random.uniform(0, 1, size=(10000, 2)) + starts = np.min(se_vec, axis=1) + ends = np.max(se_vec, axis=1) + for start, end in zip(starts, ends): + out_slice = ts._get_slice(start, end=end, mode="restrict") + out_ts = ts[out_slice] + iset = nap.IntervalSet(start, end) + out_restrict = ts.restrict(iset) + assert np.all(out_restrict == out_ts) + @pytest.mark.parametrize( "end, n_points, expectation", From 6c3d6e532577d83cb1643d2e8dccd4cb52dc7bc6 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 1 Aug 2024 19:44:52 -0400 Subject: [PATCH 114/195] Adding tests for mean PSD --- pynapple/process/signal_processing.py | 45 +++++---- tests/test_power_spectral_density.py | 130 +++++++++++++++++--------- tests/test_signal_processing.py | 36 ------- 3 files changed, 114 insertions(+), 97 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 59af1130..b7a5f8cd 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -6,8 +6,8 @@ import numpy as np import pandas as pd - -import pynapple as nap +from numbers import Number +from .. import core as nap def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): @@ -56,7 +56,7 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): return ret.loc[ret.index >= 0] return ret -def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False, time_units="s"): +def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False, time_unit="s"): """Compute mean power spectral density by averaging FFT over epochs of same size. The parameter `interval_size` controls the duration of the epochs. @@ -64,46 +64,53 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu Parameters ---------- - sig : TYPE - Description - interval_size : TYPE - Description + sig : Tsd or TsdFrame + Signal with equispaced samples + interval_size : Number + Epochs size to compute to average the FFT across fs : None, optional - Description + Sampling frequency of `sig`. If `None`, `fs` is equal to `sig.rate` ep : None, optional - Description + The `IntervalSet` to calculate the fft on. Can be any length. full_range : bool, optional - Description - time_units : str, optional - Description + If true, will return full fft frequency range, otherwise will return only positive values + time_unit : str, optional + Time units for parameter `interval_size`. Can be ('s'[default], 'ms', 'us') Returns ------- - TYPE - Description + pandas.DataFrame + Power spectral density. Raises ------ RuntimeError - Description + If splitting the epoch with `interval_size` results in an empty set. TypeError - Description + If `ep` or `sig` are not respectively pynapple time series or interval set. """ if not (ep is None or isinstance(ep, nap.IntervalSet)): raise TypeError("ep param must be a pynapple IntervalSet object, or None") if ep is None: ep = sig.time_support + + if not (fs is None or isinstance(fs, Number)): + raise TypeError("fs must be of type float or int") if fs is None: - fs = sig.rate + fs = sig.rate + + if not isinstance(full_range, bool): + raise TypeError("full_range must be of type bool or None") # Split the ep + interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[0] split_ep = ep.split(interval_size) if len(split_ep) == 0: raise RuntimeError(f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size") # Get the slices of each ep - slices = np.zeros((len(split_ep),2), dtype=int) + slices = np.zeros((len(split_ep),2), dtype=int) for i in range(len(split_ep)): sl = sig.get_slice(split_ep[i,0], split_ep[i,1], mode='backward') @@ -114,7 +121,7 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu N = np.min(np.diff(slices, 1)) if N == 0: - raise RuntimeError(f"One epoch doesn't have any signal. Check the parameter ep or the time support if no epoch is passed.") + raise RuntimeError("One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed.") # Get the freqs fft_freq = np.fft.fftfreq(N, 1 / fs) diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index a604315e..0eb39cdd 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -1,60 +1,106 @@ import numpy as np import pandas as pd import pytest - +from contextlib import nullcontext as does_not_raise import pynapple as nap +############################################################ +# Test for mean_power_spectral_density +############################################################ -def get_signal_and_output(f=2, fs=1000,duration=100,interval_size=10): - t=np.arange(0, duration, 1/fs) - d=np.cos(2*np.pi*f*t) - sig = nap.Tsd(t=t,d=d, time_support=nap.IntervalSet(0,100)) - tmp = d.reshape((int(duration/interval_size),int(fs*interval_size))).T - out = np.sum(np.fft.fft(tmp, axis=0), 1) - freq = np.fft.fftfreq(out.shape[0], 1 / fs) - order = np.argsort(freq) - out = out[order] - freq = freq[order] - return (sig, out, freq) +def test_compute_power_spectral_density(): + with pytest.raises(ValueError) as e_info: + t = np.linspace(0, 1, 1000) + sig = nap.Tsd( + d=np.random.random(1000), + t=t, + time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), + ) + r = nap.compute_power_spectral_density(sig) + assert ( + str(e_info.value) == "Given epoch (or signal time_support) must have length 1" + ) + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) + r = nap.compute_power_spectral_density(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 500 + + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) + r = nap.compute_power_spectral_density(sig) + assert isinstance(r, pd.DataFrame) + assert r.shape == (500, 4) + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) + r = nap.compute_power_spectral_density(sig, full_range=True) + assert isinstance(r, pd.DataFrame) + assert r.shape == (1000, 4) -def test_basic(): - sig, out, freq = get_signal_and_output() + with pytest.raises(TypeError) as e_info: + nap.compute_power_spectral_density("a_string") + assert ( + str(e_info.value) + == "Currently compute_spectogram is only implemented for Tsd or TsdFrame" + ) + +############################################################ +# Test for mean_power_spectral_density +############################################################ + +def get_signal_and_output(f=2, fs=1000,duration=100,interval_size=10): + t=np.arange(0, duration, 1/fs) + d=np.cos(2*np.pi*f*t) + sig = nap.Tsd(t=t,d=d, time_support=nap.IntervalSet(0,100)) + tmp = d.reshape((int(duration/interval_size),int(fs*interval_size))).T + tmp = tmp[0:-1] + out = np.sum(np.fft.fft(tmp, axis=0), 1) + freq = np.fft.fftfreq(out.shape[0], 1 / fs) + order = np.argsort(freq) + out = out[order] + freq = freq[order] + return (sig, out, freq) + +def test_compute_mean_power_spectral_density(): + sig, out, freq = get_signal_and_output() psd = nap.compute_mean_power_spectral_density(sig, 10) - assert isinstance(psd, pd.DataFrame) assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq>=0]) np.testing.assert_array_almost_equal(psd.index.values, freq[freq>=0]) + # Full range + psd = nap.compute_mean_power_spectral_density(sig, 10, full_range=True) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values.flatten(), out) + np.testing.assert_array_almost_equal(psd.index.values, freq) + # TsdFrame + sig2 = nap.TsdFrame(t=sig.t, d=np.repeat(sig.values[:,None], 2, 1), time_support = sig.time_support) + psd = nap.compute_mean_power_spectral_density(sig2, 10, full_range=True) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:,None],2,1)) + np.testing.assert_array_almost_equal(psd.index.values, freq) - -@pytest.mark.parametrize("interval_size, expected_exception", [ - (10, None), # Regular case - (200, RuntimeError), # Interval size too large - (1, RuntimeError) # Epoch too small -]) -@setup_signal_and_params -def test_compute_mean_power_spectral_density(sig, interval_size, expected_exception): - if expected_exception: - with pytest.raises(expected_exception): - compute_mean_power_spectral_density(sig, interval_size) - else: - psd = compute_mean_power_spectral_density(sig, interval_size) - assert isinstance(psd, pd.DataFrame) - assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty - -@pytest.mark.parametrize("full_range", [True, False]) -@setup_signal_and_params -def test_full_range_option(sig, full_range): - interval_size = 10 # Choose a valid interval size for this test - - psd = compute_mean_power_spectral_density(sig, interval_size, full_range=full_range) - - if full_range: - assert (psd.index >= 0).all() - else: - assert (psd.index >= 0).any() # Partial range should exclude negative frequencies +@pytest.mark.parametrize( + "sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation", + [ + (*get_signal_and_output(), 10, None, None, False, "s", does_not_raise()), + (*get_signal_and_output(), 10, "a", None, False, "s", pytest.raises(TypeError, match="fs must be of type float or int")), + (*get_signal_and_output(), 10, None, "a", False, "s", pytest.raises(TypeError, match="ep param must be a pynapple IntervalSet object, or None")), + (*get_signal_and_output(), 10, None, None, "a", "s", pytest.raises(TypeError, match="full_range must be of type bool or None")), + (*get_signal_and_output(), 10*1e3, None, None, False, "ms", does_not_raise()), + (*get_signal_and_output(), 10*1e6, None, None, False, "us", does_not_raise()), + (*get_signal_and_output(), 200, None, None, False, "s", pytest.raises(RuntimeError, match="Splitting epochs with interval_size=200 generated an empty IntervalSet. Try decreasing interval_size")), + (*get_signal_and_output(), 10, None, nap.IntervalSet([0, 200], [100,300]), False, "s", pytest.raises(RuntimeError, match="One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed.")), + ] +) +def test_compute_mean_power_spectral_density_raise_errors( + sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation + ): + with expectation: + psd = nap.compute_mean_power_spectral_density(sig, interval_size, fs, ep, full_range, time_units) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 9df76ae4..3bacbb50 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -7,42 +7,6 @@ import pynapple as nap -def test_compute_spectogram(): - with pytest.raises(ValueError) as e_info: - t = np.linspace(0, 1, 1000) - sig = nap.Tsd( - d=np.random.random(1000), - t=t, - time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), - ) - r = nap.compute_spectogram(sig) - assert ( - str(e_info.value) == "Given epoch (or signal time_support) must have length 1" - ) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.random.random(1000), t=t) - r = nap.compute_spectogram(sig) - assert isinstance(r, pd.DataFrame) - assert r.shape[0] == 500 - - sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) - r = nap.compute_spectogram(sig) - assert isinstance(r, pd.DataFrame) - assert r.shape == (500, 4) - - sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) - r = nap.compute_spectogram(sig, full_range=True) - assert isinstance(r, pd.DataFrame) - assert r.shape == (1000, 4) - - with pytest.raises(TypeError) as e_info: - nap.compute_spectogram("a_string") - assert ( - str(e_info.value) - == "Currently compute_spectogram is only implemented for Tsd or TsdFrame" - ) - def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1001) From c8a5fc47115f79e817f40cf0e4fa839a564be255 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 1 Aug 2024 19:50:26 -0400 Subject: [PATCH 115/195] linting --- pynapple/process/__init__.py | 2 +- pynapple/process/signal_processing.py | 46 ++++++++++++++++----------- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 53a80375..221ce039 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -16,8 +16,8 @@ shuffle_ts_intervals, ) from .signal_processing import ( - compute_power_spectral_density, compute_mean_power_spectral_density, + compute_power_spectral_density, compute_wavelet_transform, generate_morlet_filterbank, ) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index b7a5f8cd..4adf0567 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -4,9 +4,11 @@ Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. """ +from numbers import Number + import numpy as np import pandas as pd -from numbers import Number + from .. import core as nap @@ -56,10 +58,13 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): return ret.loc[ret.index >= 0] return ret -def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, full_range=False, time_unit="s"): - """Compute mean power spectral density by averaging FFT over epochs of same size. + +def compute_mean_power_spectral_density( + sig, interval_size, fs=None, ep=None, full_range=False, time_unit="s" +): + """Compute mean power spectral density by averaging FFT over epochs of same size. The parameter `interval_size` controls the duration of the epochs. - + Note that this function assumes a constant sampling rate for sig. Parameters @@ -76,12 +81,12 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu If true, will return full fft frequency range, otherwise will return only positive values time_unit : str, optional Time units for parameter `interval_size`. Can be ('s'[default], 'ms', 'us') - + Returns ------- pandas.DataFrame Power spectral density. - + Raises ------ RuntimeError @@ -93,7 +98,7 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu raise TypeError("ep param must be a pynapple IntervalSet object, or None") if ep is None: ep = sig.time_support - + if not (fs is None or isinstance(fs, Number)): raise TypeError("fs must be of type float or int") if fs is None: @@ -103,25 +108,31 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu raise TypeError("full_range must be of type bool or None") # Split the ep - interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[0] + interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[ + 0 + ] split_ep = ep.split(interval_size) if len(split_ep) == 0: - raise RuntimeError(f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size") - + raise RuntimeError( + f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size" + ) + # Get the slices of each ep - slices = np.zeros((len(split_ep),2), dtype=int) + slices = np.zeros((len(split_ep), 2), dtype=int) for i in range(len(split_ep)): - sl = sig.get_slice(split_ep[i,0], split_ep[i,1], mode='backward') - slices[i,0] = sl.start - slices[i,1] = sl.stop - + sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1], mode="backward") + slices[i, 0] = sl.start + slices[i, 1] = sl.stop + # Check what is the signal length N = np.min(np.diff(slices, 1)) if N == 0: - raise RuntimeError("One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed.") + raise RuntimeError( + "One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed." + ) # Get the freqs fft_freq = np.fft.fftfreq(N, 1 / fs) @@ -130,14 +141,13 @@ def compute_mean_power_spectral_density(sig, interval_size, fs=None, ep=None, fu fft_result = np.zeros((N, *sig.shape[1:]), dtype=complex) for i in range(len(slices)): - fft_result += np.fft.fft(sig[slices[i,0]:slices[i,1]].values[0:N], axis=0) + fft_result += np.fft.fft(sig[slices[i, 0] : slices[i, 1]].values[0:N], axis=0) ret = pd.DataFrame(fft_result, fft_freq) ret.sort_index(inplace=True) if not full_range: return ret.loc[ret.index >= 0] return ret - def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): From 283a0977922341929a3ba51eae3af3c588bf6abe Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 1 Aug 2024 19:59:49 -0400 Subject: [PATCH 116/195] fixed public method --- pynapple/core/base_class.py | 15 +++---- tests/test_time_series.py | 89 ++++++++++++++++--------------------- 2 files changed, 43 insertions(+), 61 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 6f7a0e8c..632d89a7 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -596,7 +596,7 @@ def _get_slice(self, start, end=None, mode="closest_t", n_points=None, time_unit return slice(idx_start, idx_end, step) - def get_slice(self, start, end=None, mode="closest", time_unit="s"): + def get_slice(self, start, end=None, time_unit="s"): """ Get a slice from the time series data based on the start and end values with the specified mode. @@ -606,21 +606,15 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): The starting value for the slice. end : int or float, optional The ending value for the slice. Defaults to None. - mode : str, optional - The mode for slicing. Can be "forward", "backward", or "closest". Defaults to "closest". time_unit : str, optional The time unit for the start and end values. Defaults to "s" (seconds). Returns ------- slice : slice - A slice determining the start and end indices, with unit step. - - If mode = "closest": - Starts/ends the slice with indices closest to the start/end time provided. - - If mode = "backward": - Starts/ends the slice with the indices preceding the start/end time provided. - - If mode = "forward": - Starts/ends the slice with the indices following the start/end time provided. + A slice determining the start and end indices, with unit step + Slicing the array will behave like get: self[s] == self.get(start, end) + Raises ------ @@ -645,6 +639,7 @@ def get_slice(self, start, end=None, mode="closest", time_unit="s"): >>> print(ts.get_slice(start, None, mode="backward")) # returns `slice(1, 2, None)` >>> print(ts.get_slice(start, None, mode="forward")) # returns `slice(2, 3, None)` """ + mode = "closest_t" if end is None else "restrict" return self._get_slice( start, end=end, mode=mode, n_points=None, time_unit=time_unit ) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 6026d1fb..ecda7ac2 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1544,18 +1544,31 @@ def test_get_slice_value(start, end, mode, expected_slice, expected_array, ts): assert np.all(out_restrict.t == out_array) -def test_get_slice_restrict_random_val_value(): +def test_get_slice_vs_get_random_val_start_end_value(): np.random.seed(123) ts = nap.Ts(np.linspace(0.2, 0.8, 100)) se_vec = np.random.uniform(0, 1, size=(10000, 2)) starts = np.min(se_vec, axis=1) ends = np.max(se_vec, axis=1) for start, end in zip(starts, ends): - out_slice = ts._get_slice(start, end=end, mode="restrict") + out_slice = ts.get_slice(start=start, end=end) out_ts = ts[out_slice] - iset = nap.IntervalSet(start, end) - out_restrict = ts.restrict(iset) - assert np.all(out_restrict == out_ts) + out_get = ts.get(start, end) + assert np.all(out_get.t == out_ts.t) + + +def test_get_slice_vs_get_random_val_start_value(): + np.random.seed(123) + ts = nap.Ts(np.linspace(0.2, 0.8, 100)) + starts = np.random.uniform(0, 1, size=(10000, )) + + for start in starts: + out_slice = ts.get_slice(start=start, end=None) + out_ts = ts[out_slice] + out_get = ts.get(start) + assert np.all(out_get.t == out_ts.t) + + @pytest.mark.parametrize( @@ -1618,51 +1631,25 @@ def test_get_slice_value_step(start, end, n_points, mode, expected_slice, expect assert np.all(out_array == expected_array) @pytest.mark.parametrize( - "start, end, mode, expected_slice, expected_array", + "start, end, expected_slice, expected_array", [ - (1, 3, "after_t", slice(0, 2), np.array([1, 2])), - (1, 3, "before_t", slice(0, 2), np.array([1, 2])), - (1, 3, "closest_t", slice(0, 2), np.array([1, 2])), - (1, 2.7, "after_t", slice(0, 2), np.array([1, 2])), - (1, 2.7, "before_t", slice(0, 1), np.array([1])), - (1, 2.7, "closest_t", slice(0, 2), np.array([1, 2])), - (1, 2.4, "after_t", slice(0, 2), np.array([1, 2])), - (1, 2.4, "before_t", slice(0, 1), np.array([1])), - (1, 2.4, "closest_t", slice(0, 1), np.array([1])), - (1.1, 3, "after_t", slice(1, 2), np.array([2])), - (1.1, 3, "before_t", slice(0, 2), np.array([1, 2])), - (1.1, 3, "closest_t", slice(0, 2), np.array([1, 2])), - (1.6, 3, "after_t", slice(1, 2), np.array([2])), - (1.6, 3, "before_t", slice(0, 2), np.array([1, 2])), - (1.6, 3, "closest_t", slice(1, 2), np.array([2])), - (1.6, 1.8, "before_t", slice(0, 0), np.array([])), - (1.6, 1.8, "after_t", slice(1, 1), np.array([])), - (1.6, 1.8, "closest_t", slice(1, 1), np.array([])), - (1.4, 1.6, "closest_t", slice(0, 1), np.array([1])), - (3, 3, "after_t", slice(2, 2), np.array([])), - (3, 3, "before_t", slice(2, 2), np.array([])), - (3, 3, "closest_t", slice(2, 2), np.array([])), - (0, 3, "after_t", slice(0, 2), np.array([1, 2])), - (0, 3, "before_t", slice(0, 2), np.array([1, 2])), - (0, 3, "closest_t", slice(0, 2), np.array([1, 2])), - (4, 4, "after_t", slice(3, 3), np.array([])), - (4, 4, "before_t", slice(3, 3), np.array([])), - (4, 4, "closest_t", slice(3, 3), np.array([])), - (4, 5, "after_t", slice(3, 4), np.array([4])), - (4, 5, "before_t", slice(3, 3), np.array([])), - (4, 5, "closest_t", slice(3, 3), np.array([])), - (0, 1, "after_t", slice(0, 0), np.array([])), - (0, 1, "before_t", slice(0, 1), np.array([1])), - (0, 1, "closest_t", slice(0, 0), np.array([])), - (0, None, "after_t", slice(0, 1), np.array([1])), - (0, None, "before_t", slice(0, 0), np.array([])), - (0, None, "closest_t", slice(0, 1), np.array([1])), - (1, None, "after_t", slice(0, 1), np.array([1])), - (1, None, "before_t", slice(0, 1), np.array([1])), - (1, None, "closest_t", slice(0, 1), np.array([1])), - (5, None, "after_t", slice(3, 3), np.array([])), - (5, None, "before_t", slice(3, 4), np.array([4])), - (5, None, "closest_t", slice(3, 4), np.array([4])) + (1, 3, slice(0, 3), np.array([1, 2, 3])), + (1, 2.7, slice(0, 2), np.array([1, 2])), + (1, 2.4, slice(0, 2), np.array([1, 2])), + (1.1, 3, slice(1, 3), np.array([2, 3])), + (1.6, 3, slice(1, 3), np.array([2, 3])), + (1.6, 1.8, slice(1, 1), np.array([])), + (1.4, 1.6, slice(1, 1), np.array([])), + (3, 3, slice(2, 3), np.array([3])), + (0, 3, slice(0, 3), np.array([1, 2, 3])), + (0, 4, slice(0, 4), np.array([1, 2, 3, 4])), + (4, 4, slice(3, 4), np.array([4])), + (4, 5, slice(3, 4), np.array([4])), + (0, 1, slice(0, 1), np.array([1])), + (0, None, slice(0, 1), np.array([1])), + (1, None, slice(0, 1), np.array([1])), + (4, None, slice(3, 4), np.array([4])), + (5, None, slice(3, 4), np.array([4])), ] ) @pytest.mark.parametrize("ts", @@ -1672,8 +1659,8 @@ def test_get_slice_value_step(start, end, n_points, mode, expected_slice, expect nap.TsdFrame(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None]), nap.TsdTensor(t=np.array([1, 2, 3, 4]), d=np.array([1, 2, 3, 4])[:, None, None]) ]) -def test_get_slice_public(start, end, mode, expected_slice, expected_array, ts): - out_slice = ts.get_slice(start, end=end, mode=mode) +def test_get_slice_public(start, end, expected_slice, expected_array, ts): + out_slice = ts.get_slice(start, end=end) out_array = ts.t[out_slice] assert out_slice == expected_slice assert np.all(out_array == expected_array) From 84196314560365c184b6a346a1ec786cde99be42 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 1 Aug 2024 20:06:49 -0400 Subject: [PATCH 117/195] fixed public method --- pynapple/core/base_class.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 632d89a7..95644e51 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -613,7 +613,7 @@ def get_slice(self, start, end=None, time_unit="s"): ------- slice : slice A slice determining the start and end indices, with unit step - Slicing the array will behave like get: self[s] == self.get(start, end) + Slicing the array will be equivalent to calling get: `ts[s].t == ts.get(start, end).t` Raises @@ -627,17 +627,18 @@ def get_slice(self, start, end=None, time_unit="s"): >>> import pynapple as nap >>> ts = nap.Ts(t = [0, 1, 2, 3]) - >>> start, end = 1.2, 2.6 >>> # slice over a range - >>> print(ts.get_slice(start, end, mode="closest")) # returns `slice(1, 3, None)` - >>> print(ts.get_slice(start, end, mode="backward")) # returns `slice(1, 2, None)` - >>> print(ts.get_slice(start, end, mode="forward")) # returns `slice(2, 3, None)` + >>> start, end = 1.2, 2.6 + >>> print(ts.get_slice(start, end)) # returns `slice(2, 3, None)` + >>> start, end = 1., 2. + >>> print(ts.get_slice(start, end, mode="forward")) # returns `slice(1, 3, None)` >>> # slice a single value - >>> print(ts.get_slice(start, None, mode="closest")) # returns `slice(1, 2, None)` - >>> print(ts.get_slice(start, None, mode="backward")) # returns `slice(1, 2, None)` - >>> print(ts.get_slice(start, None, mode="forward")) # returns `slice(2, 3, None)` + >>> start = 1.2 + >>> print(ts.get_slice(start)) # returns `slice(1, 2, None)` + >>> start = 2. + >>> print(ts.get_slice(start)) # returns `slice(2, 3, None)` """ mode = "closest_t" if end is None else "restrict" return self._get_slice( From a19f05dbc708b35838d937394db566fda0a06f5d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 1 Aug 2024 20:16:19 -0400 Subject: [PATCH 118/195] linted --- pynapple/core/base_class.py | 8 ++++++-- pynapple/core/time_series.py | 8 +++++++- pynapple/io/__init__.py | 8 +++++++- pynapple/process/perievent.py | 5 ++++- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 95644e51..1f83a228 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -467,7 +467,9 @@ def _from_npz_reader(cls, file): iset = IntervalSet(start=file["start"], end=file["end"]) return cls(time_support=iset, **kwargs) - def _get_slice(self, start, end=None, mode="closest_t", n_points=None, time_unit="s"): + def _get_slice( + self, start, end=None, mode="closest_t", n_points=None, time_unit="s" + ): """ Get a slice from the time series data based on the start and end values with the specified mode. @@ -527,7 +529,9 @@ def _get_slice(self, start, end=None, mode="closest_t", n_points=None, time_unit raise ValueError("'n_points' can be used only when 'end' is specified!") if mode == "restrict" and n_points: - raise ValueError("Fixing the number of time points is incompatible with 'restrict' mode.") + raise ValueError( + "Fixing the number of time points is incompatible with 'restrict' mode." + ) # convert and get index for start start = TsIndex.format_timestamps(np.array([start]), time_unit)[0] diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index c7cd0e4c..8aa41a62 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -26,7 +26,13 @@ from scipy import signal from tabulate import tabulate -from ._core_functions import _bin_average, _convolve, _dropna, _restrict, _threshold +from ._core_functions import ( + _bin_average, + _convolve, + _dropna, + _restrict, + _threshold, +) from .base_class import Base from .interval_set import IntervalSet from .time_index import TsIndex diff --git a/pynapple/io/__init__.py b/pynapple/io/__init__.py index 20b194c7..f4eb2a70 100644 --- a/pynapple/io/__init__.py +++ b/pynapple/io/__init__.py @@ -1,4 +1,10 @@ from .folder import Folder from .interface_npz import NPZFile from .interface_nwb import NWBFile -from .misc import append_NWB_LFP, load_eeg, load_file, load_folder, load_session +from .misc import ( + append_NWB_LFP, + load_eeg, + load_file, + load_folder, + load_session, +) diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index a3dbb5d1..84aed7b1 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -5,7 +5,10 @@ import numpy as np from .. import core as nap -from ._process_functions import _perievent_continuous, _perievent_trigger_average +from ._process_functions import ( + _perievent_continuous, + _perievent_trigger_average, +) def _align_tsd(tsd, tref, window, time_support): From 1150c08b15b332aac8e2fa114fd81f8d4aea1e21 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 1 Aug 2024 20:17:52 -0400 Subject: [PATCH 119/195] improved docstrings --- pynapple/core/base_class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 1f83a228..2c8b0e2d 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -480,7 +480,7 @@ def _get_slice( end : int or float, optional The ending value for the slice. Defaults to None. mode : str, optional - The mode for slicing. Can be "after_t", "before_t", or "closest". Defaults to "closest_t". + The mode for slicing. Can be "after_t", "before_t", "restrict", or "closest_t". Defaults to "closest_t". time_unit : str, optional The time unit for the start and end values. Defaults to "s" (seconds). n_points : int, optional From 0d134703e7cef954060bc4396dc07f8c17fda13e Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 1 Aug 2024 20:18:54 -0400 Subject: [PATCH 120/195] isorted --- pynapple/core/time_series.py | 8 +------- pynapple/io/__init__.py | 8 +------- pynapple/process/perievent.py | 5 +---- 3 files changed, 3 insertions(+), 18 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 8aa41a62..c7cd0e4c 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -26,13 +26,7 @@ from scipy import signal from tabulate import tabulate -from ._core_functions import ( - _bin_average, - _convolve, - _dropna, - _restrict, - _threshold, -) +from ._core_functions import _bin_average, _convolve, _dropna, _restrict, _threshold from .base_class import Base from .interval_set import IntervalSet from .time_index import TsIndex diff --git a/pynapple/io/__init__.py b/pynapple/io/__init__.py index f4eb2a70..20b194c7 100644 --- a/pynapple/io/__init__.py +++ b/pynapple/io/__init__.py @@ -1,10 +1,4 @@ from .folder import Folder from .interface_npz import NPZFile from .interface_nwb import NWBFile -from .misc import ( - append_NWB_LFP, - load_eeg, - load_file, - load_folder, - load_session, -) +from .misc import append_NWB_LFP, load_eeg, load_file, load_folder, load_session diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index 84aed7b1..a3dbb5d1 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -5,10 +5,7 @@ import numpy as np from .. import core as nap -from ._process_functions import ( - _perievent_continuous, - _perievent_trigger_average, -) +from ._process_functions import _perievent_continuous, _perievent_trigger_average def _align_tsd(tsd, tref, window, time_support): From 65a1a1a4a155b4066626300ddf91ef383aef0fa5 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 2 Aug 2024 16:26:48 +0100 Subject: [PATCH 121/195] param name changes --- docs/api_guide/tutorial_pynapple_wavelets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index d34de4f2..c0f13862 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -291,7 +291,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # *** -# Effect of time_decay +# Effect of gaussian_width # ------------------ # Let's increase time_decay to 7.5 and see the effect on the resultant filter bank. @@ -347,9 +347,9 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # *** -# Effect of scaling +# Effect of window_length # ------------------ -# Let's increase scaling to 2.0 and see the effect on the resultant filter bank. +# Let's increase window_length to 2.0 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) filter_bank = nap.generate_morlet_filterbank( From b16c7c80a6db1eca021e637c9aa77be4be2c92fd Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 2 Aug 2024 11:54:24 -0400 Subject: [PATCH 122/195] fixed notebooks --- docs/api_guide/tutorial_pynapple_spectrum.py | 2 + docs/examples/tutorial_human_dataset.py | 39 +++++++------------- pynapple/io/folder.py | 6 +-- pynapple/io/interface_npz.py | 11 ++++-- 4 files changed, 26 insertions(+), 32 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_spectrum.py b/docs/api_guide/tutorial_pynapple_spectrum.py index 6fe6d727..bcb64c50 100644 --- a/docs/api_guide/tutorial_pynapple_spectrum.py +++ b/docs/api_guide/tutorial_pynapple_spectrum.py @@ -14,6 +14,8 @@ # You can install all with `pip install matplotlib requests tqdm seaborn` # # Now, import the necessary libraries: +# +# mkdocs_gallery_thumbnail_number = 3 import matplotlib.pyplot as plt import numpy as np diff --git a/docs/examples/tutorial_human_dataset.py b/docs/examples/tutorial_human_dataset.py index f84cbef6..caeb5f0e 100644 --- a/docs/examples/tutorial_human_dataset.py +++ b/docs/examples/tutorial_human_dataset.py @@ -189,36 +189,23 @@ # ------------------ # # Now that we have the PETH of spiking, we can go one step further. We will plot the mean firing rate of this cell aligned to the boundary for each trial type. Doing this in Pynapple is very simple! - -bin_size = 0.2 # 200ms bin size -step_size = 0.01 # 10ms step size, to make overlapping bins -winsize = int(bin_size / step_size) # Window size - -# %% +# # Use Pynapple to compute binned spike counts - -counts_NB = NB_peth.count(step_size) # Spike counts binned in 10ms steps, for NB trials -counts_HB = HB_peth.count(step_size) # Spike counts binned in 10ms steps, for HB trials +bin_size = 0.01 +counts_NB = NB_peth.count(bin_size) # Spike counts binned in 10ms steps, for NB trials +counts_HB = HB_peth.count(bin_size) # Spike counts binned in 10ms steps, for HB trials # %% -# Smooth the binned spike counts using a window of size 20, for both trial types +# Compute firing rate for both trial types -counts_NB = ( - counts_NB.as_dataframe() - .rolling(winsize, win_type="gaussian", min_periods=1, center=True, axis=0) - .mean(std=0.2 * winsize) -) -counts_HB = ( - counts_HB.as_dataframe() - .rolling(winsize, win_type="gaussian", min_periods=1, center=True, axis=0) - .mean(std=0.2 * winsize) -) +fr_NB = counts_NB / bin_size +fr_HB = counts_HB / bin_size # %% -# Compute firing rate for both trial types +# Smooth the firing rate with a gaussian window with std=4*bin_size +counts_NB = counts_NB.smooth(bin_size*4) +counts_HB = counts_HB.smooth(bin_size*4) -fr_NB = counts_NB * winsize -fr_HB = counts_HB * winsize # %% # Compute the mean firing rate for both trial types @@ -228,9 +215,9 @@ # %% # Compute standard error of mean (SEM) of the firing rate for both trial types - -error_NB = fr_NB.sem(axis=1) -error_HB = fr_HB.sem(axis=1) +from scipy.stats import sem +error_NB = sem(fr_NB, axis=1) +error_HB = sem(fr_HB, axis=1) # %% # Plot the mean +/- SEM of firing rate for both trial types diff --git a/pynapple/io/folder.py b/pynapple/io/folder.py index de1b9ef5..8f7d2f1a 100644 --- a/pynapple/io/folder.py +++ b/pynapple/io/folder.py @@ -4,7 +4,7 @@ # @Author: Guillaume Viejo # @Date: 2023-05-15 15:32:24 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-08-06 17:37:23 +# @Last Modified time: 2024-08-02 11:35:10 """ The Folder class helps to navigate a hierarchical data tree. @@ -302,12 +302,12 @@ def metadata(self, name): with open(json_filename, "r") as ff: metadata = json.load(ff) text = "\n".join([" : ".join(it) for it in metadata.items()]) - panel = Panel.fit(text, border_style="green", title=title) + panel = Panel.fit(text, border_style="green", title=str(title)) else: panel = Panel.fit( "No metadata", border_style="red", - title=title, + title=str(title), ) with Console() as console: console.print(panel) diff --git a/pynapple/io/interface_npz.py b/pynapple/io/interface_npz.py index 22da63bd..cedb779b 100644 --- a/pynapple/io/interface_npz.py +++ b/pynapple/io/interface_npz.py @@ -4,7 +4,7 @@ # @Author: Guillaume Viejo # @Date: 2023-07-05 16:03:25 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-02 14:32:25 +# @Last Modified time: 2024-08-02 11:16:07 from pathlib import Path @@ -26,10 +26,15 @@ def _find_class_from_variables(file_variables, data_ndims=None): if data_ndims is not None: - # either TsdTensor or Tsd: + assert EXPECTED_ENTRIES["Tsd"].issubset(file_variables) - return "Tsd" if data_ndims == 1 else "TsdTensor" + if data_ndims == 1: + return "Tsd" + elif data_ndims == 2: + return "TsdFrame" + else: + return "TsdTensor" for possible_type, expected_variables in EXPECTED_ENTRIES.items(): if expected_variables.issubset(file_variables): From c816da42d9122fe4a261fbcb175f515639503fed Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 2 Aug 2024 12:10:59 -0400 Subject: [PATCH 123/195] Adding filtering.py --- pynapple/process/filtering.py | 67 +++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 pynapple/process/filtering.py diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py new file mode 100644 index 00000000..4ea1b667 --- /dev/null +++ b/pynapple/process/filtering.py @@ -0,0 +1,67 @@ +""" + Filtering module +""" + +import numpy as np +from .. import core as nap +from scipy.signal import butter, lfilter, filtfilt + + +def _butter_bandpass(lowcut, highcut, fs, order=5): + nyq = 0.5 * fs + low = lowcut / nyq + high = highcut / nyq + b, a = butter(order, [low, high], btype='band') + return b, a + +def _butter_bandpass_filter(data, lowcut, highcut, fs, order=4): + b, a = _butter_bandpass(lowcut, highcut, fs, order=order) + y = lfilter(b, a, data) + return y + +def compute_bandpass_filter(data, freq_band, sampling_frequency=None, order=4): + """ + Bandpass filtering the LFP. + + Parameters + ---------- + data : Tsd/TsdFrame + Description + lowcut : TYPE + Description + highcut : TYPE + Description + fs : TYPE + Description + order : int, optional + Description + + Raises + ------ + RuntimeError + Description + """ + time_support = data.time_support + time_index = data.as_units('s').index.values + if type(data) is nap.TsdFrame: + tmp = np.zeros(data.shape) + for i in np.arange(data.shape[1]): + tmp[:,i] = bandpass_filter(data[:,i], lowcut, highcut, fs, order) + + return nap.TsdFrame( + t = time_index, + d = tmp, + time_support = time_support, + time_units = 's', + columns = data.columns) + + elif type(data) is nap.Tsd: + flfp = _butter_bandpass_filter(data.values, lowcut, highcut, fs, order) + return nap.Tsd( + t=time_index, + d=flfp, + time_support=time_support, + time_units='s') + + else: + raise RuntimeError("Unknow format. Should be Tsd/TsdFrame") \ No newline at end of file From 127b07b1c1763b4c955c317872fb054257e6d85c Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 2 Aug 2024 16:05:53 -0400 Subject: [PATCH 124/195] Propagating dtype to numba --- pynapple/core/_core_functions.py | 5 ++--- pynapple/core/_jitted_functions.py | 8 ++++---- pynapple/core/base_class.py | 4 +++- pynapple/core/time_series.py | 10 ++++++---- tests/test_jitted.py | 2 +- tests/test_time_series.py | 10 ++++++++++ tests/test_ts_group.py | 5 ++++- 7 files changed, 30 insertions(+), 14 deletions(-) diff --git a/pynapple/core/_core_functions.py b/pynapple/core/_core_functions.py index 9bbb871b..33cb2a0a 100644 --- a/pynapple/core/_core_functions.py +++ b/pynapple/core/_core_functions.py @@ -29,11 +29,10 @@ def _restrict(time_array, starts, ends): def _count(time_array, starts, ends, bin_size=None, dtype=None): if isinstance(bin_size, (float, int)): - t, d = jitcount(time_array, starts, ends, bin_size) + t, d = jitcount(time_array, starts, ends, bin_size, dtype) else: - _, d = jitrestrict_with_count(time_array, starts, ends) + _, d = jitrestrict_with_count(time_array, starts, ends, dtype) t = starts + (ends - starts) / 2 - d = d.astype(dtype) if dtype else d return t, d diff --git a/pynapple/core/_jitted_functions.py b/pynapple/core/_jitted_functions.py index 4dae7a99..9269a245 100644 --- a/pynapple/core/_jitted_functions.py +++ b/pynapple/core/_jitted_functions.py @@ -44,11 +44,11 @@ def jitrestrict(time_array, starts, ends): @jit(nopython=True) -def jitrestrict_with_count(time_array, starts, ends): +def jitrestrict_with_count(time_array, starts, ends, dtype=np.int64): n = len(time_array) m = len(starts) ix = np.zeros(n, dtype=np.int64) - count = np.zeros(m, dtype=np.int64) + count = np.zeros(m, dtype=dtype) k = 0 t = 0 @@ -118,7 +118,7 @@ def jitvaluefrom(time_array, time_target_array, count, count_target, starts, end @jit(nopython=True) -def jitcount(time_array, starts, ends, bin_size): +def jitcount(time_array, starts, ends, bin_size, dtype): idx, countin = jitrestrict_with_count(time_array, starts, ends) time_array = time_array[idx] @@ -133,7 +133,7 @@ def jitcount(time_array, starts, ends, bin_size): nb = np.sum(nb_bins) bins = np.zeros(nb, dtype=np.float64) - cnt = np.zeros(nb, dtype=np.int64) + cnt = np.zeros(nb, dtype=dtype) k = 0 t = 0 diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 054907e4..d76b5884 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -292,7 +292,9 @@ def count(self, *args, dtype=None, **kwargs): if isinstance(a, IntervalSet): ep = a - if dtype: + if dtype is None: + dtype = np.dtype(np.int64) + else: try: dtype = np.dtype(dtype) except Exception: diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 576d8f49..74692735 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -306,7 +306,7 @@ def value_from(self, data, ep=None): t, d, time_support, kwargs = super().value_from(data, ep) return data.__class__(t=t, d=d, time_support=time_support, **kwargs) - def count(self, *args, **kwargs): + def count(self, *args, dtype=None, **kwargs): """ Count occurences of events within bin_size or within a set of bins defined as an IntervalSet. You can call this function in multiple ways : @@ -334,6 +334,8 @@ def count(self, *args, **kwargs): IntervalSet to restrict the operation time_units : str, optional Time units of bin size ('us', 'ms', 's' [default]) + dtype: type, optional + Data type for the count. Default is np.int64. Returns ------- @@ -361,7 +363,7 @@ def count(self, *args, **kwargs): start end 0 100.0 800.0 """ - t, d, ep = super().count(*args, **kwargs) + t, d, ep = super().count(*args, dtype=dtype, **kwargs) return Tsd(t=t, d=d, time_support=ep) def bin_average(self, bin_size, ep=None, time_units="s"): @@ -1619,8 +1621,8 @@ def count(self, *args, dtype=None, **kwargs): And bincount automatically inherit ep as time support: >>> bincount.time_support - >>> start end - >>> 0 100.0 800.0 + start end + 0 100.0 800.0 """ t, d, ep = super().count(*args, dtype=dtype, **kwargs) return Tsd(t=t, d=d, time_support=ep) diff --git a/tests/test_jitted.py b/tests/test_jitted.py index f97096d5..61d1aeba 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -157,7 +157,7 @@ def test_jitcount(): starts = ep.start ends = ep.end bin_size = 1.0 - t, d = nap.core._jitted_functions.jitcount(time_array, starts, ends, bin_size) + t, d = nap.core._jitted_functions.jitcount(time_array, starts, ends, bin_size, np.int64) tsd3 = nap.Tsd(t=t, d=d, time_support = ep) tsd2 = [] diff --git a/tests/test_time_series.py b/tests/test_time_series.py index b097553b..f5925148 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -627,6 +627,11 @@ def test_count(self, tsd): assert len(count) == 99 np.testing.assert_array_almost_equal(count.index, np.arange(0.5, 99, 1)) + count = tsd.count(bin_size=1, dtype=np.int16) + assert len(count) == 99 + assert count.dtype == np.dtype(np.int16) + + def test_count_time_units(self, tsd): for b, tu in zip([1, 1e3, 1e6],['s', 'ms', 'us']): count = tsd.count(b, time_units = tu) @@ -1167,6 +1172,11 @@ def test_count(self, ts): assert len(count) == 99 np.testing.assert_array_almost_equal(count.index, np.arange(0.5, 99, 1)) + count = ts.count(bin_size=1, dtype=np.int16) + assert len(count) == 99 + assert count.dtype == np.dtype(np.int16) + + def test_count_time_units(self, ts): for b, tu in zip([1, 1e3, 1e6],['s', 'ms', 'us']): count = ts.count(b, time_units = tu) diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index b4da2d6c..17d21f95 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:14:41 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-04-11 14:42:50 +# @Last Modified time: 2024-08-02 16:02:43 """Tests of ts group for `pynapple` package.""" @@ -305,6 +305,9 @@ def test_count(self, group): count = tsgroup.count() np.testing.assert_array_almost_equal(count.values, np.array([[101, 201, 501]])) + count = tsgroup.count(1.0, dtype=np.int16) + assert count.dtype == np.dtype(np.int16) + def test_count_with_ep(self, group): ep = nap.IntervalSet(start=0, end=100) tsgroup = nap.TsGroup(group) From 74d9061aa1aacf535e4cdd7e22db13bbfa67c46a Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 5 Aug 2024 15:14:09 +0100 Subject: [PATCH 125/195] better tests --- pynapple/process/signal_processing.py | 4 +- tests/test_signal_processing.py | 203 +++++++++++++++++++++++++- 2 files changed, 199 insertions(+), 8 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 4adf0567..25ef8448 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -233,7 +233,7 @@ def compute_wavelet_transform( Normalization method: * None - no normalization * 'l1' - divide by the sum of amplitudes - * 'l2' - divide by the square root of the sum of squares + * 'l2' - divide by the square root of the sum of amplitudes Returns ------- @@ -327,6 +327,8 @@ def generate_morlet_filterbank( """ if len(freqs) == 0: raise ValueError("Given list of freqs cannot be empty.") + if np.min(freqs) <= 0: + raise ValueError("All frequencies in freqs must be strictly positive") filter_bank = [] cutoff = 8 morlet_f = _morlet( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 3bacbb50..b3bae4b9 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -1,12 +1,98 @@ """Tests of `signal_processing` for pynapple""" import numpy as np -import pandas as pd import pytest import pynapple as nap +def test_generate_morlet_filterbank(): + fs = 1000 + freqs = np.linspace(10, 100, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 10000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 1000 + freqs = np.linspace(10, 100, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=3.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 10000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=3.5, window_length=1.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 1000 + freqs = np.linspace(10, 100, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=3.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + fs = 10000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.5, window_length=3.0, precision=16 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + + +@pytest.mark.parametrize( + "freqs, fs, gaussian_width, window_length, precision, expectation", + [ + ( + np.linspace(0, 100, 11), + 1000, + 1.5, + 1.0, + 16, + pytest.raises( + ValueError, match="All frequencies in freqs must be strictly positive" + ), + ), + ( + [], + 1000, + 1.5, + 1.0, + 16, + pytest.raises(ValueError, match="Given list of freqs cannot be empty."), + ), + ], +) +def test_generate_morlet_filterbank_raise_errors( + freqs, fs, gaussian_width, window_length, precision, expectation +): + with expectation: + _ = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width, window_length, precision + ) + def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1001) @@ -99,21 +185,124 @@ def test_compute_wavelet_transform(): mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 6) + t = np.linspace(0, 1, 1024) sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) freqs = np.linspace(1, 600, 10) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10, 4) - t = np.linspace(0, 1, 1024) # can remove this when we move it + t = np.linspace(0, 1, 1024) sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) freqs = np.linspace(1, 600, 10) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 10, 4, 2) - with pytest.raises(ValueError) as e_info: - nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, gaussian_width=-1.5) - assert str(e_info.value) == "gaussian_width must be a positive number." + # Testing against manual convolution for l1 norm + t = np.linspace(0, 1, 1024) + sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm="l1" + ) + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) + filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) + convolved_real = sig.convolve(filter_bank.real().values) + convolved_imag = sig.convolve(filter_bank.imag().values) + convolved = convolved_real.values + convolved_imag.values * 1j + coef = convolved / (1024 / freqs) + cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef + mwt2 = nap.TsdTensor( + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + ) + assert np.array_equal(mwt, mwt2) + + # Testing against manual convolution for l2 norm + t = np.linspace(0, 1, 1024) + sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm="l2" + ) + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) + filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) + convolved_real = sig.convolve(filter_bank.real().values) + convolved_imag = sig.convolve(filter_bank.imag().values) + convolved = convolved_real.values + convolved_imag.values * 1j + coef = convolved / (1024 / np.sqrt(freqs)) + cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef + mwt2 = nap.TsdTensor( + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + ) + assert np.array_equal(mwt, mwt2) + + # Testing against manual convolution for no normalization + t = np.linspace(0, 1, 1024) + sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) + freqs = np.linspace(1, 600, 10) + mwt = nap.compute_wavelet_transform( + sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm=None + ) + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) + filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) + convolved_real = sig.convolve(filter_bank.real().values) + convolved_imag = sig.convolve(filter_bank.imag().values) + convolved = convolved_real.values + convolved_imag.values * 1j + coef = convolved + cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef + mwt2 = nap.TsdTensor( + t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + ) + assert np.array_equal(mwt, mwt2) -if __name__ == "__main__": - test_compute_wavelet_transform() +@pytest.mark.parametrize( + "sig, fs, freqs, gaussian_width, window_length, precision, norm, expectation", + [ + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(0, 600, 10), + 1.5, + 1.0, + 16, + "l1", + pytest.raises( + ValueError, match="All frequencies in freqs must be strictly positive" + ), + ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(1, 600, 10), + -1.5, + 1.0, + 16, + "l1", + pytest.raises( + ValueError, match="gaussian_width must be a positive number." + ), + ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(1, 600, 10), + 1.5, + 1.0, + 16, + "l3", + pytest.raises( + ValueError, match="norm parameter must be 'l1', 'l2', or None." + ), + ), + ], +) +def test_compute_wavelet_transform_raise_errors( + sig, freqs, fs, gaussian_width, window_length, precision, norm, expectation +): + with expectation: + _ = nap.compute_wavelet_transform( + sig, freqs, fs, gaussian_width, window_length, precision, norm + ) From 73ee4deac38ebab1f7160f40c11225dd78ef1004 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 5 Aug 2024 15:26:32 +0100 Subject: [PATCH 126/195] one added case for tests --- pynapple/process/signal_processing.py | 15 +++++++-------- tests/test_signal_processing.py | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 25ef8448..5b8eaa3d 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -216,9 +216,9 @@ def compute_wavelet_transform( ---------- sig : pynapple.Tsd, pynapple.TsdFrame or pynapple.TsdTensor Time series. - freqs : 1d array or list of float + freqs : 1d array or tuple of float If array, frequency values to estimate with morlet wavelets. - If list, define the frequency range, as [freq_start, freq_stop, freq_step]. + If tuple, define the frequency range, as [freq_start, freq_stop, freq_step]. The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. fs : float or None Sampling rate, in Hz. Defaults to sig.rate if None is given. @@ -261,8 +261,9 @@ def compute_wavelet_transform( raise ValueError("gaussian_width must be a positive number.") if norm is not None and norm not in ["l1", "l2"]: raise ValueError("norm parameter must be 'l1', 'l2', or None.") - - if isinstance(freqs, (tuple, list)): + if not isinstance(freqs, (np.ndarray, tuple)): + raise TypeError("`freqs` must be a ndarray or tuple instance.") + if isinstance(freqs, tuple): freqs = _create_freqs(*freqs) if fs is None: @@ -307,10 +308,8 @@ def generate_morlet_filterbank( Parameters ---------- - freqs : 1d array or list of float - If array, frequency values to estimate with morlet wavelets. - If list, define the frequency range, as [freq_start, freq_stop, freq_step]. - The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. + freqs : 1d array + Frequency values to estimate with morlet wavelets. fs : float Sampling rate, in Hz. gaussian_width : float diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index b3bae4b9..130c1b7a 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -110,6 +110,17 @@ def test_compute_wavelet_transform(): == 500 ) + t = np.linspace(0, 1, 1001) + sig = nap.Tsd( + d=np.sin(t * 50 * np.pi * 2) + * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), + t=t, + ) + freqs = (10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.linspace(10, 100, 10)) + assert np.array_equal(mwt, mwt2) + t = np.linspace(0, 1, 1001) sig = nap.Tsd( d=np.sin(t * 10 * np.pi * 2) @@ -297,6 +308,18 @@ def test_compute_wavelet_transform(): ValueError, match="norm parameter must be 'l1', 'l2', or None." ), ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + None, + 1.5, + 1.0, + 16, + "l1", + pytest.raises( + TypeError, match="`freqs` must be a ndarray or tuple instance." + ), + ), ], ) def test_compute_wavelet_transform_raise_errors( From a5b4f3bdf01b2ec5ade81b3817bb780f06512303 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 5 Aug 2024 15:48:04 -0400 Subject: [PATCH 127/195] Updating tutorial_signal_processing notebool --- docs/examples/tutorial_signal_processing.py | 189 ++++++++------------ 1 file changed, 74 insertions(+), 115 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 2453e0f2..dd0ac745 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -60,28 +60,33 @@ # Let's load and print the full dataset. data = nap.load_file(path) -FS = 1250 + print(data) +# %% +# First we can extract the data from the NWB. The local field potential has been downsampled to 1250Hz. We will call it `eeg`. +# +# The `time_support` of the object `data['position']` contains the interval for which the rat was running along the linear track. We will call it `wake_ep`. +# + +FS = 1250 + +eeg = data['eeg'] + +wake_ep = data['position'].time_support # %% # *** -# Selecting slices +# Selecting example # ----------------------------------- # We will consider a single run of the experiment - where the rodent completes a full traversal of the linear track, # followed by 4 seconds of post-traversal activity. -# Define the run to use for this Analysis -run_index = 7 -# Define the IntervalSet for this run and instantiate both LFP and -# Position TsdFrame objects -RUN_interval = nap.IntervalSet( - data["forward_ep"]["start"][run_index], - data["forward_ep"]["end"][run_index] + 4.0, -) -RUN_Tsd = data["eeg"].restrict(RUN_interval) -RUN_pos = data["position"].restrict(RUN_interval) -print(RUN_Tsd) +forward_ep = data['forward_ep'] +RUN_interval = nap.IntervalSet(forward_ep.start[7], forward_ep.end[7] + 4.0) + +eeg_example = eeg.restrict(RUN_interval)[:,0] +pos_example = data['position'].restrict(RUN_interval) # %% # *** @@ -93,71 +98,47 @@ [["ephys"], ["pos"]], height_ratios=[1, 0.4], ) - -axd["ephys"].plot( - RUN_Tsd[:, 0].restrict( - nap.IntervalSet( - data["forward_ep"]["start"][run_index], data["forward_ep"]["end"][run_index] - ) - ), - label="Traversal LFP Data", - color="green", -) -axd["ephys"].plot( - RUN_Tsd[:, 0].restrict( - nap.IntervalSet( - data["forward_ep"]["end"][run_index], - data["forward_ep"]["end"][run_index] + 5.0, - ) - ), - label="Post Traversal LFP Data", - color="blue", -) -axd["ephys"].set_title("Traversal & Post Traversal LFP") +axd["ephys"].plot(eeg_example, label="CA1") +axd["ephys"].set_title("EEG (1250 Hz)") axd["ephys"].set_ylabel("LFP (v)") axd["ephys"].set_xlabel("time (s)") axd["ephys"].margins(0) axd["ephys"].legend() -axd["pos"].plot(RUN_pos, color="black") +axd["pos"].plot(pos_example, color="black") axd["pos"].margins(0) axd["pos"].set_xlabel("time (s)") axd["pos"].set_ylabel("Linearized Position") -axd["pos"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["pos"].set_xlim(RUN_interval[0,0], RUN_interval[0,1]) + # %% # *** # Getting the LFP Spectrogram # ----------------------------------- -# Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies. +# Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies during exploration (`wake_ep`). + + +power = nap.compute_power_spectral_density(eeg, fs=FS, ep=wake_ep) +print(power) + -fft = nap.compute_power_spectral_density(RUN_Tsd, fs=int(FS)) -print(fft) # %% # *** # The returned object is a pandas dataframe which uses frequencies as indexes and spectral power as values. # -# Now let's plot it +# Let's plot the power between 1 and 100 Hz. +# +# The red area outlines the theta rhythm (6-12 Hz) which is proeminent in hippocampal LFP. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) -ax.plot( - fft.index, - np.abs(fft.iloc[:, 0]), - alpha=0.5, - label="LFP Frequency Power", - c="blue", - linewidth=2, -) +ax.semilogy(np.abs(power[(power.index>1.0) & (power.index<100)]),alpha=0.5,label="LFP Frequency Power") +ax.axvspan(6, 12, color = 'red', alpha=0.1) ax.set_xlabel("Freq (Hz)") ax.set_ylabel("Frequency Power") ax.set_title("LFP Fourier Decomposition") -ax.set_xlim(1, 30) -ax.axvline(9.36, c="orange", label="9.36Hz", alpha=0.5) -ax.axvline(18.74, c="green", label="18.74Hz", alpha=0.5) ax.legend() -# ax.set_yscale('log') -# ax.set_xscale('log') # %% # *** @@ -168,63 +149,41 @@ # Let's generate a wavelet decomposition to look more closely at the changing frequency powers over time. # We must define the frequency set that we'd like to use for our decomposition -freqs = np.geomspace(5, 250, 25) +freqs = np.geomspace(3, 250, 100) # Compute and print the wavelet transform on our LFP data -mwt_RUN = nap.compute_wavelet_transform(RUN_Tsd[:, 0], fs=FS, freqs=freqs) +mwt_RUN = nap.compute_wavelet_transform(eeg_example, fs=FS, freqs=freqs) + + +# %% +# `mwt_RUN` is a TsdFrame with each column being the convolution with one wavelet at a particular frequency. +print(mwt_RUN) # %% # *** # Now let's plot it. +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +gs = plt.GridSpec(3, 1, figure=fig, height_ratios=[1.0, 0.5, 0.1]) + +ax0 = plt.subplot(gs[0,0]) +pcmesh = ax0.pcolormesh(mwt_RUN.t, freqs, np.transpose(np.abs(mwt_RUN))) +ax0.grid(False) +ax0.set_yscale("log") +ax0.set_title("Wavelet decomposition") +cbar = plt.colorbar(pcmesh, ax=ax0, orientation='vertical') +ax0.set_label("Amplitude") + +ax1 = plt.subplot(gs[1,0], sharex = ax0) +ax1.plot(eeg_example) +ax1.set_ylabel("LFP (v)") + +ax1 = plt.subplot(gs[2,0], sharex = ax0) +ax1.plot(pos_example) +ax1.set_xlabel("Time (s)") +ax1.set_ylabel("Pos.") + +plt.show() -# Define wavelet decomposition plotting function -def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): - if np.iscomplexobj(powers): - powers = abs(powers) - ax.imshow(powers, aspect="auto", **kwargs) - ax.invert_yaxis() - ax.set_xlabel("Time (s)") - ax.set_ylabel("Frequency (Hz)") - if isinstance(x_ticks, int): - x_tick_pos = np.linspace(0, times.size, x_ticks) - x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) - else: - x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] - ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - y_ticks = [np.round(f, 2) for f in freqs] - y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] - ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) - ax.grid(False) - - -# And plot -fig = plt.figure(constrained_layout=True, figsize=(10, 8)) -axd = fig.subplot_mosaic( - [ - ["wd_run"], - ["lfp_run"], - ["pos_run"], - ], - height_ratios=[1.2, 0.2, 0.6], -) -plot_timefrequency( - RUN_Tsd.index.values[:], - freqs[:], - np.transpose(mwt_RUN[:, :].values), - ax=axd["wd_run"], -) -axd["wd_run"].set_title(f"Wavelet Decomposition") -axd["lfp_run"].plot(RUN_Tsd) -axd["pos_run"].plot(RUN_pos) -axd["pos_run"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) -axd["pos_run"].set_ylabel("Lin. Position (cm)") -for k in ["lfp_run", "pos_run"]: - axd[k].margins(0) - if k != "pos_run": - axd[k].set_ylabel("LFP (v)") - axd[k].get_xaxis().set_visible(False) - for spine in ["top", "right", "bottom", "left"]: - axd[k].spines[spine].set_visible(False) # %% # *** @@ -254,14 +213,14 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): height_ratios=[1, 0.4], ) -axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, 0], alpha=0.5, label="LFP Data") +axd["lfp_run"].plot(eeg_example, alpha=0.5, label="LFP Data") axd["lfp_run"].plot( - RUN_Tsd.index.values, + eeg_example.index.values, theta_band_reconstruction, label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", ) axd["lfp_run"].plot( - RUN_Tsd.index.values, + eeg_example.index.values, theta_band_power_envelope, label=f"{np.round(freqs[theta_freq_index], 2)}Hz power envelope", ) @@ -269,14 +228,14 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].set_title(f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.") -axd["pos_run"].plot(RUN_pos) +axd["pos_run"].plot(pos_example) [axd[k].margins(0) for k in ["lfp_run", "pos_run"]] [ axd["pos_run"].spines[sp].set_visible(False) for sp in ["top", "right", "bottom", "left"] ] axd["pos_run"].get_xaxis().set_visible(False) -axd["pos_run"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["pos_run"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) axd["pos_run"].set_ylabel("Lin. Position (cm)") axd["lfp_run"].legend() @@ -306,19 +265,19 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ], height_ratios=[1, 0.4], ) -axd["lfp_run"].plot(RUN_Tsd.index.values, RUN_Tsd[:, 0], label="LFP Data") +axd["lfp_run"].plot(eeg_example, label="LFP Data") axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].margins(0) axd["lfp_run"].set_title(f"Traversal & Post-Traversal LFP") -axd["rip_pow"].plot(RUN_Tsd.index.values, ripple_power) +axd["rip_pow"].plot(eeg_example.index.values, ripple_power) axd["rip_pow"].margins(0) axd["rip_pow"].get_xaxis().set_visible(False) axd["rip_pow"].spines["top"].set_visible(False) axd["rip_pow"].spines["right"].set_visible(False) axd["rip_pow"].spines["bottom"].set_visible(False) axd["rip_pow"].spines["left"].set_visible(False) -axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") # %% @@ -352,7 +311,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): ], height_ratios=[1, 0.4], ) -axd["lfp_run"].plot(RUN_Tsd[:, 0], label="LFP Data") +axd["lfp_run"].plot(eeg_example, label="LFP Data") axd["rip_pow"].plot(smoother_swr_power) axd["rip_pow"].plot(is_ripple, color="red", linewidth=2) axd["lfp_run"].set_ylabel("LFP (v)") @@ -362,7 +321,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): [axd[k].margins(0) for k in ["lfp_run", "rip_pow"]] [axd["rip_pow"].spines[sp].set_visible(False) for sp in ["top", "left", "right"]] axd["rip_pow"].get_xaxis().set_visible(False) -axd["rip_pow"].set_xlim(RUN_Tsd.index.min(), RUN_Tsd.index.max()) +axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") # %% @@ -374,7 +333,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) buffer = 0.1 ax.plot( - RUN_Tsd.restrict( + eeg_example.restrict( nap.IntervalSet( start=is_ripple.index.min() - buffer, end=is_ripple.index.max() + buffer ) @@ -383,7 +342,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): label="Non-SWR LFP", ) ax.plot( - RUN_Tsd.restrict( + eeg_example.restrict( nap.IntervalSet(start=is_ripple.index.min(), end=is_ripple.index.max()) ), color="red", From b1540ea5d8418ccff69085c61e0f2133bb614902 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 5 Aug 2024 21:47:34 +0100 Subject: [PATCH 128/195] more concise plotting code in docs --- docs/examples/tutorial_signal_processing.py | 107 +++++++++----------- 1 file changed, 48 insertions(+), 59 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index dd0ac745..ac3dc3c9 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -64,16 +64,16 @@ print(data) # %% -# First we can extract the data from the NWB. The local field potential has been downsampled to 1250Hz. We will call it `eeg`. +# First we can extract the data from the NWB. The local field potential has been downsampled to 1250Hz. We will call it `eeg`. # # The `time_support` of the object `data['position']` contains the interval for which the rat was running along the linear track. We will call it `wake_ep`. # FS = 1250 -eeg = data['eeg'] +eeg = data["eeg"] -wake_ep = data['position'].time_support +wake_ep = data["position"].time_support # %% # *** @@ -82,11 +82,11 @@ # We will consider a single run of the experiment - where the rodent completes a full traversal of the linear track, # followed by 4 seconds of post-traversal activity. -forward_ep = data['forward_ep'] +forward_ep = data["forward_ep"] RUN_interval = nap.IntervalSet(forward_ep.start[7], forward_ep.end[7] + 4.0) -eeg_example = eeg.restrict(RUN_interval)[:,0] -pos_example = data['position'].restrict(RUN_interval) +eeg_example = eeg.restrict(RUN_interval)[:, 0] +pos_example = data["position"].restrict(RUN_interval) # %% # *** @@ -100,7 +100,7 @@ ) axd["ephys"].plot(eeg_example, label="CA1") axd["ephys"].set_title("EEG (1250 Hz)") -axd["ephys"].set_ylabel("LFP (v)") +axd["ephys"].set_ylabel("LFP (a.u.)") axd["ephys"].set_xlabel("time (s)") axd["ephys"].margins(0) axd["ephys"].legend() @@ -108,8 +108,7 @@ axd["pos"].margins(0) axd["pos"].set_xlabel("time (s)") axd["pos"].set_ylabel("Linearized Position") -axd["pos"].set_xlim(RUN_interval[0,0], RUN_interval[0,1]) - +axd["pos"].set_xlim(RUN_interval[0, 0], RUN_interval[0, 1]) # %% @@ -123,18 +122,21 @@ print(power) - # %% # *** # The returned object is a pandas dataframe which uses frequencies as indexes and spectral power as values. # -# Let's plot the power between 1 and 100 Hz. +# Let's plot the power between 1 and 100 Hz. # # The red area outlines the theta rhythm (6-12 Hz) which is proeminent in hippocampal LFP. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) -ax.semilogy(np.abs(power[(power.index>1.0) & (power.index<100)]),alpha=0.5,label="LFP Frequency Power") -ax.axvspan(6, 12, color = 'red', alpha=0.1) +ax.semilogy( + np.abs(power[(power.index > 1.0) & (power.index < 100)]), + alpha=0.5, + label="LFP Frequency Power", +) +ax.axvspan(6, 12, color="red", alpha=0.1) ax.set_xlabel("Freq (Hz)") ax.set_ylabel("Frequency Power") ax.set_title("LFP Fourier Decomposition") @@ -165,20 +167,20 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 6)) gs = plt.GridSpec(3, 1, figure=fig, height_ratios=[1.0, 0.5, 0.1]) -ax0 = plt.subplot(gs[0,0]) +ax0 = plt.subplot(gs[0, 0]) pcmesh = ax0.pcolormesh(mwt_RUN.t, freqs, np.transpose(np.abs(mwt_RUN))) ax0.grid(False) ax0.set_yscale("log") -ax0.set_title("Wavelet decomposition") -cbar = plt.colorbar(pcmesh, ax=ax0, orientation='vertical') +ax0.set_title("Wavelet Decomposition") +cbar = plt.colorbar(pcmesh, ax=ax0, orientation="vertical") ax0.set_label("Amplitude") -ax1 = plt.subplot(gs[1,0], sharex = ax0) +ax1 = plt.subplot(gs[1, 0], sharex=ax0) ax1.plot(eeg_example) ax1.set_ylabel("LFP (v)") -ax1 = plt.subplot(gs[2,0], sharex = ax0) -ax1.plot(pos_example) +ax1 = plt.subplot(gs[2, 0], sharex=ax0) +ax1.plot(pos_example, color="black") ax1.set_xlabel("Time (s)") ax1.set_ylabel("Pos.") @@ -194,7 +196,7 @@ # they match up # Find the index of the frequency closest to theta band -theta_freq_index = np.argmin(np.abs(10 - freqs)) +theta_freq_index = np.argmin(np.abs(8 - freqs)) # Extract its real component, as well as its power envelope theta_band_reconstruction = mwt_RUN[:, theta_freq_index].values.real theta_band_power_envelope = np.abs(mwt_RUN[:, theta_freq_index].values) @@ -206,38 +208,31 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 6)) axd = fig.subplot_mosaic( - [ - ["lfp_run"], - ["pos_run"], - ], + [["ephys"], ["pos"]], height_ratios=[1, 0.4], ) - -axd["lfp_run"].plot(eeg_example, alpha=0.5, label="LFP Data") -axd["lfp_run"].plot( +axd["ephys"].plot(eeg_example, label="CA1") +axd["ephys"].plot( eeg_example.index.values, theta_band_reconstruction, label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", ) -axd["lfp_run"].plot( +axd["ephys"].plot( eeg_example.index.values, theta_band_power_envelope, label=f"{np.round(freqs[theta_freq_index], 2)}Hz power envelope", ) +axd["ephys"].set_title("EEG (1250 Hz)") +axd["ephys"].set_ylabel("LFP (a.u.)") +axd["ephys"].set_xlabel("time (s)") +axd["ephys"].margins(0) +axd["ephys"].legend() +axd["pos"].plot(pos_example, color="black") +axd["pos"].margins(0) +axd["pos"].set_xlabel("time (s)") +axd["pos"].set_ylabel("Linearized Position") +axd["pos"].set_xlim(RUN_interval[0, 0], RUN_interval[0, 1]) -axd["lfp_run"].set_ylabel("LFP (v)") -axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.") -axd["pos_run"].plot(pos_example) -[axd[k].margins(0) for k in ["lfp_run", "pos_run"]] -[ - axd["pos_run"].spines[sp].set_visible(False) - for sp in ["top", "right", "bottom", "left"] -] -axd["pos_run"].get_xaxis().set_visible(False) -axd["pos_run"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) -axd["pos_run"].set_ylabel("Lin. Position (cm)") -axd["lfp_run"].legend() # %% # *** @@ -266,17 +261,12 @@ height_ratios=[1, 0.4], ) axd["lfp_run"].plot(eeg_example, label="LFP Data") +axd["rip_pow"].plot(eeg_example.index.values, ripple_power) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") axd["lfp_run"].margins(0) -axd["lfp_run"].set_title(f"Traversal & Post-Traversal LFP") -axd["rip_pow"].plot(eeg_example.index.values, ripple_power) +axd["lfp_run"].set_title(f"EEG (1250 Hz)") axd["rip_pow"].margins(0) -axd["rip_pow"].get_xaxis().set_visible(False) -axd["rip_pow"].spines["top"].set_visible(False) -axd["rip_pow"].spines["right"].set_visible(False) -axd["rip_pow"].spines["bottom"].set_visible(False) -axd["rip_pow"].spines["left"].set_visible(False) axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") @@ -313,14 +303,14 @@ ) axd["lfp_run"].plot(eeg_example, label="LFP Data") axd["rip_pow"].plot(smoother_swr_power) -axd["rip_pow"].plot(is_ripple, color="red", linewidth=2) +axd["rip_pow"].axvspan( + is_ripple.index.min(), is_ripple.index.max(), color="red", alpha=0.3 +) axd["lfp_run"].set_ylabel("LFP (v)") axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{np.round(freqs[ripple_freq_idx], 2)}Hz oscillation power.") -axd["rip_pow"].axhline(threshold) +axd["lfp_run"].set_title(f"EEG (1250 Hz)") +axd["rip_pow"].axhline(threshold, linestyle="--", color="black", alpha=0.4) [axd[k].margins(0) for k in ["lfp_run", "rip_pow"]] -[axd["rip_pow"].spines[sp].set_visible(False) for sp in ["top", "left", "right"]] -axd["rip_pow"].get_xaxis().set_visible(False) axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") @@ -341,13 +331,12 @@ color="blue", label="Non-SWR LFP", ) -ax.plot( - eeg_example.restrict( - nap.IntervalSet(start=is_ripple.index.min(), end=is_ripple.index.max()) - ), +ax.axvspan( + is_ripple.index.min(), + is_ripple.index.max(), color="red", - label="SWR", - linewidth=2, + alpha=0.3, + label="SWR LFP", ) ax.margins(0) ax.set_xlabel("Time (s)") From 05a59961a46453ac74958929521f3414bbc6fc2f Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Mon, 5 Aug 2024 23:32:15 +0100 Subject: [PATCH 129/195] signal processing tests to 100% coverage --- pynapple/process/signal_processing.py | 34 +----- tests/test_power_spectral_density.py | 169 ++++++++++++++++++++------ tests/test_signal_processing.py | 25 +++- 3 files changed, 158 insertions(+), 70 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 5b8eaa3d..e0ee1e32 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -178,7 +178,7 @@ def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): ) -def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_base=np.e): +def _create_freqs(freq_start, freq_stop, num_freqs=10, log_scaling=False): """ Creates an array of frequencies. @@ -188,12 +188,10 @@ def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_bas Starting value for the frequency definition. freq_stop: float Stopping value for the frequency definition, inclusive. - freq_step: float, optional - Step value, for linearly spaced values between start and stop. + num_freqs: int, optional + Number of freqs to create. Default 10 log_scaling: Bool If True, will use log spacing with base log_base for frequency spacing. Default False. - log_base: float - If log_scaling==True, this defines the base of the log to use. Returns ------- @@ -201,9 +199,9 @@ def _create_freqs(freq_start, freq_stop, freq_step=1, log_scaling=False, log_bas Frequency indices. """ if not log_scaling: - return np.arange(freq_start, freq_stop + freq_step, freq_step) + return np.linspace(freq_start, freq_stop, num_freqs) else: - return np.logspace(freq_start, freq_stop, base=log_base) + return np.geomspace(freq_start, freq_stop, num_freqs) def compute_wavelet_transform( @@ -358,25 +356,3 @@ def generate_morlet_filterbank( for arr in filter_bank ] return nap.TsdFrame(d=np.array(filter_bank).transpose(), t=time) - - -def _integrate(arr, step): - """ - Integrates an array with respect to some step param. Used for integrating complex wavelets. - - Parameters - ---------- - arr : np.ndarray - wave function to be integrated - step : float - Step size of given wave function array - - Returns - ------- - array - Complex-valued integrated wavelet - - """ - integral = np.cumsum(arr) - integral *= step - return integral diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index 0eb39cdd..18503294 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -1,25 +1,18 @@ +import re +from contextlib import nullcontext as does_not_raise + import numpy as np import pandas as pd import pytest -from contextlib import nullcontext as does_not_raise + import pynapple as nap ############################################################ # Test for mean_power_spectral_density ############################################################ + def test_compute_power_spectral_density(): - with pytest.raises(ValueError) as e_info: - t = np.linspace(0, 1, 1000) - sig = nap.Tsd( - d=np.random.random(1000), - t=t, - time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), - ) - r = nap.compute_power_spectral_density(sig) - assert ( - str(e_info.value) == "Given epoch (or signal time_support) must have length 1" - ) t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.random.random(1000), t=t) @@ -37,23 +30,65 @@ def test_compute_power_spectral_density(): assert isinstance(r, pd.DataFrame) assert r.shape == (1000, 4) - with pytest.raises(TypeError) as e_info: - nap.compute_power_spectral_density("a_string") - assert ( - str(e_info.value) - == "Currently compute_spectogram is only implemented for Tsd or TsdFrame" - ) + +@pytest.mark.parametrize( + "sig, fs, ep, full_range, expectation", + [ + ( + nap.Tsd( + d=np.random.random(1000), + t=np.linspace(0, 1, 1000), + time_support=nap.IntervalSet(start=[0.1, 0.6], end=[0.2, 0.81]), + ), + 1000, + None, + False, + pytest.raises( + ValueError, + match=re.escape( + "Given epoch (or signal time_support) must have length 1" + ), + ), + ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + 1000, + "not_ep", + False, + pytest.raises( + TypeError, + match="ep param must be a pynapple IntervalSet object, or None", + ), + ), + ( + "not_a_tsd", + 1000, + None, + False, + pytest.raises( + TypeError, + match="Currently compute_spectogram is only implemented for Tsd or TsdFrame", + ), + ), + ], +) +def test_compute_power_spectral_density_raise_errors( + sig, fs, ep, full_range, expectation +): + with expectation: + psd = nap.compute_power_spectral_density(sig, fs, ep, full_range) ############################################################ # Test for mean_power_spectral_density ############################################################ -def get_signal_and_output(f=2, fs=1000,duration=100,interval_size=10): - t=np.arange(0, duration, 1/fs) - d=np.cos(2*np.pi*f*t) - sig = nap.Tsd(t=t,d=d, time_support=nap.IntervalSet(0,100)) - tmp = d.reshape((int(duration/interval_size),int(fs*interval_size))).T + +def get_signal_and_output(f=2, fs=1000, duration=100, interval_size=10): + t = np.arange(0, duration, 1 / fs) + d = np.cos(2 * np.pi * f * t) + sig = nap.Tsd(t=t, d=d, time_support=nap.IntervalSet(0, 100)) + tmp = d.reshape((int(duration / interval_size), int(fs * interval_size))).T tmp = tmp[0:-1] out = np.sum(np.fft.fft(tmp, axis=0), 1) freq = np.fft.fftfreq(out.shape[0], 1 / fs) @@ -62,15 +97,16 @@ def get_signal_and_output(f=2, fs=1000,duration=100,interval_size=10): freq = freq[order] return (sig, out, freq) + def test_compute_mean_power_spectral_density(): - sig, out, freq = get_signal_and_output() + sig, out, freq = get_signal_and_output() psd = nap.compute_mean_power_spectral_density(sig, 10) assert isinstance(psd, pd.DataFrame) assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty - np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq>=0]) - np.testing.assert_array_almost_equal(psd.index.values, freq[freq>=0]) + np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq >= 0]) + np.testing.assert_array_almost_equal(psd.index.values, freq[freq >= 0]) - # Full range + # Full range psd = nap.compute_mean_power_spectral_density(sig, 10, full_range=True) assert isinstance(psd, pd.DataFrame) assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty @@ -78,29 +114,82 @@ def test_compute_mean_power_spectral_density(): np.testing.assert_array_almost_equal(psd.index.values, freq) # TsdFrame - sig2 = nap.TsdFrame(t=sig.t, d=np.repeat(sig.values[:,None], 2, 1), time_support = sig.time_support) + sig2 = nap.TsdFrame( + t=sig.t, d=np.repeat(sig.values[:, None], 2, 1), time_support=sig.time_support + ) psd = nap.compute_mean_power_spectral_density(sig2, 10, full_range=True) assert isinstance(psd, pd.DataFrame) assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty - np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:,None],2,1)) + np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:, None], 2, 1)) np.testing.assert_array_almost_equal(psd.index.values, freq) @pytest.mark.parametrize( "sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation", [ - (*get_signal_and_output(), 10, None, None, False, "s", does_not_raise()), - (*get_signal_and_output(), 10, "a", None, False, "s", pytest.raises(TypeError, match="fs must be of type float or int")), - (*get_signal_and_output(), 10, None, "a", False, "s", pytest.raises(TypeError, match="ep param must be a pynapple IntervalSet object, or None")), - (*get_signal_and_output(), 10, None, None, "a", "s", pytest.raises(TypeError, match="full_range must be of type bool or None")), - (*get_signal_and_output(), 10*1e3, None, None, False, "ms", does_not_raise()), - (*get_signal_and_output(), 10*1e6, None, None, False, "us", does_not_raise()), - (*get_signal_and_output(), 200, None, None, False, "s", pytest.raises(RuntimeError, match="Splitting epochs with interval_size=200 generated an empty IntervalSet. Try decreasing interval_size")), - (*get_signal_and_output(), 10, None, nap.IntervalSet([0, 200], [100,300]), False, "s", pytest.raises(RuntimeError, match="One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed.")), - ] + (*get_signal_and_output(), 10, None, None, False, "s", does_not_raise()), + ( + *get_signal_and_output(), + 10, + "a", + None, + False, + "s", + pytest.raises(TypeError, match="fs must be of type float or int"), + ), + ( + *get_signal_and_output(), + 10, + None, + "a", + False, + "s", + pytest.raises( + TypeError, + match="ep param must be a pynapple IntervalSet object, or None", + ), + ), + ( + *get_signal_and_output(), + 10, + None, + None, + "a", + "s", + pytest.raises(TypeError, match="full_range must be of type bool or None"), + ), + (*get_signal_and_output(), 10 * 1e3, None, None, False, "ms", does_not_raise()), + (*get_signal_and_output(), 10 * 1e6, None, None, False, "us", does_not_raise()), + ( + *get_signal_and_output(), + 200, + None, + None, + False, + "s", + pytest.raises( + RuntimeError, + match="Splitting epochs with interval_size=200 generated an empty IntervalSet. Try decreasing interval_size", + ), + ), + ( + *get_signal_and_output(), + 10, + None, + nap.IntervalSet([0, 200], [100, 300]), + False, + "s", + pytest.raises( + RuntimeError, + match="One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed.", + ), + ), + ], ) def test_compute_mean_power_spectral_density_raise_errors( sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation - ): +): with expectation: - psd = nap.compute_mean_power_spectral_density(sig, interval_size, fs, ep, full_range, time_units) + psd = nap.compute_mean_power_spectral_density( + sig, interval_size, fs, ep, full_range, time_units + ) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 130c1b7a..6310e0af 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -121,6 +121,17 @@ def test_compute_wavelet_transform(): mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.linspace(10, 100, 10)) assert np.array_equal(mwt, mwt2) + t = np.linspace(0, 1, 1001) + sig = nap.Tsd( + d=np.sin(t * 50 * np.pi * 2) + * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), + t=t, + ) + freqs = (10, 100, 10, True) + mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.geomspace(10, 100, 10)) + assert np.array_equal(mwt, mwt2) + t = np.linspace(0, 1, 1001) sig = nap.Tsd( d=np.sin(t * 10 * np.pi * 2) @@ -192,7 +203,7 @@ def test_compute_wavelet_transform(): t = np.linspace(0, 1, 1024) sig = nap.Tsd(d=np.random.random(1024), t=t) - freqs = (1, 51, 10) + freqs = (1, 51, 6) mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) assert mwt.shape == (1024, 6) @@ -320,6 +331,18 @@ def test_compute_wavelet_transform(): TypeError, match="`freqs` must be a ndarray or tuple instance." ), ), + ( + "not_a_signal", + None, + np.linspace(10, 100, 10), + 1.5, + 1.0, + 16, + "l1", + pytest.raises( + TypeError, match="`sig` must be instance of Tsd, TsdFrame, or TsdTensor" + ), + ), ], ) def test_compute_wavelet_transform_raise_errors( From 05e29b65451f158ec8dc8742a007a4c3f0daf789 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 6 Aug 2024 14:55:22 +0100 Subject: [PATCH 130/195] coverage actually to 100% --- pynapple/process/signal_processing.py | 7 +++++ tests/test_power_spectral_density.py | 22 +++++++++++++ tests/test_signal_processing.py | 45 +++++++++++++++++++++++++++ 3 files changed, 74 insertions(+) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index e0ee1e32..7f921402 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -257,6 +257,13 @@ def compute_wavelet_transform( if isinstance(gaussian_width, (int, float, np.number)): if gaussian_width <= 0: raise ValueError("gaussian_width must be a positive number.") + else: + raise TypeError("gaussian_width must be a float or int instance.") + if isinstance(window_length, (int, float, np.number)): + if window_length <= 0: + raise ValueError("window_length must be a positive number.") + else: + raise TypeError("window_length must be a float or int instance.") if norm is not None and norm not in ["l1", "l2"]: raise ValueError("norm parameter must be 'l1', 'l2', or None.") if not isinstance(freqs, (np.ndarray, tuple)): diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index 18503294..7d0ec64f 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -30,6 +30,18 @@ def test_compute_power_spectral_density(): assert isinstance(r, pd.DataFrame) assert r.shape == (1000, 4) + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) + r = nap.compute_power_spectral_density(sig, ep=sig.time_support) + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 500 + + t = np.linspace(0, 1, 1000) + sig = nap.Tsd(d=np.random.random(1000), t=t) + r = nap.compute_power_spectral_density(sig, fs=1000) + assert isinstance(r, pd.DataFrame) + assert r.shape[0] == 500 + @pytest.mark.parametrize( "sig, fs, ep, full_range, expectation", @@ -123,6 +135,16 @@ def test_compute_mean_power_spectral_density(): np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:, None], 2, 1)) np.testing.assert_array_almost_equal(psd.index.values, freq) + # TsdFrame + sig2 = nap.TsdFrame( + t=sig.t, d=np.repeat(sig.values[:, None], 2, 1), time_support=sig.time_support + ) + psd = nap.compute_mean_power_spectral_density(sig2, 10, full_range=True, fs=1000) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values, np.repeat(out[:, None], 2, 1)) + np.testing.assert_array_almost_equal(psd.index.values, freq) + @pytest.mark.parametrize( "sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation", diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 6310e0af..34b0e79f 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -121,6 +121,17 @@ def test_compute_wavelet_transform(): mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.linspace(10, 100, 10)) assert np.array_equal(mwt, mwt2) + t = np.linspace(0, 1, 1001) + sig = nap.Tsd( + d=np.sin(t * 50 * np.pi * 2) + * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), + t=t, + ) + freqs = (10, 100, 10) + mwt = nap.compute_wavelet_transform(sig, fs=1001, freqs=freqs) + mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) + assert np.array_equal(mwt, mwt2) + t = np.linspace(0, 1, 1001) sig = nap.Tsd( d=np.sin(t * 50 * np.pi * 2) @@ -307,6 +318,40 @@ def test_compute_wavelet_transform(): ValueError, match="gaussian_width must be a positive number." ), ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(1, 600, 10), + 1.5, + -1.0, + 16, + "l1", + pytest.raises(ValueError, match="window_length must be a positive number."), + ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(1, 600, 10), + "not_number", + 1.0, + 16, + "l1", + pytest.raises( + TypeError, match="gaussian_width must be a float or int instance." + ), + ), + ( + nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + None, + np.linspace(1, 600, 10), + 1.5, + "not_number", + 16, + "l1", + pytest.raises( + TypeError, match="window_length must be a float or int instance." + ), + ), ( nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), None, From d723b8a660e9d02b4b664fd7fefd93e4ab1f109b Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 6 Aug 2024 20:30:18 +0100 Subject: [PATCH 131/195] changes to notebooks --- docs/api_guide/tutorial_pynapple_wavelets.py | 298 ++++++++++++------- docs/examples/tutorial_signal_processing.py | 2 - 2 files changed, 192 insertions(+), 108 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index c0f13862..f4a20b6d 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -50,17 +50,15 @@ # Lets plot it. fig, ax = plt.subplots(3, constrained_layout=True, figsize=(10, 5)) ax[0].plot(t, two_hz_component) -ax[1].plot(t, increasing_freq_component) -ax[2].plot(sig) ax[0].set_title("2Hz Component") +ax[1].plot(t, increasing_freq_component) ax[1].set_title("Increasing Frequency Component") +ax[2].plot(sig) ax[2].set_title("Dummy Signal") [ax[i].margins(0) for i in range(3)] [ax[i].set_ylim(-2.5, 2.5) for i in range(3)] -[ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(3)] [ax[i].set_xlabel("Time (s)") for i in range(3)] [ax[i].set_ylabel("Signal") for i in range(3)] -[ax[i].set_ylim(-2.5, 2.5) for i in range(3)] # %% @@ -83,15 +81,11 @@ # Lets plot it. def plot_filterbank(filter_bank, freqs, title): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7)) - offset = 1.0 for f_i in range(filter_bank.shape[1]): - ax.plot(filter_bank[:, f_i].real() + offset * f_i) - ax.text( - -2.3, offset * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left" - ) + ax.plot(filter_bank[:, f_i].real() + 1.5 * f_i) + ax.text(-2.3, 1.5 * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") ax.margins(0) ax.yaxis.set_visible(False) - [ax.spines[sp].set_visible(False) for sp in ["left", "right", "top"]] ax.set_xlim(-2, 2) ax.set_xlabel("Time (s)") ax.set_title(title) @@ -117,33 +111,30 @@ def plot_filterbank(filter_bank, freqs, title): # %% # Lets plot it. -def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): - if np.iscomplexobj(powers): - powers = abs(powers) - ax.imshow(powers, aspect="auto", **kwargs) +def plot_timefrequency(freqs, powers, ax=None): + im = ax.imshow(abs(powers), aspect="auto") ax.invert_yaxis() ax.set_xlabel("Time (s)") ax.set_ylabel("Frequency (Hz)") - if isinstance(x_ticks, int): - x_tick_pos = np.linspace(0, times.size, x_ticks) - x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) - else: - x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] - ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - y_ticks = freqs - y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] - ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + ax.get_xaxis().set_visible(False) + ax.set(yticks=[np.argmin(np.abs(freqs - val)) for val in freqs], yticklabels=freqs) ax.grid(False) + return im -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -plot_timefrequency( - mwt.index.values[:], - freqs[:], - np.transpose(mwt[:, :].values), - ax=ax, -) -ax.set_title("Wavelet Decomposition Scalogram") +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig.suptitle("Wavelet Decomposition") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) + +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") + +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) # %% # *** @@ -162,7 +153,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # Lets plot it. -fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig = plt.figure(constrained_layout=True, figsize=(10, 4)) axd = fig.subplot_mosaic( [["signal"], ["phase"]], height_ratios=[1, 0.4], @@ -170,20 +161,12 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): axd["signal"].plot(sig, label="Raw Signal", alpha=0.5) axd["signal"].plot(t, slow_oscillation, label="2Hz Reconstruction") axd["signal"].legend() +axd["signal"].set_ylabel("Signal") + axd["phase"].plot(t, slow_oscillation_phase, alpha=0.5) axd["phase"].set_ylabel("Phase (rad)") -axd["signal"].set_ylabel("Signal") axd["phase"].set_xlabel("Time (s)") -[ - axd[f].spines[sp].set_visible(False) - for sp in ["right", "top"] - for f in ["phase", "signal"] -] -axd["signal"].get_xaxis().set_visible(False) -axd["signal"].spines["bottom"].set_visible(False) [axd[k].margins(0) for k in ["signal", "phase"]] -axd["signal"].set_ylim(-2.5, 2.5) -axd["phase"].set_ylim(-np.pi, np.pi) # %% # *** @@ -205,21 +188,24 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # Lets plot it. -fig, ax = plt.subplots(2, constrained_layout=True, figsize=(10, 6)) -ax[0].plot(t, fifteenHz_oscillation, label="15Hz Reconstruction") -ax[0].plot(t, fifteenHz_oscillation_power, label="15Hz Power") -ax[1].plot(sig, label="Raw Signal", alpha=0.5) -ax[1].plot( - t, slow_oscillation + fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction" -) -[ax[i].set_ylim(-2.5, 2.5) for i in range(2)] -[ax[i].margins(0) for i in range(2)] -[ax[i].legend() for i in range(2)] -[ax[i].spines[sp].set_visible(False) for sp in ["right", "top"] for i in range(2)] -ax[0].get_xaxis().set_visible(False) -ax[0].spines["bottom"].set_visible(False) -ax[1].set_xlabel("Time (s)") -[ax[i].set_ylabel("Signal") for i in range(2)] + +fig = plt.figure(constrained_layout=True, figsize=(10, 4)) +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 1.0]) + +ax0 = plt.subplot(gs[0, 0]) +ax0.plot(t, fifteenHz_oscillation, label="15Hz Reconstruction") +ax0.plot(t, fifteenHz_oscillation_power, label="15Hz Power") +ax0.set_xticklabels([]) + +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig, label="Raw Signal", alpha=0.5) +ax1.plot(t, slow_oscillation + fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction") +ax1.set_xlabel("Time (s)") + +[ + (a.margins(0), a.legend(), a.set_ylim(-2.5, 2.5), a.set_ylabel("Signal")) + for a in [ax0, ax1] +] # %% @@ -235,7 +221,6 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) -[ax.spines[sp].set_visible(False) for sp in ["right", "top"]] ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") @@ -250,50 +235,59 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # ------------------ # Our reconstruction seems to get the amplitude modulations of our signal correct, but the amplitude is overestimated, # in particular towards the end of the time period. Often, this is due to a suboptimal choice of parameters, which -# can lead to a low spatial or temporal resolution. Let's explore what changing our parameters does to the +# can lead to a low spatial or temporal resolution. Let's visualize what changing our parameters does to the # underlying wavelets. -freqs = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) -window_lengths = [1.0, 2.0, 3.0] -gaussian_width = [1.0, 2.0, 3.0] - +window_lengths = [1.0, 3.0] +gaussian_widths = [1.0, 3.0] +colors = np.array([["r", "g"], ["b", "y"]]) fig, ax = plt.subplots( - len(window_lengths), len(gaussian_width), constrained_layout=True, figsize=(10, 8) + len(window_lengths) + 1, + len(gaussian_widths) + 1, + constrained_layout=True, + figsize=(10, 8), ) for row_i, wl in enumerate(window_lengths): - for col_i, gw in enumerate(gaussian_width): - filter_bank = nap.generate_morlet_filterbank( - freqs, 1000, gaussian_width=gw, window_length=wl, precision=12 - ) - ax[row_i, col_i].plot(filter_bank[:, 0].real()) - ax[row_i, col_i].set_xlabel("Time (s)") - ax[row_i, col_i].set_yticks([]) - [ - ax[row_i, col_i].spines[sp].set_visible(False) - for sp in ["top", "right", "left"] - ] - if col_i != 0: - ax[row_i, col_i].get_yaxis().set_visible(False) -for col_i, gw in enumerate(gaussian_width): - ax[0, col_i].set_title(f"gaussian_width={gw}", fontsize=10) -for row_i, wl in enumerate(window_lengths): - ax[row_i, 0].set_ylabel(f"window_length={wl}", fontsize=10) -fig.suptitle("Parametrization Visualization") - + for col_i, gw in enumerate(gaussian_widths): + wavelet = nap.generate_morlet_filterbank( + np.array([1.0]), 1000, gaussian_width=gw, window_length=wl, precision=12 + )[:, 0].real() + ax[row_i, col_i].plot(wavelet, c=colors[row_i, col_i]) + fft = nap.compute_power_spectral_density(wavelet) + for i, j in [(row_i, -1), (-1, col_i)]: + ax[i, j].plot(fft.abs(), c=colors[row_i, col_i]) +for i in range(len(window_lengths)): + for j in range(len(gaussian_widths)): + ax[i, j].set(xlabel="Time (s)", yticks=[]) +for ci, gw in enumerate(gaussian_widths): + ax[0, ci].set_title(f"gaussian_width={gw}", fontsize=10) +for ri, wl in enumerate(window_lengths): + ax[ri, 0].set_ylabel(f"window_length={wl}", fontsize=10) +fig.suptitle("Parametrization Visualization (1 Hz Wavelet)") +ax[-1, -1].set_visible(False) +for i in range(len(window_lengths)): + ax[-1, i].set( + xlim=(0, 2), yticks=[], ylabel="Frequency Response", xlabel="Frequency (Hz)" + ) +for i in range(len(gaussian_widths)): + ax[i, -1].set( + xlim=(0, 2), yticks=[], ylabel="Frequency Response", xlabel="Frequency (Hz)" + ) # %% -# Increasing time_decay increases the number of wavelet cycles present in the oscillations (cycles) within the -# Gaussian window of the Morlet wavelet. It essentially controls the trade-off between time resolution -# and frequency resolution. +# Increasing window_length increases the number of wavelet cycles present in the oscillations (cycles), and +# correspondingly increases the time window that the wavelet covers. # -# The scale parameter determines the dilation or compression of the wavelet. It controls the size of the wavelet in -# time, affecting the overall shape of the wavelet. +# The gaussian_width parameter determines the shape of the gaussian window being convolved with the sinusoidal +# component of the wavelet +# +# Both of these parameters can be tweaked to control for the trade-off between time resolution and frequency resolution. # %% # *** # Effect of gaussian_width # ------------------ -# Let's increase time_decay to 7.5 and see the effect on the resultant filter bank. +# Let's increase gaussian_width to 7.5 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) filter_bank = nap.generate_morlet_filterbank( @@ -313,14 +307,19 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=1.0 ) -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) -plot_timefrequency( - mwt.index.values[:], - freqs[:], - np.transpose(mwt[:, :].values), - ax=ax, -) -ax.set_title("Wavelet Decomposition Scalogram") +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig.suptitle("Wavelet Decomposition") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) + +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") + +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) # %% # *** @@ -365,17 +364,23 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # %% # *** # Let's see what effect this has on the Wavelet Scalogram which is generated... -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 6)) mwt = nap.compute_wavelet_transform( sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0 ) -plot_timefrequency( - mwt.index.values[:], - freqs[:], - np.transpose(mwt[:, :].values), - ax=ax, -) -ax.set_title("Wavelet Decomposition Scalogram") + +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig.suptitle("Wavelet Decomposition") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) + +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") + +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) # %% # *** @@ -388,10 +393,91 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) -[ax.spines[sp].set_visible(False) for sp in ["right", "top"]] ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") ax.margins(0) ax.set_ylim(-6, 6) ax.legend() + + +# %% +# *** +# Effect of L1 vs L2 normalization +# ------------------ +# compute_wavelet_transform contains two options for normalization; L1, and L2. L1 normalization. +# By default, L1 is used as it creates cleaner looking decomposition images. +# +# L1 normalization often increases the contrast between significant and insignificant coefficients. +# This can result in a sharper and more defined visual representation, making patterns and structures within +# the signal more evident. +# +# L2 normalization is directly related to the energy of the signal. By normalizing using the +# L2 norm, you ensure that the transformed coefficients preserve the energy distribution of the original signal. +# +# Let's compare two wavelet decomposition images, each generated with a different normalization strategy + +mwt_l1 = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0, norm="l1" +) +mwt_l2 = nap.compute_wavelet_transform( + sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0, norm="l2" +) + +# %% +# Let's plot both the scalograms and see the difference. + +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig.suptitle("Wavelet Decomposition - L1 Normalization") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt_l1[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) + +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +fig.suptitle("Wavelet Decomposition - L2 Normalization") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt_l2[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(sig) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) + +# %% +# We see that the l1 normalized image contains a visually clearer image; the 5-15 Hz component of the signal is +# as powerful as the 2 Hz component, so it makes sense that they should be shown with the same power in the scalogram. +# Let's reconstruct the signal using both decompositions and see the resulting reconstruction... + +# %% + +combined_oscillations_l1 = mwt_l1.sum(axis=1).real() +combined_oscillations_l2 = mwt_l2.sum(axis=1).real() + +# %% +# Lets plot it. +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot(sig, label="Signal", linewidth=3, alpha=0.6, c="b") +ax.plot( + t, combined_oscillations_l1, label="Wavelet Reconstruction (L1)", c="g", alpha=0.6 +) +ax.plot( + t, combined_oscillations_l2, label="Wavelet Reconstruction (L2)", c="r", alpha=0.6 +) +ax.set_xlabel("Time (s)") +ax.set_ylabel("Signal") +ax.set_title("Wavelet Reconstruction of Signal") +ax.margins(0) +ax.set_ylim(-6, 6) +ax.legend() + +# %% +# We see that the reconstruction from the L2 normalized decomposition matched the original signal much more closely, +# this is due to the fact that L2 normalization preserved the energy of the original signal in its reconstruction. diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index ac3dc3c9..0e874085 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -184,8 +184,6 @@ ax1.set_xlabel("Time (s)") ax1.set_ylabel("Pos.") -plt.show() - # %% # *** From 8d78bb52aa4581e666356a020639e91622fbc132 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Tue, 6 Aug 2024 22:00:29 +0100 Subject: [PATCH 132/195] doc plot neatening --- docs/api_guide/tutorial_pynapple_wavelets.py | 2 + docs/examples/tutorial_phase_preferences.py | 186 +++++++++---------- docs/examples/tutorial_signal_processing.py | 2 + 3 files changed, 92 insertions(+), 98 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index f4a20b6d..94a74e97 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -17,6 +17,8 @@ # # You can install all with `pip install matplotlib requests tqdm seaborn` # +# mkdocs_gallery_thumbnail_number = 9 +# # Now, import the necessary libraries: import matplotlib.pyplot as plt diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 18e311c9..15d44dda 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -17,6 +17,8 @@ # # You can install all with `pip install matplotlib requests tqdm seaborn` # +# mkdocs_gallery_thumbnail_number = 6 +# # First, import the necessary libraries: import math @@ -93,10 +95,9 @@ ax.plot( REM_Tsd, label="REM LFP Data", - color="blue", ) ax.set_title("REM Local Field Potential") -ax.set_ylabel("LFP (v)") +ax.set_ylabel("LFP (a.u.)") ax.set_xlabel("time (s)") ax.margins(0) ax.legend() @@ -122,47 +123,34 @@ # Define wavelet decomposition plotting function -def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): - if np.iscomplexobj(powers): - powers = abs(powers) - ax.imshow(powers, aspect="auto", **kwargs) +def plot_timefrequency(freqs, powers, ax=None): + im = ax.imshow(abs(powers), aspect="auto") ax.invert_yaxis() ax.set_xlabel("Time (s)") ax.set_ylabel("Frequency (Hz)") - if isinstance(x_ticks, int): - x_tick_pos = np.linspace(0, times.size, x_ticks) - x_ticks = np.round(np.linspace(times[0], times[-1], x_ticks), 2) - else: - x_tick_pos = [np.argmin(np.abs(times - val)) for val in x_ticks] - ax.set(xticks=x_tick_pos, xticklabels=x_ticks) - y_ticks = [np.round(f, 2) for f in freqs] - y_ticks_pos = [np.argmin(np.abs(freqs - val)) for val in y_ticks] - ax.set(yticks=y_ticks_pos, yticklabels=y_ticks) + ax.get_xaxis().set_visible(False) + ax.set( + yticks=[np.argmin(np.abs(freqs - val)) for val in freqs], + yticklabels=np.round(freqs, 2), + ) ax.grid(False) + return im -# And plot it fig = plt.figure(constrained_layout=True, figsize=(10, 6)) -axd = fig.subplot_mosaic( - [ - ["wd_rem"], - ["lfp_rem"], - ], - height_ratios=[1, 0.2], -) -plot_timefrequency( - REM_Tsd.index.values[:], - freqs[:], - np.transpose(mwt_REM[:, :].values), - ax=axd["wd_rem"], -) -axd["wd_rem"].set_title(f"Wavelet Decomposition") -axd["lfp_rem"].plot(REM_Tsd) -axd["lfp_rem"].margins(0) -axd["lfp_rem"].set_ylabel("LFP (v)") -axd["lfp_rem"].get_xaxis().set_visible(False) -for spine in ["top", "right", "bottom", "left"]: - axd["lfp_rem"].spines[spine].set_visible(False) +fig.suptitle("Wavelet Decomposition") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) + +ax0 = plt.subplot(gs[0, 0]) +im = plot_timefrequency(freqs[:], np.transpose(mwt_REM[:, :].values), ax=ax0) +cbar = fig.colorbar(im, ax=ax0, orientation="vertical") + +ax1 = plt.subplot(gs[1, 0]) +ax1.plot(REM_Tsd) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") +ax1.margins(0) + # %% # *** @@ -171,7 +159,7 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # There seems to be a strong theta frequency present in the data during the maze traversal. # Let's plot the estimated 7Hz component of the wavelet decomposition on top of our data, and see how well # they match up. We will also extract and plot the phase of the 7Hz wavelet from the decomposition. -theta_freq_index = np.argmin(np.abs(7 - freqs)) +theta_freq_index = np.argmin(np.abs(8 - freqs)) theta_band_reconstruction = mwt_REM[:, theta_freq_index].values.real # calculating phase here theta_band_phase = nap.Tsd( @@ -182,31 +170,29 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): # *** # Now let's plot the theta power and phase, along with the LFP. -fig = plt.figure(constrained_layout=True, figsize=(10, 5)) -axd = fig.subplot_mosaic( - [ - ["theta_pow"], - ["phase"], - ], - height_ratios=[0.4, 0.2], +fig, (ax1, ax2) = plt.subplots( + 2, 1, constrained_layout=True, figsize=(10, 5), height_ratios=[0.4, 0.2] ) -axd["theta_pow"].plot(REM_Tsd, alpha=0.5, label="LFP Data - REM") -axd["theta_pow"].plot( +ax1.plot(REM_Tsd, alpha=0.5, label="LFP Data - REM") +ax1.plot( REM_Tsd.index.values, theta_band_reconstruction, label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", ) -axd["theta_pow"].set_ylabel("LFP (v)") -axd["theta_pow"].set_xlabel("Time (s)") -axd["theta_pow"].set_title( - f"{np.round(freqs[theta_freq_index],2)}Hz oscillation power." -) # -axd["theta_pow"].legend() -axd["phase"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.5) -[axd[k].margins(0) for k in ["theta_pow", "phase"]] -axd["phase"].set_ylabel("Phase") -axd["phase"].get_xaxis().set_visible(False) +ax1.set( + ylabel="LFP (v)", + title=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.", +) +ax1.get_xaxis().set_visible(False) +ax1.legend() +ax1.margins(0) + +ax2.plot(REM_Tsd.index.values, theta_band_phase, alpha=0.5) +ax2.set(ylabel="Phase", xlabel="Time (s)") +ax2.margins(0) + +plt.show() # %% @@ -232,13 +218,11 @@ def plot_timefrequency(times, freqs, powers, x_ticks=5, ax=None, **kwargs): def smoothAngularTuningCurves(tuning_curves, sigma=2): - tmp = np.concatenate( - (tuning_curves.values, tuning_curves.values, tuning_curves.values) - ) + tmp = np.concatenate([tuning_curves.values] * 3) tmp = scipy.ndimage.gaussian_filter1d(tmp, sigma=sigma, axis=0) return pd.DataFrame( + tmp[tuning_curves.shape[0] : 2 * tuning_curves.shape[0]], index=tuning_curves.index, - data=tmp[tuning_curves.shape[0] : tuning_curves.shape[0] * 2], columns=tuning_curves.columns, ) @@ -250,18 +234,18 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): figsize=(10, 6), subplot_kw={"projection": "polar"}, ) -for pl_i, sc_i in enumerate(list(smoothcurves)[:6]): - axd[f"phase_{pl_i}"].plot( - list(smoothcurves[sc_i].index) + list([smoothcurves[sc_i].index[0]]), - list(smoothcurves[sc_i].values) + list([smoothcurves[sc_i].values[0]]), + +for i, unit in enumerate(list(smoothcurves)[:6]): + ax = axd[f"phase_{i}"] + ax.plot( + list(smoothcurves[unit].index) + [smoothcurves[unit].index[0]], + list(smoothcurves[unit].values) + [smoothcurves[unit].values[0]], ) - axd[f"phase_{pl_i}"].set_xlabel("Phase (rad)") # Angle in radian, on the X-axis - axd[f"phase_{pl_i}"].set_ylabel( - "Firing Rate (Hz)" - ) # Firing rate in Hz, on the Y-axis - axd[f"phase_{pl_i}"].set_xticks([]) - axd[f"phase_{pl_i}"].set_title(f"Unit {sc_i}") + ax.set(xlabel="Phase (rad)", ylabel="Firing Rate (Hz)", title=f"Unit {unit}") + ax.set_xticks([]) + fig.suptitle("Phase Preference Histograms of First 6 Units") +plt.show() # %% @@ -297,18 +281,18 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): figsize=(10, 6), subplot_kw={"projection": "polar"}, ) -for pl_i, sc_i in enumerate(list(phase_var.keys())[:6]): - axd[f"phase_{pl_i}"].plot( - list(smoothcurves[sc_i].index) + list([smoothcurves[sc_i].index[0]]), - list(smoothcurves[sc_i].values) + list([smoothcurves[sc_i].values[0]]), + +for i, unit in enumerate(list(phase_var.keys())[:6]): + ax = axd[f"phase_{i}"] + ax.plot( + list(smoothcurves[unit].index) + [smoothcurves[unit].index[0]], + list(smoothcurves[unit].values) + [smoothcurves[unit].values[0]], ) - axd[f"phase_{pl_i}"].set_xlabel("Phase (rad)") # Angle in radian, on the X-axis - axd[f"phase_{pl_i}"].set_ylabel( - "Firing Rate (Hz)" - ) # Firing rate in Hz, on the Y-axis - axd[f"phase_{pl_i}"].set_xticks([]) - axd[f"phase_{pl_i}"].set_title(f"Unit {sc_i}") -fig.suptitle("Phase Preference Histograms of 6 Units with Highest Phase Preference ") + ax.set(xlabel="Phase (rad)", ylabel="Firing Rate (Hz)", title=f"Unit {unit}") + ax.set_xticks([]) + +fig.suptitle("Phase Preference Histograms of 6 Units with Highest Phase Preference") +plt.show() # %% # *** @@ -317,34 +301,40 @@ def smoothAngularTuningCurves(tuning_curves, sigma=2): # There is definitely some strong phase preferences happening here. Let's visualize the firing preferences # of the 6 cells we've isolated to get an impression of just how striking these preferences are. -fig = plt.figure(constrained_layout=True, figsize=(10, 8)) -axd = fig.subplot_mosaic( +fig, axd = plt.subplot_mosaic( [ ["lfp_run"], ["phase_0"], ["phase_1"], ["phase_2"], ], + constrained_layout=True, + figsize=(10, 8), height_ratios=[0.4, 0.2, 0.2, 0.2], ) -[axd[k].margins(0) for k in ["lfp_run"] + [f"phase_{i}" for i in range(3)]] -axd["lfp_run"].plot( - REM_Tsd.index.values, REM_Tsd[:, 0], alpha=0.5, label="LFP Data - REM" -) + +REM_index = REM_Tsd.index.values +axd["lfp_run"].plot(REM_index, REM_Tsd[:, 0], alpha=0.5, label="LFP Data - REM") axd["lfp_run"].plot( - REM_Tsd.index.values, + REM_index, theta_band_reconstruction, - label=f"{np.round(freqs[theta_freq_index],2)}Hz oscillations", + label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", +) +axd["lfp_run"].set( + ylabel="LFP (v)", + xlabel="Time (s)", + title=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.", ) -axd["lfp_run"].set_ylabel("LFP (v)") -axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"{np.round(freqs[theta_freq_index],2)}Hz oscillation power.") axd["lfp_run"].legend() +axd["lfp_run"].margins(0) + for i in range(3): - axd[f"phase_{i}"].plot(REM_Tsd.index.values, theta_band_phase, alpha=0.2) - axd[f"phase_{i}"].scatter( - spikes[list(phase_var.keys())[i]].index, phase[list(phase_var.keys())[i]] - ) - axd[f"phase_{i}"].set_ylabel("Phase") - axd[f"phase_{i}"].set_title(f"Unit {list(phase_var.keys())[i]}") + unit_key = list(phase_var.keys())[i] + ax = axd[f"phase_{i}"] + ax.plot(REM_index, theta_band_phase, alpha=0.2) + ax.scatter(spikes[unit_key].index, phase[unit_key]) + ax.set(ylabel="Phase", title=f"Unit {unit_key}") + ax.margins(0) + fig.suptitle("Phase Preference Visualizations") +plt.show() diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index 0e874085..0fbc8942 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -18,6 +18,8 @@ # # You can install all with `pip install matplotlib requests tqdm seaborn` # +# mkdocs_gallery_thumbnail_number = 7 +# # First, import the necessary libraries: import math From 7fee0d610fa40afabf53b3d55b3a9cea7bec9d7e Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 6 Aug 2024 17:50:45 -0400 Subject: [PATCH 133/195] Update tutorial_signal_processing.py and tests --- docs/examples/tutorial_signal_processing.py | 287 +++++++++++--------- pynapple/process/signal_processing.py | 52 +++- tests/test_power_spectral_density.py | 116 +++++++- 3 files changed, 302 insertions(+), 153 deletions(-) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index ac3dc3c9..71086630 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- """ -Computing Wavelet Transform +Wavelet Transform ============ This tutorial demonstrates how we can use the signal processing tools within Pynapple to aid with data analysis. We will examine the dataset from [Grosmark & Buzsáki (2016)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4919122/). @@ -18,7 +18,10 @@ # # You can install all with `pip install matplotlib requests tqdm seaborn` # +# # First, import the necessary libraries: +# +# mkdocs_gallery_thumbnail_number = 6 import math import os @@ -118,41 +121,83 @@ # Let's take the Fourier transform of our data to get an initial insight into the dominant frequencies during exploration (`wake_ep`). -power = nap.compute_power_spectral_density(eeg, fs=FS, ep=wake_ep) -print(power) +power = nap.compute_power_spectral_density(eeg, fs=FS, ep=wake_ep, norm=True) +print(power) # %% # *** # The returned object is a pandas dataframe which uses frequencies as indexes and spectral power as values. # # Let's plot the power between 1 and 100 Hz. -# -# The red area outlines the theta rhythm (6-12 Hz) which is proeminent in hippocampal LFP. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) -ax.semilogy( - np.abs(power[(power.index > 1.0) & (power.index < 100)]), +ax.plot( + np.abs(power[(power.index >= 1.0) & (power.index <= 100)]), alpha=0.5, label="LFP Frequency Power", ) -ax.axvspan(6, 12, color="red", alpha=0.1) +ax.axvspan(6, 10, color="red", alpha=0.1) ax.set_xlabel("Freq (Hz)") ax.set_ylabel("Frequency Power") ax.set_title("LFP Fourier Decomposition") ax.legend() + +# %% +# The red area outlines the theta rhythm (6-10 Hz) which is proeminent in hippocampal LFP. +# Hippocampal theta rhythm appears mostly when the animal is running. +# (See Hasselmo, M. E., & Stern, C. E. (2014). Theta rhythm and the encoding and retrieval of space and time. Neuroimage, 85, 656-666.) +# We can check it here by separating `wake_ep` into `run_ep` and `rest_ep`. +run_ep = data['position'].dropna().find_support(1) +rest_ep = wake_ep.set_diff(run_ep) + +# %% +# `run_ep` and `rest_ep` are IntervalSet with discontinuous epoch. +# The function `nap.compute_power_spectral_density` takes signal with a single epoch to avoid artefacts between epochs jumps. +# To compare `run_ep` with `rest_ep`, we can use the function `nap.compute_mean_power_spectral_density` which avearge the FFT over multiple epochs of same duration. The parameter `interval_size` controls the duration of those epochs. +# +# In this case, `interval_size` is equal to 1.5 seconds. + +power_run = nap.compute_mean_power_spectral_density(eeg, 1.5, fs=FS, ep=run_ep, norm=True) +power_rest = nap.compute_mean_power_spectral_density(eeg, 1.5, fs=FS, ep=rest_ep, norm=True) + +# %% +# `power_run` and `power_rest` are the power spectral density when the animal is respectively running and resting. + +fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) +ax.plot( + np.abs(power_run[(power_run.index >= 3.0) & (power_run.index <= 30)]), + alpha=1, + label="Run", + linewidth=2 +) +ax.plot( + np.abs(power_rest[(power_rest.index >= 3.0) & (power_rest.index <= 30)]), + alpha=1, + label="Rest", + linewidth=2 +) +ax.axvspan(6, 10, color="red", alpha=0.1) +ax.set_xlabel("Freq (Hz)") +ax.set_ylabel("Frequency Power") +ax.set_title("LFP Fourier Decomposition") +ax.legend() + + # %% # *** # Getting the Wavelet Decomposition # ----------------------------------- -# It looks like the prominent frequencies in the data may vary over time. For example, it looks like the -# LFP characteristics may be different while the animal is running along the track, and when it is finished. +# Overall, the prominent frequencies in the data vary over time. The LFP characteristics may be different when the animal is running along the track, and when it is finished. # Let's generate a wavelet decomposition to look more closely at the changing frequency powers over time. # We must define the frequency set that we'd like to use for our decomposition freqs = np.geomspace(3, 250, 100) + +# %% # Compute and print the wavelet transform on our LFP data + mwt_RUN = nap.compute_wavelet_transform(eeg_example, fs=FS, freqs=freqs) @@ -172,8 +217,9 @@ ax0.grid(False) ax0.set_yscale("log") ax0.set_title("Wavelet Decomposition") +ax0.set_ylabel("Frequency (Hz)") cbar = plt.colorbar(pcmesh, ax=ax0, orientation="vertical") -ax0.set_label("Amplitude") +ax0.set_ylabel("Amplitude") ax1 = plt.subplot(gs[1, 0], sharex=ax0) ax1.plot(eeg_example) @@ -184,7 +230,6 @@ ax1.set_xlabel("Time (s)") ax1.set_ylabel("Pos.") -plt.show() # %% @@ -192,14 +237,14 @@ # Visualizing Theta Band Power # ----------------------------------- # There seems to be a strong theta frequency present in the data during the maze traversal. -# Let's plot the estimated 8Hz component of the wavelet decomposition on top of our data, and see how well -# they match up +# Let's plot the estimated 6-10Hz component of the wavelet decomposition on top of our data, and see how well they match up. + +theta_freq_index = np.logical_and(freqs>6, freqs<10) + -# Find the index of the frequency closest to theta band -theta_freq_index = np.argmin(np.abs(8 - freqs)) # Extract its real component, as well as its power envelope -theta_band_reconstruction = mwt_RUN[:, theta_freq_index].values.real -theta_band_power_envelope = np.abs(mwt_RUN[:, theta_freq_index].values) +theta_band_reconstruction = np.mean(mwt_RUN[:,theta_freq_index], 1) +theta_band_power_envelope = np.abs(theta_band_reconstruction) # %% @@ -207,139 +252,111 @@ # Now let's visualise the theta band component of the signal over time. fig = plt.figure(constrained_layout=True, figsize=(10, 6)) -axd = fig.subplot_mosaic( - [["ephys"], ["pos"]], - height_ratios=[1, 0.4], -) -axd["ephys"].plot(eeg_example, label="CA1") -axd["ephys"].plot( - eeg_example.index.values, - theta_band_reconstruction, - label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", -) -axd["ephys"].plot( - eeg_example.index.values, - theta_band_power_envelope, - label=f"{np.round(freqs[theta_freq_index], 2)}Hz power envelope", -) -axd["ephys"].set_title("EEG (1250 Hz)") -axd["ephys"].set_ylabel("LFP (a.u.)") -axd["ephys"].set_xlabel("time (s)") -axd["ephys"].margins(0) -axd["ephys"].legend() -axd["pos"].plot(pos_example, color="black") -axd["pos"].margins(0) -axd["pos"].set_xlabel("time (s)") -axd["pos"].set_ylabel("Linearized Position") -axd["pos"].set_xlim(RUN_interval[0, 0], RUN_interval[0, 1]) - +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.9]) +ax0 = plt.subplot(gs[0,0]) +ax0.plot(eeg_example, label="CA1") +ax0.set_title("EEG (1250 Hz)") +ax0.set_ylabel("LFP (a.u.)") +ax0.set_xlabel("time (s)") +ax0.legend() +ax1 = plt.subplot(gs[1,0]) +ax1.plot(np.real(theta_band_reconstruction), label="6-10 Hz oscillations") +ax1.plot(theta_band_power_envelope, label="6-10 Hz power envelope") +ax1.set_xlabel("time (s)") +ax1.set_ylabel("Wavelet transform") +ax1.legend() # %% # *** -# Visualizing Sharp Wave Ripple Power +# Visualizing high frequency oscillation # ----------------------------------- -# There also seem to be peaks in the 200Hz frequency power after traversal of thew maze is complete. -# Let's plot the LFP, along with the 200Hz frequency power, to see if we can isolate these peaks and -# see what's going on. +# There also seem to be peaks in the 200Hz frequency power after traversal of thew maze is complete. Here we use the interval (18356, 18357.5) seconds to zoom in. -# Find the index of the frequency closest to sharp wave ripple oscillations -ripple_freq_idx = np.argmin(np.abs(200 - freqs)) -# Extract its power envelope -ripple_power = np.abs(mwt_RUN[:, ripple_freq_idx].values) +zoom_ep = nap.IntervalSet(18356.0, 18357.5) +mwt_zoom = mwt_RUN.restrict(zoom_ep) + +fig = plt.figure(constrained_layout=True, figsize=(10, 6)) +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5]) +ax0 = plt.subplot(gs[0, 0]) +pcmesh = ax0.pcolormesh(mwt_zoom.t, freqs, np.transpose(np.abs(mwt_zoom))) +ax0.grid(False) +ax0.set_yscale("log") +ax0.set_title("Wavelet Decomposition") +ax0.set_ylabel("Frequency (Hz)") +cbar = plt.colorbar(pcmesh, ax=ax0, orientation="vertical") +ax0.set_label("Amplitude") + +ax1 = plt.subplot(gs[1, 0], sharex=ax0) +ax1.plot(eeg_example.restrict(zoom_ep)) +ax1.set_ylabel("LFP (a.u.)") +ax1.set_xlabel("Time (s)") # %% -# *** -# Now let's visualise the 200Hz component of the signal over time. +# Those events are called Sharp-waves ripples (See : Buzsáki, G. (2015). Hippocampal sharp wave‐ripple: A cognitive biomarker for episodic memory and planning. Hippocampus, 25(10), 1073-1188.) +# +# Among other methods, we can use the Wavelet decomposition to isolate them. In this case, we will look at the power of the wavelets for frequencies between 150 to 250 Hz. -fig = plt.figure(constrained_layout=True, figsize=(10, 5)) -axd = fig.subplot_mosaic( - [ - ["lfp_run"], - ["rip_pow"], - ], - height_ratios=[1, 0.4], -) -axd["lfp_run"].plot(eeg_example, label="LFP Data") -axd["rip_pow"].plot(eeg_example.index.values, ripple_power) -axd["lfp_run"].set_ylabel("LFP (v)") -axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].margins(0) -axd["lfp_run"].set_title(f"EEG (1250 Hz)") -axd["rip_pow"].margins(0) -axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) -axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") +ripple_freq_index = np.logical_and(freqs>150, freqs<250) # %% -# *** -# Isolating Ripple Times -# ----------------------------------- -# We can see one significant peak in the 200Hz frequency power. Let's smooth the power curve and threshold -# to try to isolate this event. - -# Define threshold -threshold = 6000 -# Smooth wavelet power TsdFrame at the SWR frequency -smoother_swr_power = ( - mwt_RUN[:, ripple_freq_idx] - .abs() - .smooth(std=0.025, windowsize=0.2, time_units="s", norm=False) -) -# Threshold our TsdFrame -is_ripple = smoother_swr_power.threshold(threshold) +# We can compute the mean power for this frequency band. + +ripple_power = np.mean(np.abs(mwt_RUN[:, ripple_freq_index]), 1) # %% -# *** -# Now let's plot the threshold ripple power over time. +# Now let's visualise the 150-250 Hz mean amplitude of the wavelet decomposition over time fig = plt.figure(constrained_layout=True, figsize=(10, 5)) -axd = fig.subplot_mosaic( - [ - ["lfp_run"], - ["rip_pow"], - ], - height_ratios=[1, 0.4], -) -axd["lfp_run"].plot(eeg_example, label="LFP Data") -axd["rip_pow"].plot(smoother_swr_power) -axd["rip_pow"].axvspan( - is_ripple.index.min(), is_ripple.index.max(), color="red", alpha=0.3 -) -axd["lfp_run"].set_ylabel("LFP (v)") -axd["lfp_run"].set_xlabel("Time (s)") -axd["lfp_run"].set_title(f"EEG (1250 Hz)") -axd["rip_pow"].axhline(threshold, linestyle="--", color="black", alpha=0.4) -[axd[k].margins(0) for k in ["lfp_run", "rip_pow"]] -axd["rip_pow"].set_xlim(eeg_example.index.min(), eeg_example.index.max()) -axd["rip_pow"].set_ylabel(f"{np.round(freqs[ripple_freq_idx], 2)}Hz Power") +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5]) +ax0 = plt.subplot(gs[0,0]) +ax0.plot(eeg_example.restrict(zoom_ep), label="CA1") +ax0.set_ylabel("LFP (v)") +ax0.set_title(f"EEG (1250 Hz)") +ax1 = plt.subplot(gs[1,0]) +ax1.legend() +ax1.plot(ripple_power.restrict(zoom_ep), label="150-250 Hz") +ax1.legend() +ax1.set_ylabel("Mean Amplitude") +ax1.set_xlabel("Time (s)") + # %% -# *** -# Plotting a Sharp Wave Ripple -# ----------------------------------- -# Let's zoom in on out detected ripples and have a closer look! +# It is then easy to isolate ripple times by using the pynapple functions `smooth` and `threshold`. In the following lines, `ripples` is smoothed with a gaussian kernel of size 0.005 second and thesholded with a value of 100. +# + +smoothed_ripple_power = ripple_power.smooth(0.005) + +threshold_ripple_power = smoothed_ripple_power.threshold(100) + +# %% +# `threshold_ripple_power` contains all the time points above 100. The ripple epochs are contained in the `time_support` of the threshold time series. Here we call it `rip_ep`. + +rip_ep = threshold_ripple_power.time_support + + +# %% +# Now let's plot the ripples epoch as well as the smoothed ripple power. +# +# We can also plot `rip_ep` as vertical boxes to see if the detection is accurate + +fig = plt.figure(constrained_layout=True, figsize=(10, 5)) +gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5]) +ax0 = plt.subplot(gs[0,0]) +ax0.plot(eeg_example.restrict(zoom_ep), label="CA1") +for s,e in rip_ep.intersect(zoom_ep).values: + ax0.axvspan(s, e, color='red', alpha=0.1, ec=None) +ax0.set_ylabel("LFP (v)") +ax0.set_title(f"EEG (1250 Hz)") +ax1 = plt.subplot(gs[1,0]) +ax1.legend() +ax1.plot(ripple_power.restrict(zoom_ep), label="150-250 Hz") +ax1.plot(smoothed_ripple_power.restrict(zoom_ep)) +for s,e in rip_ep.intersect(zoom_ep).values: + ax1.axvspan(s, e, color='red', alpha=0.1, ec=None) +ax1.legend() +ax1.set_ylabel("Mean Amplitude") +ax1.set_xlabel("Time (s)") + -fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) -buffer = 0.1 -ax.plot( - eeg_example.restrict( - nap.IntervalSet( - start=is_ripple.index.min() - buffer, end=is_ripple.index.max() + buffer - ) - ), - color="blue", - label="Non-SWR LFP", -) -ax.axvspan( - is_ripple.index.min(), - is_ripple.index.max(), - color="red", - alpha=0.3, - label="SWR LFP", -) -ax.margins(0) -ax.set_xlabel("Time (s)") -ax.set_ylabel("LFP (v)") -ax.legend() -ax.set_title("Sharp Wave Ripple Visualization") diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 7f921402..722abafa 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -8,11 +8,12 @@ import numpy as np import pandas as pd +from scipy import signal from .. import core as nap -def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): +def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm=False): """ Performs numpy fft on sig, returns output. Pynapple assumes a constant sampling rate for sig. @@ -26,6 +27,8 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): The epoch to calculate the fft on. Must be length 1. full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values + norm: bool, optional + Whether the FFT result is divided by the length of the signal to normalize the amplitude Returns ------- @@ -39,9 +42,9 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): parameter otherwise will be sig.time_support, but it must only be a single epoch. """ if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): - raise TypeError( - "Currently compute_spectogram is only implemented for Tsd or TsdFrame" - ) + raise TypeError("sig must be either a Tsd or a TsdFrame object.") + if not (fs is None or isinstance(fs, Number)): + raise TypeError("fs must be of type float or int") if not (ep is None or isinstance(ep, nap.IntervalSet)): raise TypeError("ep param must be a pynapple IntervalSet object, or None") if ep is None: @@ -50,22 +53,40 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False): raise ValueError("Given epoch (or signal time_support) must have length 1") if fs is None: fs = sig.rate + if not isinstance(full_range, bool): + raise TypeError("full_range must be of type bool or None") + if not isinstance(norm, bool): + raise TypeError("norm must be of type bool") + fft_result = np.fft.fft(sig.restrict(ep).values, axis=0) fft_freq = np.fft.fftfreq(len(sig.restrict(ep).values), 1 / fs) + + if norm: + fft_result = fft_result / fft_result.shape[0] + ret = pd.DataFrame(fft_result, fft_freq) ret.sort_index(inplace=True) + if not full_range: return ret.loc[ret.index >= 0] return ret def compute_mean_power_spectral_density( - sig, interval_size, fs=None, ep=None, full_range=False, time_unit="s" + sig, + interval_size, + fs=None, + ep=None, + full_range=False, + norm=False, + time_unit="s", ): """Compute mean power spectral density by averaging FFT over epochs of same size. The parameter `interval_size` controls the duration of the epochs. - Note that this function assumes a constant sampling rate for sig. + To imporve frequency resolution, the signal is multiplied by a Hamming window. + + Note that this function assumes a constant sampling rate for `sig`. Parameters ---------- @@ -79,6 +100,8 @@ def compute_mean_power_spectral_density( The `IntervalSet` to calculate the fft on. Can be any length. full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values + norm: bool, optional + Whether the FFT result is divided by the length of the signal to normalize the amplitude time_unit : str, optional Time units for parameter `interval_size`. Can be ('s'[default], 'ms', 'us') @@ -94,6 +117,9 @@ def compute_mean_power_spectral_density( TypeError If `ep` or `sig` are not respectively pynapple time series or interval set. """ + if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): + raise TypeError("sig must be either a Tsd or a TsdFrame object.") + if not (ep is None or isinstance(ep, nap.IntervalSet)): raise TypeError("ep param must be a pynapple IntervalSet object, or None") if ep is None: @@ -107,6 +133,9 @@ def compute_mean_power_spectral_density( if not isinstance(full_range, bool): raise TypeError("full_range must be of type bool or None") + if not isinstance(norm, bool): + raise TypeError("norm must be of type bool") + # Split the ep interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[ 0 @@ -137,11 +166,20 @@ def compute_mean_power_spectral_density( # Get the freqs fft_freq = np.fft.fftfreq(N, 1 / fs) + # Get the Hamming window + window = signal.windows.hamming(N) + if sig.ndim == 2: + window = window[:, np.newaxis] + # Compute the fft fft_result = np.zeros((N, *sig.shape[1:]), dtype=complex) for i in range(len(slices)): - fft_result += np.fft.fft(sig[slices[i, 0] : slices[i, 1]].values[0:N], axis=0) + tmp = sig[slices[i, 0] : slices[i, 1]].values[0:N] * window + fft_result += np.fft.fft(tmp, axis=0) + + if norm: + fft_result = fft_result / (float(N) * float(len(slices))) ret = pd.DataFrame(fft_result, fft_freq) ret.sort_index(inplace=True) diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index 7d0ec64f..dbda75cf 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -4,13 +4,22 @@ import numpy as np import pandas as pd import pytest +from scipy import signal import pynapple as nap + ############################################################ -# Test for mean_power_spectral_density +# Test for power_spectral_density ############################################################ +def get_sorted_fft(data,fs): + fft = np.fft.fft(data, axis=0) + fft_freq = np.fft.fftfreq(len(data), 1 / fs) + order = np.argsort(fft_freq) + if fft.ndim==1: + fft = fft[:,np.newaxis] + return fft_freq[order], fft[order] def test_compute_power_spectral_density(): @@ -20,16 +29,31 @@ def test_compute_power_spectral_density(): assert isinstance(r, pd.DataFrame) assert r.shape[0] == 500 + a, b = get_sorted_fft(sig.values, sig.rate) + np.testing.assert_array_almost_equal(r.index.values, a[a>=0]) + np.testing.assert_array_almost_equal(r.values, b[a>=0]) + + r = nap.compute_power_spectral_density(sig, norm=True) + np.testing.assert_array_almost_equal(r.values, b[a>=0]/len(sig)) + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) r = nap.compute_power_spectral_density(sig) assert isinstance(r, pd.DataFrame) assert r.shape == (500, 4) + a, b = get_sorted_fft(sig.values, sig.rate) + np.testing.assert_array_almost_equal(r.index.values, a[a>=0]) + np.testing.assert_array_almost_equal(r.values, b[a>=0]) + sig = nap.TsdFrame(d=np.random.random((1000, 4)), t=t) r = nap.compute_power_spectral_density(sig, full_range=True) assert isinstance(r, pd.DataFrame) assert r.shape == (1000, 4) + a, b = get_sorted_fft(sig.values, sig.rate) + np.testing.assert_array_almost_equal(r.index.values, a) + np.testing.assert_array_almost_equal(r.values, b) + t = np.linspace(0, 1, 1000) sig = nap.Tsd(d=np.random.random(1000), t=t) r = nap.compute_power_spectral_density(sig, ep=sig.time_support) @@ -44,7 +68,7 @@ def test_compute_power_spectral_density(): @pytest.mark.parametrize( - "sig, fs, ep, full_range, expectation", + "sig, fs, ep, full_range, norm, expectation", [ ( nap.Tsd( @@ -55,6 +79,7 @@ def test_compute_power_spectral_density(): 1000, None, False, + False, pytest.raises( ValueError, match=re.escape( @@ -67,28 +92,63 @@ def test_compute_power_spectral_density(): 1000, "not_ep", False, + False, pytest.raises( TypeError, match="ep param must be a pynapple IntervalSet object, or None", ), ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + "a", + None, + False, + False, + pytest.raises( + TypeError, + match="fs must be of type float or int", + ), + ), ( "not_a_tsd", 1000, None, False, + False, + pytest.raises( + TypeError, + match="sig must be either a Tsd or a TsdFrame object.", + ), + ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + 1000, + None, + "a", + False, pytest.raises( TypeError, - match="Currently compute_spectogram is only implemented for Tsd or TsdFrame", + match="full_range must be of type bool or None", ), ), + ( + nap.Tsd(d=np.random.random(1000), t=np.linspace(0, 1, 1000)), + 1000, + None, + False, + "a", + pytest.raises( + TypeError, + match="norm must be of type bool", + ), + ), ], ) def test_compute_power_spectral_density_raise_errors( - sig, fs, ep, full_range, expectation + sig, fs, ep, full_range, norm, expectation ): with expectation: - psd = nap.compute_power_spectral_density(sig, fs, ep, full_range) + psd = nap.compute_power_spectral_density(sig, fs, ep, full_range, norm) ############################################################ @@ -102,6 +162,7 @@ def get_signal_and_output(f=2, fs=1000, duration=100, interval_size=10): sig = nap.Tsd(t=t, d=d, time_support=nap.IntervalSet(0, 100)) tmp = d.reshape((int(duration / interval_size), int(fs * interval_size))).T tmp = tmp[0:-1] + tmp = tmp*signal.windows.hamming(tmp.shape[0])[:,np.newaxis] out = np.sum(np.fft.fft(tmp, axis=0), 1) freq = np.fft.fftfreq(out.shape[0], 1 / fs) order = np.argsort(freq) @@ -125,6 +186,14 @@ def test_compute_mean_power_spectral_density(): np.testing.assert_array_almost_equal(psd.values.flatten(), out) np.testing.assert_array_almost_equal(psd.index.values, freq) + # Norm + psd = nap.compute_mean_power_spectral_density(sig, 10, norm=True) + assert isinstance(psd, pd.DataFrame) + assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty + np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq >= 0]/(9999.0*10.0)) + np.testing.assert_array_almost_equal(psd.index.values, freq[freq >= 0]) + + # TsdFrame sig2 = nap.TsdFrame( t=sig.t, d=np.repeat(sig.values[:, None], 2, 1), time_support=sig.time_support @@ -147,15 +216,26 @@ def test_compute_mean_power_spectral_density(): @pytest.mark.parametrize( - "sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation", + "sig, out, freq, interval_size, fs, ep, full_range, norm, time_units, expectation", [ - (*get_signal_and_output(), 10, None, None, False, "s", does_not_raise()), + (*get_signal_and_output(), 10, None, None, False, False, "s", does_not_raise()), + ( + "a", *get_signal_and_output()[1:], + 10, + None, + None, + False, + False, + "s", + pytest.raises(TypeError, match="sig must be either a Tsd or a TsdFrame object."), + ), ( *get_signal_and_output(), 10, "a", None, False, + False, "s", pytest.raises(TypeError, match="fs must be of type float or int"), ), @@ -165,6 +245,7 @@ def test_compute_mean_power_spectral_density(): None, "a", False, + False, "s", pytest.raises( TypeError, @@ -177,17 +258,29 @@ def test_compute_mean_power_spectral_density(): None, None, "a", + False, "s", pytest.raises(TypeError, match="full_range must be of type bool or None"), ), - (*get_signal_and_output(), 10 * 1e3, None, None, False, "ms", does_not_raise()), - (*get_signal_and_output(), 10 * 1e6, None, None, False, "us", does_not_raise()), + ( + *get_signal_and_output(), + 10, + None, + None, + None, + "a", + "s", + pytest.raises(TypeError, match="full_range must be of type bool or None"), + ), + (*get_signal_and_output(), 10 * 1e3, None, None, False, False, "ms", does_not_raise()), + (*get_signal_and_output(), 10 * 1e6, None, None, False, False, "us", does_not_raise()), ( *get_signal_and_output(), 200, None, None, False, + False, "s", pytest.raises( RuntimeError, @@ -200,6 +293,7 @@ def test_compute_mean_power_spectral_density(): None, nap.IntervalSet([0, 200], [100, 300]), False, + False, "s", pytest.raises( RuntimeError, @@ -209,9 +303,9 @@ def test_compute_mean_power_spectral_density(): ], ) def test_compute_mean_power_spectral_density_raise_errors( - sig, out, freq, interval_size, fs, ep, full_range, time_units, expectation + sig, out, freq, interval_size, fs, ep, full_range, norm, time_units, expectation ): with expectation: psd = nap.compute_mean_power_spectral_density( - sig, interval_size, fs, ep, full_range, time_units + sig, interval_size, fs, ep, full_range, norm, time_units ) From 265be7f78a780f523fcbc7a0460ba80f1bd85f13 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 7 Aug 2024 10:09:05 -0400 Subject: [PATCH 134/195] Missing test for sig processing --- tests/test_power_spectral_density.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index dbda75cf..626b0832 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -265,12 +265,22 @@ def test_compute_mean_power_spectral_density(): ( *get_signal_and_output(), 10, - None, - None, - None, - "a", - "s", + None, # FS + None, # Ep + "a", # full_range + False, # Norm + "s", # Time units pytest.raises(TypeError, match="full_range must be of type bool or None"), + ), + ( + *get_signal_and_output(), + 10, + None, # FS + None, # Ep + False, # full_range + "a", # Norm + "s", # Time units + pytest.raises(TypeError, match="norm must be of type bool"), ), (*get_signal_and_output(), 10 * 1e3, None, None, False, False, "ms", does_not_raise()), (*get_signal_and_output(), 10 * 1e6, None, None, False, False, "us", does_not_raise()), From 88857147a1ecc0adc0bc45e7406d6904d59eefe0 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 7 Aug 2024 12:28:23 -0400 Subject: [PATCH 135/195] More update on wavelets --- docs/api_guide/tutorial_pynapple_spectrum.py | 23 ++--- docs/api_guide/tutorial_pynapple_wavelets.py | 98 ++++++++++---------- 2 files changed, 63 insertions(+), 58 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_spectrum.py b/docs/api_guide/tutorial_pynapple_spectrum.py index bcb64c50..39db2432 100644 --- a/docs/api_guide/tutorial_pynapple_spectrum.py +++ b/docs/api_guide/tutorial_pynapple_spectrum.py @@ -35,11 +35,11 @@ F = [2, 10] Fs = 2000 -t = np.arange(0, 100, 1/Fs) +t = np.arange(0, 200, 1/Fs) sig = nap.Tsd( t=t, d=np.cos(t*2*np.pi*F[0])+np.cos(t*2*np.pi*F[1])+np.random.normal(0, 3, len(t)), - time_support = nap.IntervalSet(0, 100) + time_support = nap.IntervalSet(0, 200) ) # %% @@ -55,9 +55,9 @@ # Computing power spectral density (PSD) # -------------------------------------- # -# To compute a PSD of a signal, you can use the function `nap.compute_power_spectral_density` +# To compute a PSD of a signal, you can use the function `nap.compute_power_spectral_density`. With `norm=True`, the output of the FFT is divided by the length of the signal. -psd = nap.compute_power_spectral_density(sig) +psd = nap.compute_power_spectral_density(sig, norm=True) # %% # Pynapple returns a pandas DataFrame. @@ -68,7 +68,7 @@ # It is then easy to plot it. plt.figure() -plt.plot(psd) +plt.plot(np.abs(psd)) plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") @@ -79,7 +79,7 @@ # Let's zoom on the first 20 Hz. plt.figure() -plt.plot(psd) +plt.plot(np.abs(psd)) plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") plt.xlim(0, 20) @@ -108,18 +108,19 @@ # # In this case, the FFT will be computed over epochs of 10 seconds. -mean_psd = nap.compute_mean_power_spectral_density(sig, interval_size=10.0) +mean_psd = nap.compute_mean_power_spectral_density(sig, interval_size=20.0, norm=True) # %% -# Let's compare `mean_psd` to `psd`. +# Let's compare `mean_psd` to `psd`. In both cases, the ouput is normalized. plt.figure() -plt.plot(psd) -plt.plot(mean_psd) +plt.plot(np.abs(psd), label='PSD') +plt.plot(np.abs(mean_psd), label='Mean PSD (10s)') plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") -plt.xlim(0, 20) +plt.legend() +plt.xlim(0, 15) # %% # As we can see, `nap.compute_mean_power_spectral_density` was able to smooth out the noise. diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index 94a74e97..016d7cc0 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -1,9 +1,11 @@ # -*- coding: utf-8 -*- """ -Wavelet API tutorial -============ +Wavelet Transform +================= -Working with Wavelets! +This tutorial covers the use of `nap.compute_wavelet_transform` to do continuous wavelet transform. By default, pynapple uses Morlet wavelets. + +The function `nap.generate_morlet_filterbank` can help parametrize and visualize the Morlet wavelets. See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. @@ -80,15 +82,16 @@ # %% -# Lets plot it. +# Lets plot it some of the wavelets. + def plot_filterbank(filter_bank, freqs, title): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7)) for f_i in range(filter_bank.shape[1]): - ax.plot(filter_bank[:, f_i].real() + 1.5 * f_i) - ax.text(-2.3, 1.5 * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") - ax.margins(0) - ax.yaxis.set_visible(False) - ax.set_xlim(-2, 2) + ax.plot(filter_bank[:, f_i].real() + f_i*1.5) + ax.text(-5.5, 1.5 * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") + + ax.set_yticks([]) + ax.set_xlim(-5, 5) ax.set_xlabel("Time (s)") ax.set_title(title) @@ -98,23 +101,28 @@ def plot_filterbank(filter_bank, freqs, title): # %% # *** -# Decomposing the Dummy Signal -# ------------------ +# Continuous wavelet transform +# ---------------------------- # Here we will use the `compute_wavelet_transform` function to decompose our signal using the filter bank shown # above. Wavelet decomposition breaks down a signal into its constituent wavelets, capturing both time and # frequency information for analysis. We will calculate this decomposition and plot it's corresponding -# scalogram. +# scalogram (which is another name for time frequency decomposition using wavelets). # Compute the wavelet transform using the parameters above mwt = nap.compute_wavelet_transform( sig, fs=Fs, freqs=freqs, gaussian_width=1.5, window_length=1.0 ) +# %% +# `mwt` for Morlet wavelet transform is a `TsdFrame`. Each column is the result of the convolution of the signal with one wavelet. + +print(mwt) # %% # Lets plot it. + def plot_timefrequency(freqs, powers, ax=None): - im = ax.imshow(abs(powers), aspect="auto") + im = ax.imshow(np.abs(powers), aspect="auto") ax.invert_yaxis() ax.set_xlabel("Time (s)") ax.set_ylabel("Frequency (Hz)") @@ -134,7 +142,7 @@ def plot_timefrequency(freqs, powers, ax=None): ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig) -ax1.set_ylabel("LFP (a.u.)") +ax1.set_ylabel("Signal") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -149,9 +157,9 @@ def plot_timefrequency(freqs, powers, ax=None): # Get the index of the 2Hz frequency two_hz_freq_idx = np.where(freqs == 2.0)[0] # The 2Hz component is the real component of the wavelet decomposition at this index -slow_oscillation = mwt[:, two_hz_freq_idx].values.real +slow_oscillation = np.real(mwt[:, two_hz_freq_idx]) # The 2Hz wavelet phase is the angle of the wavelet decomposition at this index -slow_oscillation_phase = np.angle(mwt[:, two_hz_freq_idx].values) +slow_oscillation_phase = np.angle(mwt[:, two_hz_freq_idx]) # %% # Lets plot it. @@ -161,11 +169,11 @@ def plot_timefrequency(freqs, powers, ax=None): height_ratios=[1, 0.4], ) axd["signal"].plot(sig, label="Raw Signal", alpha=0.5) -axd["signal"].plot(t, slow_oscillation, label="2Hz Reconstruction") +axd["signal"].plot(slow_oscillation, label="2Hz Reconstruction") axd["signal"].legend() axd["signal"].set_ylabel("Signal") -axd["phase"].plot(t, slow_oscillation_phase, alpha=0.5) +axd["phase"].plot(slow_oscillation_phase, alpha=0.5) axd["phase"].set_ylabel("Phase (rad)") axd["phase"].set_xlabel("Time (s)") [axd[k].margins(0) for k in ["signal", "phase"]] @@ -184,9 +192,9 @@ def plot_timefrequency(freqs, powers, ax=None): # Get the index of the 15 Hz frequency fifteen_hz_freq_idx = np.where(freqs == 15.0)[0] # The 15 Hz component is the real component of the wavelet decomposition at this index -fifteenHz_oscillation = mwt[:, fifteen_hz_freq_idx].values.real +fifteenHz_oscillation = np.real(mwt[:, fifteen_hz_freq_idx]) # The 15 Hz poser is the absolute value of the wavelet decomposition at this index -fifteenHz_oscillation_power = np.abs(mwt[:, fifteen_hz_freq_idx].values) +fifteenHz_oscillation_power = np.abs(mwt[:, fifteen_hz_freq_idx]) # %% # Lets plot it. @@ -195,13 +203,13 @@ def plot_timefrequency(freqs, powers, ax=None): gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 1.0]) ax0 = plt.subplot(gs[0, 0]) -ax0.plot(t, fifteenHz_oscillation, label="15Hz Reconstruction") -ax0.plot(t, fifteenHz_oscillation_power, label="15Hz Power") +ax0.plot(fifteenHz_oscillation, label="15Hz Reconstruction") +ax0.plot(fifteenHz_oscillation_power, label="15Hz Power") ax0.set_xticklabels([]) ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig, label="Raw Signal", alpha=0.5) -ax1.plot(t, slow_oscillation + fifteenHz_oscillation, label="2Hz + 15Hz Reconstruction") +ax1.plot(slow_oscillation + fifteenHz_oscillation.values, label="2Hz + 15Hz Reconstruction") ax1.set_xlabel("Time (s)") [ @@ -216,13 +224,13 @@ def plot_timefrequency(freqs, powers, ax=None): # ------------------ # Let's now add together the real components of all frequency bands to recreate a version of the original signal. -combined_oscillations = mwt.sum(axis=1).real() +combined_oscillations = np.real(np.sum(mwt, axis=1)) # %% # Lets plot it. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") -ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) +ax.plot(combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") @@ -277,19 +285,19 @@ def plot_timefrequency(freqs, powers, ax=None): ) # %% -# Increasing window_length increases the number of wavelet cycles present in the oscillations (cycles), and +# Increasing `window_length` increases the number of wavelet cycles present in the oscillations (cycles), and # correspondingly increases the time window that the wavelet covers. # -# The gaussian_width parameter determines the shape of the gaussian window being convolved with the sinusoidal +# The `gaussian_width` parameter determines the shape of the gaussian window being convolved with the sinusoidal # component of the wavelet # # Both of these parameters can be tweaked to control for the trade-off between time resolution and frequency resolution. # %% # *** -# Effect of gaussian_width +# Effect of `gaussian_width` # ------------------ -# Let's increase gaussian_width to 7.5 and see the effect on the resultant filter bank. +# Let's increase `gaussian_width` to 7.5 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) filter_bank = nap.generate_morlet_filterbank( @@ -319,7 +327,7 @@ def plot_timefrequency(freqs, powers, ax=None): ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig) -ax1.set_ylabel("LFP (a.u.)") +ax1.set_ylabel("Signal") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -348,9 +356,9 @@ def plot_timefrequency(freqs, powers, ax=None): # %% # *** -# Effect of window_length +# Effect of `window_length` # ------------------ -# Let's increase window_length to 2.0 and see the effect on the resultant filter bank. +# Let's increase `window_length` to 2.0 and see the effect on the resultant filter bank. freqs = np.linspace(1, 25, num=25) filter_bank = nap.generate_morlet_filterbank( @@ -380,7 +388,7 @@ def plot_timefrequency(freqs, powers, ax=None): ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig) -ax1.set_ylabel("LFP (a.u.)") +ax1.set_ylabel("Signal") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -388,13 +396,13 @@ def plot_timefrequency(freqs, powers, ax=None): # *** # And let's see if that has an effect on the reconstructed version of the signal -combined_oscillations = mwt.sum(axis=1).real() +combined_oscillations = np.real(np.sum(mwt, axis=1)) # %% # Lets plot it. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, alpha=0.5, label="Signal") -ax.plot(t, combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) +ax.plot(combined_oscillations, label="Wavelet Reconstruction", alpha=0.5) ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") @@ -407,7 +415,7 @@ def plot_timefrequency(freqs, powers, ax=None): # *** # Effect of L1 vs L2 normalization # ------------------ -# compute_wavelet_transform contains two options for normalization; L1, and L2. L1 normalization. +# `compute_wavelet_transform` contains two options for normalization; L1, and L2. # By default, L1 is used as it creates cleaner looking decomposition images. # # L1 normalization often increases the contrast between significant and insignificant coefficients. @@ -417,7 +425,7 @@ def plot_timefrequency(freqs, powers, ax=None): # L2 normalization is directly related to the energy of the signal. By normalizing using the # L2 norm, you ensure that the transformed coefficients preserve the energy distribution of the original signal. # -# Let's compare two wavelet decomposition images, each generated with a different normalization strategy +# Let's compare two wavelet decomposition, each generated with a different normalization strategy mwt_l1 = nap.compute_wavelet_transform( sig, fs=Fs, freqs=freqs, gaussian_width=7.5, window_length=2.0, norm="l1" @@ -437,7 +445,7 @@ def plot_timefrequency(freqs, powers, ax=None): cbar = fig.colorbar(im, ax=ax0, orientation="vertical") ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig) -ax1.set_ylabel("LFP (a.u.)") +ax1.set_ylabel("Signal") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -449,7 +457,7 @@ def plot_timefrequency(freqs, powers, ax=None): cbar = fig.colorbar(im, ax=ax0, orientation="vertical") ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig) -ax1.set_ylabel("LFP (a.u.)") +ax1.set_ylabel("Signal") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -460,19 +468,15 @@ def plot_timefrequency(freqs, powers, ax=None): # %% -combined_oscillations_l1 = mwt_l1.sum(axis=1).real() -combined_oscillations_l2 = mwt_l2.sum(axis=1).real() +combined_oscillations_l1 = np.real(np.sum(mwt_l1, axis=1)) +combined_oscillations_l2 = np.real(np.sum(mwt_l2, axis=1)) # %% # Lets plot it. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 4)) ax.plot(sig, label="Signal", linewidth=3, alpha=0.6, c="b") -ax.plot( - t, combined_oscillations_l1, label="Wavelet Reconstruction (L1)", c="g", alpha=0.6 -) -ax.plot( - t, combined_oscillations_l2, label="Wavelet Reconstruction (L2)", c="r", alpha=0.6 -) +ax.plot(combined_oscillations_l1, label="Wavelet Reconstruction (L1)", c="g", alpha=0.6) +ax.plot(combined_oscillations_l2, label="Wavelet Reconstruction (L2)", c="r", alpha=0.6) ax.set_xlabel("Time (s)") ax.set_ylabel("Signal") ax.set_title("Wavelet Reconstruction of Signal") From 2793e1cd837d327cdf57c36c3186974214f43a7b Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 7 Aug 2024 12:32:24 -0400 Subject: [PATCH 136/195] change title --- docs/examples/tutorial_phase_preferences.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 15d44dda..0abb753b 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -Computing Phase Preferences -============ +Spikes-phase coupling +===================== In this tutorial we will learn how to isolate phase information from our wavelet decomposition and combine it with spiking data, to find phase preferences of spiking units. From ebe5e76d055f59872a8d029559076c25cd7f8413 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 7 Aug 2024 17:18:21 -0400 Subject: [PATCH 137/195] improved filtering --- pynapple/process/filtering.py | 108 ++++++++++++++++++---------------- 1 file changed, 57 insertions(+), 51 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 4ea1b667..1cb7afd4 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -1,67 +1,73 @@ -""" - Filtering module -""" +"""Filtering module.""" import numpy as np from .. import core as nap -from scipy.signal import butter, lfilter, filtfilt +from scipy.signal import butter, filtfilt +from numbers import Number -def _butter_bandpass(lowcut, highcut, fs, order=5): - nyq = 0.5 * fs - low = lowcut / nyq - high = highcut / nyq - b, a = butter(order, [low, high], btype='band') - return b, a +def compute_filtered_signal(data, freq_band, filter_type="bandpass", order=4, sampling_frequency=None): + """ + Apply a Butterworth filter to the provided signal data. -def _butter_bandpass_filter(data, lowcut, highcut, fs, order=4): - b, a = _butter_bandpass(lowcut, highcut, fs, order=order) - y = lfilter(b, a, data) - return y + This function performs bandpass filtering on Local Field Potential (LFP) + data using a Butterworth filter. The filter can be configured to be of + type "bandpass", "bandstop", "highpass", or "lowpass". -def compute_bandpass_filter(data, freq_band, sampling_frequency=None, order=4): - """ - Bandpass filtering the LFP. - Parameters ---------- - data : Tsd/TsdFrame - Description - lowcut : TYPE - Description - highcut : TYPE - Description - fs : TYPE - Description + data : Tsd, TsdFrame, or TsdTensor + The signal to be filtered. + freq_band : tuple of (float, float) or float + Cutoff frequency(ies) in Hz. + - For "bandpass" and "bandstop" filters, provide a tuple specifying the two cutoff frequencies. + - For "lowpass" and "highpass" filters, provide a single float specifying the cutoff frequency. + filter_type : {'bandpass', 'bandstop', 'highpass', 'lowpass'}, optional + The type of frequency filter to apply. Default is "bandpass". order : int, optional - Description - + The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. + Default is 4. + sampling_frequency : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + + Returns + ------- + filtered_data : Tsd, TsdFrame, or TsdTensor + The filtered signal, with the same data type as the input. + Raises ------ - RuntimeError - Description + ValueError + If `filter_type` is not one of {"bandpass", "bandstop", "highpass", "lowpass"}. + If `freq_band` is not a float for "lowpass" and "highpass" filters. + If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + + Notes + ----- + The cutoff frequency is defined as the frequency at which the amplitude of the signal + is reduced by -3 dB (decibels). """ - time_support = data.time_support - time_index = data.as_units('s').index.values - if type(data) is nap.TsdFrame: - tmp = np.zeros(data.shape) - for i in np.arange(data.shape[1]): - tmp[:,i] = bandpass_filter(data[:,i], lowcut, highcut, fs, order) + if sampling_frequency is None: + sampling_frequency = data.rate + + if filter_type not in ["lowpass", "highpass", "bandpass", "bandstop"]: + raise ValueError(f"Unrecognized filter type {filter_type}. " + "filter_type must be either 'lowpass', 'highpass', 'bandpass',or 'bandstop'.") + if filter_type in ["lowpass", "highpass"] and isinstance(freq_band, Number): + raise ValueError("Must provide a single float for specifying a 'highpass' and 'lowpass' filters. " + f"{freq_band} provided instead!") + elif filter_type in ["bandpass", "bandstop"] and len(tuple(freq_band)) != 2: + raise ValueError("Must provide a two floats for specifying a 'bandpass' and 'bandstop' filters. " + f"{freq_band} provided instead!") - return nap.TsdFrame( - t = time_index, - d = tmp, - time_support = time_support, - time_units = 's', - columns = data.columns) + b, a = butter(order, freq_band, btype=filter_type, fs=sampling_frequency) - elif type(data) is nap.Tsd: - flfp = _butter_bandpass_filter(data.values, lowcut, highcut, fs, order) - return nap.Tsd( - t=time_index, - d=flfp, - time_support=time_support, - time_units='s') + out = np.zeros_like(data.d) + for ep in data.time_support: + slc = data.get_slice(start=ep.start[0], end=ep.end[0]) + out[slc] = filtfilt(b, a, data.d[slc], axis=0) - else: - raise RuntimeError("Unknow format. Should be Tsd/TsdFrame") \ No newline at end of file + kwargs = dict(t=data.t, d=out, time_support=data.time_support) + if isinstance(data, nap.TsdFrame): + kwargs["columns"] = data.columns + return data.__class__(**kwargs) From bc20031a432d51e8699e64c54812e6f1be53803a Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 7 Aug 2024 17:53:11 -0400 Subject: [PATCH 138/195] added regressoin test --- pynapple/process/__init__.py | 2 ++ pynapple/process/filtering.py | 2 +- tests/test_filtering.py | 36 +++++++++++++++++++++++++++++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) create mode 100644 tests/test_filtering.py diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 2e1af412..ddc5bffa 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -24,3 +24,5 @@ compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, ) + +from .filtering import compute_filtered_signal diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 1cb7afd4..aff10347 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -53,7 +53,7 @@ def compute_filtered_signal(data, freq_band, filter_type="bandpass", order=4, sa if filter_type not in ["lowpass", "highpass", "bandpass", "bandstop"]: raise ValueError(f"Unrecognized filter type {filter_type}. " "filter_type must be either 'lowpass', 'highpass', 'bandpass',or 'bandstop'.") - if filter_type in ["lowpass", "highpass"] and isinstance(freq_band, Number): + elif filter_type in ["lowpass", "highpass"] and not isinstance(freq_band, Number): raise ValueError("Must provide a single float for specifying a 'highpass' and 'lowpass' filters. " f"{freq_band} provided instead!") elif filter_type in ["bandpass", "bandstop"] and len(tuple(freq_band)) != 2: diff --git a/tests/test_filtering.py b/tests/test_filtering.py new file mode 100644 index 00000000..ebf0a204 --- /dev/null +++ b/tests/test_filtering.py @@ -0,0 +1,36 @@ +import pytest +import pynapple as nap +import numpy as np +from scipy import signal + + +@pytest.mark.parametrize("freq", [10, 100]) +@pytest.mark.parametrize("order", [2, 4, 6]) +@pytest.mark.parametrize("btype", ["lowpass", "highpass"]) +@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) +@pytest.mark.parametrize( + "ep", + [ + nap.IntervalSet(start=[0], end=[1]), + nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), + nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1]) + ] +) +def test_filtering_single_freq(freq, order, btype, shape: tuple, ep): + + t = np.linspace(0, 1, shape[0]) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape)) + + if len(shape) == 1: + tsd = nap.Tsd(t, y, time_support=ep) + elif len(shape) == 2: + tsd = nap.TsdFrame(t, y, time_support=ep) + else: + tsd = nap.TsdTensor(t, y, time_support=ep) + out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order) + b, a = signal.butter(order, freq, fs=tsd.rate, btype=btype) + out_sci = [] + for iset in ep: + out_sci.append(signal.filtfilt(b, a, tsd.restrict(iset).d, axis=0)) + out_sci = np.concatenate(out_sci, axis=0) + np.testing.assert_array_equal(out.d, out_sci) From 77dab5fa3838379c73508227b952094ae24f6510 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 7 Aug 2024 18:03:29 -0400 Subject: [PATCH 139/195] added unit testing --- pynapple/process/filtering.py | 10 +++++++--- tests/test_filtering.py | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index aff10347..246ba801 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -56,9 +56,13 @@ def compute_filtered_signal(data, freq_band, filter_type="bandpass", order=4, sa elif filter_type in ["lowpass", "highpass"] and not isinstance(freq_band, Number): raise ValueError("Must provide a single float for specifying a 'highpass' and 'lowpass' filters. " f"{freq_band} provided instead!") - elif filter_type in ["bandpass", "bandstop"] and len(tuple(freq_band)) != 2: - raise ValueError("Must provide a two floats for specifying a 'bandpass' and 'bandstop' filters. " - f"{freq_band} provided instead!") + elif filter_type in ["bandpass", "bandstop"]: + try: + if len(freq_band) != 2: + raise ValueError + except Exception: + raise ValueError("Must provide a two floats for specifying a 'bandpass' and 'bandstop' filters. " + f"{freq_band} provided instead!") b, a = butter(order, freq_band, btype=filter_type, fs=sampling_frequency) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index ebf0a204..d9f16ace 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -2,6 +2,16 @@ import pynapple as nap import numpy as np from scipy import signal +from contextlib import nullcontext as does_not_raise + + +@pytest.fixture +def sample_data(): + # Create a sample Tsd data object + t = np.linspace(0, 1, 500) + d = np.sin(2 * np.pi * 10 * t) + np.random.normal(0, 0.5, t.shape) + time_support = nap.IntervalSet(start=[0], end=[1]) + return nap.Tsd(t=t, d=d, time_support=time_support) @pytest.mark.parametrize("freq", [10, 100]) @@ -34,3 +44,20 @@ def test_filtering_single_freq(freq, order, btype, shape: tuple, ep): out_sci.append(signal.filtfilt(b, a, tsd.restrict(iset).d, axis=0)) out_sci = np.concatenate(out_sci, axis=0) np.testing.assert_array_equal(out.d, out_sci) + + +@pytest.mark.parametrize("freq_band, filter_type, order, expected_exception", [ + ((5, 15), "bandpass", 4, does_not_raise()), + ((5, 15), "bandstop", 4, does_not_raise()), + (10, "highpass", 4, does_not_raise()), + (10, "lowpass", 4, does_not_raise()), + ((5, 15), "invalid_filter", 4, pytest.raises(ValueError)), + (10, "bandpass", 4, pytest.raises(ValueError)), + ((5, 15), "highpass", 4, pytest.raises(ValueError)), +]) +def test_compute_filtered_signal(sample_data, freq_band, filter_type, order, expected_exception): + with expected_exception: + filtered_data = nap.filtering.compute_filtered_signal(sample_data, freq_band, filter_type, order) + if not expected_exception: + assert isinstance(filtered_data, type(sample_data)) + assert filtered_data.d.shape == sample_data.d.shape \ No newline at end of file From 86ca5885d8a341863c5b00d8db6138bff5b92898 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 7 Aug 2024 18:07:22 -0400 Subject: [PATCH 140/195] test dytpe --- tests/test_filtering.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index d9f16ace..ae11a243 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -26,7 +26,7 @@ def sample_data(): nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1]) ] ) -def test_filtering_single_freq(freq, order, btype, shape: tuple, ep): +def test_filtering_single_freq_match_sci(freq, order, btype, shape: tuple, ep): t = np.linspace(0, 1, shape[0]) y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape)) @@ -46,6 +46,36 @@ def test_filtering_single_freq(freq, order, btype, shape: tuple, ep): np.testing.assert_array_equal(out.d, out_sci) +@pytest.mark.parametrize("freq", [10, 100]) +@pytest.mark.parametrize("order", [2, 4, 6]) +@pytest.mark.parametrize("btype", ["lowpass", "highpass"]) +@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) +@pytest.mark.parametrize( + "ep", + [ + nap.IntervalSet(start=[0], end=[1]), + nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), + nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1]) + ] +) +def test_filtering_single_freq_dtype(freq, order, btype, shape: tuple, ep): + t = np.linspace(0, 1, shape[0]) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape)) + + if len(shape) == 1: + tsd = nap.Tsd(t, y, time_support=ep) + elif len(shape) == 2: + tsd = nap.TsdFrame(t, y, time_support=ep, columns=np.arange(10, 10 + y.shape[1])) + else: + tsd = nap.TsdTensor(t, y, time_support=ep) + out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order) + assert isinstance(out, type(tsd)) + assert np.all(out.t == tsd.t) + assert np.all(out.time_support == tsd.time_support) + if isinstance(tsd, nap.TsdFrame): + assert np.all(tsd.columns == out.columns) + + @pytest.mark.parametrize("freq_band, filter_type, order, expected_exception", [ ((5, 15), "bandpass", 4, does_not_raise()), ((5, 15), "bandstop", 4, does_not_raise()), From 098de145d58ed020d6a059ab66ca3c722b7e5165 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 7 Aug 2024 18:09:10 -0400 Subject: [PATCH 141/195] removed simple test --- tests/test_filtering.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index ae11a243..b73c8e27 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -74,20 +74,3 @@ def test_filtering_single_freq_dtype(freq, order, btype, shape: tuple, ep): assert np.all(out.time_support == tsd.time_support) if isinstance(tsd, nap.TsdFrame): assert np.all(tsd.columns == out.columns) - - -@pytest.mark.parametrize("freq_band, filter_type, order, expected_exception", [ - ((5, 15), "bandpass", 4, does_not_raise()), - ((5, 15), "bandstop", 4, does_not_raise()), - (10, "highpass", 4, does_not_raise()), - (10, "lowpass", 4, does_not_raise()), - ((5, 15), "invalid_filter", 4, pytest.raises(ValueError)), - (10, "bandpass", 4, pytest.raises(ValueError)), - ((5, 15), "highpass", 4, pytest.raises(ValueError)), -]) -def test_compute_filtered_signal(sample_data, freq_band, filter_type, order, expected_exception): - with expected_exception: - filtered_data = nap.filtering.compute_filtered_signal(sample_data, freq_band, filter_type, order) - if not expected_exception: - assert isinstance(filtered_data, type(sample_data)) - assert filtered_data.d.shape == sample_data.d.shape \ No newline at end of file From 287523e695f99f6ceec977c9a1851ef0667ff169 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 7 Aug 2024 18:16:39 -0400 Subject: [PATCH 142/195] improved tests --- pynapple/process/filtering.py | 6 +++--- tests/test_filtering.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 246ba801..364ea407 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -54,14 +54,14 @@ def compute_filtered_signal(data, freq_band, filter_type="bandpass", order=4, sa raise ValueError(f"Unrecognized filter type {filter_type}. " "filter_type must be either 'lowpass', 'highpass', 'bandpass',or 'bandstop'.") elif filter_type in ["lowpass", "highpass"] and not isinstance(freq_band, Number): - raise ValueError("Must provide a single float for specifying a 'highpass' and 'lowpass' filters. " + raise ValueError("Low/high-pass filter specification requires a single frequency. " f"{freq_band} provided instead!") elif filter_type in ["bandpass", "bandstop"]: try: - if len(freq_band) != 2: + if len(freq_band) != 2 or not all(isinstance(fq, Number) for fq in freq_band): raise ValueError except Exception: - raise ValueError("Must provide a two floats for specifying a 'bandpass' and 'bandstop' filters. " + raise ValueError("Band-pass/stop filter specification requires two frequencies. " f"{freq_band} provided instead!") b, a = butter(order, freq_band, btype=filter_type, fs=sampling_frequency) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index b73c8e27..3d7e2972 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -74,3 +74,22 @@ def test_filtering_single_freq_dtype(freq, order, btype, shape: tuple, ep): assert np.all(out.time_support == tsd.time_support) if isinstance(tsd, nap.TsdFrame): assert np.all(tsd.columns == out.columns) + + +@pytest.mark.parametrize("freq_band, filter_type, order, expected_exception", [ + ((5, 15), "bandpass", 4, does_not_raise()), + ((5, 15), "bandstop", 4, does_not_raise()), + (10, "highpass", 4, does_not_raise()), + (10, "lowpass", 4, does_not_raise()), + ((5, 15), "invalid_filter", 4, pytest.raises(ValueError, match="Unrecognized filter type")), + (10, "bandpass", 4, pytest.raises(ValueError, match="Band-pass/stop filter specification requires two frequencies")), + ((5, 15), "highpass", 4, pytest.raises(ValueError, match="Low/high-pass filter specification requires a single frequency")), + (None, "bandpass", 4, pytest.raises(ValueError, match="Band-pass/stop filter specification requires two frequencies")), + ((None, 1), "highpass", 4, pytest.raises(ValueError, match="Low/high-pass filter specification requires a single frequency")) +]) +def test_compute_filtered_signal(sample_data, freq_band, filter_type, order, expected_exception): + with expected_exception: + filtered_data = nap.filtering.compute_filtered_signal(sample_data, freq_band, filter_type, order) + if not expected_exception: + assert isinstance(filtered_data, type(sample_data)) + assert filtered_data.d.shape == sample_data.d.shape From 5d3a340199addeaf7a448a44e640d7340081cf32 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 7 Aug 2024 22:18:46 -0400 Subject: [PATCH 143/195] Updating tests --- pynapple/process/signal_processing.py | 126 ++++--- tests/test_signal_processing.py | 499 ++++++++++++++------------ 2 files changed, 328 insertions(+), 297 deletions(-) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 722abafa..4d814549 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -1,7 +1,11 @@ """ -Signal processing tools for Pynapple. +# Signal processing tools + +- `nap.compute_power_spectral_density` +- `nap.compute_mean_power_spectral_density` +- `nap.compute_wavelet_transform` +- `nap.generate_morlet_filterbank` -Contains functionality for signal processing pynapple object; fourier transforms and wavelet decomposition. """ from numbers import Number @@ -23,7 +27,7 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm Time series. fs : float, optional Sampling rate, in Hz. If None, will be calculated from the given signal - ep : pynapple.IntervalSet or None, optional + ep : None or pynapple.IntervalSet, optional The epoch to calculate the fft on. Must be length 1. full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values @@ -90,13 +94,13 @@ def compute_mean_power_spectral_density( Parameters ---------- - sig : Tsd or TsdFrame + sig : pynapple.Tsd or pynapple.TsdFrame Signal with equispaced samples interval_size : Number Epochs size to compute to average the FFT across fs : None, optional Sampling frequency of `sig`. If `None`, `fs` is equal to `sig.rate` - ep : None, optional + ep : None or pynapple.IntervalSet, optional The `IntervalSet` to calculate the fft on. Can be any length. full_range : bool, optional If true, will return full fft frequency range, otherwise will return only positive values @@ -216,32 +220,6 @@ def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): ) -def _create_freqs(freq_start, freq_stop, num_freqs=10, log_scaling=False): - """ - Creates an array of frequencies. - - Parameters - ---------- - freq_start : float - Starting value for the frequency definition. - freq_stop: float - Stopping value for the frequency definition, inclusive. - num_freqs: int, optional - Number of freqs to create. Default 10 - log_scaling: Bool - If True, will use log spacing with base log_base for frequency spacing. Default False. - - Returns - ------- - freqs: 1d array - Frequency indices. - """ - if not log_scaling: - return np.linspace(freq_start, freq_stop, num_freqs) - else: - return np.geomspace(freq_start, freq_stop, num_freqs) - - def compute_wavelet_transform( sig, freqs, fs=None, gaussian_width=1.5, window_length=1.0, precision=16, norm="l1" ): @@ -250,26 +228,24 @@ def compute_wavelet_transform( Parameters ---------- - sig : pynapple.Tsd, pynapple.TsdFrame or pynapple.TsdTensor + sig : pynapple.Tsd or pynapple.TsdFrame or pynapple.TsdTensor Time series. - freqs : 1d array or tuple of float - If array, frequency values to estimate with morlet wavelets. - If tuple, define the frequency range, as [freq_start, freq_stop, freq_step]. - The `freq_step` is optional, and defaults to 1. Range is inclusive of `freq_stop` value. + freqs : 1d array + Frequency values to estimate with Morlet wavelets. fs : float or None - Sampling rate, in Hz. Defaults to sig.rate if None is given. + Sampling rate, in Hz. Defaults to `sig.rate` if None is given. gaussian_width : float - Defines width of Gaussian to be used in wavelet creation. + Defines width of Gaussian to be used in wavelet creation. Default is 1.5. window_length : float - The length of window to be used for wavelet creation. + The length of window to be used for wavelet creation. Default is 1.0. precision: int. - Precision of wavelet to use. . Defines the number of timepoints to evaluate the Morlet wavelet at. - Default is 16 + Precision of wavelet to use. Defines the number of timepoints to evaluate the Morlet wavelet at. + Default is 16. norm : {None, 'l1', 'l2'}, optional Normalization method: - * None - no normalization - * 'l1' - divide by the sum of amplitudes - * 'l2' - divide by the square root of the sum of amplitudes + - None - no normalization + - 'l1' - divide by the sum of amplitudes + - 'l2' - divide by the square root of the sum of amplitudes Returns ------- @@ -280,10 +256,10 @@ def compute_wavelet_transform( -------- >>> import numpy as np >>> import pynapple as nap - >>> t = np.linspace(0, 1, 1000) + >>> t = np.arange(0, 1, 1/1000) >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) >>> freqs = np.linspace(10, 100, 10) - >>> mwt = nap.compute_wavelet_transform(signal, fs=None, freqs=freqs) + >>> mwt = nap.compute_wavelet_transform(signal, fs=1000, freqs=freqs) Notes ----- @@ -292,31 +268,28 @@ def compute_wavelet_transform( if not isinstance(sig, (nap.Tsd, nap.TsdFrame, nap.TsdTensor)): raise TypeError("`sig` must be instance of Tsd, TsdFrame, or TsdTensor") - if isinstance(gaussian_width, (int, float, np.number)): - if gaussian_width <= 0: - raise ValueError("gaussian_width must be a positive number.") - else: - raise TypeError("gaussian_width must be a float or int instance.") - if isinstance(window_length, (int, float, np.number)): - if window_length <= 0: - raise ValueError("window_length must be a positive number.") - else: - raise TypeError("window_length must be a float or int instance.") + + if not isinstance(freqs, np.ndarray): + raise TypeError("`freqs` must be a ndarray") + if len(freqs) == 0: + raise ValueError("Given list of freqs cannot be empty.") + if np.min(freqs) <= 0: + raise ValueError("All frequencies in freqs must be strictly positive") + + if fs is not None and not isinstance(fs, (int, float, np.number)): + raise TypeError("`fs` must be of type float or int or None") + if norm is not None and norm not in ["l1", "l2"]: raise ValueError("norm parameter must be 'l1', 'l2', or None.") - if not isinstance(freqs, (np.ndarray, tuple)): - raise TypeError("`freqs` must be a ndarray or tuple instance.") - if isinstance(freqs, tuple): - freqs = _create_freqs(*freqs) if fs is None: fs = sig.rate - if isinstance(sig, nap.Tsd): + if sig.ndim == 1: output_shape = (sig.shape[0], len(freqs)) else: output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) - sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) + sig = np.reshape(sig, (sig.shape[0], np.prod(sig.shape[1:]))) filter_bank = generate_morlet_filterbank( freqs, fs, gaussian_width, window_length, precision @@ -324,6 +297,7 @@ def compute_wavelet_transform( convolved_real = sig.convolve(filter_bank.real().values) convolved_imag = sig.convolve(filter_bank.imag().values) convolved = convolved_real.values + convolved_imag.values * 1j + if norm == "l1": coef = convolved / (fs / freqs) elif norm == "l2": @@ -352,8 +326,8 @@ def generate_morlet_filterbank( Parameters ---------- freqs : 1d array - Frequency values to estimate with morlet wavelets. - fs : float + frequency values to estimate with Morlet wavelets. + fs : float or int Sampling rate, in Hz. gaussian_width : float Defines width of Gaussian to be used in wavelet creation. @@ -367,10 +341,34 @@ def generate_morlet_filterbank( filter_bank : pynapple.TsdFrame list of Morlet wavelet filters of the frequencies given """ + if not isinstance(freqs, np.ndarray): + raise TypeError("`freqs` must be a ndarray") if len(freqs) == 0: raise ValueError("Given list of freqs cannot be empty.") if np.min(freqs) <= 0: raise ValueError("All frequencies in freqs must be strictly positive") + + if not isinstance(fs, (int, float, np.number)): + raise TypeError("`fs` must be of type float or int ndarray") + + if isinstance(gaussian_width, (int, float, np.number)): + if gaussian_width <= 0: + raise ValueError("gaussian_width must be a positive number.") + else: + raise TypeError("gaussian_width must be a float or int instance.") + + if isinstance(window_length, (int, float, np.number)): + if window_length <= 0: + raise ValueError("window_length must be a positive number.") + else: + raise TypeError("window_length must be a float or int instance.") + + if isinstance(precision, int): + if precision <= 0: + raise ValueError("precision must be a positive number.") + else: + raise TypeError("precision must be a float or int instance.") + filter_bank = [] cutoff = 8 morlet_f = _morlet( diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 34b0e79f..6d1fc11a 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -2,6 +2,8 @@ import numpy as np import pytest +import re +from contextlib import nullcontext as does_not_raise import pynapple as nap @@ -76,13 +78,95 @@ def test_generate_morlet_filterbank(): ), ), ( - [], + "a", 1000, 1.5, 1.0, 16, - pytest.raises(ValueError, match="Given list of freqs cannot be empty."), + pytest.raises( + TypeError, match="`freqs` must be a ndarray" + ), + ), + ( + np.array([]), + 1000, + 1.5, + 1.0, + 16, + pytest.raises( + ValueError, match="Given list of freqs cannot be empty." + ), + ), + ( + np.linspace(1, 10, 1), + "a", + 1.5, + 1.0, + 16, + pytest.raises( + TypeError, match="`fs` must be of type float or int ndarray" + ), + ), + ( + np.linspace(1, 10, 1), + 1000, + -1.5, + 1.0, + 16, + pytest.raises( + ValueError, match="gaussian_width must be a positive number." + ), + ), + ( + np.linspace(1, 10, 1), + 1000, + "a", + 1.0, + 16, + pytest.raises( + TypeError, match="gaussian_width must be a float or int instance." + ), ), + ( + np.linspace(1, 10, 1), + 1000, + 1.5, + -1.0, + 16, + pytest.raises( + ValueError, match="window_length must be a positive number." + ), + ), + ( + np.linspace(1, 10, 1), + 1000, + 1.5, + "a", + 16, + pytest.raises( + TypeError, match="window_length must be a float or int instance." + ), + ), + ( + np.linspace(1, 10, 1), + 1000, + 1.5, + 1.0, + -16, + pytest.raises( + ValueError, match="precision must be a positive number." + ), + ), + ( + np.linspace(1, 10, 1), + 1000, + 1.5, + 1.0, + "a", + pytest.raises( + TypeError, match="precision must be a float or int instance." + ), + ), ], ) def test_generate_morlet_filterbank_raise_errors( @@ -94,300 +178,249 @@ def test_generate_morlet_filterbank_raise_errors( ) -def test_compute_wavelet_transform(): - t = np.linspace(0, 1, 1001) - sig = nap.Tsd( - d=np.sin(t * 50 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), - t=t, - ) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 50 - assert ( - np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] - == 500 - ) - t = np.linspace(0, 1, 1001) - sig = nap.Tsd( - d=np.sin(t * 50 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), - t=t, - ) - freqs = (10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.linspace(10, 100, 10)) - assert np.array_equal(mwt, mwt2) +############################################################ +# Test for compute_wavelet_transform +############################################################ +import pynapple as nap +import numpy as np - t = np.linspace(0, 1, 1001) - sig = nap.Tsd( - d=np.sin(t * 50 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), - t=t, - ) - freqs = (10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=1001, freqs=freqs) - mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert np.array_equal(mwt, mwt2) - t = np.linspace(0, 1, 1001) - sig = nap.Tsd( - d=np.sin(t * 50 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), - t=t, - ) - freqs = (10, 100, 10, True) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mwt2 = nap.compute_wavelet_transform(sig, fs=None, freqs=np.geomspace(10, 100, 10)) - assert np.array_equal(mwt, mwt2) +def get_1d_signal(fs=1000, fc=50): + t = np.arange(0, 2, 1 / fs) + d = np.sin(t * fc * np.pi * 2) * np.interp(t, [0, 1, 2], [0, 1, 0]) + return nap.Tsd(t, d, time_support=nap.IntervalSet(0, 2)) + +def get_2d_signal(fs=1000, fc=50): + t = np.arange(0, 2, 1 / fs) + d = np.sin(t * fc * np.pi * 2) * np.interp(t, [0, 1, 2], [0, 1, 0]) + return nap.TsdFrame(t, d[:,np.newaxis], time_support=nap.IntervalSet(0, 2)) + +def get_output_1d(sig, wavelets): + T = sig.shape[0] + M, N = wavelets.shape + out = [] + for n in range(N): + out.append(np.convolve(sig, wavelets[:, n], mode="full")) + out = np.array(out).T + cut = ((M - 1) // 2, T + M - 1 - ((M - 1) // 2) - (1 - M % 2)) + return out[cut[0] : cut[1]] + +def get_output_2d(sig, wavelets): + T, K = sig.shape + M, N = wavelets.shape + out = [] + for k in range(K): + tmp = [] + for n in range(N): + tmp.append(np.convolve(sig[:,k], wavelets[:, n], mode="full")) + out.append(np.array(tmp)) + out = np.array(out).T + cut = ((M - 1) // 2, T + M - 1 - ((M - 1) // 2) - (1 - M % 2)) + return out[cut[0] : cut[1]] - t = np.linspace(0, 1, 1001) - sig = nap.Tsd( - d=np.sin(t * 10 * np.pi * 2) - * np.interp(np.linspace(0, 1, 1001), [0, 0.5, 1], [0, 1, 0]), - t=t, - ) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 10 - assert ( - np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] - == 500 +@pytest.mark.parametrize( + "func, freqs, fs, gaussian_width, window_length, precision, norm, fc, maxt", + [ + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, None, 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), None, 1.5, 1.0, 16, None, 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 3.0, 1.0, 16, None, 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 2.0, 16, None, 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 20, None, 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, "l1", 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, "l2", 50, 1000), + (get_1d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, None, 20, 1000), + (get_2d_signal, np.linspace(10, 100, 10), 1000, 1.5, 1.0, 16, None, 20, 1000), + ], +) +def test_compute_wavelet_transform( + func, freqs, fs, gaussian_width, window_length, precision, norm, fc, maxt +): + sig = func(1000, fc) + wavelets = nap.generate_morlet_filterbank( + freqs, 1000, gaussian_width, window_length, precision ) + if sig.ndim == 1: + output = get_output_1d(sig.d, wavelets.values) + if sig.ndim == 2: + output = get_output_2d(sig.d, wavelets.values) - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 50 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 20 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="l1") - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 20 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm="l2") - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 20 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 20 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs, norm=None) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 20 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1000) - sig = nap.Tsd(d=np.sin(t * 70 * np.pi * 2), t=t) - freqs = np.linspace(10, 100, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - mpf = freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] - assert mpf == 70 - assert mwt.shape == (1000, 10) - - t = np.linspace(0, 1, 1024) - sig = nap.Tsd(d=np.random.random(1024), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert mwt.shape == (1024, 10) - - t = np.linspace(0, 1, 1024) - sig = nap.Tsd(d=np.random.random(1024), t=t) - freqs = (1, 51, 6) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert mwt.shape == (1024, 6) - - t = np.linspace(0, 1, 1024) - sig = nap.TsdFrame(d=np.random.random((1024, 4)), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert mwt.shape == (1024, 10, 4) - - t = np.linspace(0, 1, 1024) - sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform(sig, fs=None, freqs=freqs) - assert mwt.shape == (1024, 10, 4, 2) - - # Testing against manual convolution for l1 norm - t = np.linspace(0, 1, 1024) - sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm="l1" - ) - output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) - sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) - convolved_real = sig.convolve(filter_bank.real().values) - convolved_imag = sig.convolve(filter_bank.imag().values) - convolved = convolved_real.values + convolved_imag.values * 1j - coef = convolved / (1024 / freqs) - cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef - mwt2 = nap.TsdTensor( - t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support - ) - assert np.array_equal(mwt, mwt2) + if norm == "l1": + output = output / (1000 / freqs) + if norm == "l2": + output = output / (1000 / np.sqrt(freqs)) - # Testing against manual convolution for l2 norm - t = np.linspace(0, 1, 1024) - sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) - freqs = np.linspace(1, 600, 10) mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm="l2" - ) - output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) - sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) - convolved_real = sig.convolve(filter_bank.real().values) - convolved_imag = sig.convolve(filter_bank.imag().values) - convolved = convolved_real.values + convolved_imag.values * 1j - coef = convolved / (1024 / np.sqrt(freqs)) - cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef - mwt2 = nap.TsdTensor( - t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + sig, + freqs, + fs=fs, + gaussian_width=gaussian_width, + window_length=window_length, + precision=precision, + norm=norm, ) - assert np.array_equal(mwt, mwt2) - # Testing against manual convolution for no normalization - t = np.linspace(0, 1, 1024) - sig = nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=t) - freqs = np.linspace(1, 600, 10) - mwt = nap.compute_wavelet_transform( - sig, fs=None, freqs=freqs, gaussian_width=1.5, window_length=1.0, norm=None - ) - output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) - sig = sig.reshape((sig.shape[0], np.prod(sig.shape[1:]))) - filter_bank = nap.generate_morlet_filterbank(freqs, 1024, 1.5, 1.0, precision=16) - convolved_real = sig.convolve(filter_bank.real().values) - convolved_imag = sig.convolve(filter_bank.imag().values) - convolved = convolved_real.values + convolved_imag.values * 1j - coef = convolved - cwt = np.expand_dims(coef, -1) if len(coef.shape) == 2 else coef - mwt2 = nap.TsdTensor( - t=sig.index, d=cwt.reshape(output_shape), time_support=sig.time_support + np.testing.assert_array_almost_equal(output, mwt.values) + assert freqs[np.argmax(np.sum(np.abs(mwt), axis=0))] == fc + assert ( + np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] + == maxt ) - assert np.array_equal(mwt, mwt2) + np.testing.assert_array_almost_equal(mwt.time_support.values, sig.time_support.values) @pytest.mark.parametrize( - "sig, fs, freqs, gaussian_width, window_length, precision, norm, expectation", + "sig, freqs, fs, gaussian_width, window_length, precision, norm, expectation", [ + (get_1d_signal(), np.linspace(1, 10, 2), 1000, 1.5, 1, 16, None, does_not_raise()), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), - None, - np.linspace(0, 600, 10), + "a", + np.linspace(1, 10, 2), + 1000, 1.5, - 1.0, + 1, 16, - "l1", + None, pytest.raises( - ValueError, match="All frequencies in freqs must be strictly positive" + TypeError, + match=re.escape( + "`sig` must be instance of Tsd, TsdFrame, or TsdTensor" + ), ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + get_1d_signal(), + np.linspace(1, 10, 2), + "a", + 1.5, + 1, + 16, None, - np.linspace(1, 600, 10), + pytest.raises( + TypeError, + match=re.escape( + "`fs` must be of type float or int or None" + ), + ), + ), + ( + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, -1.5, - 1.0, + 1, 16, - "l1", + None, pytest.raises( - ValueError, match="gaussian_width must be a positive number." + ValueError, + match=re.escape("gaussian_width must be a positive number."), ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, + "a", + 1, + 16, None, - np.linspace(1, 600, 10), + pytest.raises( + TypeError, + match=re.escape("gaussian_width must be a float or int instance."), + ), + ), + ( + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, 1.5, - -1.0, + -1, 16, - "l1", - pytest.raises(ValueError, match="window_length must be a positive number."), + None, + pytest.raises( + ValueError, + match=re.escape("window_length must be a positive number."), + ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), - None, - np.linspace(1, 600, 10), - "not_number", - 1.0, + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, + 1.5, + "a", 16, - "l1", + None, pytest.raises( - TypeError, match="gaussian_width must be a float or int instance." + TypeError, + match=re.escape("window_length must be a float or int instance."), ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), - None, - np.linspace(1, 600, 10), + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, 1.5, - "not_number", + 1, 16, - "l1", + "a", pytest.raises( - TypeError, match="window_length must be a float or int instance." + ValueError, + match=re.escape("norm parameter must be 'l1', 'l2', or None."), ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), - None, - np.linspace(1, 600, 10), + get_1d_signal(), + "a", + 1000, 1.5, - 1.0, + 1, 16, - "l3", + None, pytest.raises( - ValueError, match="norm parameter must be 'l1', 'l2', or None." + TypeError, + match=re.escape("`freqs` must be a ndarray"), ), ), ( - nap.TsdTensor(d=np.random.random((1024, 4, 2)), t=np.linspace(0, 1, 1024)), - None, - None, + get_1d_signal(), + np.array([]), + 1000, 1.5, - 1.0, + 1, 16, - "l1", + None, pytest.raises( - TypeError, match="`freqs` must be a ndarray or tuple instance." + ValueError, + match=re.escape("Given list of freqs cannot be empty."), ), ), ( - "not_a_signal", + get_1d_signal(), + np.array([-1]), + 1000, + 1.5, + 1, + 16, None, - np.linspace(10, 100, 10), + pytest.raises( + ValueError, + match=re.escape("All frequencies in freqs must be strictly positive"), + ), + ), + ( + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, 1.5, - 1.0, + 1, 16, - "l1", + 1, pytest.raises( - TypeError, match="`sig` must be instance of Tsd, TsdFrame, or TsdTensor" + ValueError, + match=re.escape("norm parameter must be 'l1', 'l2', or None."), ), ), + ], ) def test_compute_wavelet_transform_raise_errors( From 5fab4dc4b006ffff279d41f92d9a1f982156c98f Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 8 Aug 2024 16:29:31 +0100 Subject: [PATCH 144/195] check that fft of wavelet is correct gaussian --- tests/test_signal_processing.py | 77 ++++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 34b0e79f..91c590a7 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -1,5 +1,5 @@ """Tests of `signal_processing` for pynapple""" - +import matplotlib.pyplot as plt import numpy as np import pytest @@ -61,6 +61,81 @@ def test_generate_morlet_filterbank(): for i, f in enumerate(freqs): assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + # Checking that the power spectra of the wavelets resemble correct Gaussians + fs = 2000 + freqs = np.linspace(100, 1000, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.0, window_length=1.0, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 1.0 + window_length = 1.0 + fz = power.index + factor = np.pi ** 0.25 * gaussian_width ** 0.25 + morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f)/f) ** 2) + assert np.isclose(power.iloc[:,i]/np.max(power.iloc[:,i]), morlet_ft/np.max(morlet_ft), atol=0.1).all() + + fs = 100 + freqs = np.linspace(1, 10, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=1.0, window_length=1.0, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 1.0 + window_length = 1.0 + fz = power.index + factor = np.pi ** 0.25 * gaussian_width ** 0.25 + morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) + assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), + atol=0.1).all() + + fs = 100 + freqs = np.linspace(1, 10, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=4.0, window_length=1.0, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 4.0 + window_length = 1.0 + fz = power.index + factor = np.pi ** 0.25 * gaussian_width ** 0.25 + morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) + assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), + atol=0.1).all() + + fs = 100 + freqs = np.linspace(1, 10, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=4.0, window_length=3.0, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 4.0 + window_length = 3.0 + fz = power.index + factor = np.pi ** 0.25 * gaussian_width ** 0.25 + morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) + assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), + atol=0.1).all() + + fs = 1000 + freqs = np.linspace(1, 10, 10) + fb = nap.generate_morlet_filterbank( + freqs, fs, gaussian_width=3.5, window_length=1.25, precision=24 + ) + power = np.abs(nap.compute_power_spectral_density(fb)) + for i, f in enumerate(freqs): + gaussian_width = 3.5 + window_length = 1.25 + fz = power.index + factor = np.pi ** 0.25 * gaussian_width ** 0.25 + morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length * (fz - f) / f) ** 2) + assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), + atol=0.1).all() + @pytest.mark.parametrize( "freqs, fs, gaussian_width, window_length, precision, expectation", From a761f6efebab215a5cc982a3a936068747ff923b Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 8 Aug 2024 16:31:09 +0100 Subject: [PATCH 145/195] linting --- tests/test_signal_processing.py | 126 ++++++++++++++++++++------------ 1 file changed, 78 insertions(+), 48 deletions(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 594d8550..5ea3bf4f 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -1,9 +1,11 @@ """Tests of `signal_processing` for pynapple""" + +import re +from contextlib import nullcontext as does_not_raise + import matplotlib.pyplot as plt import numpy as np import pytest -import re -from contextlib import nullcontext as does_not_raise import pynapple as nap @@ -74,9 +76,15 @@ def test_generate_morlet_filterbank(): gaussian_width = 1.0 window_length = 1.0 fz = power.index - factor = np.pi ** 0.25 * gaussian_width ** 0.25 - morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f)/f) ** 2) - assert np.isclose(power.iloc[:,i]/np.max(power.iloc[:,i]), morlet_ft/np.max(morlet_ft), atol=0.1).all() + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + power.iloc[:, i] / np.max(power.iloc[:, i]), + morlet_ft / np.max(morlet_ft), + atol=0.1, + ).all() fs = 100 freqs = np.linspace(1, 10, 10) @@ -88,10 +96,15 @@ def test_generate_morlet_filterbank(): gaussian_width = 1.0 window_length = 1.0 fz = power.index - factor = np.pi ** 0.25 * gaussian_width ** 0.25 - morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) - assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1).all() + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + power.iloc[:, i] / np.max(power.iloc[:, i]), + morlet_ft / np.max(morlet_ft), + atol=0.1, + ).all() fs = 100 freqs = np.linspace(1, 10, 10) @@ -103,10 +116,15 @@ def test_generate_morlet_filterbank(): gaussian_width = 4.0 window_length = 1.0 fz = power.index - factor = np.pi ** 0.25 * gaussian_width ** 0.25 - morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) - assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1).all() + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + power.iloc[:, i] / np.max(power.iloc[:, i]), + morlet_ft / np.max(morlet_ft), + atol=0.1, + ).all() fs = 100 freqs = np.linspace(1, 10, 10) @@ -118,10 +136,15 @@ def test_generate_morlet_filterbank(): gaussian_width = 4.0 window_length = 3.0 fz = power.index - factor = np.pi ** 0.25 * gaussian_width ** 0.25 - morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length*(fz - f) / f) ** 2) - assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1).all() + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + power.iloc[:, i] / np.max(power.iloc[:, i]), + morlet_ft / np.max(morlet_ft), + atol=0.1, + ).all() fs = 1000 freqs = np.linspace(1, 10, 10) @@ -133,10 +156,15 @@ def test_generate_morlet_filterbank(): gaussian_width = 3.5 window_length = 1.25 fz = power.index - factor = np.pi ** 0.25 * gaussian_width ** 0.25 - morlet_ft = factor * np.exp(-np.pi ** 2 * gaussian_width * (window_length * (fz - f) / f) ** 2) - assert np.isclose(power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1).all() + factor = np.pi**0.25 * gaussian_width**0.25 + morlet_ft = factor * np.exp( + -np.pi**2 * gaussian_width * (window_length * (fz - f) / f) ** 2 + ) + assert np.isclose( + power.iloc[:, i] / np.max(power.iloc[:, i]), + morlet_ft / np.max(morlet_ft), + atol=0.1, + ).all() @pytest.mark.parametrize( @@ -158,9 +186,7 @@ def test_generate_morlet_filterbank(): 1.5, 1.0, 16, - pytest.raises( - TypeError, match="`freqs` must be a ndarray" - ), + pytest.raises(TypeError, match="`freqs` must be a ndarray"), ), ( np.array([]), @@ -168,9 +194,7 @@ def test_generate_morlet_filterbank(): 1.5, 1.0, 16, - pytest.raises( - ValueError, match="Given list of freqs cannot be empty." - ), + pytest.raises(ValueError, match="Given list of freqs cannot be empty."), ), ( np.linspace(1, 10, 1), @@ -178,9 +202,7 @@ def test_generate_morlet_filterbank(): 1.5, 1.0, 16, - pytest.raises( - TypeError, match="`fs` must be of type float or int ndarray" - ), + pytest.raises(TypeError, match="`fs` must be of type float or int ndarray"), ), ( np.linspace(1, 10, 1), @@ -208,9 +230,7 @@ def test_generate_morlet_filterbank(): 1.5, -1.0, 16, - pytest.raises( - ValueError, match="window_length must be a positive number." - ), + pytest.raises(ValueError, match="window_length must be a positive number."), ), ( np.linspace(1, 10, 1), @@ -228,9 +248,7 @@ def test_generate_morlet_filterbank(): 1.5, 1.0, -16, - pytest.raises( - ValueError, match="precision must be a positive number." - ), + pytest.raises(ValueError, match="precision must be a positive number."), ), ( np.linspace(1, 10, 1), @@ -241,7 +259,7 @@ def test_generate_morlet_filterbank(): pytest.raises( TypeError, match="precision must be a float or int instance." ), - ), + ), ], ) def test_generate_morlet_filterbank_raise_errors( @@ -253,12 +271,12 @@ def test_generate_morlet_filterbank_raise_errors( ) +import numpy as np ############################################################ # Test for compute_wavelet_transform ############################################################ import pynapple as nap -import numpy as np def get_1d_signal(fs=1000, fc=50): @@ -266,21 +284,24 @@ def get_1d_signal(fs=1000, fc=50): d = np.sin(t * fc * np.pi * 2) * np.interp(t, [0, 1, 2], [0, 1, 0]) return nap.Tsd(t, d, time_support=nap.IntervalSet(0, 2)) + def get_2d_signal(fs=1000, fc=50): t = np.arange(0, 2, 1 / fs) d = np.sin(t * fc * np.pi * 2) * np.interp(t, [0, 1, 2], [0, 1, 0]) - return nap.TsdFrame(t, d[:,np.newaxis], time_support=nap.IntervalSet(0, 2)) + return nap.TsdFrame(t, d[:, np.newaxis], time_support=nap.IntervalSet(0, 2)) + def get_output_1d(sig, wavelets): T = sig.shape[0] M, N = wavelets.shape out = [] for n in range(N): - out.append(np.convolve(sig, wavelets[:, n], mode="full")) + out.append(np.convolve(sig, wavelets[:, n], mode="full")) out = np.array(out).T cut = ((M - 1) // 2, T + M - 1 - ((M - 1) // 2) - (1 - M % 2)) return out[cut[0] : cut[1]] + def get_output_2d(sig, wavelets): T, K = sig.shape M, N = wavelets.shape @@ -288,12 +309,13 @@ def get_output_2d(sig, wavelets): for k in range(K): tmp = [] for n in range(N): - tmp.append(np.convolve(sig[:,k], wavelets[:, n], mode="full")) + tmp.append(np.convolve(sig[:, k], wavelets[:, n], mode="full")) out.append(np.array(tmp)) out = np.array(out).T cut = ((M - 1) // 2, T + M - 1 - ((M - 1) // 2) - (1 - M % 2)) return out[cut[0] : cut[1]] + @pytest.mark.parametrize( "func, freqs, fs, gaussian_width, window_length, precision, norm, fc, maxt", [ @@ -341,13 +363,24 @@ def test_compute_wavelet_transform( np.unravel_index(np.abs(mwt.values).argmax(), np.abs(mwt.values).shape)[0] == maxt ) - np.testing.assert_array_almost_equal(mwt.time_support.values, sig.time_support.values) + np.testing.assert_array_almost_equal( + mwt.time_support.values, sig.time_support.values + ) @pytest.mark.parametrize( "sig, freqs, fs, gaussian_width, window_length, precision, norm, expectation", [ - (get_1d_signal(), np.linspace(1, 10, 2), 1000, 1.5, 1, 16, None, does_not_raise()), + ( + get_1d_signal(), + np.linspace(1, 10, 2), + 1000, + 1.5, + 1, + 16, + None, + does_not_raise(), + ), ( "a", np.linspace(1, 10, 2), @@ -373,11 +406,9 @@ def test_compute_wavelet_transform( None, pytest.raises( TypeError, - match=re.escape( - "`fs` must be of type float or int or None" - ), + match=re.escape("`fs` must be of type float or int or None"), ), - ), + ), ( get_1d_signal(), np.linspace(1, 10, 2), @@ -495,7 +526,6 @@ def test_compute_wavelet_transform( match=re.escape("norm parameter must be 'l1', 'l2', or None."), ), ), - ], ) def test_compute_wavelet_transform_raise_errors( From bda58921d3a04ebc0ba539b2cf4801ae9bbf7066 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Thu, 8 Aug 2024 16:35:29 +0100 Subject: [PATCH 146/195] removed bad import --- tests/test_signal_processing.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index 5ea3bf4f..b3e2e3d1 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -3,7 +3,6 @@ import re from contextlib import nullcontext as does_not_raise -import matplotlib.pyplot as plt import numpy as np import pytest @@ -271,12 +270,9 @@ def test_generate_morlet_filterbank_raise_errors( ) -import numpy as np - ############################################################ # Test for compute_wavelet_transform ############################################################ -import pynapple as nap def get_1d_signal(fs=1000, fc=50): From 9e5672a44f24f85090264999aa939e514e19badd Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 8 Aug 2024 14:45:00 -0400 Subject: [PATCH 147/195] Changed get method --- pynapple/core/base_class.py | 170 +++++++++++++++++++----------------- tests/test_time_series.py | 31 ++++--- 2 files changed, 106 insertions(+), 95 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 2c8b0e2d..75ec274f 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -403,69 +403,66 @@ def get(self, start, end=None, time_units="s"): end : float or int or None The end """ - assert isinstance(start, Number), "start should be a float or int" - time_array = self.index.values + sl = self.get_slice(start, end, time_units) if end is None: - start = TsIndex.format_timestamps(np.array([start]), time_units)[0] - idx = int(np.searchsorted(time_array, start)) - if idx == 0: - return self[idx] - elif idx >= self.shape[0]: - return self[-1] - else: - if start - time_array[idx - 1] < time_array[idx] - start: - return self[idx - 1] - else: - return self[idx] - else: - assert isinstance(end, Number), "end should be a float or int" - assert start < end, "Start should not precede end" - start, end = TsIndex.format_timestamps(np.array([start, end]), time_units) - idx_start = np.searchsorted(time_array, start) - idx_end = np.searchsorted(time_array, end, side="right") - return self[idx_start:idx_end] + sl = sl.start - def _get_filename(self, filename): - """Check if the filename is valid and return the path + return self[sl] + + def get_slice(self, start, end=None, time_unit="s"): + """ + Get a slice object from the time series data based on the start and end values such that all the timestamps satisfy `start<=t<=end`. + If `end` is None, only the timepoint closest to `start` is returned. + + By default, the time support doesn't change. If you want to change the time support, use the `restrict` function. + + This function is equivalent of calling the `get` method. Parameters ---------- - filename : str or Path - The filename + start : int or float + The starting value for the slice. + end : int or float, optional + The ending value for the slice. Defaults to None. + time_unit : str, optional + The time unit for the start and end values. Defaults to "s" (seconds). Returns ------- - Path - The path to the file + slice : slice + A slice determining the start and end indices, with unit step + Slicing the array will be equivalent to calling get: `ts[s].t == ts.get(start, end).t` with `s` being the slice object. + Raises ------ - RuntimeError - If the filename is a directory or the parent does not exist - """ + ValueError + - If start or end is not a number. + - If start is greater than end. - return check_filename(filename) + Examples + -------- + >>> import pynapple as nap - @classmethod - def _from_npz_reader(cls, file): - """Load a time series object from a npz file interface. + >>> ts = nap.Ts(t = [0, 1, 2, 3]) - Parameters - ---------- - file : NPZFile object - opened npz file interface. + >>> # slice over a range + >>> start, end = 1.2, 2.6 + >>> print(ts.get_slice(start, end)) # returns `slice(2, 3, None)` + >>> start, end = 1., 2. + >>> print(ts.get_slice(start, end, mode="forward")) # returns `slice(1, 3, None)` - Returns - ------- - out : Ts or Tsd or TsdFrame or TsdTensor - The time series object + >>> # slice a single value + >>> start = 1.2 + >>> print(ts.get_slice(start)) # returns `slice(1, 2, None)` + >>> start = 2. + >>> print(ts.get_slice(start)) # returns `slice(2, 3, None)` """ - kwargs = { - key: file[key] for key in file.keys() if key not in ["start", "end", "type"] - } - iset = IntervalSet(start=file["start"], end=file["end"]) - return cls(time_support=iset, **kwargs) + mode = "closest_t" if end is None else "restrict" + return self._get_slice( + start, end=end, mode=mode, n_points=None, time_unit=time_unit + ) def _get_slice( self, start, end=None, mode="closest_t", n_points=None, time_unit="s" @@ -473,6 +470,9 @@ def _get_slice( """ Get a slice from the time series data based on the start and end values with the specified mode. + For a given time t, mode `before_t` means you want the timepoint right before t to start the slice. + Mode `after_t` means you want the timepoint right after t to start the slice. + Parameters ---------- start : int or float @@ -502,16 +502,20 @@ def _get_slice( - For mode == "restrict": - slice the indices such that start <= self.t[idx] <= end If end is provided: - - For mode == "backward": + - For mode == "before_t": - An empty slice if end < self.t[0] - slice(idx_start, idx_end) with self.t[idx_start] <= start < self.t[idx_start+1] and self.t[idx_end] <= end < self.t[idx_end+1] - - For mode == "forward": + - For mode == "after_t": - An empty slice if start > self.t[-1] - slice(idx_start, idx_end) with self.t[idx_start-1] <= start < self.t[idx_start] and self.t[idx_end-1] <= end < self.t[idx_end] - For mode == "closest": - slice(idx_start, idx_end) with the closest indices to start and end + - For mode == "restrict": + - An empty slice if start > self.t[-1] or end < self.t[0] + - slice(idx_start, idx_end) with self.t[idx_start] <= start <= self.t[idx_start+1] and + self.t[idx_end] <= end <= self.t[idx_end+1] Raises ------ @@ -525,9 +529,19 @@ def _get_slice( f"'start' must be an int or a float. Type {type(start)} provided instead!" ) + if n_points is not None and not isinstance(n_points, int): + raise TypeError( + f"'n_points' must be of type int or None. Type {type(n_points)} provided instead!" + ) + if end is None and n_points: raise ValueError("'n_points' can be used only when 'end' is specified!") + if mode not in ["before_t", "after_t", "closest_t", "restrict"]: + raise ValueError( + "'mode' only accepts 'before_t', 'after_t', 'closest_t' or 'restrict'." + ) + if mode == "restrict" and n_points: raise ValueError( "Fixing the number of time points is incompatible with 'restrict' mode." @@ -600,51 +614,43 @@ def _get_slice( return slice(idx_start, idx_end, step) - def get_slice(self, start, end=None, time_unit="s"): - """ - Get a slice from the time series data based on the start and end values with the specified mode. + def _get_filename(self, filename): + """Check if the filename is valid and return the path Parameters ---------- - start : int or float - The starting value for the slice. - end : int or float, optional - The ending value for the slice. Defaults to None. - time_unit : str, optional - The time unit for the start and end values. Defaults to "s" (seconds). + filename : str or Path + The filename Returns ------- - slice : slice - A slice determining the start and end indices, with unit step - Slicing the array will be equivalent to calling get: `ts[s].t == ts.get(start, end).t` - + Path + The path to the file Raises ------ - ValueError - - If start or end is not a number. - - If start is greater than end. + RuntimeError + If the filename is a directory or the parent does not exist + """ - Examples - -------- - >>> import pynapple as nap + return check_filename(filename) - >>> ts = nap.Ts(t = [0, 1, 2, 3]) + @classmethod + def _from_npz_reader(cls, file): + """Load a time series object from a npz file interface. - >>> # slice over a range - >>> start, end = 1.2, 2.6 - >>> print(ts.get_slice(start, end)) # returns `slice(2, 3, None)` - >>> start, end = 1., 2. - >>> print(ts.get_slice(start, end, mode="forward")) # returns `slice(1, 3, None)` + Parameters + ---------- + file : NPZFile object + opened npz file interface. - >>> # slice a single value - >>> start = 1.2 - >>> print(ts.get_slice(start)) # returns `slice(1, 2, None)` - >>> start = 2. - >>> print(ts.get_slice(start)) # returns `slice(2, 3, None)` + Returns + ------- + out : Ts or Tsd or TsdFrame or TsdTensor + The time series object """ - mode = "closest_t" if end is None else "restrict" - return self._get_slice( - start, end=end, mode=mode, n_points=None, time_unit=time_unit - ) + kwargs = { + key: file[key] for key in file.keys() if key not in ["start", "end", "type"] + } + iset = IntervalSet(start=file["start"], end=file["end"]) + return cls(time_support=iset, **kwargs) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index ecda7ac2..b7da1603 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1441,24 +1441,30 @@ def test_pickling(obj): assert np.all(obj.time_support == unpickled_obj.time_support) +#################################################### +# Test for slicing +#################################################### + + @pytest.mark.parametrize( - "start, end, expectation", + "start, end, mode, n_points, expectation", [ - (1, 3, does_not_raise()), - (3, 1, pytest.raises(ValueError, match="'start' should not precede 'end'")), - (1., 3., does_not_raise()), - (1., None, does_not_raise()), - (None, 3, pytest.raises(ValueError, match="'start' must be an int or a float")), - ("a", 3, pytest.raises(ValueError, match="'start' must be an int or a float")), - (2, "a", pytest.raises(ValueError, match="'end' must be an int or a float")), - + (1, 3, "closest_t", None, does_not_raise()), + (None, 3, "closest_t", None, pytest.raises(ValueError, match="'start' must be an int or a float")), + (2, "a", "closest_t", None, pytest.raises(ValueError, match="'end' must be an int or a float. Type provided instead!")), + (1, 3, "closest_t", "a", pytest.raises(TypeError, match="'n_points' must be of type int or None. Type provided instead!")), + (1, None, "closest_t", 1, pytest.raises(ValueError, match="'n_points' can be used only when 'end' is specified!")), + (1, 3, "banana", None, pytest.raises(ValueError, match="'mode' only accepts 'before_t', 'after_t', 'closest_t' or 'restrict'.")), + (3, 1, "closest_t", None, pytest.raises(ValueError, match="'start' should not precede 'end'")), + (1, 3, "restrict", 1, pytest.raises(ValueError, match="Fixing the number of time points is incompatible with 'restrict' mode.")), + (1., 3., "closest_t", None, does_not_raise()), + (1., None, "closest_t", None, does_not_raise()), ] ) -@pytest.mark.parametrize("time_unit", ["s", "ms", "us"]) -def test_get_slice_value_types(start, end, time_unit, expectation): +def test_get_slice_raise_errors(start, end, mode, n_points, expectation): ts = nap.Ts(t=np.array([1, 2, 3, 4])) with expectation: - ts._get_slice(start, end, time_unit=time_unit) + ts._get_slice(start, end, mode, n_points) @pytest.mark.parametrize( @@ -1570,7 +1576,6 @@ def test_get_slice_vs_get_random_val_start_value(): - @pytest.mark.parametrize( "end, n_points, expectation", [ From 96510776dd309dd68c457234efae194250c633ca Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 8 Aug 2024 15:22:22 -0400 Subject: [PATCH 148/195] updating --- pynapple/core/base_class.py | 2 +- pynapple/process/signal_processing.py | 2 +- tests/test_power_spectral_density.py | 4 ++-- tests/test_time_series.py | 2 ++ 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index 337f304f..c119b57a 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -568,7 +568,7 @@ def _get_slice( # get index of preceding time value idx_start = np.searchsorted(self.t, start, side="left") - if idx_start == len(self.t): + if idx_start == len(self.t) and mode != "restrict": idx_start -= 1 # make sure the index is not out of bound if mode == "before_t": diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 4d814549..cdb5639f 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -155,7 +155,7 @@ def compute_mean_power_spectral_density( slices = np.zeros((len(split_ep), 2), dtype=int) for i in range(len(split_ep)): - sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1], mode="backward") + sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1]) slices[i, 0] = sl.start slices[i, 1] = sl.stop diff --git a/tests/test_power_spectral_density.py b/tests/test_power_spectral_density.py index 626b0832..fc76103c 100644 --- a/tests/test_power_spectral_density.py +++ b/tests/test_power_spectral_density.py @@ -161,7 +161,7 @@ def get_signal_and_output(f=2, fs=1000, duration=100, interval_size=10): d = np.cos(2 * np.pi * f * t) sig = nap.Tsd(t=t, d=d, time_support=nap.IntervalSet(0, 100)) tmp = d.reshape((int(duration / interval_size), int(fs * interval_size))).T - tmp = tmp[0:-1] + # tmp = tmp[0:-1] tmp = tmp*signal.windows.hamming(tmp.shape[0])[:,np.newaxis] out = np.sum(np.fft.fft(tmp, axis=0), 1) freq = np.fft.fftfreq(out.shape[0], 1 / fs) @@ -190,7 +190,7 @@ def test_compute_mean_power_spectral_density(): psd = nap.compute_mean_power_spectral_density(sig, 10, norm=True) assert isinstance(psd, pd.DataFrame) assert psd.shape[0] > 0 # Check that the psd DataFrame is not empty - np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq >= 0]/(9999.0*10.0)) + np.testing.assert_array_almost_equal(psd.values.flatten(), out[freq >= 0]/(10000.0*10.0)) np.testing.assert_array_almost_equal(psd.index.values, freq[freq >= 0]) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index b7be2f1a..446a4fd0 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -1682,6 +1682,8 @@ def test_get_slice_value_step(start, end, n_points, mode, expected_slice, expect (1, None, slice(0, 1), np.array([1])), (4, None, slice(3, 4), np.array([4])), (5, None, slice(3, 4), np.array([4])), + (-1, 0, slice(0, 0), np.array([])), + (5, 6, slice(4, 4), np.array([])), ] ) @pytest.mark.parametrize("ts", From ee558e560193a594dc673871105c78744ee77ad0 Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 9 Aug 2024 15:52:25 +0100 Subject: [PATCH 149/195] addressing comments, slight notebook improvements --- docs/api_guide/tutorial_pynapple_wavelets.py | 20 +++- docs/examples/tutorial_phase_preferences.py | 2 +- docs/examples/tutorial_signal_processing.py | 114 ++++++++++++------- pynapple/io/folder.py | 9 -- pynapple/process/signal_processing.py | 65 +++++++---- tests/test_signal_processing.py | 6 + 6 files changed, 143 insertions(+), 73 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_wavelets.py b/docs/api_guide/tutorial_pynapple_wavelets.py index 016d7cc0..20246d0f 100644 --- a/docs/api_guide/tutorial_pynapple_wavelets.py +++ b/docs/api_guide/tutorial_pynapple_wavelets.py @@ -5,11 +5,14 @@ This tutorial covers the use of `nap.compute_wavelet_transform` to do continuous wavelet transform. By default, pynapple uses Morlet wavelets. +Wavelet are a great tool for capturing changes of spectral characteristics of a signal over time. As neural signals change +and develop over time, wavelet decompositions can aid both visualization and analysis. + The function `nap.generate_morlet_filterbank` can help parametrize and visualize the Morlet wavelets. See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. -This tutorial was made by Kipp Freud. +This tutorial was made by [Kipp Freud](https://kippfreud.com/). """ @@ -84,12 +87,13 @@ # %% # Lets plot it some of the wavelets. + def plot_filterbank(filter_bank, freqs, title): fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 7)) for f_i in range(filter_bank.shape[1]): - ax.plot(filter_bank[:, f_i].real() + f_i*1.5) + ax.plot(filter_bank[:, f_i].real() + f_i * 1.5) ax.text(-5.5, 1.5 * f_i, f"{np.round(freqs[f_i], 2)}Hz", va="center", ha="left") - + ax.set_yticks([]) ax.set_xlim(-5, 5) ax.set_xlabel("Time (s)") @@ -121,6 +125,7 @@ def plot_filterbank(filter_bank, freqs, title): # %% # Lets plot it. + def plot_timefrequency(freqs, powers, ax=None): im = ax.imshow(np.abs(powers), aspect="auto") ax.invert_yaxis() @@ -209,7 +214,9 @@ def plot_timefrequency(freqs, powers, ax=None): ax1 = plt.subplot(gs[1, 0]) ax1.plot(sig, label="Raw Signal", alpha=0.5) -ax1.plot(slow_oscillation + fifteenHz_oscillation.values, label="2Hz + 15Hz Reconstruction") +ax1.plot( + slow_oscillation + fifteenHz_oscillation.values, label="2Hz + 15Hz Reconstruction" +) ax1.set_xlabel("Time (s)") [ @@ -222,6 +229,11 @@ def plot_timefrequency(freqs, powers, ax=None): # *** # Adding ALL the Oscillations! # ------------------ +# We will now learn how to interpret the parameters of the wavelet, and in particular how to trade off the +# accuracy in the frequency decomposition with the accuracy in the time domain reconstruction; + +# Up to this point we have used default wavelet and normalization parameters. +# # Let's now add together the real components of all frequency bands to recreate a version of the original signal. combined_oscillations = np.real(np.sum(mwt, axis=1)) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 0abb753b..5af93b3c 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -8,7 +8,7 @@ Specifically, we will examine LFP and spiking data from a period of REM sleep, after traversal of a linear track. -This tutorial was made by Kipp Freud. +This tutorial was made by [Kipp Freud](https://kippfreud.com/). """ # %% diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_signal_processing.py index dd7d66df..1daf034a 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_signal_processing.py @@ -7,7 +7,7 @@ Specifically, we will examine Local Field Potential data from a period of active traversal of a linear track. -This tutorial was made by Kipp Freud. +This tutorial was made by [Kipp Freud](https://kippfreud.com/). """ @@ -26,6 +26,7 @@ import math import os +import matplotlib.colors as mcolors import matplotlib.pyplot as plt import numpy as np import requests @@ -54,18 +55,11 @@ total=math.ceil(int(r.headers.get("content-length", 0)) // block_size), ): f.write(data) - - -# %% -# *** -# Loading the data -# ------------------ # Let's load and print the full dataset. - data = nap.load_file(path) - print(data) + # %% # First we can extract the data from the NWB. The local field potential has been downsampled to 1250Hz. We will call it `eeg`. # @@ -113,6 +107,9 @@ axd["pos"].set_ylabel("Linearized Position") axd["pos"].set_xlim(RUN_interval[0, 0], RUN_interval[0, 1]) +# %% +# In the top panel, we can see the lfp trace as a function of time, and on the bottom the mouse position on the linear +# track as a function of time. Position 0 and 1 correspond to the start and end of the trial respectively. # %% # *** @@ -145,22 +142,30 @@ # %% -# The red area outlines the theta rhythm (6-10 Hz) which is proeminent in hippocampal LFP. -# Hippocampal theta rhythm appears mostly when the animal is running. -# (See Hasselmo, M. E., & Stern, C. E. (2014). Theta rhythm and the encoding and retrieval of space and time. Neuroimage, 85, 656-666.) -# We can check it here by separating `wake_ep` into `run_ep` and `rest_ep`. -run_ep = data['position'].dropna().find_support(1) +# The red area outlines the theta rhythm (6-10 Hz) which is proeminent in hippocampal LFP. +# Hippocampal theta rhythm appears mostly when the animal is running [1]. +# We can check it here by separating the wake epochs (`wake_ep`) into run epochs (`run_ep`) and rest epochs (`rest_ep`). + +# The run epoch is the portion of the data for which we have position data +run_ep = data["position"].dropna().find_support(1) +# The rest epoch is the data at all points where we do not have position data rest_ep = wake_ep.set_diff(run_ep) # %% -# `run_ep` and `rest_ep` are IntervalSet with discontinuous epoch. +# `run_ep` and `rest_ep` are IntervalSet with discontinuous epoch. +# # The function `nap.compute_power_spectral_density` takes signal with a single epoch to avoid artefacts between epochs jumps. -# To compare `run_ep` with `rest_ep`, we can use the function `nap.compute_mean_power_spectral_density` which avearge the FFT over multiple epochs of same duration. The parameter `interval_size` controls the duration of those epochs. -# -# In this case, `interval_size` is equal to 1.5 seconds. +# +# To compare `run_ep` with `rest_ep`, we can use the function `nap.compute_mean_power_spectral_density` which avearge the FFT over multiple epochs of same duration. The parameter `interval_size` controls the duration of those epochs. +# +# In this case, `interval_size` is equal to 1.5 seconds. -power_run = nap.compute_mean_power_spectral_density(eeg, 1.5, fs=FS, ep=run_ep, norm=True) -power_rest = nap.compute_mean_power_spectral_density(eeg, 1.5, fs=FS, ep=rest_ep, norm=True) +power_run = nap.compute_mean_power_spectral_density( + eeg, 1.5, fs=FS, ep=run_ep, norm=True +) +power_rest = nap.compute_mean_power_spectral_density( + eeg, 1.5, fs=FS, ep=rest_ep, norm=True +) # %% # `power_run` and `power_rest` are the power spectral density when the animal is respectively running and resting. @@ -170,13 +175,13 @@ np.abs(power_run[(power_run.index >= 3.0) & (power_run.index <= 30)]), alpha=1, label="Run", - linewidth=2 + linewidth=2, ) ax.plot( np.abs(power_rest[(power_rest.index >= 3.0) & (power_rest.index <= 30)]), alpha=1, label="Rest", - linewidth=2 + linewidth=2, ) ax.axvspan(6, 10, color="red", alpha=0.1) ax.set_xlabel("Freq (Hz)") @@ -223,7 +228,7 @@ ax1 = plt.subplot(gs[1, 0], sharex=ax0) ax1.plot(eeg_example) -ax1.set_ylabel("LFP (v)") +ax1.set_ylabel("LFP (a.u.)") ax1 = plt.subplot(gs[2, 0], sharex=ax0) ax1.plot(pos_example, color="black") @@ -238,11 +243,11 @@ # There seems to be a strong theta frequency present in the data during the maze traversal. # Let's plot the estimated 6-10Hz component of the wavelet decomposition on top of our data, and see how well they match up. -theta_freq_index = np.logical_and(freqs>6, freqs<10) +theta_freq_index = np.logical_and(freqs > 6, freqs < 10) # Extract its real component, as well as its power envelope -theta_band_reconstruction = np.mean(mwt_RUN[:,theta_freq_index], 1) +theta_band_reconstruction = np.mean(mwt_RUN[:, theta_freq_index], 1) theta_band_power_envelope = np.abs(theta_band_reconstruction) @@ -252,13 +257,13 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 6)) gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.9]) -ax0 = plt.subplot(gs[0,0]) +ax0 = plt.subplot(gs[0, 0]) ax0.plot(eeg_example, label="CA1") ax0.set_title("EEG (1250 Hz)") ax0.set_ylabel("LFP (a.u.)") ax0.set_xlabel("time (s)") ax0.legend() -ax1 = plt.subplot(gs[1,0]) +ax1 = plt.subplot(gs[1, 0]) ax1.plot(np.real(theta_band_reconstruction), label="6-10 Hz oscillations") ax1.plot(theta_band_power_envelope, label="6-10 Hz power envelope") ax1.set_xlabel("time (s)") @@ -267,7 +272,12 @@ # %% # *** -# Visualizing high frequency oscillation +# We observe that the theta power is far stronger during the first 4 seconds of the dataset, during which the rat +# is traversing the linear track. + +# %% +# *** +# Visualizing High Frequency Oscillation # ----------------------------------- # There also seem to be peaks in the 200Hz frequency power after traversal of thew maze is complete. Here we use the interval (18356, 18357.5) seconds to zoom in. @@ -292,11 +302,11 @@ ax1.set_xlabel("Time (s)") # %% -# Those events are called Sharp-waves ripples (See : Buzsáki, G. (2015). Hippocampal sharp wave‐ripple: A cognitive biomarker for episodic memory and planning. Hippocampus, 25(10), 1073-1188.) +# Those events are called Sharp-waves ripples [2]. # # Among other methods, we can use the Wavelet decomposition to isolate them. In this case, we will look at the power of the wavelets for frequencies between 150 to 250 Hz. -ripple_freq_index = np.logical_and(freqs>150, freqs<250) +ripple_freq_index = np.logical_and(freqs > 150, freqs < 250) # %% # We can compute the mean power for this frequency band. @@ -309,11 +319,11 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 5)) gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5]) -ax0 = plt.subplot(gs[0,0]) +ax0 = plt.subplot(gs[0, 0]) ax0.plot(eeg_example.restrict(zoom_ep), label="CA1") -ax0.set_ylabel("LFP (v)") +ax0.set_ylabel("LFP (a.u.)") ax0.set_title(f"EEG (1250 Hz)") -ax1 = plt.subplot(gs[1,0]) +ax1 = plt.subplot(gs[1, 0]) ax1.legend() ax1.plot(ripple_power.restrict(zoom_ep), label="150-250 Hz") ax1.legend() @@ -342,20 +352,44 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 5)) gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.5]) -ax0 = plt.subplot(gs[0,0]) +ax0 = plt.subplot(gs[0, 0]) ax0.plot(eeg_example.restrict(zoom_ep), label="CA1") -for s,e in rip_ep.intersect(zoom_ep).values: - ax0.axvspan(s, e, color='red', alpha=0.1, ec=None) -ax0.set_ylabel("LFP (v)") +for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values): + ax0.axvspan(s, e, color=list(mcolors.TABLEAU_COLORS.keys())[i], alpha=0.2, ec=None) +ax0.set_ylabel("LFP (a.u.)") ax0.set_title(f"EEG (1250 Hz)") -ax1 = plt.subplot(gs[1,0]) +ax1 = plt.subplot(gs[1, 0]) ax1.legend() ax1.plot(ripple_power.restrict(zoom_ep), label="150-250 Hz") ax1.plot(smoothed_ripple_power.restrict(zoom_ep)) -for s,e in rip_ep.intersect(zoom_ep).values: - ax1.axvspan(s, e, color='red', alpha=0.1, ec=None) +for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values): + ax1.axvspan(s, e, color=list(mcolors.TABLEAU_COLORS.keys())[i], alpha=0.2, ec=None) ax1.legend() ax1.set_ylabel("Mean Amplitude") ax1.set_xlabel("Time (s)") +# %% +# Finally, let's zoom in on each of our isolated ripples + +fig = plt.figure(constrained_layout=True, figsize=(10, 5)) +gs = plt.GridSpec(2, 2, figure=fig, height_ratios=[1.0, 1.0]) +buffer = 0.02 +plt.suptitle("Isolated Sharp Wave Ripples") +for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values): + ax = plt.subplot(gs[int(i / 2), i % 2]) + ax.plot(eeg_example.restrict(nap.IntervalSet(s - buffer, e + buffer))) + ax.axvspan(s, e, color=list(mcolors.TABLEAU_COLORS.keys())[i], alpha=0.2, ec=None) + ax.set_xlim(s - buffer, e + buffer) + ax.set_xlabel("Time (s)") + ax.set_ylabel("LFP (a.u.)") + + +# %% +# *** +# References +# ----------------------------------- +# +# [1] Hasselmo, M. E., & Stern, C. E. (2014). Theta rhythm and the encoding and retrieval of space and time. Neuroimage, 85, 656-666. +# +# [2] Buzsáki, G. (2015). Hippocampal sharp wave‐ripple: A cognitive biomarker for episodic memory and planning. Hippocampus, 25(10), 1073-1188. diff --git a/pynapple/io/folder.py b/pynapple/io/folder.py index 8f7d2f1a..a35af18d 100644 --- a/pynapple/io/folder.py +++ b/pynapple/io/folder.py @@ -1,16 +1,7 @@ -#!/usr/bin/env python - -# -*- coding: utf-8 -*- -# @Author: Guillaume Viejo -# @Date: 2023-05-15 15:32:24 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2024-08-02 11:35:10 - """ The Folder class helps to navigate a hierarchical data tree. """ - import json import string from collections import UserDict diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 4d814549..99f3d9ba 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -1,5 +1,5 @@ """ -# Signal processing tools +# Signal processing tools. - `nap.compute_power_spectral_density` - `nap.compute_mean_power_spectral_density` @@ -19,7 +19,7 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm=False): """ - Performs numpy fft on sig, returns output. Pynapple assumes a constant sampling rate for sig. + Perform numpy fft on sig, returns output assuming a constant sampling rate for the signal. Parameters ---------- @@ -85,7 +85,9 @@ def compute_mean_power_spectral_density( norm=False, time_unit="s", ): - """Compute mean power spectral density by averaging FFT over epochs of same size. + """ + Compute mean power spectral density by averaging FFT over epochs of same size. + The parameter `interval_size` controls the duration of the epochs. To imporve frequency resolution, the signal is multiplied by a Hamming window. @@ -114,6 +116,14 @@ def compute_mean_power_spectral_density( pandas.DataFrame Power spectral density. + Examples + -------- + >>> import numpy as np + >>> import pynapple as nap + >>> t = np.arange(0, 1, 1/1000) + >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) + >>> mpsd = nap.compute_mean_power_spectral_density(signal, 0.1) + Raises ------ RuntimeError @@ -155,7 +165,7 @@ def compute_mean_power_spectral_density( slices = np.zeros((len(split_ep), 2), dtype=int) for i in range(len(split_ep)): - sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1], mode="backward") + sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1]) slices[i, 0] = sl.start slices[i, 1] = sl.stop @@ -244,7 +254,7 @@ def compute_wavelet_transform( norm : {None, 'l1', 'l2'}, optional Normalization method: - None - no normalization - - 'l1' - divide by the sum of amplitudes + - 'l1' - (default) divide by the sum of amplitudes - 'l2' - divide by the square root of the sum of amplitudes Returns @@ -320,8 +330,10 @@ def generate_morlet_filterbank( freqs, fs, gaussian_width=1.5, window_length=1.0, precision=16 ): """ - Generates a Morlet filterbank using the given frequencies and parameters. Can be used purely for visualization, - or to convolve with a pynapple Tsd, TsdFrame, or TsdTensor as part of a wavelet decomposition process. + Generates a Morlet filterbank using the given frequencies and parameters. + + This function can be used purely for visualization, or to convolve with a pynapple Tsd, + TsdFrame, or TsdTensor as part of a wavelet decomposition process. Parameters ---------- @@ -340,6 +352,12 @@ def generate_morlet_filterbank( ------- filter_bank : pynapple.TsdFrame list of Morlet wavelet filters of the frequencies given + + Notes + ----- + This algorithm first computes a single, finely sampled wavelet using the provided hyperparameters. + Wavelets of different frequencies are generated by resampling this mother wavelet with an appropriate step size. + The step size is determined based on the desired frequency and the sampling rate. """ if not isinstance(freqs, np.ndarray): raise TypeError("`freqs` must be a ndarray") @@ -369,27 +387,35 @@ def generate_morlet_filterbank( else: raise TypeError("precision must be a float or int instance.") + # Initialize filter bank and parameters filter_bank = [] - cutoff = 8 - morlet_f = _morlet( - int(2**precision), gaussian_width=gaussian_width, window_length=window_length + cutoff = 8 # Define cutoff for wavelet + # Compute a single, finely sampled Morlet wavelet + morlet_f = np.conj( + _morlet( + int(2**precision), + gaussian_width=gaussian_width, + window_length=window_length, + ) ) x = np.linspace(-cutoff, cutoff, int(2**precision)) - int_psi = np.conj(morlet_f) - max_len = -1 + max_len = -1 # Track maximum length of wavelet for freq in freqs: scale = window_length / (freq / fs) + # Calculate the indices for subsampling the wavelet and achieve the right frequency + # After the slicing the size will be reduced, therefore we will pad with 0s. j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) - j = j.astype(int) # floor - if j[-1] >= int_psi.size: - j = np.extract(j < int_psi.size, j) - int_psi_scale = int_psi[j][::-1] - if len(int_psi_scale) > max_len: - max_len = len(int_psi_scale) + j = j.astype(int) # Floor the values to get integer indices + if j[-1] >= morlet_f.size: + j = np.extract(j < morlet_f.size, j) + scaled_morlet = morlet_f[j][::-1] # Scale and reverse wavelet + if len(scaled_morlet) > max_len: + max_len = len(scaled_morlet) time = np.linspace( -cutoff * window_length / freq, cutoff * window_length / freq, max_len ) - filter_bank.append(int_psi_scale) + filter_bank.append(scaled_morlet) + # Pad wavelets to ensure all are of the same length filter_bank = [ np.pad( arr, @@ -398,4 +424,5 @@ def generate_morlet_filterbank( ) for arr in filter_bank ] + # Return filter bank as a TsdFrame return nap.TsdFrame(d=np.array(filter_bank).transpose(), t=time) diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index b3e2e3d1..a708134f 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -17,6 +17,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() fs = 10000 @@ -26,6 +27,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() fs = 1000 @@ -35,6 +37,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() fs = 10000 @@ -44,6 +47,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() fs = 1000 @@ -53,6 +57,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() fs = 10000 @@ -62,6 +67,7 @@ def test_generate_morlet_filterbank(): ) power = np.abs(nap.compute_power_spectral_density(fb)) for i, f in enumerate(freqs): + # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() # Checking that the power spectra of the wavelets resemble correct Gaussians From 861042a38785315e4284ef7751813bae739e796e Mon Sep 17 00:00:00 2001 From: Kipp Freud Date: Fri, 9 Aug 2024 17:04:08 +0100 Subject: [PATCH 150/195] suggested changes and wavelet arange rounding fix --- ...rocessing.py => tutorial_wavelet_decomposition.py} | 2 +- pynapple/process/signal_processing.py | 9 +++------ tests/test_signal_processing.py | 11 ++++++----- 3 files changed, 10 insertions(+), 12 deletions(-) rename docs/examples/{tutorial_signal_processing.py => tutorial_wavelet_decomposition.py} (99%) diff --git a/docs/examples/tutorial_signal_processing.py b/docs/examples/tutorial_wavelet_decomposition.py similarity index 99% rename from docs/examples/tutorial_signal_processing.py rename to docs/examples/tutorial_wavelet_decomposition.py index 1daf034a..530540ec 100644 --- a/docs/examples/tutorial_signal_processing.py +++ b/docs/examples/tutorial_wavelet_decomposition.py @@ -374,7 +374,7 @@ fig = plt.figure(constrained_layout=True, figsize=(10, 5)) gs = plt.GridSpec(2, 2, figure=fig, height_ratios=[1.0, 1.0]) -buffer = 0.02 +buffer = 0.075 plt.suptitle("Isolated Sharp Wave Ripples") for i, (s, e) in enumerate(rip_ep.intersect(zoom_ep).values): ax = plt.subplot(gs[int(i / 2), i % 2]) diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 99f3d9ba..6071e98e 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -295,11 +295,8 @@ def compute_wavelet_transform( if fs is None: fs = sig.rate - if sig.ndim == 1: - output_shape = (sig.shape[0], len(freqs)) - else: - output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) - sig = np.reshape(sig, (sig.shape[0], np.prod(sig.shape[1:]))) + output_shape = (sig.shape[0], len(freqs), *sig.shape[1:]) + sig = np.reshape(sig, (sig.shape[0], -1)) filter_bank = generate_morlet_filterbank( freqs, fs, gaussian_width, window_length, precision @@ -405,7 +402,7 @@ def generate_morlet_filterbank( # Calculate the indices for subsampling the wavelet and achieve the right frequency # After the slicing the size will be reduced, therefore we will pad with 0s. j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * (x[1] - x[0])) - j = j.astype(int) # Floor the values to get integer indices + j = np.ceil(j).astype(int) # Ceil the values to get integer indices if j[-1] >= morlet_f.size: j = np.extract(j < morlet_f.size, j) scaled_morlet = morlet_f[j][::-1] # Scale and reverse wavelet diff --git a/tests/test_signal_processing.py b/tests/test_signal_processing.py index a708134f..c9c1495b 100644 --- a/tests/test_signal_processing.py +++ b/tests/test_signal_processing.py @@ -70,6 +70,7 @@ def test_generate_morlet_filterbank(): # Check that peak freq matched expectation assert power.iloc[:, i].argmax() == np.abs(power.index - f).argmin() + gaussian_atol = 1e-4 # Checking that the power spectra of the wavelets resemble correct Gaussians fs = 2000 freqs = np.linspace(100, 1000, 10) @@ -88,7 +89,7 @@ def test_generate_morlet_filterbank(): assert np.isclose( power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1, + atol=gaussian_atol, ).all() fs = 100 @@ -108,7 +109,7 @@ def test_generate_morlet_filterbank(): assert np.isclose( power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1, + atol=gaussian_atol, ).all() fs = 100 @@ -128,7 +129,7 @@ def test_generate_morlet_filterbank(): assert np.isclose( power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1, + atol=gaussian_atol, ).all() fs = 100 @@ -148,7 +149,7 @@ def test_generate_morlet_filterbank(): assert np.isclose( power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1, + atol=gaussian_atol, ).all() fs = 1000 @@ -168,7 +169,7 @@ def test_generate_morlet_filterbank(): assert np.isclose( power.iloc[:, i] / np.max(power.iloc[:, i]), morlet_ft / np.max(morlet_ft), - atol=0.1, + atol=gaussian_atol, ).all() From 08960269b946f3bd231b00c131fced6b1751f86f Mon Sep 17 00:00:00 2001 From: Erik Schomburg Date: Mon, 24 Jun 2024 16:34:20 +0200 Subject: [PATCH 151/195] allow IntervalSet creation with iterable of tuples --- pynapple/core/interval_set.py | 17 +++++++++++++---- tests/test_interval_set.py | 17 +++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index c19ba075..596fd6c3 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -94,7 +94,7 @@ def __init__(self, start, end=None, time_units="s"): Parameters ---------- - start : numpy.ndarray or number or pandas.DataFrame or pandas.Series + start : numpy.ndarray or number or pandas.DataFrame or pandas.Series or iterable of (start, end) pairs Beginning of intervals end : numpy.ndarray or number or pandas.Series, optional Ends of intervals @@ -108,8 +108,8 @@ def __init__(self, start, end=None, time_units="s"): """ if isinstance(start, IntervalSet): - end = start.values[:, 1].astype(np.float64) - start = start.values[:, 0].astype(np.float64) + end = start.end.astype(np.float64) + start = start.start.astype(np.float64) elif isinstance(start, pd.DataFrame): assert ( @@ -125,7 +125,16 @@ def __init__(self, start, end=None, time_units="s"): start = start["start"].values.astype(np.float64) else: - assert end is not None, "Missing end argument when initializing IntervalSet" + if end is None: + # Require iterable of (start, end) tuples + try: + start_end_array = np.array(list(start)) + if start_end_array.ndim == 1: + start, end = start_end_array + else: + start, end = zip(*start_end_array) + except (TypeError, ValueError): + raise ValueError("Unable to Interpret the input. Please provide a list of start-end intervals.") args = {"start": start, "end": end} diff --git a/tests/test_interval_set.py b/tests/test_interval_set.py index 2c8fbf47..ec7105d0 100644 --- a/tests/test_interval_set.py +++ b/tests/test_interval_set.py @@ -79,6 +79,23 @@ def test_create_iset_from_mock_array(): np.testing.assert_array_almost_equal(ep.start, start) np.testing.assert_array_almost_equal(ep.end, end) +def test_create_iset_from_tuple(): + start = 0 + end = 5 + ep = nap.IntervalSet((start, end)) + assert isinstance(ep, nap.core.interval_set.IntervalSet) + np.testing.assert_array_almost_equal(start, ep.start[0]) + np.testing.assert_array_almost_equal(end, ep.end[0]) + +def test_create_iset_from_tuple_iter(): + start = [0, 10, 16, 25] + end = [5, 15, 20, 40] + pairs = zip(start, end) + ep = nap.IntervalSet(pairs) + assert isinstance(ep, nap.core.interval_set.IntervalSet) + np.testing.assert_array_almost_equal(start, ep.start) + np.testing.assert_array_almost_equal(end, ep.end) + def test_create_iset_from_unknown_format(): with pytest.raises(RuntimeError) as e: nap.IntervalSet(start="abc", end=[1, 2]) From fc6f731de8c7a44888bbe187cd7b35394e7b8585 Mon Sep 17 00:00:00 2001 From: Erik Schomburg Date: Fri, 9 Aug 2024 14:24:24 -0400 Subject: [PATCH 152/195] make IntervalSet(start_end_pairs) initialization clearer and more robust --- pynapple/core/interval_set.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 596fd6c3..147c1316 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -95,7 +95,12 @@ def __init__(self, start, end=None, time_units="s"): Parameters ---------- start : numpy.ndarray or number or pandas.DataFrame or pandas.Series or iterable of (start, end) pairs - Beginning of intervals + Beginning of intervals. Alternatively, the `end` argument can be left out and `start` can be one of the + following: + - IntervalSet + - pandas.DataFrame with columns ["start", "end"] + - iterable of (start, end) pairs + - a single (start, end) pair end : numpy.ndarray or number or pandas.Series, optional Ends of intervals time_units : str, optional @@ -128,13 +133,12 @@ def __init__(self, start, end=None, time_units="s"): if end is None: # Require iterable of (start, end) tuples try: - start_end_array = np.array(list(start)) - if start_end_array.ndim == 1: - start, end = start_end_array - else: - start, end = zip(*start_end_array) + start_end_array = np.array(list(start)).reshape(-1, 2) + start, end = zip(*start_end_array) except (TypeError, ValueError): - raise ValueError("Unable to Interpret the input. Please provide a list of start-end intervals.") + raise ValueError( + "Unable to Interpret the input. Please provide a list of start-end pairs." + ) args = {"start": start, "end": end} From 459224122f1a6490160665991ce7b73221ce72c0 Mon Sep 17 00:00:00 2001 From: Erik Schomburg Date: Tue, 25 Jun 2024 23:26:52 +0200 Subject: [PATCH 153/195] accept iterable of Ts-compatible objects for TsGroup --- pynapple/core/ts_group.py | 11 +++++++---- tests/test_ts_group.py | 12 ++++++++++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index be6c955c..cab8541e 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -75,9 +75,9 @@ def __init__( Parameters ---------- - data : dict - Dictionary containing Ts/Tsd objects, keys should contain integer values or should be convertible - to integer. + data : dict or iterable + Dictionary or iterable of Ts/Tsd objects. The keys should be integer-convertible; if a non-dict iterator is + passed, its values will be used to create a dict with integer keys. time_support : IntervalSet, optional The time support of the TsGroup. Ts/Tsd objects will be restricted to the time support if passed. If no time support is specified, TsGroup will merge time supports from all the Ts/Tsd objects in data. @@ -117,13 +117,16 @@ def __init__( self._initialized = False + if not isinstance(data, dict): + data = dict(enumerate(data)) + # convert all keys to integer try: keys = [int(k) for k in data.keys()] except Exception: raise ValueError("All keys must be convertible to integer.") - # check that there were no floats with decimal points in keys.i + # check that there were no floats with decimal points in keys. # i.e. 0.5 is not a valid key if not all(np.allclose(keys[j], float(k)) for j, k in enumerate(data.keys())): raise ValueError("All keys must have integer value!}") diff --git a/tests/test_ts_group.py b/tests/test_ts_group.py index b9387fcf..31f54107 100644 --- a/tests/test_ts_group.py +++ b/tests/test_ts_group.py @@ -51,10 +51,19 @@ def test_create_ts_group(self, group): assert isinstance(tsgroup, UserDict) assert len(tsgroup) == 3 + def test_create_ts_group_from_iter(self, group): + tsgroup = nap.TsGroup(group.values()) + assert isinstance(tsgroup, UserDict) + assert len(tsgroup) == 3 + + def test_create_ts_group_from_invalid(self): + with pytest.raises(AttributeError): + tsgroup = nap.TsGroup(np.arange(0, 200)) + @pytest.mark.parametrize( "test_dict, expectation", [ - ({"1": nap.Ts(np.arange(10)), "2":nap.Ts(np.arange(10))}, does_not_raise()), + ({"1": nap.Ts(np.arange(10)), "2": nap.Ts(np.arange(10))}, does_not_raise()), ({"1": nap.Ts(np.arange(10)), 2: nap.Ts(np.arange(10))}, does_not_raise()), ({"1": nap.Ts(np.arange(10)), 1: nap.Ts(np.arange(10))}, pytest.raises(ValueError, match="Two dictionary keys contain the same integer")), @@ -82,7 +91,6 @@ def test_initialize_from_dict(self, test_dict, expectation): def test_metadata_len_match(self, tsgroup): assert len(tsgroup._metadata) == len(tsgroup) - def test_create_ts_group_from_array(self): with warnings.catch_warnings(record=True) as w: nap.TsGroup({ From ce6f373ec39a82adc7a5d7f8f9c549321d3037cc Mon Sep 17 00:00:00 2001 From: Erik Schomburg Date: Tue, 25 Jun 2024 23:27:25 +0200 Subject: [PATCH 154/195] ensure expected data dict ordering in TsGroup --- pynapple/core/ts_group.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index cab8541e..abc0e8a2 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -138,6 +138,7 @@ def __init__( data = {keys[j]: data[k] for j, k in enumerate(data.keys())} self.index = np.sort(keys) + data = {k: data[k] for k in self.index} # Make sure data dict and index are ordered the same self._metadata = pd.DataFrame(index=self.index, columns=["rate"], dtype="float") From 64ffdc0658b8507a814044cf643b4f4b8cce9f81 Mon Sep 17 00:00:00 2001 From: Erik Schomburg Date: Fri, 9 Aug 2024 14:37:05 -0400 Subject: [PATCH 155/195] black formatting --- pynapple/core/ts_group.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index abc0e8a2..2052f696 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -138,7 +138,8 @@ def __init__( data = {keys[j]: data[k] for j, k in enumerate(data.keys())} self.index = np.sort(keys) - data = {k: data[k] for k in self.index} # Make sure data dict and index are ordered the same + # Make sure data dict and index are ordered the same + data = {k: data[k] for k in self.index} self._metadata = pd.DataFrame(index=self.index, columns=["rate"], dtype="float") From d1851a9b708e1a086938dc1e3c84c58430ff6c22 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 15 Aug 2024 15:20:33 +0200 Subject: [PATCH 156/195] do not run actions on draft PR --- .github/workflows/main.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0c779db4..8c991883 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -6,6 +6,11 @@ on: branches: [ main ] pull_request: branches: [ main, dev ] + types: + - opened + - reopened + - synchronize + - ready_for_review jobs: lint: From b6f8724283d699ec355d3cc49c67823c0cdeaed4 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 15 Aug 2024 15:22:19 +0200 Subject: [PATCH 157/195] linted --- pynapple/process/__init__.py | 3 +-- pynapple/process/filtering.py | 32 ++++++++++++++++++++++---------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index ddc5bffa..ac41c4fb 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -4,6 +4,7 @@ compute_eventcorrelogram, ) from .decoding import decode_1d, decode_2d +from .filtering import compute_filtered_signal from .perievent import ( compute_event_trigger_average, compute_perievent, @@ -24,5 +25,3 @@ compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, ) - -from .filtering import compute_filtered_signal diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 364ea407..329df9f5 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -1,12 +1,16 @@ """Filtering module.""" +from numbers import Number + import numpy as np -from .. import core as nap from scipy.signal import butter, filtfilt -from numbers import Number + +from .. import core as nap -def compute_filtered_signal(data, freq_band, filter_type="bandpass", order=4, sampling_frequency=None): +def compute_filtered_signal( + data, freq_band, filter_type="bandpass", order=4, sampling_frequency=None +): """ Apply a Butterworth filter to the provided signal data. @@ -51,18 +55,26 @@ def compute_filtered_signal(data, freq_band, filter_type="bandpass", order=4, sa sampling_frequency = data.rate if filter_type not in ["lowpass", "highpass", "bandpass", "bandstop"]: - raise ValueError(f"Unrecognized filter type {filter_type}. " - "filter_type must be either 'lowpass', 'highpass', 'bandpass',or 'bandstop'.") + raise ValueError( + f"Unrecognized filter type {filter_type}. " + "filter_type must be either 'lowpass', 'highpass', 'bandpass',or 'bandstop'." + ) elif filter_type in ["lowpass", "highpass"] and not isinstance(freq_band, Number): - raise ValueError("Low/high-pass filter specification requires a single frequency. " - f"{freq_band} provided instead!") + raise ValueError( + "Low/high-pass filter specification requires a single frequency. " + f"{freq_band} provided instead!" + ) elif filter_type in ["bandpass", "bandstop"]: try: - if len(freq_band) != 2 or not all(isinstance(fq, Number) for fq in freq_band): + if len(freq_band) != 2 or not all( + isinstance(fq, Number) for fq in freq_band + ): raise ValueError except Exception: - raise ValueError("Band-pass/stop filter specification requires two frequencies. " - f"{freq_band} provided instead!") + raise ValueError( + "Band-pass/stop filter specification requires two frequencies. " + f"{freq_band} provided instead!" + ) b, a = butter(order, freq_band, btype=filter_type, fs=sampling_frequency) From 3467a85ea00d8dd9c09905b686fdfb08fef9763c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 15 Aug 2024 15:35:14 +0200 Subject: [PATCH 158/195] added tests --- tests/test_filtering.py | 77 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 3d7e2972..784bbcdd 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -46,6 +46,38 @@ def test_filtering_single_freq_match_sci(freq, order, btype, shape: tuple, ep): np.testing.assert_array_equal(out.d, out_sci) +@pytest.mark.parametrize("freq", [[10, 30], [100,150]]) +@pytest.mark.parametrize("order", [2, 4, 6]) +@pytest.mark.parametrize("btype", ["bandpass", "bandstop"]) +@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) +@pytest.mark.parametrize( + "ep", + [ + nap.IntervalSet(start=[0], end=[1]), + nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), + nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1]) + ] +) +def test_filtering_freq_band_match_sci(freq, order, btype, shape: tuple, ep): + + t = np.linspace(0, 1, shape[0]) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape)) + + if len(shape) == 1: + tsd = nap.Tsd(t, y, time_support=ep) + elif len(shape) == 2: + tsd = nap.TsdFrame(t, y, time_support=ep) + else: + tsd = nap.TsdTensor(t, y, time_support=ep) + out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order) + b, a = signal.butter(order, freq, fs=tsd.rate, btype=btype) + out_sci = [] + for iset in ep: + out_sci.append(signal.filtfilt(b, a, tsd.restrict(iset).d, axis=0)) + out_sci = np.concatenate(out_sci, axis=0) + np.testing.assert_array_equal(out.d, out_sci) + + @pytest.mark.parametrize("freq", [10, 100]) @pytest.mark.parametrize("order", [2, 4, 6]) @pytest.mark.parametrize("btype", ["lowpass", "highpass"]) @@ -76,6 +108,36 @@ def test_filtering_single_freq_dtype(freq, order, btype, shape: tuple, ep): assert np.all(tsd.columns == out.columns) +@pytest.mark.parametrize("freq", [[10, 30], [100, 150]]) +@pytest.mark.parametrize("order", [2, 4, 6]) +@pytest.mark.parametrize("btype", ["bandpass", "bandstop"]) +@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) +@pytest.mark.parametrize( + "ep", + [ + nap.IntervalSet(start=[0], end=[1]), + nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), + nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1]) + ] +) +def test_filtering_freq_band_dtype(freq, order, btype, shape: tuple, ep): + t = np.linspace(0, 1, shape[0]) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape)) + + if len(shape) == 1: + tsd = nap.Tsd(t, y, time_support=ep) + elif len(shape) == 2: + tsd = nap.TsdFrame(t, y, time_support=ep, columns=np.arange(10, 10 + y.shape[1])) + else: + tsd = nap.TsdTensor(t, y, time_support=ep) + out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order) + assert isinstance(out, type(tsd)) + assert np.all(out.t == tsd.t) + assert np.all(out.time_support == tsd.time_support) + if isinstance(tsd, nap.TsdFrame): + assert np.all(tsd.columns == out.columns) + + @pytest.mark.parametrize("freq_band, filter_type, order, expected_exception", [ ((5, 15), "bandpass", 4, does_not_raise()), ((5, 15), "bandstop", 4, does_not_raise()), @@ -93,3 +155,18 @@ def test_compute_filtered_signal(sample_data, freq_band, filter_type, order, exp if not expected_exception: assert isinstance(filtered_data, type(sample_data)) assert filtered_data.d.shape == sample_data.d.shape + + +# Test with edge-case frequencies close to Nyquist frequency +@pytest.mark.parametrize("nyquist_fraction", [0.99, 0.999]) +@pytest.mark.parametrize("order", [2, 4]) +def test_filtering_nyquist_edge_case(nyquist_fraction, order, sample_data): + nyquist_freq = 0.5 * sample_data.rate + freq = nyquist_freq * nyquist_fraction + + out = nap.filtering.compute_filtered_signal( + sample_data, freq_band=freq, filter_type="lowpass", order=order + ) + assert isinstance(out, type(sample_data)) + np.testing.assert_allclose(out.t, sample_data.t) + np.testing.assert_allclose(out.time_support, sample_data.time_support) From 9a1ed8ad59325858a355f51503acebb15cca80dc Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 15 Aug 2024 15:41:07 +0200 Subject: [PATCH 159/195] switch to sos filter --- pynapple/process/filtering.py | 6 +++--- tests/test_filtering.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 329df9f5..53326777 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -3,7 +3,7 @@ from numbers import Number import numpy as np -from scipy.signal import butter, filtfilt +from scipy.signal import butter, filtfilt, sosfiltfilt from .. import core as nap @@ -76,12 +76,12 @@ def compute_filtered_signal( f"{freq_band} provided instead!" ) - b, a = butter(order, freq_band, btype=filter_type, fs=sampling_frequency) + sos = butter(order, freq_band, btype=filter_type, fs=sampling_frequency, output="sos") out = np.zeros_like(data.d) for ep in data.time_support: slc = data.get_slice(start=ep.start[0], end=ep.end[0]) - out[slc] = filtfilt(b, a, data.d[slc], axis=0) + out[slc] = sosfiltfilt(sos, data.d[slc], axis=0) kwargs = dict(t=data.t, d=out, time_support=data.time_support) if isinstance(data, nap.TsdFrame): diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 784bbcdd..0e842019 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -38,10 +38,10 @@ def test_filtering_single_freq_match_sci(freq, order, btype, shape: tuple, ep): else: tsd = nap.TsdTensor(t, y, time_support=ep) out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order) - b, a = signal.butter(order, freq, fs=tsd.rate, btype=btype) + sos = signal.butter(order, freq, fs=tsd.rate, btype=btype, output="sos") out_sci = [] for iset in ep: - out_sci.append(signal.filtfilt(b, a, tsd.restrict(iset).d, axis=0)) + out_sci.append(signal.sosfiltfilt(sos, tsd.restrict(iset).d, axis=0)) out_sci = np.concatenate(out_sci, axis=0) np.testing.assert_array_equal(out.d, out_sci) @@ -70,10 +70,10 @@ def test_filtering_freq_band_match_sci(freq, order, btype, shape: tuple, ep): else: tsd = nap.TsdTensor(t, y, time_support=ep) out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order) - b, a = signal.butter(order, freq, fs=tsd.rate, btype=btype) + sos = signal.butter(order, freq, fs=tsd.rate, btype=btype, output="sos") out_sci = [] for iset in ep: - out_sci.append(signal.filtfilt(b, a, tsd.restrict(iset).d, axis=0)) + out_sci.append(signal.sosfiltfilt(sos, tsd.restrict(iset).d, axis=0)) out_sci = np.concatenate(out_sci, axis=0) np.testing.assert_array_equal(out.d, out_sci) From b1916bf9ce220fcb9f978ec93ee3f1754ca854b2 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 15 Aug 2024 15:41:31 +0200 Subject: [PATCH 160/195] removed unused import --- pynapple/process/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 53326777..bd2e3973 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -3,7 +3,7 @@ from numbers import Number import numpy as np -from scipy.signal import butter, filtfilt, sosfiltfilt +from scipy.signal import butter, sosfiltfilt from .. import core as nap From f669f951524d6b6e228bdd5a1544688c6aa23db9 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Thu, 15 Aug 2024 15:45:16 +0200 Subject: [PATCH 161/195] linted --- pynapple/process/filtering.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index bd2e3973..f6263176 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -76,7 +76,9 @@ def compute_filtered_signal( f"{freq_band} provided instead!" ) - sos = butter(order, freq_band, btype=filter_type, fs=sampling_frequency, output="sos") + sos = butter( + order, freq_band, btype=filter_type, fs=sampling_frequency, output="sos" + ) out = np.zeros_like(data.d) for ep in data.time_support: From 6c5d88aa6e54010edd5d7037fce779d37198ca4f Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 15 Aug 2024 11:30:54 -0400 Subject: [PATCH 162/195] few changes --- pynapple/process/filtering.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 364ea407..809ddab7 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -8,10 +8,10 @@ def compute_filtered_signal(data, freq_band, filter_type="bandpass", order=4, sampling_frequency=None): """ - Apply a Butterworth filter to the provided signal data. + Apply a Butterworth filter to the provided signal. - This function performs bandpass filtering on Local Field Potential (LFP) - data using a Butterworth filter. The filter can be configured to be of + This function performs bandpass filtering on time series data + using a Butterworth filter. The filter can be configured to be of type "bandpass", "bandstop", "highpass", or "lowpass". Parameters From 009477a4934146a5cc1dbf8aca12cd777fcbd2c1 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Fri, 16 Aug 2024 08:53:49 +0200 Subject: [PATCH 163/195] Update pynapple/process/filtering.py Co-authored-by: Guillaume Viejo --- pynapple/process/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 921ca532..c45e6066 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -9,7 +9,7 @@ def compute_filtered_signal( - data, freq_band, filter_type="bandpass", order=4, sampling_frequency=None + data, freq_band, fs=None, filter_type="bandpass", order=4 ): """ Apply a Butterworth filter to the provided signal. From 0d2187926a634719d64fa20ad0bf59f7e0852690 Mon Sep 17 00:00:00 2001 From: Edoardo Balzani Date: Fri, 16 Aug 2024 08:53:58 +0200 Subject: [PATCH 164/195] Update pynapple/process/filtering.py Co-authored-by: Guillaume Viejo --- pynapple/process/filtering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index c45e6066..c164c42e 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -31,7 +31,7 @@ def compute_filtered_signal( order : int, optional The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. Default is 4. - sampling_frequency : float, optional +fs : float, optional The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. Returns From 292ee298eb27165ffafeb77b557b2905a4cea883 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 16 Aug 2024 09:03:49 +0200 Subject: [PATCH 165/195] changed to sampling_frequency --- pynapple/process/filtering.py | 70 +++++++++++++++++------------------ tests/test_filtering.py | 2 +- 2 files changed, 35 insertions(+), 37 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index c164c42e..76423608 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -8,48 +8,46 @@ from .. import core as nap -def compute_filtered_signal( - data, freq_band, fs=None, filter_type="bandpass", order=4 -): +def compute_filtered_signal(data, freq_band, sampling_frequency=None, filter_type="bandpass", order=4): """ - Apply a Butterworth filter to the provided signal. + Apply a Butterworth filter to the provided signal. - This function performs bandpass filtering on time series data - using a Butterworth filter. The filter can be configured to be of - type "bandpass", "bandstop", "highpass", or "lowpass". + This function performs bandpass filtering on time series data + using a Butterworth filter. The filter can be configured to be of + type "bandpass", "bandstop", "highpass", or "lowpass". - Parameters - ---------- - data : Tsd, TsdFrame, or TsdTensor - The signal to be filtered. - freq_band : tuple of (float, float) or float - Cutoff frequency(ies) in Hz. - - For "bandpass" and "bandstop" filters, provide a tuple specifying the two cutoff frequencies. - - For "lowpass" and "highpass" filters, provide a single float specifying the cutoff frequency. - filter_type : {'bandpass', 'bandstop', 'highpass', 'lowpass'}, optional - The type of frequency filter to apply. Default is "bandpass". - order : int, optional - The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. - Default is 4. -fs : float, optional - The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + Parameters + ---------- + data : Tsd, TsdFrame, or TsdTensor + The signal to be filtered. + freq_band : tuple of (float, float) or float + Cutoff frequency(ies) in Hz. + - For "bandpass" and "bandstop" filters, provide a tuple specifying the two cutoff frequencies. + - For "lowpass" and "highpass" filters, provide a single float specifying the cutoff frequency. + filter_type : {'bandpass', 'bandstop', 'highpass', 'lowpass'}, optional + The type of frequency filter to apply. Default is "bandpass". + order : int, optional + The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. + Default is 4. + sampling_frequency : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. - Returns - ------- - filtered_data : Tsd, TsdFrame, or TsdTensor - The filtered signal, with the same data type as the input. + Returns + ------- + filtered_data : Tsd, TsdFrame, or TsdTensor + The filtered signal, with the same data type as the input. - Raises - ------ - ValueError - If `filter_type` is not one of {"bandpass", "bandstop", "highpass", "lowpass"}. - If `freq_band` is not a float for "lowpass" and "highpass" filters. - If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + Raises + ------ + ValueError + If `filter_type` is not one of {"bandpass", "bandstop", "highpass", "lowpass"}. + If `freq_band` is not a float for "lowpass" and "highpass" filters. + If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. - Notes - ----- - The cutoff frequency is defined as the frequency at which the amplitude of the signal - is reduced by -3 dB (decibels). + Notes + ----- + The cutoff frequency is defined as the frequency at which the amplitude of the signal + is reduced by -3 dB (decibels). """ if sampling_frequency is None: sampling_frequency = data.rate diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 0e842019..490548bc 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -151,7 +151,7 @@ def test_filtering_freq_band_dtype(freq, order, btype, shape: tuple, ep): ]) def test_compute_filtered_signal(sample_data, freq_band, filter_type, order, expected_exception): with expected_exception: - filtered_data = nap.filtering.compute_filtered_signal(sample_data, freq_band, filter_type, order) + filtered_data = nap.filtering.compute_filtered_signal(sample_data, freq_band, filter_type=filter_type, order=order) if not expected_exception: assert isinstance(filtered_data, type(sample_data)) assert filtered_data.d.shape == sample_data.d.shape From 58e3ca1b2fbc50af929002ab5d364b80cdd39218 Mon Sep 17 00:00:00 2001 From: gviejo Date: Fri, 16 Aug 2024 03:12:57 -0400 Subject: [PATCH 166/195] linting --- pynapple/process/filtering.py | 70 ++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 76423608..f99bb5eb 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -8,46 +8,48 @@ from .. import core as nap -def compute_filtered_signal(data, freq_band, sampling_frequency=None, filter_type="bandpass", order=4): +def compute_filtered_signal( + data, freq_band, sampling_frequency=None, filter_type="bandpass", order=4 +): """ - Apply a Butterworth filter to the provided signal. + Apply a Butterworth filter to the provided signal. - This function performs bandpass filtering on time series data - using a Butterworth filter. The filter can be configured to be of - type "bandpass", "bandstop", "highpass", or "lowpass". + This function performs bandpass filtering on time series data + using a Butterworth filter. The filter can be configured to be of + type "bandpass", "bandstop", "highpass", or "lowpass". - Parameters - ---------- - data : Tsd, TsdFrame, or TsdTensor - The signal to be filtered. - freq_band : tuple of (float, float) or float - Cutoff frequency(ies) in Hz. - - For "bandpass" and "bandstop" filters, provide a tuple specifying the two cutoff frequencies. - - For "lowpass" and "highpass" filters, provide a single float specifying the cutoff frequency. - filter_type : {'bandpass', 'bandstop', 'highpass', 'lowpass'}, optional - The type of frequency filter to apply. Default is "bandpass". - order : int, optional - The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. - Default is 4. - sampling_frequency : float, optional - The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + Parameters + ---------- + data : Tsd, TsdFrame, or TsdTensor + The signal to be filtered. + freq_band : tuple of (float, float) or float + Cutoff frequency(ies) in Hz. + - For "bandpass" and "bandstop" filters, provide a tuple specifying the two cutoff frequencies. + - For "lowpass" and "highpass" filters, provide a single float specifying the cutoff frequency. + filter_type : {'bandpass', 'bandstop', 'highpass', 'lowpass'}, optional + The type of frequency filter to apply. Default is "bandpass". + order : int, optional + The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. + Default is 4. + sampling_frequency : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. - Returns - ------- - filtered_data : Tsd, TsdFrame, or TsdTensor - The filtered signal, with the same data type as the input. + Returns + ------- + filtered_data : Tsd, TsdFrame, or TsdTensor + The filtered signal, with the same data type as the input. - Raises - ------ - ValueError - If `filter_type` is not one of {"bandpass", "bandstop", "highpass", "lowpass"}. - If `freq_band` is not a float for "lowpass" and "highpass" filters. - If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + Raises + ------ + ValueError + If `filter_type` is not one of {"bandpass", "bandstop", "highpass", "lowpass"}. + If `freq_band` is not a float for "lowpass" and "highpass" filters. + If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. - Notes - ----- - The cutoff frequency is defined as the frequency at which the amplitude of the signal - is reduced by -3 dB (decibels). + Notes + ----- + The cutoff frequency is defined as the frequency at which the amplitude of the signal + is reduced by -3 dB (decibels). """ if sampling_frequency is None: sampling_frequency = data.rate From 58a95540101fbcbbba43a0f6812641625bed59db Mon Sep 17 00:00:00 2001 From: gviejo Date: Sun, 1 Sep 2024 08:40:32 -0400 Subject: [PATCH 167/195] Adding tests and notebook for filtering module --- .gitignore | 1 + docs/api_guide/tutorial_pynapple_filtering.py | 271 ++++++++++++ pynapple/core/interval_set.py | 3 + pynapple/process/__init__.py | 7 +- pynapple/process/filtering.py | 415 ++++++++++++++++-- tests/test_filtering.py | 277 ++++++++---- 6 files changed, 831 insertions(+), 143 deletions(-) create mode 100644 docs/api_guide/tutorial_pynapple_filtering.py diff --git a/.gitignore b/.gitignore index 5e7adc67..d88327c1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +*.npz *.nwb *.pickle *.py.md5 diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py new file mode 100644 index 00000000..21fc5ccf --- /dev/null +++ b/docs/api_guide/tutorial_pynapple_filtering.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +""" +Filtering +========= + +The filtering module holds the functions for frequency manipulation : + +- `nap.compute_bandstop_filter` +- `nap.compute_lowpass_filter` +- `nap.compute_highpass_filter` +- `nap.compute_bandpass_filter` + +The functions have similar calling signatures. For example, to filter a 1000 Hz signal between +10 and 20 Hz using a Butterworth filter: + +```{python} +>>> new_tsd = nap.compute_bandpass_filter(tsd, (10, 20), fs=1000, mode='butter') +``` + +Currently, the filtering module provides two methods for frequency manipulation: `butter` +for a recursive Butterworth filter and `sinc` for a Windowed-sinc convolution. This notebook provides +a comparison of the two methods. +""" + +# %% +# !!! warning +# This tutorial uses matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib requests tqdm seaborn` +# +# mkdocs_gallery_thumbnail_number = 1 +# +# Now, import the necessary libraries: + +import matplotlib.pyplot as plt +import numpy as np +import pynapple as nap +import seaborn +seaborn.set_theme() +seaborn.set(font_scale=1.25) + +# %% +# *** +# Introduction +# ------------ +# +# We start by generating a signal with multiple frequencies (2, 10 and 50 Hz). +fs = 1000 # sampling frequency +t = np.linspace(0, 2, fs * 2) +f2 = np.cos(t*2*np.pi*2) +f10 = np.cos(t*2*np.pi*10) +f50 = np.cos(t*2*np.pi*50) + +sig = nap.Tsd(t=t,d=f2+f10+f50 + np.random.normal(0, 0.5, len(t))) + +# %% +# Let's plot it +# fig = plt.figure(figsize = (15, 5)) +# plt.plot(sig) +# plt.xlabel("Time (s)") +# plt.show() + +# %% +# We can compute the Fourier transform of `sig` to verify that all the frequencies are there. +psd = nap.compute_power_spectral_density(sig, fs, norm=True) + +# fig = plt.figure(figsize = (15, 5)) +# plt.plot(np.abs(psd)) +# plt.xlabel("Frequency (Hz)") +# plt.ylabel("Amplitude") +# plt.xlim(0, 100) +# plt.show() + +# %% +# Let's say we would like to see only the 10 Hz component. +# We can use the function `compute_bandpass_filter` with mode `butter` for Butterworth. + +sig_butter = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='butter') +# sig_butter = nap.compute_lowpass_filter(sig, 5, fs, mode='butter') + +# %% +# Let's compare it to the `sinc` mode for Windowed-sinc. +sig_sinc = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='sinc') +# sig_sinc = nap.compute_lowpass_filter(sig, 5, fs, mode='sinc') + +# %% +# Let's plot it +fig = plt.figure(figsize = (15, 5)) +plt.subplot(211) +plt.plot(t, f10, '-', color = 'gray', label = "10 Hz component") +plt.xlim(0, 1) +plt.legend() +plt.subplot(212) +# plt.plot(sig, alpha=0.5) +plt.plot(sig_butter, label = "Butterworth") +plt.plot(sig_sinc, '--', label = "Windowed-sinc") +plt.legend() +plt.xlabel("Time (Hz)") +plt.xlim(0, 1) +plt.show() + +import sys +sys.exit() + +# %% +# This gives similar results except at the edges. +# +# The remaining notebook compares the two modes. +# +# *** +# Frequency responses +# ------------------- +# +# Let's get filter coefficients for a 250Hz recursive low pass filter with different order: +butter_sos = { + order:nap.process.filtering._get_butter_coefficients(250, "lowpass", fs, order=order) + for order in [2, 4, 6]} + +# %% +# ... and the kernel for the Windowed-sinc equivalent with different transition bandwitdh +sinc_kernel = { + tb:nap.process.filtering._get_windowed_sinc_kernel(250/fs, "lowpass", transition_bandwidth=tb) + for tb in [0.02, 0.1, 0.2]} + +# %% +# Let's plot the frequency response of both. + +from scipy import signal + +fig = plt.figure(figsize = (20, 10)) +gs = plt.GridSpec(2, 2) +for order, sos in butter_sos.items(): + plt.subplot(gs[0, 0]) + w, h = signal.sosfreqz(sos, worN=1500, fs=fs) + plt.plot(w, np.abs(h), label = f"order={order}") + plt.ylabel('Amplitude') + plt.legend() + plt.title("Butterworth recursive") + plt.subplot(gs[1, 0]) + plt.plot(w, 20 * np.log10(np.abs(h)), label = f"order={order}") + plt.xlabel('Frequency [Hz]') + plt.ylabel('Amplitude [dB]') + plt.ylim(-100,40) + plt.legend() + +for trans_bandwidth, kernel in sinc_kernel.items(): + plt.subplot(gs[0, 1]) + fft_sinc = nap.compute_power_spectral_density( + nap.Tsd(t=np.arange(len(kernel)) / fs, d=kernel), fs) + plt.plot(np.abs(fft_sinc), label= f"width={trans_bandwidth}") + plt.ylabel('Amplitude') + plt.legend() + plt.title("Windowed-sinc conv.") + plt.subplot(gs[1, 1]) + plt.plot(20*np.log10(np.abs(fft_sinc)), label= f"width={trans_bandwidth}") + plt.xlabel('Frequency [Hz]') + plt.ylabel('Amplitude [dB]') + plt.ylim(-100,40) + plt.legend() + +plt.show() +# %% +# *** +# Step responses +# -------------- + +step = nap.Tsd(t=sig.t, d=np.zeros(len(sig))) +step[len(step)//2:] = 1.0 + +# %% +# +step_butter = { + order:nap.compute_lowpass_filter(step, 0.2*fs, fs, mode='butter', order=order) + for order in [4, 12]} + +# %% +# Let's compare it to the `sinc` mode for Windowed-sinc. + +step_sinc = { + tb:nap.compute_lowpass_filter(step, 0.2*fs, fs, mode='sinc', transition_bandwidth=tb) + for tb in [0.005, 0.2]} + +# %% +plt.figure() +plt.subplot(211) +plt.plot(step) +for order, filt_step in step_butter.items(): + plt.plot(filt_step, label=f"order={order}") +plt.legend() +plt.xlim(0.95, 1.05) +plt.title("Butterworth filter") +plt.subplot(212) +plt.plot(step) +for tb, filt_step in step_sinc.items(): + plt.plot(filt_step, label=f"width={tb}") +plt.legend() +plt.xlim(0.95, 1.05) +plt.title("Windowed-sinc") +plt.show() + +# %% +# *** +# Performances +# ------------ +# Let's compare the performance of each when varying the number of time points and the number of dimensions. +from time import perf_counter + +def get_mean_perf(tsd, mode, n=10): + tmp = np.zeros(n) + for i in range(n): + t1 = perf_counter() + _ = nap.compute_lowpass_filter(tsd, 0.25 * tsd.rate, mode=mode) + t2 = perf_counter() + tmp[i] = t2 - t1 + return [np.mean(tmp), np.std(tmp)] + +def benchmark_time_points(mode): + times = [] + for T in np.arange(1000, 100000, 40000): + time_array = np.arange(T)/1000 + data_array = np.random.randn(len(time_array)) + startend = np.linspace(0, time_array[-1], T//100).reshape(T//200, 2) + ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1]) + tsd = nap.Tsd(t=time_array, d=data_array, time_support=ep) + times.append([T]+get_mean_perf(tsd, mode)) + return np.array(times) + +def benchmark_dimensions(mode): + times = [] + for n in np.arange(1, 100, 30): + time_array = np.arange(10000)/1000 + data_array = np.random.randn(len(time_array), n) + startend = np.linspace(0, time_array[-1], 10000//100).reshape(10000//200, 2) + ep = nap.IntervalSet(start=startend[::2,0], end=startend[::2,1]) + tsd = nap.TsdFrame(t=time_array, d=data_array, time_support=ep) + times.append([n]+get_mean_perf(tsd, mode)) + return np.array(times) + + +times_sinc = benchmark_time_points(mode="sinc") +times_butter = benchmark_time_points(mode="butter") + +dims_sinc = benchmark_dimensions(mode="sinc") +dims_butter = benchmark_dimensions(mode="butter") + + +plt.figure(figsize = (16, 5)) +plt.subplot(121) +for arr, label in zip( + [times_sinc, times_butter], + ["Windowed-sinc", "Butter"], + ): + plt.plot(arr[:, 0], arr[:, 1], "o-", label=label) + plt.fill_between(arr[:, 0], arr[:, 1] - arr[:, 2], arr[:, 1] + arr[:, 2], alpha=0.2) +plt.legend() +plt.xlabel("Number of time points") +plt.ylabel("Time (s)") +plt.title("Low pass filtering benchmark") +plt.subplot(122) +for arr, label in zip( + [dims_sinc, dims_butter], + ["Windowed-sinc", "Butter"], + ): + plt.plot(arr[:, 0], arr[:, 1], "o-", label=label) + plt.fill_between(arr[:, 0], arr[:, 1] - arr[:, 2], arr[:, 1] + arr[:, 2], alpha=0.2) +plt.legend() +plt.xlabel("Number of dimensions") +plt.ylabel("Time (s)") +plt.title("Low pass filtering benchmark") + +plt.show() diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 147c1316..ef74b18b 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -241,6 +241,9 @@ def __str__(self): def __len__(self): return len(self.values) + # def __iter__(self): + # pass + def __setitem__(self, key, value): raise RuntimeError( "IntervalSet is immutable. Starts and ends have been already sorted." diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 811df65a..6a0ccdb4 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -4,7 +4,12 @@ compute_eventcorrelogram, ) from .decoding import decode_1d, decode_2d -from .filtering import compute_filtered_signal +from .filtering import ( + compute_bandpass_filter, + compute_bandstop_filter, + compute_highpass_filter, + compute_lowpass_filter, +) from .perievent import ( compute_event_trigger_average, compute_perievent, diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index f99bb5eb..fcd1ec38 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -1,5 +1,6 @@ """Filtering module.""" +from functools import wraps from numbers import Number import numpy as np @@ -8,31 +9,238 @@ from .. import core as nap -def compute_filtered_signal( +def _get_butter_coefficients(freq_band, filter_type, sampling_frequency, order=4): + return butter( + order, freq_band, btype=filter_type, fs=sampling_frequency, output="sos" + ) + + +def _compute_butterworth_filter( data, freq_band, sampling_frequency=None, filter_type="bandpass", order=4 ): """ Apply a Butterworth filter to the provided signal. + """ + sos = _get_butter_coefficients(freq_band, filter_type, sampling_frequency, order) + out = np.zeros_like(data.d) + for ep in data.time_support: + slc = data.get_slice(start=ep.start[0], end=ep.end[0]) + out[slc] = sosfiltfilt(sos, data.d[slc], axis=0) + + kwargs = dict(t=data.t, d=out, time_support=data.time_support) + if isinstance(data, nap.TsdFrame): + kwargs["columns"] = data.columns + return data.__class__(**kwargs) + + +def _compute_spectral_inversion(kernel): + """ + Compute the spectral inversion. + Parameters + ---------- + kernel: ndarray + + Returns + ------- + ndarray + """ + kernel *= -1.0 + kernel[len(kernel) // 2] = 1.0 + kernel[len(kernel) // 2] + return kernel + + +def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0.02): + """ + Get the windowed-sinc kernel. + Smith, S. (2003). Digital signal processing: a practical guide for engineers and scientists. + Chapter 16, equation 16-4 + + Parameters + ---------- + fc: float or tuple of float + Cutting frequency between 0 and 0.5. Single float for 'lowpass' and 'highpass'. Tuple of float for + 'bandpass' and 'bandstop'. + filter_type: str + Either 'lowpass', 'highpass', 'bandstop' or 'bandpass'. + transition_bandwidth: float + Percentage between 0 and 0.5 + Returns + ------- + np.ndarray + """ + M = int(np.rint(4.0 / transition_bandwidth)) + x = np.arange(-(M // 2), 1 + (M // 2)) + fc = np.transpose(np.atleast_2d(fc)) + kernel = np.sinc(2 * fc * x) + kernel = np.transpose(kernel * np.blackman(kernel.shape[1])) + kernel = kernel / np.sum(kernel, axis=0) + + if filter_type == "lowpass": + kernel = kernel.flatten() + return kernel + elif filter_type == "highpass": + kernel = _compute_spectral_inversion(kernel.flatten()) + return kernel + elif filter_type == "bandstop": + kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) + kernel = np.sum(kernel, axis=1) + return kernel + elif filter_type == "bandpass": + kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) + kernel = _compute_spectral_inversion(np.sum(kernel, axis=1)) + return kernel + else: + raise ValueError + + +def _compute_windowed_sinc_filter( + data, freq, sampling_frequency, filter_type="lowpass", transition_bandwidth=0.02 +): + """ + Apply a windowed-sinc filter to the provided signal. + + Parameters + ---------- + filter_type + """ + kernel = _get_windowed_sinc_kernel( + freq / sampling_frequency, filter_type, transition_bandwidth + ) + return data.convolve(kernel) + + +def _validate_filtering_inputs(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Validate each positional argument + if not isinstance(args[0], nap.time_series.BaseTsd): + raise ValueError( + f"Invalid value: {args[0]}. First argument should be of type Tsd, TsdFrame or TsdTensor" + ) + if not isinstance(args[1], Number): + if len(args[1]) != 2 or not all(isinstance(fq, Number) for fq in args[1]): + raise ValueError + + # Validate each keyword argument + if "fs" in kwargs: + if kwargs["fs"] is not None and not isinstance(kwargs["fs"], Number): + raise ValueError( + "Invalid value for 'fs'. Parameter 'fs' should be of type float or int" + ) + + if "order" in kwargs: + if not isinstance(kwargs["order"], int): + raise ValueError( + "Invalid value for 'order': Parameter 'order' should be of type int" + ) - This function performs bandpass filtering on time series data - using a Butterworth filter. The filter can be configured to be of - type "bandpass", "bandstop", "highpass", or "lowpass". + if "transition_bandwidth" in kwargs: + if not isinstance(kwargs["transition_bandwidth"], float): + raise ValueError( + "Invalid value for 'transition_bandwidth'. 'transition_bandwidth' should be of type float" + ) + + # Call the original function with validated inputs + return func(*args, **kwargs) + + return wrapper + + +@_validate_filtering_inputs +def compute_bandpass_filter( + data, freq_band, fs=None, mode="butter", order=4, transition_bandwidth=0.02 +): + """ + Apply a band-pass filter to the provided signal. + Mode can be : + - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. + - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. Parameters ---------- data : Tsd, TsdFrame, or TsdTensor The signal to be filtered. - freq_band : tuple of (float, float) or float - Cutoff frequency(ies) in Hz. - - For "bandpass" and "bandstop" filters, provide a tuple specifying the two cutoff frequencies. - - For "lowpass" and "highpass" filters, provide a single float specifying the cutoff frequency. - filter_type : {'bandpass', 'bandstop', 'highpass', 'lowpass'}, optional - The type of frequency filter to apply. Default is "bandpass". + freq_band : tuple of (float, float) + Cutoff frequencies in Hz. + fs : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + mode : {'butter', 'sinc'}, optional + Filtering mode. Default is 'butter'. order : int, optional The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. Default is 4. - sampling_frequency : float, optional + transition_bandwidth : float, optional + The transition bandwidth. 0.2 corresponds to 20% of the frequency band between 0 and the sampling frequency. + The smaller the transition bandwidth, the larger the windowed-sinc kernel. + Default is 0.02. + + Returns + ------- + filtered_data : Tsd, TsdFrame, or TsdTensor + The filtered signal, with the same data type as the input. + + Raises + ------ + ValueError + If `data` is not a Tsd, TsdFrame, or TsdTensor. + If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + If `fs` is not float or None. + If `mode` is not "butter" or "sinc". + If `order` is not an int. + If "transition_bandwidth" is not a float. + Notes + ----- + For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal + is reduced by -3 dB (decibels). + """ + if fs is None: + fs = data.rate + + freq_band = np.array(freq_band) + + if mode == "butter": + return _compute_butterworth_filter( + data, freq_band, fs, filter_type="bandpass", order=order + ) + if mode == "sinc": + return _compute_windowed_sinc_filter( + data, + freq_band, + fs, + filter_type="bandpass", + transition_bandwidth=transition_bandwidth, + ) + else: + raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") + + +@_validate_filtering_inputs +def compute_bandstop_filter( + data, freq_band, fs=None, mode="butter", order=4, transition_bandwidth=0.02 +): + """ + Apply a band-stop filter to the provided signal. + Mode can be : + - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. + - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. + + Parameters + ---------- + data : Tsd, TsdFrame, or TsdTensor + The signal to be filtered. + freq_band : tuple of (float, float) + Cutoff frequencies in Hz. + fs : float, optional The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + mode : {'butter', 'sinc'}, optional + Filtering mode. Default is 'butter'. + order : int, optional + The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. + Default is 4. + transition_bandwidth : float, optional + The transition bandwidth. 0.2 corresponds to 20% of the frequency band between 0 and the sampling frequency. + The smaller the transition bandwidth, the larger the windowed-sinc kernel. + Default is 0.02. Returns ------- @@ -42,50 +250,165 @@ def compute_filtered_signal( Raises ------ ValueError - If `filter_type` is not one of {"bandpass", "bandstop", "highpass", "lowpass"}. - If `freq_band` is not a float for "lowpass" and "highpass" filters. + If `data` is not a Tsd, TsdFrame, or TsdTensor. If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + If `fs` is not float or None. + If `mode` is not "butter" or "sinc". + If `order` is not an int. + If "transition_bandwidth" is not a float. + Notes + ----- + For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal + is reduced by -3 dB (decibels). + """ + if fs is None: + fs = data.rate + + freq_band = np.array(freq_band) + + if mode == "butter": + return _compute_butterworth_filter( + data, freq_band, fs, filter_type="bandstop", order=order + ) + elif mode == "sinc": + return _compute_windowed_sinc_filter( + data, + freq_band, + fs, + filter_type="bandstop", + transition_bandwidth=transition_bandwidth, + ) + else: + raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") + + +@_validate_filtering_inputs +def compute_highpass_filter( + data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 +): + """ + Apply a high-pass filter to the provided signal. + Mode can be : + - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. + - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. + Parameters + ---------- + data : Tsd, TsdFrame, or TsdTensor + The signal to be filtered. + cutoff : float + Cutoff frequency in Hz. + fs : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + mode : {'butter', 'sinc'}, optional + Filtering mode. Default is 'butter'. + order : int, optional + The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. + Default is 4. + transition_bandwidth : float, optional + The transition bandwidth. 0.2 corresponds to 20% of the frequency band between 0 and the sampling frequency. + The smaller the transition bandwidth, the larger the windowed-sinc kernel. + Default is 0.02. + + Returns + ------- + filtered_data : Tsd, TsdFrame, or TsdTensor + The filtered signal, with the same data type as the input. + + Raises + ------ + ValueError + If `data` is not a Tsd, TsdFrame, or TsdTensor. + If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + If `fs` is not float or None. + If `mode` is not "butter" or "sinc". + If `order` is not an int. + If "transition_bandwidth" is not a float. Notes ----- - The cutoff frequency is defined as the frequency at which the amplitude of the signal + For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal is reduced by -3 dB (decibels). """ - if sampling_frequency is None: - sampling_frequency = data.rate + if fs is None: + fs = data.rate - if filter_type not in ["lowpass", "highpass", "bandpass", "bandstop"]: - raise ValueError( - f"Unrecognized filter type {filter_type}. " - "filter_type must be either 'lowpass', 'highpass', 'bandpass',or 'bandstop'." + if mode == "butter": + return _compute_butterworth_filter( + data, cutoff, fs, filter_type="highpass", order=order ) - elif filter_type in ["lowpass", "highpass"] and not isinstance(freq_band, Number): - raise ValueError( - "Low/high-pass filter specification requires a single frequency. " - f"{freq_band} provided instead!" + elif mode == "sinc": + return _compute_windowed_sinc_filter( + data, + cutoff, + fs, + filter_type="highpass", + transition_bandwidth=transition_bandwidth, ) - elif filter_type in ["bandpass", "bandstop"]: - try: - if len(freq_band) != 2 or not all( - isinstance(fq, Number) for fq in freq_band - ): - raise ValueError - except Exception: - raise ValueError( - "Band-pass/stop filter specification requires two frequencies. " - f"{freq_band} provided instead!" - ) + else: + raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") - sos = butter( - order, freq_band, btype=filter_type, fs=sampling_frequency, output="sos" - ) - out = np.zeros_like(data.d) - for ep in data.time_support: - slc = data.get_slice(start=ep.start[0], end=ep.end[0]) - out[slc] = sosfiltfilt(sos, data.d[slc], axis=0) +@_validate_filtering_inputs +def compute_lowpass_filter( + data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 +): + """ + Apply a low-pass filter to the provided signal. + Mode can be : + - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. + - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. - kwargs = dict(t=data.t, d=out, time_support=data.time_support) - if isinstance(data, nap.TsdFrame): - kwargs["columns"] = data.columns - return data.__class__(**kwargs) + Parameters + ---------- + data : Tsd, TsdFrame, or TsdTensor + The signal to be filtered. + cutoff : float + Cutoff frequency in Hz. + fs : float, optional + The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. + mode : {'butter', 'sinc'}, optional + Filtering mode. Default is 'butter'. + order : int, optional + The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. + Default is 4. + transition_bandwidth : float, optional + The transition bandwidth. 0.2 corresponds to 20% of the frequency band between 0 and the sampling frequency. + The smaller the transition bandwidth, the larger the windowed-sinc kernel. + Default is 0.02. + + Returns + ------- + filtered_data : Tsd, TsdFrame, or TsdTensor + The filtered signal, with the same data type as the input. + + Raises + ------ + ValueError + If `data` is not a Tsd, TsdFrame, or TsdTensor. + If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + If `fs` is not float or None. + If `mode` is not "butter" or "sinc". + If `order` is not an int. + If "transition_bandwidth" is not a float. + Notes + ----- + For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal + is reduced by -3 dB (decibels). + """ + if fs is None: + fs = data.rate + + if mode == "butter": + return _compute_butterworth_filter( + data, cutoff, fs, filter_type="lowpass", order=order + ) + elif mode == "sinc": + return _compute_windowed_sinc_filter( + data, + cutoff, + fs, + filter_type="lowpass", + transition_bandwidth=transition_bandwidth, + ) + else: + raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 490548bc..04969083 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -2,10 +2,11 @@ import pynapple as nap import numpy as np from scipy import signal +import warnings from contextlib import nullcontext as does_not_raise -@pytest.fixture +# @pytest.fixture def sample_data(): # Create a sample Tsd data object t = np.linspace(0, 1, 500) @@ -14,22 +15,30 @@ def sample_data(): return nap.Tsd(t=t, d=d, time_support=time_support) -@pytest.mark.parametrize("freq", [10, 100]) -@pytest.mark.parametrize("order", [2, 4, 6]) -@pytest.mark.parametrize("btype", ["lowpass", "highpass"]) -@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) -@pytest.mark.parametrize( - "ep", - [ - nap.IntervalSet(start=[0], end=[1]), - nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), - nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1]) - ] -) -def test_filtering_single_freq_match_sci(freq, order, btype, shape: tuple, ep): +def compare_scipy(tsd, ep, order, freq, fs, btype): + sos = signal.butter(order, freq, btype=btype, fs=fs, output="sos") + out_sci = [] + for iset in ep: + out_sci.append(signal.sosfiltfilt(sos, tsd.restrict(iset).d, axis=0)) + out_sci = np.concatenate(out_sci, axis=0) + return out_sci + +def compare_sinc(tsd, ep, transition_bandwidth, freq, fs, ftype): + + kernel = nap.process.filtering._get_windowed_sinc_kernel(freq/fs, ftype, transition_bandwidth) + return tsd.convolve(kernel, ep).d + +@pytest.mark.parametrize("freq", [10]) +@pytest.mark.parametrize("mode", ["butter", "sinc"]) +@pytest.mark.parametrize("order", [4]) +@pytest.mark.parametrize("transition_bandwidth", [0.02]) +@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) +@pytest.mark.parametrize("sampling_frequency", [None, 5000.0]) +@pytest.mark.parametrize("ep", [nap.IntervalSet(start=[0], end=[1]), nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), ]) +def test_low_pass(freq, mode, order, transition_bandwidth, shape, sampling_frequency, ep): t = np.linspace(0, 1, shape[0]) - y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape)) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1] * (len(shape) - 1)) + np.random.normal(size=shape)) if len(shape) == 1: tsd = nap.Tsd(t, y, time_support=ep) @@ -37,31 +46,36 @@ def test_filtering_single_freq_match_sci(freq, order, btype, shape: tuple, ep): tsd = nap.TsdFrame(t, y, time_support=ep) else: tsd = nap.TsdTensor(t, y, time_support=ep) - out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order) - sos = signal.butter(order, freq, fs=tsd.rate, btype=btype, output="sos") - out_sci = [] - for iset in ep: - out_sci.append(signal.sosfiltfilt(sos, tsd.restrict(iset).d, axis=0)) - out_sci = np.concatenate(out_sci, axis=0) - np.testing.assert_array_equal(out.d, out_sci) + if sampling_frequency is not None and sampling_frequency != tsd.rate: + sampling_frequency = tsd.rate + out = nap.compute_lowpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) + if mode == "butter": + out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "lowpass") + np.testing.assert_array_equal(out.d, out_sci) -@pytest.mark.parametrize("freq", [[10, 30], [100,150]]) -@pytest.mark.parametrize("order", [2, 4, 6]) -@pytest.mark.parametrize("btype", ["bandpass", "bandstop"]) -@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) -@pytest.mark.parametrize( - "ep", - [ - nap.IntervalSet(start=[0], end=[1]), - nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), - nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1]) - ] -) -def test_filtering_freq_band_match_sci(freq, order, btype, shape: tuple, ep): + if mode == "sinc": + out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "lowpass") + np.testing.assert_array_equal(out.d, out_sinc) + assert isinstance(out, type(tsd)) + assert np.all(out.t == tsd.t) + assert np.all(out.time_support == tsd.time_support) + if isinstance(tsd, nap.TsdFrame): + assert np.all(tsd.columns == out.columns) + + +@pytest.mark.parametrize("freq", [10]) +@pytest.mark.parametrize("mode", ["butter", "sinc"]) +@pytest.mark.parametrize("order", [4]) +@pytest.mark.parametrize("transition_bandwidth", [0.02]) +@pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) +@pytest.mark.parametrize("sampling_frequency", [None, 5000.0]) +@pytest.mark.parametrize("ep", [nap.IntervalSet(start=[0], end=[1]), nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), ]) +def test_high_pass(freq, mode, order, transition_bandwidth, shape, sampling_frequency, ep): t = np.linspace(0, 1, shape[0]) - y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape)) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1] * (len(shape) - 1)) + np.random.normal(size=shape)) if len(shape) == 1: tsd = nap.Tsd(t, y, time_support=ep) @@ -69,38 +83,58 @@ def test_filtering_freq_band_match_sci(freq, order, btype, shape: tuple, ep): tsd = nap.TsdFrame(t, y, time_support=ep) else: tsd = nap.TsdTensor(t, y, time_support=ep) - out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order) - sos = signal.butter(order, freq, fs=tsd.rate, btype=btype, output="sos") - out_sci = [] - for iset in ep: - out_sci.append(signal.sosfiltfilt(sos, tsd.restrict(iset).d, axis=0)) - out_sci = np.concatenate(out_sci, axis=0) - np.testing.assert_array_equal(out.d, out_sci) + if sampling_frequency is not None and sampling_frequency != tsd.rate: + sampling_frequency = tsd.rate + out = nap.compute_highpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) -@pytest.mark.parametrize("freq", [10, 100]) -@pytest.mark.parametrize("order", [2, 4, 6]) -@pytest.mark.parametrize("btype", ["lowpass", "highpass"]) + if mode == "sinc": + out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "highpass") + np.testing.assert_array_equal(out.d, out_sinc) + + if mode == "butter": + out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "highpass") + np.testing.assert_array_equal(out.d, out_sci) + + assert isinstance(out, type(tsd)) + assert np.all(out.t == tsd.t) + assert np.all(out.time_support == tsd.time_support) + if isinstance(tsd, nap.TsdFrame): + assert np.all(tsd.columns == out.columns) + + +@pytest.mark.parametrize("freq", [[10, 30]]) +@pytest.mark.parametrize("mode", ["butter", "sinc"]) +@pytest.mark.parametrize("order", [4]) +@pytest.mark.parametrize("transition_bandwidth", [0.02]) @pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) -@pytest.mark.parametrize( - "ep", - [ - nap.IntervalSet(start=[0], end=[1]), - nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), - nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1]) - ] -) -def test_filtering_single_freq_dtype(freq, order, btype, shape: tuple, ep): +@pytest.mark.parametrize("sampling_frequency", [None, 5000.0]) +@pytest.mark.parametrize("ep", [nap.IntervalSet(start=[0], end=[1]), nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), ]) +def test_bandpass(freq, mode, order, transition_bandwidth, shape, sampling_frequency, ep): t = np.linspace(0, 1, shape[0]) - y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape)) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1] * (len(shape) - 1)) + np.random.normal(size=shape)) if len(shape) == 1: tsd = nap.Tsd(t, y, time_support=ep) elif len(shape) == 2: - tsd = nap.TsdFrame(t, y, time_support=ep, columns=np.arange(10, 10 + y.shape[1])) + tsd = nap.TsdFrame(t, y, time_support=ep) else: tsd = nap.TsdTensor(t, y, time_support=ep) - out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order) + if sampling_frequency is not None and sampling_frequency != tsd.rate: + sampling_frequency = tsd.rate + + out = nap.compute_bandpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) + + if mode == "sinc": + out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "bandpass") + np.testing.assert_array_equal(out.d, out_sinc) + + if mode == "butter": + out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "bandpass") + np.testing.assert_array_equal(out.d, out_sci) + assert isinstance(out, type(tsd)) assert np.all(out.t == tsd.t) assert np.all(out.time_support == tsd.time_support) @@ -108,65 +142,116 @@ def test_filtering_single_freq_dtype(freq, order, btype, shape: tuple, ep): assert np.all(tsd.columns == out.columns) -@pytest.mark.parametrize("freq", [[10, 30], [100, 150]]) -@pytest.mark.parametrize("order", [2, 4, 6]) -@pytest.mark.parametrize("btype", ["bandpass", "bandstop"]) +@pytest.mark.parametrize("freq", [[10, 30]]) +@pytest.mark.parametrize("mode", ["butter", "sinc"]) +@pytest.mark.parametrize("order", [2, 4]) +@pytest.mark.parametrize("transition_bandwidth", [0.02]) @pytest.mark.parametrize("shape", [(5000,), (5000, 2), (5000, 2, 3)]) -@pytest.mark.parametrize( - "ep", - [ - nap.IntervalSet(start=[0], end=[1]), - nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), - nap.IntervalSet(start=[0, 0.5, 0.95], end=[0.4, 0.9, 1]) - ] -) -def test_filtering_freq_band_dtype(freq, order, btype, shape: tuple, ep): +@pytest.mark.parametrize("sampling_frequency", [None, 5000.0]) +@pytest.mark.parametrize("ep", [nap.IntervalSet(start=[0], end=[1]), nap.IntervalSet(start=[0, 0.5], end=[0.4, 1]), ]) +def test_bandstop(freq, mode, order, transition_bandwidth, shape, sampling_frequency, ep): t = np.linspace(0, 1, shape[0]) - y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1]*(len(shape) - 1)) + np.random.normal(size=shape)) + y = np.squeeze(np.cos(np.pi * 2 * 80 * t).reshape(-1, *[1] * (len(shape) - 1)) + np.random.normal(size=shape)) if len(shape) == 1: tsd = nap.Tsd(t, y, time_support=ep) elif len(shape) == 2: - tsd = nap.TsdFrame(t, y, time_support=ep, columns=np.arange(10, 10 + y.shape[1])) + tsd = nap.TsdFrame(t, y, time_support=ep) else: tsd = nap.TsdTensor(t, y, time_support=ep) - out = nap.compute_filtered_signal(tsd, freq_band=freq, filter_type=btype, order=order) + if sampling_frequency is not None and sampling_frequency != tsd.rate: + sampling_frequency = tsd.rate + + out = nap.compute_bandstop_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) + + if mode == "sinc": + out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "bandstop") + np.testing.assert_array_equal(out.d, out_sinc) + + if mode == "butter": + out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "bandstop") + np.testing.assert_array_equal(out.d, out_sci) + + assert isinstance(out, type(tsd)) assert np.all(out.t == tsd.t) assert np.all(out.time_support == tsd.time_support) if isinstance(tsd, nap.TsdFrame): assert np.all(tsd.columns == out.columns) +######################################################################## +# Errors +######################################################################## +@pytest.mark.parametrize("func, freq", [ + (nap.compute_lowpass_filter, 10), + (nap.compute_highpass_filter, 10), + (nap.compute_bandpass_filter, [10, 20]), + (nap.compute_bandstop_filter, [10, 20]), +]) +@pytest.mark.parametrize("data, fs, mode, order, transition_bandwidth, expected_exception", [ + (sample_data(), None, "butter", "a", 0.02, pytest.raises(ValueError,match="Invalid value for 'order': Parameter 'order' should be of type int")), + ("invalid_data", None, "butter", 4, 0.02, pytest.raises(ValueError,match="Invalid value: invalid_data. First argument should be of type Tsd, TsdFrame or TsdTensor")), + (sample_data(), None, "invalid_mode", 4, 0.02, pytest.raises(ValueError,match="Unrecognized filter mode. Choose either 'butter' or 'sinc'")), + (sample_data(), "invalid_fs", "butter", 4, 0.02, pytest.raises(ValueError,match="Invalid value for 'fs'. Parameter 'fs' should be of type float or int")), + (sample_data(), None, "sinc", 4, "a", pytest.raises(ValueError,match="Invalid value for 'transition_bandwidth'. 'transition_bandwidth' should be of type float")), +]) +def test_compute_filtered_signal_raise_errors(func, freq, data, fs, mode, order, transition_bandwidth, expected_exception): + with expected_exception: + func(data, freq, fs=fs, mode=mode, order=order, transition_bandwidth=transition_bandwidth) -@pytest.mark.parametrize("freq_band, filter_type, order, expected_exception", [ - ((5, 15), "bandpass", 4, does_not_raise()), - ((5, 15), "bandstop", 4, does_not_raise()), - (10, "highpass", 4, does_not_raise()), - (10, "lowpass", 4, does_not_raise()), - ((5, 15), "invalid_filter", 4, pytest.raises(ValueError, match="Unrecognized filter type")), - (10, "bandpass", 4, pytest.raises(ValueError, match="Band-pass/stop filter specification requires two frequencies")), - ((5, 15), "highpass", 4, pytest.raises(ValueError, match="Low/high-pass filter specification requires a single frequency")), - (None, "bandpass", 4, pytest.raises(ValueError, match="Band-pass/stop filter specification requires two frequencies")), - ((None, 1), "highpass", 4, pytest.raises(ValueError, match="Low/high-pass filter specification requires a single frequency")) +@pytest.mark.parametrize("func, freq, expected_exception", [ + (nap.compute_lowpass_filter, "a", pytest.raises(ValueError)), + (nap.compute_highpass_filter, "b", pytest.raises(ValueError)), + (nap.compute_bandpass_filter, [10, "b"], pytest.raises(ValueError)), + (nap.compute_bandstop_filter, [10, 20, 30], pytest.raises(ValueError)), ]) -def test_compute_filtered_signal(sample_data, freq_band, filter_type, order, expected_exception): +def test_compute_filtered_signal_bad_freq(func, freq, expected_exception): with expected_exception: - filtered_data = nap.filtering.compute_filtered_signal(sample_data, freq_band, filter_type=filter_type, order=order) - if not expected_exception: - assert isinstance(filtered_data, type(sample_data)) - assert filtered_data.d.shape == sample_data.d.shape + func(sample_data(), freq) +################################################################# # Test with edge-case frequencies close to Nyquist frequency @pytest.mark.parametrize("nyquist_fraction", [0.99, 0.999]) @pytest.mark.parametrize("order", [2, 4]) -def test_filtering_nyquist_edge_case(nyquist_fraction, order, sample_data): - nyquist_freq = 0.5 * sample_data.rate +def test_filtering_nyquist_edge_case(nyquist_fraction, order): + data = sample_data() + nyquist_freq = 0.5 * data.rate freq = nyquist_freq * nyquist_fraction - out = nap.filtering.compute_filtered_signal( - sample_data, freq_band=freq, filter_type="lowpass", order=order - ) - assert isinstance(out, type(sample_data)) - np.testing.assert_allclose(out.t, sample_data.t) - np.testing.assert_allclose(out.time_support, sample_data.time_support) + out = nap.filtering.compute_lowpass_filter(data, freq, order=order) + assert isinstance(out, type(data)) + np.testing.assert_allclose(out.t, data.t) + np.testing.assert_allclose(out.time_support, data.time_support) + +################################################################# +# Test windowedsinc kernel + +@pytest.mark.parametrize("tb", [0.2, 0.3]) +def test_get_odd_kernel(tb): + kernel = nap.process.filtering._get_windowed_sinc_kernel(0.25, transition_bandwidth=tb) + assert len(kernel)%2 != 0 + +@pytest.mark.parametrize("filter_type, expected_exception", [ + ("a", pytest.raises(ValueError)), +]) +def test_get_kernel_error(filter_type, expected_exception): + with expected_exception: + nap.process.filtering._get_windowed_sinc_kernel(0.25, filter_type=filter_type) + +def test_compare_sinc_kernel(): + kernel = nap.process.filtering._get_windowed_sinc_kernel(0.25) + x = np.arange(-(len(kernel)//2), 1+len(kernel)//2) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + kernel2 = np.sin(2*np.pi*x*0.25)/x#(2*np.pi*x*0.25) + kernel2[len(kernel)//2] = 0.25*2*np.pi + kernel2 = kernel2 * np.blackman(len(kernel2)) + kernel2 = kernel2/kernel2.sum() + np.testing.assert_allclose(kernel, kernel2) + + ikernel = nap.process.filtering._compute_spectral_inversion(kernel) + ikernel2 = kernel2 * -1.0 + ikernel2[len(ikernel2) // 2] = 1.0 + ikernel2[len(kernel2) // 2] + np.testing.assert_allclose(ikernel, ikernel2) From 4b12da622aa09fc8f11c2bd57f149daf6852f670 Mon Sep 17 00:00:00 2001 From: gviejo Date: Mon, 2 Sep 2024 05:46:40 -0400 Subject: [PATCH 168/195] Fixing notebooks --- docs/api_guide/tutorial_pynapple_filtering.py | 29 +- docs/examples/tutorial_phase_preferences.py | 272 ++++++------------ pynapple/process/filtering.py | 59 ++-- tests/test_filtering.py | 5 + 4 files changed, 140 insertions(+), 225 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py index 21fc5ccf..d917990b 100644 --- a/docs/api_guide/tutorial_pynapple_filtering.py +++ b/docs/api_guide/tutorial_pynapple_filtering.py @@ -34,8 +34,10 @@ import matplotlib.pyplot as plt import numpy as np -import pynapple as nap import seaborn + +import pynapple as nap + seaborn.set_theme() seaborn.set(font_scale=1.25) @@ -55,33 +57,31 @@ # %% # Let's plot it -# fig = plt.figure(figsize = (15, 5)) -# plt.plot(sig) -# plt.xlabel("Time (s)") -# plt.show() +fig = plt.figure(figsize = (15, 5)) +plt.plot(sig) +plt.xlabel("Time (s)") +plt.show() # %% # We can compute the Fourier transform of `sig` to verify that all the frequencies are there. psd = nap.compute_power_spectral_density(sig, fs, norm=True) -# fig = plt.figure(figsize = (15, 5)) -# plt.plot(np.abs(psd)) -# plt.xlabel("Frequency (Hz)") -# plt.ylabel("Amplitude") -# plt.xlim(0, 100) -# plt.show() +fig = plt.figure(figsize = (15, 5)) +plt.plot(np.abs(psd)) +plt.xlabel("Frequency (Hz)") +plt.ylabel("Amplitude") +plt.xlim(0, 100) +plt.show() # %% # Let's say we would like to see only the 10 Hz component. # We can use the function `compute_bandpass_filter` with mode `butter` for Butterworth. sig_butter = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='butter') -# sig_butter = nap.compute_lowpass_filter(sig, 5, fs, mode='butter') # %% # Let's compare it to the `sinc` mode for Windowed-sinc. sig_sinc = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='sinc') -# sig_sinc = nap.compute_lowpass_filter(sig, 5, fs, mode='sinc') # %% # Let's plot it @@ -99,9 +99,6 @@ plt.xlim(0, 1) plt.show() -import sys -sys.exit() - # %% # This gives similar results except at the edges. # diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 5af93b3c..d826d2c3 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -3,12 +3,12 @@ Spikes-phase coupling ===================== -In this tutorial we will learn how to isolate phase information from our wavelet decomposition and combine it +In this tutorial we will learn how to isolate phase information using band-pass filtering and combine it with spiking data, to find phase preferences of spiking units. Specifically, we will examine LFP and spiking data from a period of REM sleep, after traversal of a linear track. -This tutorial was made by [Kipp Freud](https://kippfreud.com/). +This tutorial was made by [Kipp Freud](https://kippfreud.com/) & Guillaume Viejo """ # %% @@ -32,7 +32,8 @@ import seaborn import tqdm -seaborn.set_theme() +custom_params = {"axes.spines.right": False, "axes.spines.top": False} +seaborn.set_theme(context='notebook', style="ticks", rc=custom_params) import pynapple as nap @@ -71,19 +72,19 @@ # *** # Selecting slices # ----------------------------------- -# Let's consider a 10-second slice of data taken during REM sleep +# For later visualization, we define an interval of 3 seconds of data during REM sleep. -# Define the IntervalSet for this run and instantiate both LFP and -# Position TsdFrame objects -REM_minute_interval = nap.IntervalSet( - data["rem"]["start"][0] + 95.0, +ep_ex_rem = nap.IntervalSet( + data["rem"]["start"][0] + 97.0, data["rem"]["start"][0] + 100.0, ) -REM_Tsd = data["eeg"].restrict(REM_minute_interval) +# %% +# Here we restrict the lfp to the REM epochs. +tsd_rem = data["eeg"][:,0].restrict(data["rem"]) # We will also extract spike times from all units in our dataset -# which occur during our specified interval -spikes = data["units"].restrict(REM_minute_interval) +# which occur during REM sleep +spikes = data["units"].restrict(data["rem"]) # %% # *** @@ -92,15 +93,12 @@ # We should first plot our REM Local Field Potential data. fig, ax = plt.subplots(1, constrained_layout=True, figsize=(10, 3)) -ax.plot( - REM_Tsd, - label="REM LFP Data", -) +ax.plot(tsd_rem.restrict(ep_ex_rem)) ax.set_title("REM Local Field Potential") ax.set_ylabel("LFP (a.u.)") ax.set_xlabel("time (s)") -ax.margins(0) -ax.legend() + + # %% # *** @@ -110,12 +108,15 @@ # - this is a common feature of REM sleep. Let's perform a wavelet decomposition, # as we did in the last tutorial, to see get a more informative breakdown of the # frequencies present in the data. +# +# We must define the frequency set that we'd like to use for our decomposition. - -# We must define the frequency set that we'd like to use for our decomposition freqs = np.geomspace(5, 200, 25) -# Compute the wavelet transform on our LFP data -mwt_REM = nap.compute_wavelet_transform(REM_Tsd[:, 0], fs=FS, freqs=freqs) + +# %% +# We compute the wavelet transform on our LFP data (only during the example interval). + +cwt_rem = nap.compute_wavelet_transform(tsd_rem.restrict(ep_ex_rem), fs=FS, freqs=freqs) # %% # *** @@ -124,29 +125,25 @@ # Define wavelet decomposition plotting function def plot_timefrequency(freqs, powers, ax=None): - im = ax.imshow(abs(powers), aspect="auto") + im = ax.imshow(np.abs(powers), aspect="auto") ax.invert_yaxis() ax.set_xlabel("Time (s)") ax.set_ylabel("Frequency (Hz)") ax.get_xaxis().set_visible(False) - ax.set( - yticks=[np.argmin(np.abs(freqs - val)) for val in freqs], - yticklabels=np.round(freqs, 2), - ) + ax.set(yticks=np.arange(len(freqs))[::2], yticklabels=np.rint(freqs[::2])) ax.grid(False) return im - fig = plt.figure(constrained_layout=True, figsize=(10, 6)) fig.suptitle("Wavelet Decomposition") gs = plt.GridSpec(2, 1, figure=fig, height_ratios=[1.0, 0.3]) ax0 = plt.subplot(gs[0, 0]) -im = plot_timefrequency(freqs[:], np.transpose(mwt_REM[:, :].values), ax=ax0) +im = plot_timefrequency(freqs, np.transpose(cwt_rem[:, :].values), ax=ax0) cbar = fig.colorbar(im, ax=ax0, orientation="vertical") ax1 = plt.subplot(gs[1, 0]) -ax1.plot(REM_Tsd) +ax1.plot(tsd_rem.restrict(ep_ex_rem)) ax1.set_ylabel("LFP (a.u.)") ax1.set_xlabel("Time (s)") ax1.margins(0) @@ -154,187 +151,94 @@ def plot_timefrequency(freqs, powers, ax=None): # %% # *** -# Visualizing Theta Band Power and Phase -# ----------------------------------- -# There seems to be a strong theta frequency present in the data during the maze traversal. -# Let's plot the estimated 7Hz component of the wavelet decomposition on top of our data, and see how well -# they match up. We will also extract and plot the phase of the 7Hz wavelet from the decomposition. -theta_freq_index = np.argmin(np.abs(8 - freqs)) -theta_band_reconstruction = mwt_REM[:, theta_freq_index].values.real -# calculating phase here -theta_band_phase = nap.Tsd( - t=mwt_REM.index, d=np.angle(mwt_REM[:, theta_freq_index].values) -) +# Filtering Theta +# --------------- +# +# As expected, there is a strong 8Hz component during REM sleep. We can filter it using the function `nap.compute_bandpass_filter`. + +theta_band = nap.compute_bandpass_filter(tsd_rem, cutoff=(6.0, 10.0), fs=FS) # %% -# *** -# Now let's plot the theta power and phase, along with the LFP. +# We can plot the original signal and the filtered signal. -fig, (ax1, ax2) = plt.subplots( - 2, 1, constrained_layout=True, figsize=(10, 5), height_ratios=[0.4, 0.2] -) +plt.figure(constrained_layout=True, figsize=(12, 3)) +plt.plot(tsd_rem.restrict(ep_ex_rem), alpha=0.5) +plt.plot(theta_band.restrict(ep_ex_rem)) +plt.xlabel("Time (s)") +plt.show() -ax1.plot(REM_Tsd, alpha=0.5, label="LFP Data - REM") -ax1.plot( - REM_Tsd.index.values, - theta_band_reconstruction, - label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", -) -ax1.set( - ylabel="LFP (v)", - title=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.", -) -ax1.get_xaxis().set_visible(False) -ax1.legend() -ax1.margins(0) -ax2.plot(REM_Tsd.index.values, theta_band_phase, alpha=0.5) -ax2.set(ylabel="Phase", xlabel="Time (s)") -ax2.margins(0) +# %% +# *** +# Computing phase +# --------------- +# +# From the filtered signal, it is easy to get the phase using the Hilbert transform. Here we use scipy Hilbert method. +from scipy import signal + +theta_phase = nap.Tsd(t=theta_band.t, d=np.angle(signal.hilbert(theta_band))) +# %% +# Let's plot the phase. + +plt.figure(constrained_layout=True, figsize=(12, 3)) +plt.subplot(211) +plt.plot(tsd_rem.restrict(ep_ex_rem), alpha=0.5) +plt.plot(theta_band.restrict(ep_ex_rem)) +plt.subplot(212) +plt.plot(theta_phase.restrict(ep_ex_rem), color='r') +plt.ylabel("Phase (rad)") +plt.xlabel("Time (s)") plt.show() + # %% # *** # Finding Phase of Spikes -# ----------------------------------- +# ----------------------- # Now that we have the phase of our theta wavelet, and our spike times, we can find the phase firing preferences -# of each of the units using the compute_1d_tuning_curves function. +# of each of the units using the `compute_1d_tuning_curves` function. # # We will start by throwing away cells which do not have a high enough firing rate during our interval. - -# Filter units based on firing rate spikes = spikes[spikes.rate > 5.0] -# Calculate theta phase firing preferences -tuning_curves = nap.compute_1d_tuning_curves( - group=spikes, feature=theta_band_phase, nb_bins=61, minmax=(-np.pi, np.pi) -) # %% -# *** -# Now we will plot these preferences as smoothed angular histograms. We will select the first 6 units -# to plot. - - -def smoothAngularTuningCurves(tuning_curves, sigma=2): - tmp = np.concatenate([tuning_curves.values] * 3) - tmp = scipy.ndimage.gaussian_filter1d(tmp, sigma=sigma, axis=0) - return pd.DataFrame( - tmp[tuning_curves.shape[0] : 2 * tuning_curves.shape[0]], - index=tuning_curves.index, - columns=tuning_curves.columns, - ) - - -smoothcurves = smoothAngularTuningCurves(tuning_curves, sigma=2) -fig, axd = plt.subplot_mosaic( - [["phase_0", "phase_1", "phase_2"], ["phase_3", "phase_4", "phase_5"]], - constrained_layout=True, - figsize=(10, 6), - subplot_kw={"projection": "polar"}, +# The feature is the theta phase during REM sleep. + +phase_modulation = nap.compute_1d_tuning_curves( + group=spikes, feature=theta_phase, nb_bins=61, minmax=(-np.pi, np.pi) ) -for i, unit in enumerate(list(smoothcurves)[:6]): - ax = axd[f"phase_{i}"] - ax.plot( - list(smoothcurves[unit].index) + [smoothcurves[unit].index[0]], - list(smoothcurves[unit].values) + [smoothcurves[unit].values[0]], - ) - ax.set(xlabel="Phase (rad)", ylabel="Firing Rate (Hz)", title=f"Unit {unit}") - ax.set_xticks([]) +# %% +# Let's plot the first 3 neurons. -fig.suptitle("Phase Preference Histograms of First 6 Units") +plt.figure(constrained_layout=True, figsize = (12, 3)) +for i in range(3): + plt.subplot(1,3,i+1) + plt.plot(phase_modulation.iloc[:,i]) + plt.xlabel("Phase (rad)") + plt.ylabel("Firing rate (Hz)") plt.show() - # %% -# *** -# Isolating Strong Phase Preferences -# ----------------------------------- -# It looks like there could be some phase preferences happening here, but there's a lot of cells to go through. -# Now that we have our phases of firing for each unit, we can sort the units by the circular variance of the phase -# of their spikes, to isolate the cells with the strongest phase preferences without manual inspection. - -# Get phase of each spike -phase = {} -for i in spikes: - phase_i = [ - theta_band_phase[np.argmin(np.abs(REM_Tsd.index.values - s.index))] - for s in spikes[i] - ] - phase[i] = np.array(phase_i) -phase_var = { - key: scipy.stats.circvar(value, low=-np.pi, high=np.pi) - for key, value in phase.items() -} -phase_var = dict(sorted(phase_var.items(), key=lambda item: item[1])) +# There is clearly a strong modulation for the third neuron. +# Finally, we can use the function `value_from` to align each spikes to the corresponding phase position and overlay +# it with the LFP. -# %% -# *** -# And now we plot the phase preference histograms of the 6 units with the least variance in the phase of their -# spiking behaviour. - -fig, axd = plt.subplot_mosaic( - [["phase_0", "phase_1", "phase_2"], ["phase_3", "phase_4", "phase_5"]], - constrained_layout=True, - figsize=(10, 6), - subplot_kw={"projection": "polar"}, -) - -for i, unit in enumerate(list(phase_var.keys())[:6]): - ax = axd[f"phase_{i}"] - ax.plot( - list(smoothcurves[unit].index) + [smoothcurves[unit].index[0]], - list(smoothcurves[unit].values) + [smoothcurves[unit].values[0]], - ) - ax.set(xlabel="Phase (rad)", ylabel="Firing Rate (Hz)", title=f"Unit {unit}") - ax.set_xticks([]) - -fig.suptitle("Phase Preference Histograms of 6 Units with Highest Phase Preference") -plt.show() +spike_phase = spikes[spikes.index[3]].value_from(theta_phase) # %% -# *** -# Visualizing Phase Preferences -# ----------------------------------- -# There is definitely some strong phase preferences happening here. Let's visualize the firing preferences -# of the 6 cells we've isolated to get an impression of just how striking these preferences are. - -fig, axd = plt.subplot_mosaic( - [ - ["lfp_run"], - ["phase_0"], - ["phase_1"], - ["phase_2"], - ], - constrained_layout=True, - figsize=(10, 8), - height_ratios=[0.4, 0.2, 0.2, 0.2], -) +# Let's plot it. +plt.figure(constrained_layout=True, figsize=(12, 3)) +plt.subplot(211) +plt.plot(tsd_rem.restrict(ep_ex_rem), alpha=0.5) +plt.plot(theta_band.restrict(ep_ex_rem)) +plt.subplot(212) +plt.plot(theta_phase.restrict(ep_ex_rem), alpha=0.5) +plt.plot(spike_phase.restrict(ep_ex_rem), 'o') +plt.ylabel("Phase (rad)") +plt.xlabel("Time (s)") +plt.show() -REM_index = REM_Tsd.index.values -axd["lfp_run"].plot(REM_index, REM_Tsd[:, 0], alpha=0.5, label="LFP Data - REM") -axd["lfp_run"].plot( - REM_index, - theta_band_reconstruction, - label=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillations", -) -axd["lfp_run"].set( - ylabel="LFP (v)", - xlabel="Time (s)", - title=f"{np.round(freqs[theta_freq_index], 2)}Hz oscillation power.", -) -axd["lfp_run"].legend() -axd["lfp_run"].margins(0) -for i in range(3): - unit_key = list(phase_var.keys())[i] - ax = axd[f"phase_{i}"] - ax.plot(REM_index, theta_band_phase, alpha=0.2) - ax.scatter(spikes[unit_key].index, phase[unit_key]) - ax.set(ylabel="Phase", title=f"Unit {unit_key}") - ax.margins(0) - -fig.suptitle("Phase Preference Visualizations") -plt.show() diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index fcd1ec38..f7cdb826 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -1,5 +1,6 @@ """Filtering module.""" +import inspect from functools import wraps from numbers import Number @@ -9,19 +10,17 @@ from .. import core as nap -def _get_butter_coefficients(freq_band, filter_type, sampling_frequency, order=4): - return butter( - order, freq_band, btype=filter_type, fs=sampling_frequency, output="sos" - ) +def _get_butter_coefficients(cutoff, filter_type, sampling_frequency, order=4): + return butter(order, cutoff, btype=filter_type, fs=sampling_frequency, output="sos") def _compute_butterworth_filter( - data, freq_band, sampling_frequency=None, filter_type="bandpass", order=4 + data, cutoff, sampling_frequency=None, filter_type="bandpass", order=4 ): """ Apply a Butterworth filter to the provided signal. """ - sos = _get_butter_coefficients(freq_band, filter_type, sampling_frequency, order) + sos = _get_butter_coefficients(cutoff, filter_type, sampling_frequency, order) out = np.zeros_like(data.d) for ep in data.time_support: slc = data.get_slice(start=ep.start[0], end=ep.end[0]) @@ -113,15 +112,25 @@ def _validate_filtering_inputs(func): @wraps(func) def wrapper(*args, **kwargs): # Validate each positional argument - if not isinstance(args[0], nap.time_series.BaseTsd): + sig = inspect.signature(func) + kwargs = sig.bind_partial(*args, **kwargs).arguments + + if "data" not in kwargs or "cutoff" not in kwargs: + raise TypeError( + "Function needs time series and cutoff frequency to be specified." + ) + + if not isinstance(kwargs["data"], nap.time_series.BaseTsd): raise ValueError( f"Invalid value: {args[0]}. First argument should be of type Tsd, TsdFrame or TsdTensor" ) - if not isinstance(args[1], Number): - if len(args[1]) != 2 or not all(isinstance(fq, Number) for fq in args[1]): + + if not isinstance(kwargs["cutoff"], Number): + if len(kwargs["cutoff"]) != 2 or not all( + isinstance(fq, Number) for fq in kwargs["cutoff"] + ): raise ValueError - # Validate each keyword argument if "fs" in kwargs: if kwargs["fs"] is not None and not isinstance(kwargs["fs"], Number): raise ValueError( @@ -141,14 +150,14 @@ def wrapper(*args, **kwargs): ) # Call the original function with validated inputs - return func(*args, **kwargs) + return func(**kwargs) return wrapper @_validate_filtering_inputs def compute_bandpass_filter( - data, freq_band, fs=None, mode="butter", order=4, transition_bandwidth=0.02 + data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 ): """ Apply a band-pass filter to the provided signal. @@ -160,7 +169,7 @@ def compute_bandpass_filter( ---------- data : Tsd, TsdFrame, or TsdTensor The signal to be filtered. - freq_band : tuple of (float, float) + cutoff : tuple of (float, float) Cutoff frequencies in Hz. fs : float, optional The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. @@ -183,7 +192,7 @@ def compute_bandpass_filter( ------ ValueError If `data` is not a Tsd, TsdFrame, or TsdTensor. - If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + If `cutoff` is not a tuple of two floats for "bandpass" and "bandstop" filters. If `fs` is not float or None. If `mode` is not "butter" or "sinc". If `order` is not an int. @@ -196,16 +205,16 @@ def compute_bandpass_filter( if fs is None: fs = data.rate - freq_band = np.array(freq_band) + cutoff = np.array(cutoff) if mode == "butter": return _compute_butterworth_filter( - data, freq_band, fs, filter_type="bandpass", order=order + data, cutoff, fs, filter_type="bandpass", order=order ) if mode == "sinc": return _compute_windowed_sinc_filter( data, - freq_band, + cutoff, fs, filter_type="bandpass", transition_bandwidth=transition_bandwidth, @@ -216,7 +225,7 @@ def compute_bandpass_filter( @_validate_filtering_inputs def compute_bandstop_filter( - data, freq_band, fs=None, mode="butter", order=4, transition_bandwidth=0.02 + data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 ): """ Apply a band-stop filter to the provided signal. @@ -228,7 +237,7 @@ def compute_bandstop_filter( ---------- data : Tsd, TsdFrame, or TsdTensor The signal to be filtered. - freq_band : tuple of (float, float) + cutoff : tuple of (float, float) Cutoff frequencies in Hz. fs : float, optional The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. @@ -251,7 +260,7 @@ def compute_bandstop_filter( ------ ValueError If `data` is not a Tsd, TsdFrame, or TsdTensor. - If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + If `cutoff` is not a tuple of two floats for "bandpass" and "bandstop" filters. If `fs` is not float or None. If `mode` is not "butter" or "sinc". If `order` is not an int. @@ -264,16 +273,16 @@ def compute_bandstop_filter( if fs is None: fs = data.rate - freq_band = np.array(freq_band) + cutoff = np.array(cutoff) if mode == "butter": return _compute_butterworth_filter( - data, freq_band, fs, filter_type="bandstop", order=order + data, cutoff, fs, filter_type="bandstop", order=order ) elif mode == "sinc": return _compute_windowed_sinc_filter( data, - freq_band, + cutoff, fs, filter_type="bandstop", transition_bandwidth=transition_bandwidth, @@ -319,7 +328,7 @@ def compute_highpass_filter( ------ ValueError If `data` is not a Tsd, TsdFrame, or TsdTensor. - If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + If `cutoff` is not a number. If `fs` is not float or None. If `mode` is not "butter" or "sinc". If `order` is not an int. @@ -385,7 +394,7 @@ def compute_lowpass_filter( ------ ValueError If `data` is not a Tsd, TsdFrame, or TsdTensor. - If `freq_band` is not a tuple of two floats for "bandpass" and "bandstop" filters. + If `cutoff` is not a number. If `fs` is not float or None. If `mode` is not "butter" or "sinc". If `order` is not an int. diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 04969083..80f0ca2a 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -240,6 +240,11 @@ def test_get_kernel_error(filter_type, expected_exception): with expected_exception: nap.process.filtering._get_windowed_sinc_kernel(0.25, filter_type=filter_type) +def test_get__error(): + with pytest.raises(TypeError, match="Function needs time series and cutoff frequency to be specified."): + nap.compute_lowpass_filter(cutoff=0.25) + + def test_compare_sinc_kernel(): kernel = nap.process.filtering._get_windowed_sinc_kernel(0.25) x = np.arange(-(len(kernel)//2), 1+len(kernel)//2) From eb939766982fee1ae2e9349bf15b6e9ef740a044 Mon Sep 17 00:00:00 2001 From: gviejo Date: Mon, 2 Sep 2024 13:03:23 -0400 Subject: [PATCH 169/195] Fixing sinc --- docs/api_guide/tutorial_pynapple_filtering.py | 43 ++++++++++++++++++- pynapple/process/filtering.py | 19 +++++--- tests/test_filtering.py | 2 +- 3 files changed, 55 insertions(+), 9 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py index d917990b..b21e0abc 100644 --- a/docs/api_guide/tutorial_pynapple_filtering.py +++ b/docs/api_guide/tutorial_pynapple_filtering.py @@ -38,8 +38,8 @@ import pynapple as nap -seaborn.set_theme() -seaborn.set(font_scale=1.25) +custom_params = {"axes.spines.right": False, "axes.spines.top": False} +seaborn.set_theme(context='notebook', style="ticks", rc=custom_params) # %% # *** @@ -102,6 +102,45 @@ # %% # This gives similar results except at the edges. # +# Another use of filtering is to remove some frequencies (notch filter). Here we can try to remove +# the 50 Hz component in the signal. + +sig_butter = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='butter') +sig_sinc = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='sinc') + + +# %% +# Let's plot it +fig = plt.figure(figsize = (15, 5)) +plt.subplot(211) +plt.plot(t, sig, '-', color = 'gray', label = "Original signal") +plt.xlim(0, 1) +plt.legend() +plt.subplot(212) +# plt.plot(sig, alpha=0.5) +plt.plot(sig_butter, label = "Butterworth") +plt.plot(sig_sinc, '--', label = "Windowed-sinc") +plt.legend() +plt.xlabel("Time (Hz)") +plt.xlim(0, 1) +plt.show() + +# %% +# Let's see what frequencies remain; + +psd_butter = nap.compute_power_spectral_density(sig_butter, fs, norm=True) +psd_sinc = nap.compute_power_spectral_density(sig_sinc, fs, norm=True) + +fig = plt.figure(figsize = (10, 5)) +plt.plot(np.abs(psd_butter), label = "Butterworth filter") +plt.plot(np.abs(psd_sinc), label = "Windowed-sinc convolution") +plt.legend() +plt.xlabel("Frequency (Hz)") +plt.ylabel("Amplitude") +plt.xlim(0, 70) +plt.show() + +# %% # The remaining notebook compares the two modes. # # *** diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index f7cdb826..379074ca 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -5,7 +5,7 @@ from numbers import Number import numpy as np -from scipy.signal import butter, sosfiltfilt +from scipy.signal import butter, freqz, sosfiltfilt from .. import core as nap @@ -71,22 +71,29 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. x = np.arange(-(M // 2), 1 + (M // 2)) fc = np.transpose(np.atleast_2d(fc)) kernel = np.sinc(2 * fc * x) - kernel = np.transpose(kernel * np.blackman(kernel.shape[1])) + kernel = np.transpose(kernel) kernel = kernel / np.sum(kernel, axis=0) if filter_type == "lowpass": - kernel = kernel.flatten() + kernel = kernel.flatten() * np.blackman(len(kernel)) return kernel elif filter_type == "highpass": - kernel = _compute_spectral_inversion(kernel.flatten()) + kernel = _compute_spectral_inversion(kernel.flatten()) * np.blackman( + len(kernel) + ) return kernel elif filter_type == "bandstop": kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) kernel = np.sum(kernel, axis=1) + kernel = kernel * np.blackman(len(kernel)) + w, h = freqz(kernel, worN=4000) + kernel /= np.max(np.abs(h)) return kernel elif filter_type == "bandpass": - kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) - kernel = _compute_spectral_inversion(np.sum(kernel, axis=1)) + kernel = kernel[:, 1] - kernel[:, 0] + kernel = kernel * np.blackman(len(kernel)) + w, h = freqz(kernel, worN=4000) + kernel /= np.max(np.abs(h)) return kernel else: raise ValueError diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 80f0ca2a..b9070fa5 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -252,8 +252,8 @@ def test_compare_sinc_kernel(): warnings.simplefilter("ignore") kernel2 = np.sin(2*np.pi*x*0.25)/x#(2*np.pi*x*0.25) kernel2[len(kernel)//2] = 0.25*2*np.pi - kernel2 = kernel2 * np.blackman(len(kernel2)) kernel2 = kernel2/kernel2.sum() + kernel2 = kernel2 * np.blackman(len(kernel2)) np.testing.assert_allclose(kernel, kernel2) ikernel = nap.process.filtering._compute_spectral_inversion(kernel) From 4cb76ce12d8f93fea3f9e019e7f7ec8e1e2a8072 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Sep 2024 15:32:50 -0400 Subject: [PATCH 170/195] added exception --- pynapple/process/filtering.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 379074ca..ba00bea4 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -156,6 +156,13 @@ def wrapper(*args, **kwargs): "Invalid value for 'transition_bandwidth'. 'transition_bandwidth' should be of type float" ) + if np.any(np.isnan(kwargs["data"])): + raise ValueError( + "The input signal contains NaN values, which are not supported for filtering. " + "Please remove or handle NaNs before applying the filter. " + "You can use the `dropna()` method to drop all NaN values." + ) + # Call the original function with validated inputs return func(**kwargs) From e9468b00674f6658a4ee50f9d5b75de330a5cbfc Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Sep 2024 15:36:03 -0400 Subject: [PATCH 171/195] added test exception --- tests/test_filtering.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index b9070fa5..b53d5abb 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -15,6 +15,15 @@ def sample_data(): return nap.Tsd(t=t, d=d, time_support=time_support) +def sample_data_with_nan(): + # Create a sample Tsd data object + t = np.linspace(0, 1, 500) + d = np.sin(2 * np.pi * 10 * t) + np.random.normal(0, 0.5, t.shape) + d[10] = np.nan + time_support = nap.IntervalSet(start=[0], end=[1]) + return nap.Tsd(t=t, d=d, time_support=time_support) + + def compare_scipy(tsd, ep, order, freq, fs, btype): sos = signal.butter(order, freq, btype=btype, fs=fs, output="sos") out_sci = [] @@ -195,6 +204,9 @@ def test_bandstop(freq, mode, order, transition_bandwidth, shape, sampling_frequ (sample_data(), None, "invalid_mode", 4, 0.02, pytest.raises(ValueError,match="Unrecognized filter mode. Choose either 'butter' or 'sinc'")), (sample_data(), "invalid_fs", "butter", 4, 0.02, pytest.raises(ValueError,match="Invalid value for 'fs'. Parameter 'fs' should be of type float or int")), (sample_data(), None, "sinc", 4, "a", pytest.raises(ValueError,match="Invalid value for 'transition_bandwidth'. 'transition_bandwidth' should be of type float")), + (sample_data_with_nan(), None, "sinc", 4, 0.02, pytest.raises(ValueError,match="The input signal contains NaN values, which are not supported for filtering")), + (sample_data_with_nan(), None, "butter", 4, 0.02, pytest.raises(ValueError, match="The input signal contains NaN values, which are not supported for filtering")) + ]) def test_compute_filtered_signal_raise_errors(func, freq, data, fs, mode, order, transition_bandwidth, expected_exception): with expected_exception: From 9774346a57fb71f739aa75deb3b91e22246a4651 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Sep 2024 18:20:57 -0400 Subject: [PATCH 172/195] fixes lowpass and highpass by normalizing --- pynapple/process/filtering.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index ba00bea4..00068b30 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -72,29 +72,28 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. fc = np.transpose(np.atleast_2d(fc)) kernel = np.sinc(2 * fc * x) kernel = np.transpose(kernel) - kernel = kernel / np.sum(kernel, axis=0) if filter_type == "lowpass": kernel = kernel.flatten() * np.blackman(len(kernel)) - return kernel + return kernel / np.sum(kernel) elif filter_type == "highpass": kernel = _compute_spectral_inversion(kernel.flatten()) * np.blackman( len(kernel) ) - return kernel + return kernel / np.sum(kernel) elif filter_type == "bandstop": kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) kernel = np.sum(kernel, axis=1) kernel = kernel * np.blackman(len(kernel)) w, h = freqz(kernel, worN=4000) kernel /= np.max(np.abs(h)) - return kernel + return kernel / np.sum(kernel) elif filter_type == "bandpass": kernel = kernel[:, 1] - kernel[:, 0] kernel = kernel * np.blackman(len(kernel)) w, h = freqz(kernel, worN=4000) kernel /= np.max(np.abs(h)) - return kernel + return kernel / np.sum(kernel) else: raise ValueError From 55105f2a2b646a88566c890175ea08219e8022f0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Sep 2024 18:26:52 -0400 Subject: [PATCH 173/195] fix tests --- tests/test_filtering.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index b53d5abb..26b9eb48 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -264,8 +264,9 @@ def test_compare_sinc_kernel(): warnings.simplefilter("ignore") kernel2 = np.sin(2*np.pi*x*0.25)/x#(2*np.pi*x*0.25) kernel2[len(kernel)//2] = 0.25*2*np.pi - kernel2 = kernel2/kernel2.sum() + kernel2 = kernel2 kernel2 = kernel2 * np.blackman(len(kernel2)) + kernel2 /= kernel2.sum() np.testing.assert_allclose(kernel, kernel2) ikernel = nap.process.filtering._compute_spectral_inversion(kernel) From 2c9a194e77ac7cd6c8ee361bab1b3051fa0922d9 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Sep 2024 19:28:36 -0400 Subject: [PATCH 174/195] fix highpass --- pynapple/process/filtering.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 00068b30..268b61a9 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -77,23 +77,26 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. kernel = kernel.flatten() * np.blackman(len(kernel)) return kernel / np.sum(kernel) elif filter_type == "highpass": - kernel = _compute_spectral_inversion(kernel.flatten()) * np.blackman( - len(kernel) - ) - return kernel / np.sum(kernel) + # kernel = _compute_spectral_inversion( + # kernel.flatten() * np.blackman(len(kernel)) + # ) + kernel = kernel.flatten() * np.blackman(len(kernel)) + kernel /= np.sum(kernel) + kernel = _compute_spectral_inversion(kernel) + return kernel #/ np.sum(kernel) elif filter_type == "bandstop": kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) kernel = np.sum(kernel, axis=1) kernel = kernel * np.blackman(len(kernel)) w, h = freqz(kernel, worN=4000) kernel /= np.max(np.abs(h)) - return kernel / np.sum(kernel) + return kernel #/ np.sum(kernel) elif filter_type == "bandpass": kernel = kernel[:, 1] - kernel[:, 0] kernel = kernel * np.blackman(len(kernel)) w, h = freqz(kernel, worN=4000) kernel /= np.max(np.abs(h)) - return kernel / np.sum(kernel) + return kernel #/ np.sum(kernel) else: raise ValueError From c3a681a926061a6b4e7ebe9eb19583ae5b8632c1 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Sep 2024 19:28:57 -0400 Subject: [PATCH 175/195] fix highpass --- pynapple/process/filtering.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 268b61a9..31a9394e 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -77,13 +77,11 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. kernel = kernel.flatten() * np.blackman(len(kernel)) return kernel / np.sum(kernel) elif filter_type == "highpass": - # kernel = _compute_spectral_inversion( - # kernel.flatten() * np.blackman(len(kernel)) - # ) + kernel = kernel.flatten() * np.blackman(len(kernel)) kernel /= np.sum(kernel) kernel = _compute_spectral_inversion(kernel) - return kernel #/ np.sum(kernel) + return kernel elif filter_type == "bandstop": kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) kernel = np.sum(kernel, axis=1) From affd113eb15876209480757522d264b84957c47f Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Sep 2024 19:41:24 -0400 Subject: [PATCH 176/195] fix bandstop --- pynapple/process/filtering.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 31a9394e..7fc6f109 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -77,18 +77,19 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. kernel = kernel.flatten() * np.blackman(len(kernel)) return kernel / np.sum(kernel) elif filter_type == "highpass": - kernel = kernel.flatten() * np.blackman(len(kernel)) kernel /= np.sum(kernel) kernel = _compute_spectral_inversion(kernel) return kernel elif filter_type == "bandstop": + bw = np.blackman(len(kernel)) + kernel[:, 1] *= bw + kernel[:, 1] /= kernel[:, 1].sum() kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) + kernel[:, 0] *= bw + kernel[:, 0] /= kernel[:, 0].sum() kernel = np.sum(kernel, axis=1) - kernel = kernel * np.blackman(len(kernel)) - w, h = freqz(kernel, worN=4000) - kernel /= np.max(np.abs(h)) - return kernel #/ np.sum(kernel) + return kernel elif filter_type == "bandpass": kernel = kernel[:, 1] - kernel[:, 0] kernel = kernel * np.blackman(len(kernel)) From 031cae449fe04fabd4843ce513ccc2f4915d0bc0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Fri, 6 Sep 2024 20:13:45 -0400 Subject: [PATCH 177/195] fix bandpass --- pynapple/process/filtering.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 7fc6f109..ac447c32 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -91,11 +91,14 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. kernel = np.sum(kernel, axis=1) return kernel elif filter_type == "bandpass": - kernel = kernel[:, 1] - kernel[:, 0] - kernel = kernel * np.blackman(len(kernel)) - w, h = freqz(kernel, worN=4000) - kernel /= np.max(np.abs(h)) - return kernel #/ np.sum(kernel) + bw = np.blackman(len(kernel)) + kernel[:, 1] *= bw + kernel[:, 1] /= kernel[:, 1].sum() + kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) + kernel[:, 0] *= bw + kernel[:, 0] /= kernel[:, 0].sum() + kernel = _compute_spectral_inversion(kernel[:, 1] + kernel[:, 0]) + return kernel else: raise ValueError From bb9ed1839a5968620dd1456fdebf3b087185eae6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sat, 7 Sep 2024 11:57:40 -0400 Subject: [PATCH 178/195] added a parameter for spectral density --- docs/api_guide/tutorial_pynapple_filtering.py | 10 +++++----- pynapple/process/signal_processing.py | 12 +++++++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py index b21e0abc..673bd439 100644 --- a/docs/api_guide/tutorial_pynapple_filtering.py +++ b/docs/api_guide/tutorial_pynapple_filtering.py @@ -81,7 +81,7 @@ # %% # Let's compare it to the `sinc` mode for Windowed-sinc. -sig_sinc = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='sinc') +sig_sinc = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='sinc', transition_bandwidth=0.001) # %% # Let's plot it @@ -106,7 +106,7 @@ # the 50 Hz component in the signal. sig_butter = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='butter') -sig_sinc = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='sinc') +sig_sinc = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='sinc', transition_bandwidth=0.001) # %% @@ -128,8 +128,8 @@ # %% # Let's see what frequencies remain; -psd_butter = nap.compute_power_spectral_density(sig_butter, fs, norm=True) -psd_sinc = nap.compute_power_spectral_density(sig_sinc, fs, norm=True) +psd_butter = nap.compute_power_spectral_density(sig_butter, fs, norm=True, n=1024) +psd_sinc = nap.compute_power_spectral_density(sig_sinc, fs, norm=True, n=1024) fig = plt.figure(figsize = (10, 5)) plt.plot(np.abs(psd_butter), label = "Butterworth filter") @@ -182,7 +182,7 @@ for trans_bandwidth, kernel in sinc_kernel.items(): plt.subplot(gs[0, 1]) fft_sinc = nap.compute_power_spectral_density( - nap.Tsd(t=np.arange(len(kernel)) / fs, d=kernel), fs) + nap.Tsd(t=np.arange(len(kernel)) / fs, d=kernel), fs, n=1024) plt.plot(np.abs(fft_sinc), label= f"width={trans_bandwidth}") plt.ylabel('Amplitude') plt.legend() diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 6071e98e..3b72cfd6 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -17,7 +17,7 @@ from .. import core as nap -def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm=False): +def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm=False, n=None): """ Perform numpy fft on sig, returns output assuming a constant sampling rate for the signal. @@ -33,6 +33,10 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm If true, will return full fft frequency range, otherwise will return only positive values norm: bool, optional Whether the FFT result is divided by the length of the signal to normalize the amplitude + n: int, optional + Length of the transformed axis of the output. If n is smaller than the length of the input, + the input is cropped. If it is larger, the input is padded with zeros. If n is not given, + the length of the input along the axis specified by axis is used. Returns ------- @@ -62,8 +66,10 @@ def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm if not isinstance(norm, bool): raise TypeError("norm must be of type bool") - fft_result = np.fft.fft(sig.restrict(ep).values, axis=0) - fft_freq = np.fft.fftfreq(len(sig.restrict(ep).values), 1 / fs) + fft_result = np.fft.fft(sig.restrict(ep).values, n=n, axis=0) + if n is None: + n = len(sig.restrict(ep)) + fft_freq = np.fft.fftfreq(n, 1 / fs) if norm: fft_result = fft_result / fft_result.shape[0] From d96087e371affc82ee78723bc8826bb6d0f29df6 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Sat, 7 Sep 2024 11:59:15 -0400 Subject: [PATCH 179/195] edited tutorial --- docs/api_guide/tutorial_pynapple_filtering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py index 673bd439..62e4e681 100644 --- a/docs/api_guide/tutorial_pynapple_filtering.py +++ b/docs/api_guide/tutorial_pynapple_filtering.py @@ -128,8 +128,8 @@ # %% # Let's see what frequencies remain; -psd_butter = nap.compute_power_spectral_density(sig_butter, fs, norm=True, n=1024) -psd_sinc = nap.compute_power_spectral_density(sig_sinc, fs, norm=True, n=1024) +psd_butter = nap.compute_power_spectral_density(sig_butter, fs, norm=True) +psd_sinc = nap.compute_power_spectral_density(sig_sinc, fs, norm=True) fig = plt.figure(figsize = (10, 5)) plt.plot(np.abs(psd_butter), label = "Butterworth filter") From 830062663eced3d51b8627ca5fc20116d299dce0 Mon Sep 17 00:00:00 2001 From: gviejo Date: Mon, 9 Sep 2024 09:15:58 -0400 Subject: [PATCH 180/195] Update --- docs/api_guide/tutorial_pynapple_filtering.py | 4 ++-- pynapple/process/filtering.py | 24 +++++-------------- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py index 62e4e681..b2f55f20 100644 --- a/docs/api_guide/tutorial_pynapple_filtering.py +++ b/docs/api_guide/tutorial_pynapple_filtering.py @@ -81,7 +81,7 @@ # %% # Let's compare it to the `sinc` mode for Windowed-sinc. -sig_sinc = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='sinc', transition_bandwidth=0.001) +sig_sinc = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='sinc') # %% # Let's plot it @@ -106,7 +106,7 @@ # the 50 Hz component in the signal. sig_butter = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='butter') -sig_sinc = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='sinc', transition_bandwidth=0.001) +sig_sinc = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='sinc') # %% diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index ac447c32..55748f1f 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -67,37 +67,25 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. ------- np.ndarray """ - M = int(np.rint(4.0 / transition_bandwidth)) + M = int(np.rint(20.0 / transition_bandwidth)) x = np.arange(-(M // 2), 1 + (M // 2)) fc = np.transpose(np.atleast_2d(fc)) kernel = np.sinc(2 * fc * x) + kernel = kernel*np.blackman(len(x)) kernel = np.transpose(kernel) + kernel = kernel/kernel.sum(0) if filter_type == "lowpass": - kernel = kernel.flatten() * np.blackman(len(kernel)) - return kernel / np.sum(kernel) + return kernel.flatten() elif filter_type == "highpass": - kernel = kernel.flatten() * np.blackman(len(kernel)) - kernel /= np.sum(kernel) - kernel = _compute_spectral_inversion(kernel) - return kernel + return _compute_spectral_inversion(kernel.flatten()) elif filter_type == "bandstop": - bw = np.blackman(len(kernel)) - kernel[:, 1] *= bw - kernel[:, 1] /= kernel[:, 1].sum() kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) - kernel[:, 0] *= bw - kernel[:, 0] /= kernel[:, 0].sum() kernel = np.sum(kernel, axis=1) return kernel elif filter_type == "bandpass": - bw = np.blackman(len(kernel)) - kernel[:, 1] *= bw - kernel[:, 1] /= kernel[:, 1].sum() kernel[:, 1] = _compute_spectral_inversion(kernel[:, 1]) - kernel[:, 0] *= bw - kernel[:, 0] /= kernel[:, 0].sum() - kernel = _compute_spectral_inversion(kernel[:, 1] + kernel[:, 0]) + kernel = _compute_spectral_inversion(np.sum(kernel, axis=1)) return kernel else: raise ValueError From c19d78532bd4c0e42e4e1ad850d0326f0329911e Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Sep 2024 09:44:57 -0400 Subject: [PATCH 181/195] linted --- pynapple/process/filtering.py | 6 +++--- pynapple/process/signal_processing.py | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 55748f1f..65d2447c 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -5,7 +5,7 @@ from numbers import Number import numpy as np -from scipy.signal import butter, freqz, sosfiltfilt +from scipy.signal import butter, sosfiltfilt from .. import core as nap @@ -71,9 +71,9 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. x = np.arange(-(M // 2), 1 + (M // 2)) fc = np.transpose(np.atleast_2d(fc)) kernel = np.sinc(2 * fc * x) - kernel = kernel*np.blackman(len(x)) + kernel = kernel * np.blackman(len(x)) kernel = np.transpose(kernel) - kernel = kernel/kernel.sum(0) + kernel = kernel / kernel.sum(0) if filter_type == "lowpass": return kernel.flatten() diff --git a/pynapple/process/signal_processing.py b/pynapple/process/signal_processing.py index 3b72cfd6..fb575db8 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/signal_processing.py @@ -17,7 +17,9 @@ from .. import core as nap -def compute_power_spectral_density(sig, fs=None, ep=None, full_range=False, norm=False, n=None): +def compute_power_spectral_density( + sig, fs=None, ep=None, full_range=False, norm=False, n=None +): """ Perform numpy fft on sig, returns output assuming a constant sampling rate for the signal. From ef948f0665d3bd0fb68aa0eead3665c0e6f19b3d Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Sep 2024 10:38:22 -0400 Subject: [PATCH 182/195] removed repeated code. Simplified validate --- pynapple/process/filtering.py | 147 ++++++++++++++++------------------ tests/test_filtering.py | 2 +- 2 files changed, 68 insertions(+), 81 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 65d2447c..88130d41 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -114,11 +114,6 @@ def wrapper(*args, **kwargs): sig = inspect.signature(func) kwargs = sig.bind_partial(*args, **kwargs).arguments - if "data" not in kwargs or "cutoff" not in kwargs: - raise TypeError( - "Function needs time series and cutoff frequency to be specified." - ) - if not isinstance(kwargs["data"], nap.time_series.BaseTsd): raise ValueError( f"Invalid value: {args[0]}. First argument should be of type Tsd, TsdFrame or TsdTensor" @@ -162,6 +157,37 @@ def wrapper(*args, **kwargs): @_validate_filtering_inputs +def _compute_filter( + data, + cutoff, + fs=None, + mode="butter", + order=4, + transition_bandwidth=0.02, + filter_type="bandpass", +): + """Filter the signal.""" + if fs is None: + fs = data.rate + + cutoff = np.array(cutoff) + + if mode == "butter": + return _compute_butterworth_filter( + data, cutoff, fs, filter_type=filter_type, order=order + ) + if mode == "sinc": + return _compute_windowed_sinc_filter( + data, + cutoff, + fs, + filter_type=filter_type, + transition_bandwidth=transition_bandwidth, + ) + else: + raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") + + def compute_bandpass_filter( data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 ): @@ -208,28 +234,17 @@ def compute_bandpass_filter( For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal is reduced by -3 dB (decibels). """ - if fs is None: - fs = data.rate - - cutoff = np.array(cutoff) - - if mode == "butter": - return _compute_butterworth_filter( - data, cutoff, fs, filter_type="bandpass", order=order - ) - if mode == "sinc": - return _compute_windowed_sinc_filter( - data, - cutoff, - fs, - filter_type="bandpass", - transition_bandwidth=transition_bandwidth, - ) - else: - raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") + return _compute_filter( + data, + cutoff, + fs=fs, + mode=mode, + order=order, + transition_bandwidth=transition_bandwidth, + filter_type="bandpass", + ) -@_validate_filtering_inputs def compute_bandstop_filter( data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 ): @@ -276,28 +291,17 @@ def compute_bandstop_filter( For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal is reduced by -3 dB (decibels). """ - if fs is None: - fs = data.rate - - cutoff = np.array(cutoff) - - if mode == "butter": - return _compute_butterworth_filter( - data, cutoff, fs, filter_type="bandstop", order=order - ) - elif mode == "sinc": - return _compute_windowed_sinc_filter( - data, - cutoff, - fs, - filter_type="bandstop", - transition_bandwidth=transition_bandwidth, - ) - else: - raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") + return _compute_filter( + data, + cutoff, + fs=fs, + mode=mode, + order=order, + transition_bandwidth=transition_bandwidth, + filter_type="bandstop", + ) -@_validate_filtering_inputs def compute_highpass_filter( data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 ): @@ -344,26 +348,17 @@ def compute_highpass_filter( For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal is reduced by -3 dB (decibels). """ - if fs is None: - fs = data.rate - - if mode == "butter": - return _compute_butterworth_filter( - data, cutoff, fs, filter_type="highpass", order=order - ) - elif mode == "sinc": - return _compute_windowed_sinc_filter( - data, - cutoff, - fs, - filter_type="highpass", - transition_bandwidth=transition_bandwidth, - ) - else: - raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") + return _compute_filter( + data, + cutoff, + fs=fs, + mode=mode, + order=order, + transition_bandwidth=transition_bandwidth, + filter_type="highpass", + ) -@_validate_filtering_inputs def compute_lowpass_filter( data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 ): @@ -410,20 +405,12 @@ def compute_lowpass_filter( For the Butterworth filter, the cutoff frequency is defined as the frequency at which the amplitude of the signal is reduced by -3 dB (decibels). """ - if fs is None: - fs = data.rate - - if mode == "butter": - return _compute_butterworth_filter( - data, cutoff, fs, filter_type="lowpass", order=order - ) - elif mode == "sinc": - return _compute_windowed_sinc_filter( - data, - cutoff, - fs, - filter_type="lowpass", - transition_bandwidth=transition_bandwidth, - ) - else: - raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") + return _compute_filter( + data, + cutoff, + fs=fs, + mode=mode, + order=order, + transition_bandwidth=transition_bandwidth, + filter_type="lowpass", + ) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 26b9eb48..d158b1b1 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -253,7 +253,7 @@ def test_get_kernel_error(filter_type, expected_exception): nap.process.filtering._get_windowed_sinc_kernel(0.25, filter_type=filter_type) def test_get__error(): - with pytest.raises(TypeError, match="Function needs time series and cutoff frequency to be specified."): + with pytest.raises(TypeError, match=r"compute_lowpass_filter\(\) missing 1 required positional argument: 'data'"): nap.compute_lowpass_filter(cutoff=0.25) From d76aa3dd8a6e4d27bceca1a26ee39325683e3b35 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 9 Sep 2024 15:29:39 -0400 Subject: [PATCH 183/195] Update --- docs/api_guide/tutorial_pynapple_filtering.py | 98 ++++----- pynapple/core/time_series.py | 3 + pynapple/process/__init__.py | 1 + pynapple/process/filtering.py | 195 ++++++++++++------ tests/test_filtering.py | 34 ++- tests/test_time_series.py | 14 +- 6 files changed, 215 insertions(+), 130 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py index b2f55f20..9bbf23d3 100644 --- a/docs/api_guide/tutorial_pynapple_filtering.py +++ b/docs/api_guide/tutorial_pynapple_filtering.py @@ -60,7 +60,7 @@ fig = plt.figure(figsize = (15, 5)) plt.plot(sig) plt.xlabel("Time (s)") -plt.show() + # %% # We can compute the Fourier transform of `sig` to verify that all the frequencies are there. @@ -71,7 +71,7 @@ plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") plt.xlim(0, 100) -plt.show() + # %% # Let's say we would like to see only the 10 Hz component. @@ -97,12 +97,12 @@ plt.legend() plt.xlabel("Time (Hz)") plt.xlim(0, 1) -plt.show() + # %% # This gives similar results except at the edges. # -# Another use of filtering is to remove some frequencies (notch filter). Here we can try to remove +# Another use of filtering is to remove some frequencies. Here we can try to remove # the 50 Hz component in the signal. sig_butter = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='butter') @@ -123,7 +123,7 @@ plt.legend() plt.xlabel("Time (Hz)") plt.xlim(0, 1) -plt.show() + # %% # Let's see what frequencies remain; @@ -138,7 +138,7 @@ plt.xlabel("Frequency (Hz)") plt.ylabel("Amplitude") plt.xlim(0, 70) -plt.show() + # %% # The remaining notebook compares the two modes. @@ -147,16 +147,20 @@ # Frequency responses # ------------------- # -# Let's get filter coefficients for a 250Hz recursive low pass filter with different order: -butter_sos = { - order:nap.process.filtering._get_butter_coefficients(250, "lowpass", fs, order=order) +# In order to check the validity of the filter, the function `get_filter_frequency_response` provides the frequency +# response of the filters. The calling signature is similar to the previous functions. +# The function returns a pandas Series with the frequencies as index. +# +# Let's get the frequency response for a Butterworth low pass filter with different order: +butter_freq = { + order: nap.get_filter_frequency_response(250, fs, "lowpass", "butter", order=order) for order in [2, 4, 6]} # %% -# ... and the kernel for the Windowed-sinc equivalent with different transition bandwitdh -sinc_kernel = { - tb:nap.process.filtering._get_windowed_sinc_kernel(250/fs, "lowpass", transition_bandwidth=tb) - for tb in [0.02, 0.1, 0.2]} +# ... and the frequency response for the Windowed-sinc equivalent with different transition bandwidth. +sinc_freq = { + tb:nap.get_filter_frequency_response(250, fs,"lowpass", "sinc", transition_bandwidth=tb) + for tb in [0.002, 0.02, 0.2]} # %% # Let's plot the frequency response of both. @@ -165,74 +169,54 @@ fig = plt.figure(figsize = (20, 10)) gs = plt.GridSpec(2, 2) -for order, sos in butter_sos.items(): +for order in butter_freq.keys(): plt.subplot(gs[0, 0]) - w, h = signal.sosfreqz(sos, worN=1500, fs=fs) - plt.plot(w, np.abs(h), label = f"order={order}") + plt.plot(butter_freq[order], label = f"order={order}") plt.ylabel('Amplitude') plt.legend() plt.title("Butterworth recursive") plt.subplot(gs[1, 0]) - plt.plot(w, 20 * np.log10(np.abs(h)), label = f"order={order}") + plt.plot(20*np.log10(butter_freq[order]), label = f"order={order}") plt.xlabel('Frequency [Hz]') plt.ylabel('Amplitude [dB]') - plt.ylim(-100,40) + plt.ylim(-200,20) plt.legend() -for trans_bandwidth, kernel in sinc_kernel.items(): +for tb in sinc_freq.keys(): plt.subplot(gs[0, 1]) - fft_sinc = nap.compute_power_spectral_density( - nap.Tsd(t=np.arange(len(kernel)) / fs, d=kernel), fs, n=1024) - plt.plot(np.abs(fft_sinc), label= f"width={trans_bandwidth}") + plt.plot(sinc_freq[tb], label= f"width={tb}") plt.ylabel('Amplitude') plt.legend() plt.title("Windowed-sinc conv.") plt.subplot(gs[1, 1]) - plt.plot(20*np.log10(np.abs(fft_sinc)), label= f"width={trans_bandwidth}") + plt.plot(20*np.log10(sinc_freq[tb]), label= f"width={tb}") plt.xlabel('Frequency [Hz]') plt.ylabel('Amplitude [dB]') - plt.ylim(-100,40) + plt.ylim(-200,20) plt.legend() -plt.show() # %% -# *** -# Step responses -# -------------- +# In some cases, the transition bandwidth that is too high generates a kernel that is too short. The amplitude of the +# original signal will then be lower than expected. +# In this case, the solution is to decrease the transition bandwidth when using the windowed-sinc mode. +# Note that this increases the length of the kernel significantly. +# Let see it with the band pass filter. -step = nap.Tsd(t=sig.t, d=np.zeros(len(sig))) -step[len(step)//2:] = 1.0 -# %% -# -step_butter = { - order:nap.compute_lowpass_filter(step, 0.2*fs, fs, mode='butter', order=order) - for order in [4, 12]} +sinc_freq = { + tb:nap.get_filter_frequency_response((100, 200), fs, "bandpass", "sinc", transition_bandwidth=tb) + for tb in [0.004, 0.5]} -# %% -# Let's compare it to the `sinc` mode for Windowed-sinc. -step_sinc = { - tb:nap.compute_lowpass_filter(step, 0.2*fs, fs, mode='sinc', transition_bandwidth=tb) - for tb in [0.005, 0.2]} +fig = plt.figure(figsize = (20, 10)) +for tb in sinc_freq.keys(): + plt.plot(sinc_freq[tb], label= f"width={tb}") + plt.ylabel('Amplitude') + plt.legend() + plt.title("Windowed-sinc conv.") + plt.legend() + -# %% -plt.figure() -plt.subplot(211) -plt.plot(step) -for order, filt_step in step_butter.items(): - plt.plot(filt_step, label=f"order={order}") -plt.legend() -plt.xlim(0.95, 1.05) -plt.title("Butterworth filter") -plt.subplot(212) -plt.plot(step) -for tb, filt_step in step_sinc.items(): - plt.plot(filt_step, label=f"width={tb}") -plt.legend() -plt.xlim(0.95, 1.05) -plt.title("Windowed-sinc") -plt.show() # %% # *** diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 74692735..2af7f269 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -604,6 +604,9 @@ def smooth(self, std, windowsize=None, time_units="s", size_factor=100, norm=Tru else: M = std_size * size_factor + if M % 2 == 0: + M += 1 + window = signal.windows.gaussian(M=M, std=std_size) if norm: diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 6a0ccdb4..8a913e28 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -9,6 +9,7 @@ compute_bandstop_filter, compute_highpass_filter, compute_lowpass_filter, + get_filter_frequency_response, ) from .perievent import ( compute_event_trigger_average, diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 88130d41..fc48284e 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -5,12 +5,51 @@ from numbers import Number import numpy as np -from scipy.signal import butter, sosfiltfilt +import pandas as pd +from scipy.signal import butter, sosfiltfilt, sosfreqz from .. import core as nap +def _validate_filtering_inputs(func): + @wraps(func) + def wrapper(*args, **kwargs): + # Validate each positional argument + sig = inspect.signature(func) + kwargs = sig.bind_partial(*args, **kwargs).arguments + + if not isinstance(kwargs["cutoff"], Number): + if len(kwargs["cutoff"]) != 2 or not all( + isinstance(fq, Number) for fq in kwargs["cutoff"] + ): + raise ValueError + + if "fs" in kwargs: + if kwargs["fs"] is not None and not isinstance(kwargs["fs"], Number): + raise ValueError( + "Invalid value for 'fs'. Parameter 'fs' should be of type float or int" + ) + + if "order" in kwargs: + if not isinstance(kwargs["order"], int): + raise ValueError( + "Invalid value for 'order': Parameter 'order' should be of type int" + ) + + if "transition_bandwidth" in kwargs: + if not isinstance(kwargs["transition_bandwidth"], float): + raise ValueError( + "Invalid value for 'transition_bandwidth'. 'transition_bandwidth' should be of type float" + ) + + # Call the original function with validated inputs + return func(**kwargs) + + return wrapper + + def _get_butter_coefficients(cutoff, filter_type, sampling_frequency, order=4): + """Calls scipy butter""" return butter(order, cutoff, btype=filter_type, fs=sampling_frequency, output="sos") @@ -48,7 +87,9 @@ def _compute_spectral_inversion(kernel): return kernel -def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0.02): +def _get_windowed_sinc_kernel( + fc, filter_type, sampling_frequency, transition_bandwidth=0.02 +): """ Get the windowed-sinc kernel. Smith, S. (2003). Digital signal processing: a practical guide for engineers and scientists. @@ -57,10 +98,12 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. Parameters ---------- fc: float or tuple of float - Cutting frequency between 0 and 0.5. Single float for 'lowpass' and 'highpass'. Tuple of float for + Cutting frequency in Hz. Single float for 'lowpass' and 'highpass'. Tuple of float for 'bandpass' and 'bandstop'. filter_type: str Either 'lowpass', 'highpass', 'bandstop' or 'bandpass'. + sampling_frequency: float + Sampling frequency in Hz. transition_bandwidth: float Percentage between 0 and 0.5 Returns @@ -69,7 +112,7 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. """ M = int(np.rint(20.0 / transition_bandwidth)) x = np.arange(-(M // 2), 1 + (M // 2)) - fc = np.transpose(np.atleast_2d(fc)) + fc = np.transpose(np.atleast_2d(fc / sampling_frequency)) kernel = np.sinc(2 * fc * x) kernel = kernel * np.blackman(len(x)) kernel = np.transpose(kernel) @@ -92,70 +135,34 @@ def _get_windowed_sinc_kernel(fc, filter_type="lowpass", transition_bandwidth=0. def _compute_windowed_sinc_filter( - data, freq, sampling_frequency, filter_type="lowpass", transition_bandwidth=0.02 + data, freq, filter_type, sampling_frequency, transition_bandwidth=0.02 ): """ Apply a windowed-sinc filter to the provided signal. Parameters ---------- - filter_type + data: Tsd, TsdFrame or TsdTensor + + freq: float or tuple of float + Cutting frequency in Hz. Single float for 'lowpass' and 'highpass'. Tuple of float for + 'bandpass' and 'bandstop'. + sampling_frequency: float + Sampling frequency in Hz. + filter_type: str + Either 'lowpass', 'highpass', 'bandstop' or 'bandpass'. + transition_bandwidth: float + Percentage between 0 and 0.5 + Returns + ------- + Tsd, TsdFrame or TsdTensor """ kernel = _get_windowed_sinc_kernel( - freq / sampling_frequency, filter_type, transition_bandwidth + freq, filter_type, sampling_frequency, transition_bandwidth ) return data.convolve(kernel) -def _validate_filtering_inputs(func): - @wraps(func) - def wrapper(*args, **kwargs): - # Validate each positional argument - sig = inspect.signature(func) - kwargs = sig.bind_partial(*args, **kwargs).arguments - - if not isinstance(kwargs["data"], nap.time_series.BaseTsd): - raise ValueError( - f"Invalid value: {args[0]}. First argument should be of type Tsd, TsdFrame or TsdTensor" - ) - - if not isinstance(kwargs["cutoff"], Number): - if len(kwargs["cutoff"]) != 2 or not all( - isinstance(fq, Number) for fq in kwargs["cutoff"] - ): - raise ValueError - - if "fs" in kwargs: - if kwargs["fs"] is not None and not isinstance(kwargs["fs"], Number): - raise ValueError( - "Invalid value for 'fs'. Parameter 'fs' should be of type float or int" - ) - - if "order" in kwargs: - if not isinstance(kwargs["order"], int): - raise ValueError( - "Invalid value for 'order': Parameter 'order' should be of type int" - ) - - if "transition_bandwidth" in kwargs: - if not isinstance(kwargs["transition_bandwidth"], float): - raise ValueError( - "Invalid value for 'transition_bandwidth'. 'transition_bandwidth' should be of type float" - ) - - if np.any(np.isnan(kwargs["data"])): - raise ValueError( - "The input signal contains NaN values, which are not supported for filtering. " - "Please remove or handle NaNs before applying the filter. " - "You can use the `dropna()` method to drop all NaN values." - ) - - # Call the original function with validated inputs - return func(**kwargs) - - return wrapper - - @_validate_filtering_inputs def _compute_filter( data, @@ -166,7 +173,21 @@ def _compute_filter( transition_bandwidth=0.02, filter_type="bandpass", ): - """Filter the signal.""" + """ + Filter the signal. + """ + if not isinstance(data, nap.time_series.BaseTsd): + raise ValueError( + f"Invalid value: {data}. First argument should be of type Tsd, TsdFrame or TsdTensor" + ) + + if np.any(np.isnan(data)): + raise ValueError( + "The input signal contains NaN values, which are not supported for filtering. " + "Please remove or handle NaNs before applying the filter. " + "You can use the `dropna()` method to drop all NaN values." + ) + if fs is None: fs = data.rate @@ -178,11 +199,7 @@ def _compute_filter( ) if mode == "sinc": return _compute_windowed_sinc_filter( - data, - cutoff, - fs, - filter_type=filter_type, - transition_bandwidth=transition_bandwidth, + data, cutoff, filter_type, fs, transition_bandwidth=transition_bandwidth ) else: raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") @@ -414,3 +431,57 @@ def compute_lowpass_filter( transition_bandwidth=transition_bandwidth, filter_type="lowpass", ) + + +@_validate_filtering_inputs +def get_filter_frequency_response( + cutoff, fs, filter_type, mode, order=4, transition_bandwidth=0.02 +): + """ + Utility function to evaluate the frequency response of a particular type of filter. The arguments are the same + as the function `compute_lowpass_filter`, `compute_highpass_filter`, `compute_bandpass_filter` and + `compute_bandstop_filter`. + + This function returns a pandas Series object with the index as frequencies. + + Parameters + ---------- + cutoff : float or tuple of float + Cutoff frequency in Hz. + fs : float + The sampling frequency of the signal in Hz. + filter_type: str + Can be "lowpass", "highpass", "bandpass" or "bandstop" + mode: str + Can be "butter" or "sinc". + order : int, optional + The order of the Butterworth filter. Higher values result in sharper frequency cutoffs. + Default is 4. + transition_bandwidth : float, optional + The transition bandwidth. 0.2 corresponds to 20% of the frequency band between 0 and the sampling frequency. + The smaller the transition bandwidth, the larger the windowed-sinc kernel. + Default is 0.02. + + Returns + ------- + pandas.Series + """ + cutoff = np.array(cutoff) + + if mode == "butter": + sos = _get_butter_coefficients(cutoff, filter_type, fs, order) + w, h = sosfreqz(sos, worN=1024, fs=fs) + return pd.Series(index=w, data=np.abs(h)) + if mode == "sinc": + kernel = _get_windowed_sinc_kernel( + cutoff, filter_type, fs, transition_bandwidth + ) + fft_result = np.fft.fft(kernel) + fft_result = np.fft.fftshift(fft_result) + fft_freq = np.fft.fftfreq(n=len(kernel), d=1 / fs) + fft_freq = np.fft.fftshift(fft_freq) + return pd.Series( + index=fft_freq[fft_freq >= 0], data=np.abs(fft_result[fft_freq >= 0]) + ) + else: + raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") diff --git a/tests/test_filtering.py b/tests/test_filtering.py index d158b1b1..d6b1e748 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -2,6 +2,7 @@ import pynapple as nap import numpy as np from scipy import signal +import pandas as pd import warnings from contextlib import nullcontext as does_not_raise @@ -34,7 +35,7 @@ def compare_scipy(tsd, ep, order, freq, fs, btype): def compare_sinc(tsd, ep, transition_bandwidth, freq, fs, ftype): - kernel = nap.process.filtering._get_windowed_sinc_kernel(freq/fs, ftype, transition_bandwidth) + kernel = nap.process.filtering._get_windowed_sinc_kernel(freq, ftype, fs, transition_bandwidth) return tsd.convolve(kernel, ep).d @@ -242,7 +243,7 @@ def test_filtering_nyquist_edge_case(nyquist_fraction, order): @pytest.mark.parametrize("tb", [0.2, 0.3]) def test_get_odd_kernel(tb): - kernel = nap.process.filtering._get_windowed_sinc_kernel(0.25, transition_bandwidth=tb) + kernel = nap.process.filtering._get_windowed_sinc_kernel(1, "lowpass", 4, transition_bandwidth=tb) assert len(kernel)%2 != 0 @pytest.mark.parametrize("filter_type, expected_exception", [ @@ -250,7 +251,7 @@ def test_get_odd_kernel(tb): ]) def test_get_kernel_error(filter_type, expected_exception): with expected_exception: - nap.process.filtering._get_windowed_sinc_kernel(0.25, filter_type=filter_type) + nap.process.filtering._get_windowed_sinc_kernel(1, filter_type, 4) def test_get__error(): with pytest.raises(TypeError, match=r"compute_lowpass_filter\(\) missing 1 required positional argument: 'data'"): @@ -258,7 +259,7 @@ def test_get__error(): def test_compare_sinc_kernel(): - kernel = nap.process.filtering._get_windowed_sinc_kernel(0.25) + kernel = nap.process.filtering._get_windowed_sinc_kernel(1, "lowpass", 4) x = np.arange(-(len(kernel)//2), 1+len(kernel)//2) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -273,3 +274,28 @@ def test_compare_sinc_kernel(): ikernel2 = kernel2 * -1.0 ikernel2[len(ikernel2) // 2] = 1.0 + ikernel2[len(kernel2) // 2] np.testing.assert_allclose(ikernel, ikernel2) + +@pytest.mark.parametrize("cutoff, fs, filter_type, mode, order, tb", [ + (250, 1000, "lowpass", "butter", 4, 0.02), + (250, 1000, "lowpass", "sinc", 4, 0.02), +]) +def test_get_filter_frequency_response(cutoff, fs, filter_type, mode, order, tb): + output = nap.get_filter_frequency_response(cutoff, fs, filter_type, mode, order, tb) + assert isinstance(output, pd.Series) + if mode == "butter": + sos = nap.process.filtering._get_butter_coefficients(cutoff, filter_type, fs, order) + w, h = signal.sosfreqz(sos, worN=1024, fs=fs) + np.testing.assert_array_almost_equal(w, output.index.values) + np.testing.assert_array_almost_equal(np.abs(h), output.values) + if mode == "sinc": + kernel = nap.process.filtering._get_windowed_sinc_kernel(cutoff, filter_type, fs, tb) + fft_result = np.fft.fft(kernel) + fft_result = np.fft.fftshift(fft_result) + fft_freq = np.fft.fftfreq(n=len(kernel), d=1 / fs) + fft_freq = np.fft.fftshift(fft_freq) + np.testing.assert_array_almost_equal(fft_freq[fft_freq >= 0], output.index.values) + np.testing.assert_array_almost_equal(np.abs(fft_result[fft_freq >= 0]), output.values) + +def test_get_filter_frequency_response_error(): + with pytest.raises(ValueError, match="Unrecognized filter mode. Choose either 'butter' or 'sinc'"): + nap.get_filter_frequency_response(250, 1000, "lowpass", "a", 4, 0.02) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 446a4fd0..97cd2361 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -514,16 +514,16 @@ def test_convolve_2d_kernel(self, tsd): def test_smooth(self, tsd): if not isinstance(tsd, nap.Ts): from scipy import signal - tsd2 = tsd.smooth(1) + tsd2 = tsd.smooth(1, size_factor=10) tmp = tsd.values.reshape(tsd.shape[0], -1) tmp2 = np.zeros_like(tmp) std = int(tsd.rate * 1) - M = std*100 + M = std*11 window = signal.windows.gaussian(M, std=std) window = window / window.sum() for i in range(tmp.shape[-1]): - tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2-1:1-M//2-1] + tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2:1-M//2-1] np.testing.assert_array_almost_equal( tmp2, tsd2.values.reshape(tsd2.shape[0], -1) @@ -541,10 +541,10 @@ def test_smooth(self, tsd): tmp = tsd.values.reshape(tsd.shape[0], -1) tmp2 = np.zeros_like(tmp) std = int(tsd.rate * 1) - M = std*200 + M = std*201 window = signal.windows.gaussian(M, std=std) for i in range(tmp.shape[-1]): - tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2-1:1-M//2-1] + tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2:1-M//2-1] np.testing.assert_array_almost_equal( tmp2, tsd2.values.reshape(tsd2.shape[0], -1) @@ -554,10 +554,10 @@ def test_smooth(self, tsd): tmp = tsd.values.reshape(tsd.shape[0], -1) tmp2 = np.zeros_like(tmp) std = int(tsd.rate * 1) - M = int(tsd.rate * 10) + M = int(tsd.rate * 11) window = signal.windows.gaussian(M, std=std) for i in range(tmp.shape[-1]): - tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2-1:1-M//2-1] + tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[M//2:1-M//2-1] np.testing.assert_array_almost_equal( tmp2, tsd2.values.reshape(tsd2.shape[0], -1) From 900c4b50f56ea26c9ce392eeab4c1ac6c23d86b0 Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Sep 2024 19:27:10 -0400 Subject: [PATCH 184/195] nan fix --- tests/test_numpy_compatibility.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_numpy_compatibility.py b/tests/test_numpy_compatibility.py index aecb5677..7b1b3777 100644 --- a/tests/test_numpy_compatibility.py +++ b/tests/test_numpy_compatibility.py @@ -11,7 +11,7 @@ # tsd = nap.TsdFrame(t=np.arange(100), d=np.random.randn(100, 6)) -tsd.d[tsd.values>0.9] = np.NaN +tsd.d[tsd.values>0.9] = np.nan @pytest.mark.parametrize( From bbf629847b763c6010ce4eb8b16c0f8962ecd86c Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Mon, 9 Sep 2024 19:31:26 -0400 Subject: [PATCH 185/195] fix all nans --- pynapple/core/base_class.py | 2 +- tests/test_jitted.py | 8 ++++---- tests/test_time_series.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pynapple/core/base_class.py b/pynapple/core/base_class.py index c119b57a..8436c222 100644 --- a/pynapple/core/base_class.py +++ b/pynapple/core/base_class.py @@ -44,7 +44,7 @@ def __init__(self, t, time_units="s", time_support=None): self.time_support.values[:, 1] - self.time_support.values[:, 0] ) else: - self.rate = np.NaN + self.rate = np.nan self.time_support = IntervalSet(start=[], end=[]) @property diff --git a/tests/test_jitted.py b/tests/test_jitted.py index 7b437c7a..5b8a4aa1 100644 --- a/tests/test_jitted.py +++ b/tests/test_jitted.py @@ -52,9 +52,9 @@ def restrict(ep, tsd): ) ) ix3 = np.vstack((ix, ix2)).T - # ix[np.floor(ix / 2) * 2 != ix] = np.NaN + # ix[np.floor(ix / 2) * 2 != ix] = np.nan # ix = np.floor(ix/2) - ix3[np.floor(ix3 / 2) * 2 != ix3] = np.NaN + ix3[np.floor(ix3 / 2) * 2 != ix3] = np.nan ix3 = np.floor(ix3 / 2) ix3[np.isnan(ix3[:, 0]), 0] = ix3[np.isnan(ix3[:, 0]), 1] @@ -90,7 +90,7 @@ def test_jitrestrict_with_count(): ix = np.array(pd.cut(tsd.index, bins, labels=np.arange(len(bins) - 1, dtype=np.float64))) ix2 = np.array(pd.cut(tsd.index,bins,labels=np.arange(len(bins) - 1, dtype=np.float64),right=False,)) ix3 = np.vstack((ix, ix2)).T - ix3[np.floor(ix3 / 2) * 2 != ix3] = np.NaN + ix3[np.floor(ix3 / 2) * 2 != ix3] = np.nan ix3 = np.floor(ix3 / 2) ix3[np.isnan(ix3[:, 0]), 0] = ix3[np.isnan(ix3[:, 0]), 1] ix = ix3[:,0] @@ -417,7 +417,7 @@ def test_jitin_interval(): ) ) ix3 = np.vstack((ix, ix2)).T - ix3[np.floor(ix3 / 2) * 2 != ix3] = np.NaN + ix3[np.floor(ix3 / 2) * 2 != ix3] = np.nan ix3 = np.floor(ix3 / 2) ix3[np.isnan(ix3[:, 0]), 0] = ix3[np.isnan(ix3[:, 0]), 1] inep2 = ix3[:, 0] diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 97cd2361..7e098fe4 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -388,7 +388,7 @@ def test_dropna(self, tsd): np.testing.assert_array_equal(tsd.values, new_tsd.values) tmp = np.random.rand(*tsd.shape) - tmp[tmp>0.9] = np.NaN + tmp[tmp>0.9] = np.nan tsd = tsd.__class__(t=tsd.t, d=tmp) new_tsd = tsd.dropna() @@ -406,7 +406,7 @@ def test_dropna(self, tsd): np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values) np.testing.assert_array_equal(new_tsd.time_support, tsd.time_support) - tsd = tsd.__class__(t=tsd.t, d=np.ones(tsd.shape)*np.NaN) + tsd = tsd.__class__(t=tsd.t, d=np.ones(tsd.shape)*np.nan) new_tsd = tsd.dropna() assert len(new_tsd) == 0 assert len(new_tsd.time_support) == 0 From 6af28c04f0c0e46eb9a3959d7f5c6686ad20998b Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 11 Sep 2024 17:05:02 -0400 Subject: [PATCH 186/195] added description of bands. Fixed typing --- docs/api_guide/tutorial_pynapple_filtering.py | 75 +++++++++++++++---- mkdocs.yml | 3 +- pynapple/process/filtering.py | 14 ++-- 3 files changed, 68 insertions(+), 24 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py index 9bbf23d3..4134f3ca 100644 --- a/docs/api_guide/tutorial_pynapple_filtering.py +++ b/docs/api_guide/tutorial_pynapple_filtering.py @@ -81,7 +81,7 @@ # %% # Let's compare it to the `sinc` mode for Windowed-sinc. -sig_sinc = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='sinc') +sig_sinc = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='sinc', transition_bandwidth=0.003) # %% # Let's plot it @@ -106,7 +106,7 @@ # the 50 Hz component in the signal. sig_butter = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='butter') -sig_sinc = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='sinc') +sig_sinc = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='sinc', transition_bandwidth=0.004) # %% @@ -117,7 +117,6 @@ plt.xlim(0, 1) plt.legend() plt.subplot(212) -# plt.plot(sig, alpha=0.5) plt.plot(sig_butter, label = "Butterworth") plt.plot(sig_sinc, '--', label = "Windowed-sinc") plt.legend() @@ -144,29 +143,75 @@ # The remaining notebook compares the two modes. # # *** -# Frequency responses +# Frequency Responses # ------------------- # -# In order to check the validity of the filter, the function `get_filter_frequency_response` provides the frequency -# response of the filters. The calling signature is similar to the previous functions. -# The function returns a pandas Series with the frequencies as index. +# We can inspect the frequency response of a filter by plotting its power spectral density (PSD). +# To do this, we can use the `get_filter_frequency_response` function, which returns a pandas Series with the frequencies +# as the index and the PSD as values. +# +# Let's plot the frequency response of a Butterworth filter and a sinc low-pass filter. + +# compute the frequency response of the filters +psd_butter = nap.get_filter_frequency_response( + 200, fs,"lowpass", "butter", order=8 +) +psd_sinc = nap.get_filter_frequency_response( + 200, fs,"lowpass", "sinc", transition_bandwidth=0.1 +) + +# compute the transition bandwidth +tb_butter = psd_butter[psd_butter > 0.99].index.max(), psd_butter[psd_butter < 0.01].index.min() +tb_sinc = psd_sinc[psd_sinc > 0.99].index.max(), psd_sinc[psd_sinc < 0.01].index.min() + +fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(15, 5)) +fig.suptitle("Frequency response", fontsize="x-large") +axs[0].set_title("Butterworth Filter") +axs[0].plot(psd_butter) +axs[0].axvspan(0, tb_butter[0], alpha=0.4, color="green", label="Pass Band") +axs[0].axvspan(*tb_butter, alpha=0.4, color="orange", label="Transition Band") +axs[0].axvspan(tb_butter[1], 500, alpha=0.4, color="red", label="Stop Band") +axs[0].legend().get_frame().set_alpha(1.) +axs[0].set_xlim(0, 500) +axs[0].set_xlabel("Frequency (Hz)") +axs[0].set_ylabel("Amplitude") + +axs[1].set_title("Sinc Filter") +axs[1].plot(psd_sinc) +axs[1].axvspan(0, tb_sinc[0], alpha=0.4, color="green", label="Pass Band") +axs[1].axvspan(*tb_sinc, alpha=0.4, color="orange", label="Transition Band") +axs[1].axvspan(tb_sinc[1], 500, alpha=0.4, color="red", label="Stop Band") +axs[1].legend().get_frame().set_alpha(1.) +axs[1].set_xlabel("Frequency (Hz)") + +# %% +# The frequency band with response close to one will be preserved by the filtering (pass band), +# the band with response close to zero will be discarded (stop band), and the band in between will be partially attenuated +# (transition band). +# +# ??? note "Transition Bandwidth (Click to expand/collapse)" +# Here, we define the transition band as the range where the amplitude attenuation is between 99% and 1%. +# The `transition_bandwidth` parameter of the sinc filter is approximately the width of the transition +# band normalized by the sampling frequency. In the example above, if you divide the transition band width +# of 122Hz by the sampling frequency of 1000Hz, you get 0.122, which is close to the 0.1 value set. # -# Let's get the frequency response for a Butterworth low pass filter with different order: +# You can modulate the width of the transition band by setting the `order` parameter of the Butterworth filter +# or the `transition_bandwidth` parameter of the sinc filter. +# First, let's get the frequency response for a Butterworth low pass filter with different order: + butter_freq = { order: nap.get_filter_frequency_response(250, fs, "lowpass", "butter", order=order) for order in [2, 4, 6]} # %% -# ... and the frequency response for the Windowed-sinc equivalent with different transition bandwidth. +# ... and then the frequency response for the Windowed-sinc equivalent with different transition bandwidth. sinc_freq = { - tb:nap.get_filter_frequency_response(250, fs,"lowpass", "sinc", transition_bandwidth=tb) + tb: nap.get_filter_frequency_response(250, fs,"lowpass", "sinc", transition_bandwidth=tb) for tb in [0.002, 0.02, 0.2]} # %% # Let's plot the frequency response of both. -from scipy import signal - fig = plt.figure(figsize = (20, 10)) gs = plt.GridSpec(2, 2) for order in butter_freq.keys(): @@ -196,8 +241,8 @@ plt.legend() # %% -# In some cases, the transition bandwidth that is too high generates a kernel that is too short. The amplitude of the -# original signal will then be lower than expected. +# ⚠️ **Warning:** In some cases, the transition bandwidth that is too high generates a kernel that is too short. +# The amplitude of the original signal will then be lower than expected. # In this case, the solution is to decrease the transition bandwidth when using the windowed-sinc mode. # Note that this increases the length of the kernel significantly. # Let see it with the band pass filter. @@ -287,5 +332,3 @@ def benchmark_dimensions(mode): plt.xlabel("Number of dimensions") plt.ylabel("Time (s)") plt.title("Low pass filtering benchmark") - -plt.show() diff --git a/mkdocs.yml b/mkdocs.yml index 063d2549..b2e35d8d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -65,4 +65,5 @@ markdown_extensions: pygments_lang_class: true - pymdownx.inlinehilite - pymdownx.snippets - - pymdownx.superfences \ No newline at end of file + - pymdownx.superfences + - admonition diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index fc48284e..993e2c70 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -110,7 +110,7 @@ def _get_windowed_sinc_kernel( ------- np.ndarray """ - M = int(np.rint(20.0 / transition_bandwidth)) + M = int(np.rint(4.0 / transition_bandwidth)) x = np.arange(-(M // 2), 1 + (M // 2)) fc = np.transpose(np.atleast_2d(fc / sampling_frequency)) kernel = np.sinc(2 * fc * x) @@ -191,7 +191,7 @@ def _compute_filter( if fs is None: fs = data.rate - cutoff = np.array(cutoff) + cutoff = np.array(cutoff, dtype=float) if mode == "butter": return _compute_butterworth_filter( @@ -218,7 +218,7 @@ def compute_bandpass_filter( ---------- data : Tsd, TsdFrame, or TsdTensor The signal to be filtered. - cutoff : tuple of (float, float) + cutoff : (Numeric, Numeric) Cutoff frequencies in Hz. fs : float, optional The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. @@ -275,7 +275,7 @@ def compute_bandstop_filter( ---------- data : Tsd, TsdFrame, or TsdTensor The signal to be filtered. - cutoff : tuple of (float, float) + cutoff : (Numeric, Numeric) Cutoff frequencies in Hz. fs : float, optional The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. @@ -332,7 +332,7 @@ def compute_highpass_filter( ---------- data : Tsd, TsdFrame, or TsdTensor The signal to be filtered. - cutoff : float + cutoff : Numeric Cutoff frequency in Hz. fs : float, optional The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. @@ -389,7 +389,7 @@ def compute_lowpass_filter( ---------- data : Tsd, TsdFrame, or TsdTensor The signal to be filtered. - cutoff : float + cutoff : Numeric Cutoff frequency in Hz. fs : float, optional The sampling frequency of the signal in Hz. If not provided, it will be inferred from the time axis of the data. @@ -446,7 +446,7 @@ def get_filter_frequency_response( Parameters ---------- - cutoff : float or tuple of float + cutoff : Numeric or tuple of Numeric Cutoff frequency in Hz. fs : float The sampling frequency of the signal in Hz. From 6309c925922bd2f44edd620b9a1b23fa4b0c9492 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 11 Sep 2024 17:26:39 -0400 Subject: [PATCH 187/195] adding jax compatibility --- pynapple/process/filtering.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index fc48284e..3f0ca72d 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -60,10 +60,23 @@ def _compute_butterworth_filter( Apply a Butterworth filter to the provided signal. """ sos = _get_butter_coefficients(cutoff, filter_type, sampling_frequency, order) - out = np.zeros_like(data.d) - for ep in data.time_support: - slc = data.get_slice(start=ep.start[0], end=ep.end[0]) - out[slc] = sosfiltfilt(sos, data.d[slc], axis=0) + + if nap.utils.get_backend() == "jax": + from pynajax.jax_process_filtering import jax_sosfiltfilt + + out = jax_sosfiltfilt( + sos, + data.index.values, + data.values, + data.time_support.start, + data.time_support.end, + ) + + else: + out = np.zeros_like(data.d) + for ep in data.time_support: + slc = data.get_slice(start=ep.start[0], end=ep.end[0]) + out[slc] = sosfiltfilt(sos, data.d[slc], axis=0) kwargs = dict(t=data.t, d=out, time_support=data.time_support) if isinstance(data, nap.TsdFrame): From 5f3eb77f395d57e6e3ccbeef785029798b85d5df Mon Sep 17 00:00:00 2001 From: BalzaniEdoardo Date: Wed, 11 Sep 2024 17:48:55 -0400 Subject: [PATCH 188/195] improved flow --- docs/api_guide/tutorial_pynapple_filtering.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py index 4134f3ca..cd5c6f2c 100644 --- a/docs/api_guide/tutorial_pynapple_filtering.py +++ b/docs/api_guide/tutorial_pynapple_filtering.py @@ -150,7 +150,7 @@ # To do this, we can use the `get_filter_frequency_response` function, which returns a pandas Series with the frequencies # as the index and the PSD as values. # -# Let's plot the frequency response of a Butterworth filter and a sinc low-pass filter. +# Let's extract the frequency response of a Butterworth filter and a sinc low-pass filter. # compute the frequency response of the filters psd_butter = nap.get_filter_frequency_response( @@ -160,6 +160,9 @@ 200, fs,"lowpass", "sinc", transition_bandwidth=0.1 ) +# %% +# ...and plot it. + # compute the transition bandwidth tb_butter = psd_butter[psd_butter > 0.99].index.max(), psd_butter[psd_butter < 0.01].index.min() tb_sinc = psd_sinc[psd_sinc > 0.99].index.max(), psd_sinc[psd_sinc < 0.01].index.min() @@ -184,6 +187,9 @@ axs[1].legend().get_frame().set_alpha(1.) axs[1].set_xlabel("Frequency (Hz)") +print(f"Transition band butterworth filter: ({int(tb_butter[0])}Hz, {int(tb_butter[1])}Hz)") +print(f"Transition band sinc filter: ({int(tb_sinc[0])}Hz, {int(tb_sinc[1])}Hz)") + # %% # The frequency band with response close to one will be preserved by the filtering (pass band), # the band with response close to zero will be discarded (stop band), and the band in between will be partially attenuated From d4d4941994be683ad84cc5bf62f8987c6289a70e Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 12 Sep 2024 15:49:55 -0400 Subject: [PATCH 189/195] Fix tests for pynajax --- tests/test_filtering.py | 16 ++++++++-------- tests/test_utils.py | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index d6b1e748..63bc0bd8 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -63,11 +63,11 @@ def test_low_pass(freq, mode, order, transition_bandwidth, shape, sampling_frequ transition_bandwidth=transition_bandwidth) if mode == "butter": out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "lowpass") - np.testing.assert_array_equal(out.d, out_sci) + np.testing.assert_array_almost_equal(out.d, out_sci) if mode == "sinc": out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "lowpass") - np.testing.assert_array_equal(out.d, out_sinc) + np.testing.assert_array_almost_equal(out.d, out_sinc) assert isinstance(out, type(tsd)) assert np.all(out.t == tsd.t) @@ -101,11 +101,11 @@ def test_high_pass(freq, mode, order, transition_bandwidth, shape, sampling_freq if mode == "sinc": out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "highpass") - np.testing.assert_array_equal(out.d, out_sinc) + np.testing.assert_array_almost_equal(out.d, out_sinc) if mode == "butter": out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "highpass") - np.testing.assert_array_equal(out.d, out_sci) + np.testing.assert_array_almost_equal(out.d, out_sci) assert isinstance(out, type(tsd)) assert np.all(out.t == tsd.t) @@ -139,11 +139,11 @@ def test_bandpass(freq, mode, order, transition_bandwidth, shape, sampling_frequ if mode == "sinc": out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "bandpass") - np.testing.assert_array_equal(out.d, out_sinc) + np.testing.assert_array_almost_equal(out.d, out_sinc) if mode == "butter": out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "bandpass") - np.testing.assert_array_equal(out.d, out_sci) + np.testing.assert_array_almost_equal(out.d, out_sci) assert isinstance(out, type(tsd)) assert np.all(out.t == tsd.t) @@ -177,11 +177,11 @@ def test_bandstop(freq, mode, order, transition_bandwidth, shape, sampling_frequ if mode == "sinc": out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "bandstop") - np.testing.assert_array_equal(out.d, out_sinc) + np.testing.assert_array_almost_equal(out.d, out_sinc) if mode == "butter": out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "bandstop") - np.testing.assert_array_equal(out.d, out_sci) + np.testing.assert_array_almost_equal(out.d, out_sci) assert isinstance(out, type(tsd)) diff --git a/tests/test_utils.py b/tests/test_utils.py index fdafbb3d..71c06027 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,7 +6,7 @@ import pytest def test_get_backend(): - assert nap.core.utils.get_backend() == "numba" + assert nap.core.utils.get_backend() in ["numba", "jax"] def test_is_array_like(): assert nap.core.utils.is_array_like(np.ones(3)) From 764d6b67997e578c20089c28575050952930a10d Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 12 Sep 2024 17:15:34 -0400 Subject: [PATCH 190/195] Updating --- docs/api_guide/tutorial_pynapple_filtering.py | 30 +++++++-------- docs/examples/tutorial_phase_preferences.py | 2 +- pynapple/process/__init__.py | 8 ++-- pynapple/process/filtering.py | 35 ++++++++++++----- tests/test_filtering.py | 38 +++++++++---------- 5 files changed, 64 insertions(+), 49 deletions(-) diff --git a/docs/api_guide/tutorial_pynapple_filtering.py b/docs/api_guide/tutorial_pynapple_filtering.py index cd5c6f2c..87516578 100644 --- a/docs/api_guide/tutorial_pynapple_filtering.py +++ b/docs/api_guide/tutorial_pynapple_filtering.py @@ -5,16 +5,16 @@ The filtering module holds the functions for frequency manipulation : -- `nap.compute_bandstop_filter` -- `nap.compute_lowpass_filter` -- `nap.compute_highpass_filter` -- `nap.compute_bandpass_filter` +- `nap.apply_bandstop_filter` +- `nap.apply_lowpass_filter` +- `nap.apply_highpass_filter` +- `nap.apply_bandpass_filter` The functions have similar calling signatures. For example, to filter a 1000 Hz signal between 10 and 20 Hz using a Butterworth filter: ```{python} ->>> new_tsd = nap.compute_bandpass_filter(tsd, (10, 20), fs=1000, mode='butter') +>>> new_tsd = nap.apply_bandpass_filter(tsd, (10, 20), fs=1000, mode='butter') ``` Currently, the filtering module provides two methods for frequency manipulation: `butter` @@ -75,13 +75,13 @@ # %% # Let's say we would like to see only the 10 Hz component. -# We can use the function `compute_bandpass_filter` with mode `butter` for Butterworth. +# We can use the function `apply_bandpass_filter` with mode `butter` for Butterworth. -sig_butter = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='butter') +sig_butter = nap.apply_bandpass_filter(sig, (8, 12), fs, mode='butter') # %% # Let's compare it to the `sinc` mode for Windowed-sinc. -sig_sinc = nap.compute_bandpass_filter(sig, (8, 12), fs, mode='sinc', transition_bandwidth=0.003) +sig_sinc = nap.apply_bandpass_filter(sig, (8, 12), fs, mode='sinc', transition_bandwidth=0.003) # %% # Let's plot it @@ -95,7 +95,7 @@ plt.plot(sig_butter, label = "Butterworth") plt.plot(sig_sinc, '--', label = "Windowed-sinc") plt.legend() -plt.xlabel("Time (Hz)") +plt.xlabel("Time (s)") plt.xlim(0, 1) @@ -105,8 +105,8 @@ # Another use of filtering is to remove some frequencies. Here we can try to remove # the 50 Hz component in the signal. -sig_butter = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='butter') -sig_sinc = nap.compute_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='sinc', transition_bandwidth=0.004) +sig_butter = nap.apply_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='butter') +sig_sinc = nap.apply_bandstop_filter(sig, cutoff=(45, 55), fs=fs, mode='sinc', transition_bandwidth=0.004) # %% @@ -256,7 +256,7 @@ sinc_freq = { tb:nap.get_filter_frequency_response((100, 200), fs, "bandpass", "sinc", transition_bandwidth=tb) - for tb in [0.004, 0.5]} + for tb in [0.004, 0.2]} fig = plt.figure(figsize = (20, 10)) @@ -280,14 +280,14 @@ def get_mean_perf(tsd, mode, n=10): tmp = np.zeros(n) for i in range(n): t1 = perf_counter() - _ = nap.compute_lowpass_filter(tsd, 0.25 * tsd.rate, mode=mode) + _ = nap.apply_lowpass_filter(tsd, 0.25 * tsd.rate, mode=mode) t2 = perf_counter() tmp[i] = t2 - t1 return [np.mean(tmp), np.std(tmp)] def benchmark_time_points(mode): times = [] - for T in np.arange(1000, 100000, 40000): + for T in np.arange(1000, 100000, 20000): time_array = np.arange(T)/1000 data_array = np.random.randn(len(time_array)) startend = np.linspace(0, time_array[-1], T//100).reshape(T//200, 2) @@ -298,7 +298,7 @@ def benchmark_time_points(mode): def benchmark_dimensions(mode): times = [] - for n in np.arange(1, 100, 30): + for n in np.arange(1, 100, 10): time_array = np.arange(10000)/1000 data_array = np.random.randn(len(time_array), n) startend = np.linspace(0, time_array[-1], 10000//100).reshape(10000//200, 2) diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index d826d2c3..7f8583cc 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -156,7 +156,7 @@ def plot_timefrequency(freqs, powers, ax=None): # # As expected, there is a strong 8Hz component during REM sleep. We can filter it using the function `nap.compute_bandpass_filter`. -theta_band = nap.compute_bandpass_filter(tsd_rem, cutoff=(6.0, 10.0), fs=FS) +theta_band = nap.apply_bandpass_filter(tsd_rem, cutoff=(6.0, 10.0), fs=FS) # %% # We can plot the original signal and the filtered signal. diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 8a913e28..5a92f223 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -5,10 +5,10 @@ ) from .decoding import decode_1d, decode_2d from .filtering import ( - compute_bandpass_filter, - compute_bandstop_filter, - compute_highpass_filter, - compute_lowpass_filter, + apply_bandpass_filter, + apply_bandstop_filter, + apply_highpass_filter, + apply_lowpass_filter, get_filter_frequency_response, ) from .perievent import ( diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 51a7cc8c..9555354f 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -1,6 +1,7 @@ """Filtering module.""" import inspect +from collections.abc import Iterable from functools import wraps from numbers import Number @@ -18,11 +19,21 @@ def wrapper(*args, **kwargs): sig = inspect.signature(func) kwargs = sig.bind_partial(*args, **kwargs).arguments - if not isinstance(kwargs["cutoff"], Number): - if len(kwargs["cutoff"]) != 2 or not all( - isinstance(fq, Number) for fq in kwargs["cutoff"] + cutoff = kwargs["cutoff"] + filter_type = kwargs["filter_type"] + if filter_type in ["lowpass", "highpass"] and not isinstance(cutoff, Number): + raise ValueError( + f"{filter_type} filter require a single number. {cutoff} provided instead." + ) + if filter_type in ["bandpass", "bandstop"]: + if ( + not isinstance(cutoff, Iterable) + or len(cutoff) != 2 + or not all(isinstance(fq, Number) for fq in cutoff) ): - raise ValueError + raise ValueError( + f"{filter_type} filter require a tuple of two numbers. {cutoff} provided instead." + ) if "fs" in kwargs: if kwargs["fs"] is not None and not isinstance(kwargs["fs"], Number): @@ -218,12 +229,13 @@ def _compute_filter( raise ValueError("Unrecognized filter mode. Choose either 'butter' or 'sinc'") -def compute_bandpass_filter( +def apply_bandpass_filter( data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 ): """ Apply a band-pass filter to the provided signal. Mode can be : + - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. @@ -275,12 +287,13 @@ def compute_bandpass_filter( ) -def compute_bandstop_filter( +def apply_bandstop_filter( data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 ): """ Apply a band-stop filter to the provided signal. Mode can be : + - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. @@ -332,12 +345,13 @@ def compute_bandstop_filter( ) -def compute_highpass_filter( +def apply_highpass_filter( data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 ): """ Apply a high-pass filter to the provided signal. Mode can be : + - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. @@ -389,12 +403,13 @@ def compute_highpass_filter( ) -def compute_lowpass_filter( +def apply_lowpass_filter( data, cutoff, fs=None, mode="butter", order=4, transition_bandwidth=0.02 ): """ Apply a low-pass filter to the provided signal. Mode can be : + - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. @@ -452,8 +467,8 @@ def get_filter_frequency_response( ): """ Utility function to evaluate the frequency response of a particular type of filter. The arguments are the same - as the function `compute_lowpass_filter`, `compute_highpass_filter`, `compute_bandpass_filter` and - `compute_bandstop_filter`. + as the function `apply_lowpass_filter`, `apply_highpass_filter`, `apply_bandpass_filter` and + `apply_bandstop_filter`. This function returns a pandas Series object with the index as frequencies. diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 63bc0bd8..19461760 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -59,8 +59,8 @@ def test_low_pass(freq, mode, order, transition_bandwidth, shape, sampling_frequ if sampling_frequency is not None and sampling_frequency != tsd.rate: sampling_frequency = tsd.rate - out = nap.compute_lowpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, - transition_bandwidth=transition_bandwidth) + out = nap.apply_lowpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) if mode == "butter": out_sci = compare_scipy(tsd, ep, order, freq, tsd.rate, "lowpass") np.testing.assert_array_almost_equal(out.d, out_sci) @@ -96,8 +96,8 @@ def test_high_pass(freq, mode, order, transition_bandwidth, shape, sampling_freq if sampling_frequency is not None and sampling_frequency != tsd.rate: sampling_frequency = tsd.rate - out = nap.compute_highpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, - transition_bandwidth=transition_bandwidth) + out = nap.apply_highpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) if mode == "sinc": out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "highpass") @@ -134,8 +134,8 @@ def test_bandpass(freq, mode, order, transition_bandwidth, shape, sampling_frequ if sampling_frequency is not None and sampling_frequency != tsd.rate: sampling_frequency = tsd.rate - out = nap.compute_bandpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, - transition_bandwidth=transition_bandwidth) + out = nap.apply_bandpass_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) if mode == "sinc": out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "bandpass") @@ -172,8 +172,8 @@ def test_bandstop(freq, mode, order, transition_bandwidth, shape, sampling_frequ if sampling_frequency is not None and sampling_frequency != tsd.rate: sampling_frequency = tsd.rate - out = nap.compute_bandstop_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, - transition_bandwidth=transition_bandwidth) + out = nap.apply_bandstop_filter(tsd, freq, fs=sampling_frequency, mode=mode, order=order, + transition_bandwidth=transition_bandwidth) if mode == "sinc": out_sinc = compare_sinc(tsd, ep, transition_bandwidth, freq, tsd.rate, "bandstop") @@ -194,10 +194,10 @@ def test_bandstop(freq, mode, order, transition_bandwidth, shape, sampling_frequ # Errors ######################################################################## @pytest.mark.parametrize("func, freq", [ - (nap.compute_lowpass_filter, 10), - (nap.compute_highpass_filter, 10), - (nap.compute_bandpass_filter, [10, 20]), - (nap.compute_bandstop_filter, [10, 20]), + (nap.apply_lowpass_filter, 10), + (nap.apply_highpass_filter, 10), + (nap.apply_bandpass_filter, [10, 20]), + (nap.apply_bandstop_filter, [10, 20]), ]) @pytest.mark.parametrize("data, fs, mode, order, transition_bandwidth, expected_exception", [ (sample_data(), None, "butter", "a", 0.02, pytest.raises(ValueError,match="Invalid value for 'order': Parameter 'order' should be of type int")), @@ -214,10 +214,10 @@ def test_compute_filtered_signal_raise_errors(func, freq, data, fs, mode, order, func(data, freq, fs=fs, mode=mode, order=order, transition_bandwidth=transition_bandwidth) @pytest.mark.parametrize("func, freq, expected_exception", [ - (nap.compute_lowpass_filter, "a", pytest.raises(ValueError)), - (nap.compute_highpass_filter, "b", pytest.raises(ValueError)), - (nap.compute_bandpass_filter, [10, "b"], pytest.raises(ValueError)), - (nap.compute_bandstop_filter, [10, 20, 30], pytest.raises(ValueError)), + (nap.apply_lowpass_filter, "a", pytest.raises(ValueError,match=r"lowpass filter require a single number. a provided instead.")), + (nap.apply_highpass_filter, "b", pytest.raises(ValueError,match=r"highpass filter require a single number. b provided instead.")), + (nap.apply_bandpass_filter, [10, "b"], pytest.raises(ValueError,match="bandpass filter require a tuple of two numbers. \[10, 'b'\] provided instead.")), + (nap.apply_bandstop_filter, [10, 20, 30], pytest.raises(ValueError,match=r"bandstop filter require a tuple of two numbers. \[10, 20, 30\] provided instead.")) ]) def test_compute_filtered_signal_bad_freq(func, freq, expected_exception): with expected_exception: @@ -233,7 +233,7 @@ def test_filtering_nyquist_edge_case(nyquist_fraction, order): nyquist_freq = 0.5 * data.rate freq = nyquist_freq * nyquist_fraction - out = nap.filtering.compute_lowpass_filter(data, freq, order=order) + out = nap.filtering.apply_lowpass_filter(data, freq, order=order) assert isinstance(out, type(data)) np.testing.assert_allclose(out.t, data.t) np.testing.assert_allclose(out.time_support, data.time_support) @@ -254,8 +254,8 @@ def test_get_kernel_error(filter_type, expected_exception): nap.process.filtering._get_windowed_sinc_kernel(1, filter_type, 4) def test_get__error(): - with pytest.raises(TypeError, match=r"compute_lowpass_filter\(\) missing 1 required positional argument: 'data'"): - nap.compute_lowpass_filter(cutoff=0.25) + with pytest.raises(TypeError, match=r"apply_lowpass_filter\(\) missing 1 required positional argument: 'data'"): + nap.apply_lowpass_filter(cutoff=0.25) def test_compare_sinc_kernel(): From a86090c9937c4bdbccd2fc9f8adfa8cf26397cba Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 12 Sep 2024 17:23:32 -0400 Subject: [PATCH 191/195] Update --- pynapple/process/filtering.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pynapple/process/filtering.py b/pynapple/process/filtering.py index 9555354f..dc13f967 100644 --- a/pynapple/process/filtering.py +++ b/pynapple/process/filtering.py @@ -236,8 +236,8 @@ def apply_bandpass_filter( Apply a band-pass filter to the provided signal. Mode can be : - - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. - - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. + - `"butter"` for Butterworth filter. In this case, `order` determines the order of the filter. + - `"sinc"` for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. Parameters ---------- @@ -294,8 +294,8 @@ def apply_bandstop_filter( Apply a band-stop filter to the provided signal. Mode can be : - - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. - - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. + - `"butter"` for Butterworth filter. In this case, `order` determines the order of the filter. + - `"sinc"` for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. Parameters ---------- @@ -352,8 +352,8 @@ def apply_highpass_filter( Apply a high-pass filter to the provided signal. Mode can be : - - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. - - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. + - `"butter"` for Butterworth filter. In this case, `order` determines the order of the filter. + - `"sinc"` for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. Parameters ---------- @@ -410,8 +410,8 @@ def apply_lowpass_filter( Apply a low-pass filter to the provided signal. Mode can be : - - 'butter' for Butterworth filter. In this case, `order` determines the order of the filter. - - 'sinc' for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. + - `"butter"` for Butterworth filter. In this case, `order` determines the order of the filter. + - `"sinc"` for Windowed-Sinc convolution. `transition_bandwidth` determines the transition bandwidth. Parameters ---------- From 6da301bec9c69198b036d221c55bc17e7917319b Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 13 Sep 2024 14:19:10 -0400 Subject: [PATCH 192/195] firt commit --- README.md | 11 +- docs/AUTHORS.md | 5 +- docs/HISTORY.md | 15 +- docs/examples/tutorial_phase_preferences.py | 2 +- docs/index.md | 12 +- pynapple/process/__init__.py | 5 +- pynapple/process/spectrum.py | 207 ++++++++++++++++++ .../{signal_processing.py => wavelets.py} | 207 +----------------- 8 files changed, 250 insertions(+), 214 deletions(-) create mode 100644 pynapple/process/spectrum.py rename pynapple/process/{signal_processing.py => wavelets.py} (54%) diff --git a/README.md b/README.md index 9f7a984b..27d61abf 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,15 @@ pynapple is a light-weight python library for neurophysiological data analysis. New release :fire: ------------------ +### pynapple >= 0.7 + +Pynapple now implements signal processing. For example, you can filter any time series with a particular bandpass: + +```python +nap.apply_bandpass_filter(signal, (10, 20), fs=1250) +``` +New functions includes power spectral density and Morlet wavelet decomposition. See the [documentation](https://pynapple-org.github.io/pynapple/reference/process/) for more details. + ### pynapple >= 0.6 Starting with 0.6, [`IntervalSet`](https://pynapple-org.github.io/pynapple/reference/core/interval_set/) objects are behaving as immutable numpy ndarray. Before 0.6, you could select an interval within an `IntervalSet` object with: @@ -45,8 +54,6 @@ With pynapple>=0.6, the slicing is similar to numpy and it returns an `IntervalS new_intervalset = intervalset[0] ``` -See the [documentation](https://pynapple-org.github.io/pynapple/reference/core/interval_set/) for more details. - ### pynapple >= 0.4 Starting with 0.4, pynapple rely on the [numpy array container](https://numpy.org/doc/stable/user/basics.dispatch.html) approach instead of Pandas for the time series. Pynapple builtin functions will remain the same except for functions inherited from Pandas. diff --git a/docs/AUTHORS.md b/docs/AUTHORS.md index 6d9ff76d..d6182d39 100644 --- a/docs/AUTHORS.md +++ b/docs/AUTHORS.md @@ -5,14 +5,15 @@ Development Lead ---------------- - Guillaume Viejo +- Edoardo Balzani Contributors ------------ -- Edoardo Balzani - Adrien Peyrache - Dan Levenstein - Sofia Skromne Carrasco - Davide Spalla -- Luigi Petrucco \ No newline at end of file +- Luigi Petrucco + - ... [and many more!](https://github.com/pynapple-org/pynapple/graphs/contributors) \ No newline at end of file diff --git a/docs/HISTORY.md b/docs/HISTORY.md index 07ba8ce0..75624468 100644 --- a/docs/HISTORY.md +++ b/docs/HISTORY.md @@ -6,7 +6,20 @@ Another postdoc in the lab, Francesco Battaglia, then made major contributions t Around 2016-2017, Luke Sjulson started *TSToolbox2*, still in Matlab and which includes some important changes. In 2018, Francesco started neuroseries, a Python package built on Pandas. It was quickly adopted in Adrien's lab, especially by Guillaume Viejo, a postdoc in the lab. Gradually, the majority of the lab was using it and new functions were constantly added. -In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuroseries and started *pynapple*. The core of pynapple is largely built upon neuroseries. Some of the original changes to TSToolbox made by Luke were included in this package, especially the *time_support* property of all ts/tsd objects. +In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuroseries and started *pynapple*. +The core of pynapple is largely built upon neuroseries. Some of the original changes to TSToolbox made by Luke were included in this package, especially the *time_support* property of all ts/tsd objects. + +Since 2023, the development of pynapple is lead by Guillaume Viejo and Edoardo Balzani at the Center for Computational Neuroscience +of the Flatiron institute. + + + +0.7.0 (2024-09-16) +------------------ + +- Morlet wavelets spectrogram with utility for plotting the wavelets. +- (Mean) Power spectral density. Returns a Pandas DataFrame. +- 0.6.6 (2024-05-28) ------------------ diff --git a/docs/examples/tutorial_phase_preferences.py b/docs/examples/tutorial_phase_preferences.py index 7f8583cc..c500ca82 100644 --- a/docs/examples/tutorial_phase_preferences.py +++ b/docs/examples/tutorial_phase_preferences.py @@ -154,7 +154,7 @@ def plot_timefrequency(freqs, powers, ax=None): # Filtering Theta # --------------- # -# As expected, there is a strong 8Hz component during REM sleep. We can filter it using the function `nap.compute_bandpass_filter`. +# As expected, there is a strong 8Hz component during REM sleep. We can filter it using the function `nap.apply_bandpass_filter`. theta_band = nap.apply_bandpass_filter(tsd_rem, cutoff=(6.0, 10.0), fs=FS) diff --git a/docs/index.md b/docs/index.md index 0df91139..c6a121b9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -30,6 +30,15 @@ To ask any questions or get support for using pynapple, please consider joining New releases :fire: ------------------ +### pynapple >= 0.7 + +Pynapple now implements signal processing. For example, you can filter any time series with a particular bandpass: + +```python +nap.apply_bandpass_filter(signal, (10, 20), fs=1250) +``` +New functions includes power spectral density and Morlet wavelet decomposition. See the [documentation](https://pynapple-org.github.io/pynapple/reference/process/) for more details. + ### pynapple >= 0.6 Starting with 0.6, [`IntervalSet`](https://pynapple-org.github.io/pynapple/reference/core/interval_set/) objects are behaving as immutable numpy ndarray. Before 0.6, you could select an interval within an `IntervalSet` object with: @@ -44,8 +53,6 @@ With pynapple>=0.6, the slicing is similar to numpy and it returns an `IntervalS new_intervalset = intervalset[0] ``` -See the [documentation](https://pynapple-org.github.io/pynapple/reference/core/interval_set/) for more details. - ### pynapple >= 0.4 Starting with 0.4, pynapple rely on the [numpy array container](https://numpy.org/doc/stable/user/basics.dispatch.html) approach instead of Pandas for the time series. Pynapple builtin functions will remain the same except for functions inherited from Pandas. @@ -54,7 +61,6 @@ This allows for a better handling of returned objects. Additionaly, it is now possible to define time series objects with more than 2 dimensions with `TsdTensor`. You can also look at this [notebook](https://pynapple-org.github.io/pynapple/generated/gallery/tutorial_pynapple_numpy/) for a demonstration of numpy compatibilities. - Getting Started --------------- diff --git a/pynapple/process/__init__.py b/pynapple/process/__init__.py index 5a92f223..b7d9576c 100644 --- a/pynapple/process/__init__.py +++ b/pynapple/process/__init__.py @@ -22,11 +22,9 @@ shift_timestamps, shuffle_ts_intervals, ) -from .signal_processing import ( +from .spectrum import ( compute_mean_power_spectral_density, compute_power_spectral_density, - compute_wavelet_transform, - generate_morlet_filterbank, ) from .tuning_curves import ( compute_1d_mutual_info, @@ -37,3 +35,4 @@ compute_2d_tuning_curves_continuous, compute_discrete_tuning_curves, ) +from .wavelets import compute_wavelet_transform, generate_morlet_filterbank diff --git a/pynapple/process/spectrum.py b/pynapple/process/spectrum.py new file mode 100644 index 00000000..30668c06 --- /dev/null +++ b/pynapple/process/spectrum.py @@ -0,0 +1,207 @@ +""" +# Power spectral density + +This module contains functions to compute power spectral density and mean power spectral density. + +""" + +from numbers import Number + +import numpy as np +import pandas as pd +from scipy import signal + +from .. import core as nap + + +def compute_power_spectral_density( + sig, fs=None, ep=None, full_range=False, norm=False, n=None +): + """ + Perform numpy fft on sig, returns output assuming a constant sampling rate for the signal. + + Parameters + ---------- + sig : pynapple.Tsd or pynapple.TsdFrame + Time series. + fs : float, optional + Sampling rate, in Hz. If None, will be calculated from the given signal + ep : None or pynapple.IntervalSet, optional + The epoch to calculate the fft on. Must be length 1. + full_range : bool, optional + If true, will return full fft frequency range, otherwise will return only positive values + norm: bool, optional + Whether the FFT result is divided by the length of the signal to normalize the amplitude + n: int, optional + Length of the transformed axis of the output. If n is smaller than the length of the input, + the input is cropped. If it is larger, the input is padded with zeros. If n is not given, + the length of the input along the axis specified by axis is used. + + Returns + ------- + pandas.DataFrame + Time frequency representation of the input signal, indexes are frequencies, values + are powers. + + Notes + ----- + compute_spectogram computes fft on only a single epoch of data. This epoch be given with the ep + parameter otherwise will be sig.time_support, but it must only be a single epoch. + """ + if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): + raise TypeError("sig must be either a Tsd or a TsdFrame object.") + if not (fs is None or isinstance(fs, Number)): + raise TypeError("fs must be of type float or int") + if not (ep is None or isinstance(ep, nap.IntervalSet)): + raise TypeError("ep param must be a pynapple IntervalSet object, or None") + if ep is None: + ep = sig.time_support + if len(ep) != 1: + raise ValueError("Given epoch (or signal time_support) must have length 1") + if fs is None: + fs = sig.rate + if not isinstance(full_range, bool): + raise TypeError("full_range must be of type bool or None") + if not isinstance(norm, bool): + raise TypeError("norm must be of type bool") + + fft_result = np.fft.fft(sig.restrict(ep).values, n=n, axis=0) + if n is None: + n = len(sig.restrict(ep)) + fft_freq = np.fft.fftfreq(n, 1 / fs) + + if norm: + fft_result = fft_result / fft_result.shape[0] + + ret = pd.DataFrame(fft_result, fft_freq) + ret.sort_index(inplace=True) + + if not full_range: + return ret.loc[ret.index >= 0] + return ret + + +def compute_mean_power_spectral_density( + sig, + interval_size, + fs=None, + ep=None, + full_range=False, + norm=False, + time_unit="s", +): + """ + Compute mean power spectral density by averaging FFT over epochs of same size. + + The parameter `interval_size` controls the duration of the epochs. + + To imporve frequency resolution, the signal is multiplied by a Hamming window. + + Note that this function assumes a constant sampling rate for `sig`. + + Parameters + ---------- + sig : pynapple.Tsd or pynapple.TsdFrame + Signal with equispaced samples + interval_size : Number + Epochs size to compute to average the FFT across + fs : None, optional + Sampling frequency of `sig`. If `None`, `fs` is equal to `sig.rate` + ep : None or pynapple.IntervalSet, optional + The `IntervalSet` to calculate the fft on. Can be any length. + full_range : bool, optional + If true, will return full fft frequency range, otherwise will return only positive values + norm: bool, optional + Whether the FFT result is divided by the length of the signal to normalize the amplitude + time_unit : str, optional + Time units for parameter `interval_size`. Can be ('s'[default], 'ms', 'us') + + Returns + ------- + pandas.DataFrame + Power spectral density. + + Examples + -------- + >>> import numpy as np + >>> import pynapple as nap + >>> t = np.arange(0, 1, 1/1000) + >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) + >>> mpsd = nap.compute_mean_power_spectral_density(signal, 0.1) + + Raises + ------ + RuntimeError + If splitting the epoch with `interval_size` results in an empty set. + TypeError + If `ep` or `sig` are not respectively pynapple time series or interval set. + """ + if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): + raise TypeError("sig must be either a Tsd or a TsdFrame object.") + + if not (ep is None or isinstance(ep, nap.IntervalSet)): + raise TypeError("ep param must be a pynapple IntervalSet object, or None") + if ep is None: + ep = sig.time_support + + if not (fs is None or isinstance(fs, Number)): + raise TypeError("fs must be of type float or int") + if fs is None: + fs = sig.rate + + if not isinstance(full_range, bool): + raise TypeError("full_range must be of type bool or None") + + if not isinstance(norm, bool): + raise TypeError("norm must be of type bool") + + # Split the ep + interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[ + 0 + ] + split_ep = ep.split(interval_size) + + if len(split_ep) == 0: + raise RuntimeError( + f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size" + ) + + # Get the slices of each ep + slices = np.zeros((len(split_ep), 2), dtype=int) + + for i in range(len(split_ep)): + sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1]) + slices[i, 0] = sl.start + slices[i, 1] = sl.stop + + # Check what is the signal length + N = np.min(np.diff(slices, 1)) + + if N == 0: + raise RuntimeError( + "One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed." + ) + + # Get the freqs + fft_freq = np.fft.fftfreq(N, 1 / fs) + + # Get the Hamming window + window = signal.windows.hamming(N) + if sig.ndim == 2: + window = window[:, np.newaxis] + + # Compute the fft + fft_result = np.zeros((N, *sig.shape[1:]), dtype=complex) + + for i in range(len(slices)): + tmp = sig[slices[i, 0] : slices[i, 1]].values[0:N] * window + fft_result += np.fft.fft(tmp, axis=0) + + if norm: + fft_result = fft_result / (float(N) * float(len(slices))) + + ret = pd.DataFrame(fft_result, fft_freq) + ret.sort_index(inplace=True) + if not full_range: + return ret.loc[ret.index >= 0] + return ret diff --git a/pynapple/process/signal_processing.py b/pynapple/process/wavelets.py similarity index 54% rename from pynapple/process/signal_processing.py rename to pynapple/process/wavelets.py index fb575db8..e8aa601c 100644 --- a/pynapple/process/signal_processing.py +++ b/pynapple/process/wavelets.py @@ -1,215 +1,18 @@ """ -# Signal processing tools. +# Wavelets decomposition -- `nap.compute_power_spectral_density` -- `nap.compute_mean_power_spectral_density` -- `nap.compute_wavelet_transform` -- `nap.generate_morlet_filterbank` +The main function for doing wavelet decomposition is `nap.compute_wavelet_transform` -""" +For now, pynapple only implements Morlet wavelets. To check the shape and quality of the wavelets, check out +the function `nap.generate_morlet_filterbank` to plot the wavelets. -from numbers import Number +""" import numpy as np -import pandas as pd -from scipy import signal from .. import core as nap -def compute_power_spectral_density( - sig, fs=None, ep=None, full_range=False, norm=False, n=None -): - """ - Perform numpy fft on sig, returns output assuming a constant sampling rate for the signal. - - Parameters - ---------- - sig : pynapple.Tsd or pynapple.TsdFrame - Time series. - fs : float, optional - Sampling rate, in Hz. If None, will be calculated from the given signal - ep : None or pynapple.IntervalSet, optional - The epoch to calculate the fft on. Must be length 1. - full_range : bool, optional - If true, will return full fft frequency range, otherwise will return only positive values - norm: bool, optional - Whether the FFT result is divided by the length of the signal to normalize the amplitude - n: int, optional - Length of the transformed axis of the output. If n is smaller than the length of the input, - the input is cropped. If it is larger, the input is padded with zeros. If n is not given, - the length of the input along the axis specified by axis is used. - - Returns - ------- - pandas.DataFrame - Time frequency representation of the input signal, indexes are frequencies, values - are powers. - - Notes - ----- - compute_spectogram computes fft on only a single epoch of data. This epoch be given with the ep - parameter otherwise will be sig.time_support, but it must only be a single epoch. - """ - if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): - raise TypeError("sig must be either a Tsd or a TsdFrame object.") - if not (fs is None or isinstance(fs, Number)): - raise TypeError("fs must be of type float or int") - if not (ep is None or isinstance(ep, nap.IntervalSet)): - raise TypeError("ep param must be a pynapple IntervalSet object, or None") - if ep is None: - ep = sig.time_support - if len(ep) != 1: - raise ValueError("Given epoch (or signal time_support) must have length 1") - if fs is None: - fs = sig.rate - if not isinstance(full_range, bool): - raise TypeError("full_range must be of type bool or None") - if not isinstance(norm, bool): - raise TypeError("norm must be of type bool") - - fft_result = np.fft.fft(sig.restrict(ep).values, n=n, axis=0) - if n is None: - n = len(sig.restrict(ep)) - fft_freq = np.fft.fftfreq(n, 1 / fs) - - if norm: - fft_result = fft_result / fft_result.shape[0] - - ret = pd.DataFrame(fft_result, fft_freq) - ret.sort_index(inplace=True) - - if not full_range: - return ret.loc[ret.index >= 0] - return ret - - -def compute_mean_power_spectral_density( - sig, - interval_size, - fs=None, - ep=None, - full_range=False, - norm=False, - time_unit="s", -): - """ - Compute mean power spectral density by averaging FFT over epochs of same size. - - The parameter `interval_size` controls the duration of the epochs. - - To imporve frequency resolution, the signal is multiplied by a Hamming window. - - Note that this function assumes a constant sampling rate for `sig`. - - Parameters - ---------- - sig : pynapple.Tsd or pynapple.TsdFrame - Signal with equispaced samples - interval_size : Number - Epochs size to compute to average the FFT across - fs : None, optional - Sampling frequency of `sig`. If `None`, `fs` is equal to `sig.rate` - ep : None or pynapple.IntervalSet, optional - The `IntervalSet` to calculate the fft on. Can be any length. - full_range : bool, optional - If true, will return full fft frequency range, otherwise will return only positive values - norm: bool, optional - Whether the FFT result is divided by the length of the signal to normalize the amplitude - time_unit : str, optional - Time units for parameter `interval_size`. Can be ('s'[default], 'ms', 'us') - - Returns - ------- - pandas.DataFrame - Power spectral density. - - Examples - -------- - >>> import numpy as np - >>> import pynapple as nap - >>> t = np.arange(0, 1, 1/1000) - >>> signal = nap.Tsd(d=np.sin(t * 50 * np.pi * 2), t=t) - >>> mpsd = nap.compute_mean_power_spectral_density(signal, 0.1) - - Raises - ------ - RuntimeError - If splitting the epoch with `interval_size` results in an empty set. - TypeError - If `ep` or `sig` are not respectively pynapple time series or interval set. - """ - if not isinstance(sig, (nap.Tsd, nap.TsdFrame)): - raise TypeError("sig must be either a Tsd or a TsdFrame object.") - - if not (ep is None or isinstance(ep, nap.IntervalSet)): - raise TypeError("ep param must be a pynapple IntervalSet object, or None") - if ep is None: - ep = sig.time_support - - if not (fs is None or isinstance(fs, Number)): - raise TypeError("fs must be of type float or int") - if fs is None: - fs = sig.rate - - if not isinstance(full_range, bool): - raise TypeError("full_range must be of type bool or None") - - if not isinstance(norm, bool): - raise TypeError("norm must be of type bool") - - # Split the ep - interval_size = nap.TsIndex.format_timestamps(np.array([interval_size]), time_unit)[ - 0 - ] - split_ep = ep.split(interval_size) - - if len(split_ep) == 0: - raise RuntimeError( - f"Splitting epochs with interval_size={interval_size} generated an empty IntervalSet. Try decreasing interval_size" - ) - - # Get the slices of each ep - slices = np.zeros((len(split_ep), 2), dtype=int) - - for i in range(len(split_ep)): - sl = sig.get_slice(split_ep[i, 0], split_ep[i, 1]) - slices[i, 0] = sl.start - slices[i, 1] = sl.stop - - # Check what is the signal length - N = np.min(np.diff(slices, 1)) - - if N == 0: - raise RuntimeError( - "One interval doesn't have any signal associated. Check the parameter ep or the time support if no epoch is passed." - ) - - # Get the freqs - fft_freq = np.fft.fftfreq(N, 1 / fs) - - # Get the Hamming window - window = signal.windows.hamming(N) - if sig.ndim == 2: - window = window[:, np.newaxis] - - # Compute the fft - fft_result = np.zeros((N, *sig.shape[1:]), dtype=complex) - - for i in range(len(slices)): - tmp = sig[slices[i, 0] : slices[i, 1]].values[0:N] * window - fft_result += np.fft.fft(tmp, axis=0) - - if norm: - fft_result = fft_result / (float(N) * float(len(slices))) - - ret = pd.DataFrame(fft_result, fft_freq) - ret.sort_index(inplace=True) - if not full_range: - return ret.loc[ret.index >= 0] - return ret - - def _morlet(M=1024, gaussian_width=1.5, window_length=1.0, precision=8): """ Defines the complex Morlet wavelet kernel. From e5ee5b2a49b3587f972e592819a42ae1eda31ccc Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 13 Sep 2024 14:38:06 -0400 Subject: [PATCH 193/195] More updates --- README.md | 2 +- docs/HISTORY.md | 13 +++++++++++-- docs/index.md | 2 +- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 27d61abf..83638ec0 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ New release :fire: ### pynapple >= 0.7 -Pynapple now implements signal processing. For example, you can filter any time series with a particular bandpass: +Pynapple now implements signal processing. For example, to filter a 1250 Hz sampled time series between 10 Hz and 20 Hz: ```python nap.apply_bandpass_filter(signal, (10, 20), fs=1250) diff --git a/docs/HISTORY.md b/docs/HISTORY.md index 75624468..a64d469e 100644 --- a/docs/HISTORY.md +++ b/docs/HISTORY.md @@ -9,7 +9,8 @@ In 2018, Francesco started neuroseries, a Python package built on Pandas. It was In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuroseries and started *pynapple*. The core of pynapple is largely built upon neuroseries. Some of the original changes to TSToolbox made by Luke were included in this package, especially the *time_support* property of all ts/tsd objects. -Since 2023, the development of pynapple is lead by Guillaume Viejo and Edoardo Balzani at the Center for Computational Neuroscience +Since 2023, the development of pynapple is lead by [Guillaume Viejo](https://www.simonsfoundation.org/people/guillaume-viejo/) +and [Edoardo Balzani](https://www.simonsfoundation.org/people/edoardo-balzani/) at the Center for Computational Neuroscience of the Flatiron institute. @@ -19,7 +20,15 @@ of the Flatiron institute. - Morlet wavelets spectrogram with utility for plotting the wavelets. - (Mean) Power spectral density. Returns a Pandas DataFrame. -- +- Convolve function works for any dimension of time series and any dimensions of kernel. +- `dtype` in count function +- `get_slice`: public method with a simplified API, argument start, end, time_units. returns a slice that matches behavior of Base.get. +- `_get_slice`: private method, adds the argument "mode" this can be: "after_t", "before_t", "closest_t", "restrict". +- `split` method for IntervalSet. Argument is `interval_size` in time unit. +- Changed os import to pathlib. +- Fixed pickling issue. TsGroup can now be saved as pickle. +- TsGroup can be created from an iterable of Ts/Tsd objects. +- IntervalSet can be created from (start, end) pairs 0.6.6 (2024-05-28) ------------------ diff --git a/docs/index.md b/docs/index.md index c6a121b9..d517eeb1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -32,7 +32,7 @@ New releases :fire: ### pynapple >= 0.7 -Pynapple now implements signal processing. For example, you can filter any time series with a particular bandpass: +Pynapple now implements signal processing. For example, to filter a 1250 Hz sampled time series between 10 Hz and 20 Hz: ```python nap.apply_bandpass_filter(signal, (10, 20), fs=1250) From bedb89f14d9951c8961ade8c24d51ead43b609c3 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 13 Sep 2024 14:42:08 -0400 Subject: [PATCH 194/195] Changing version number --- pynapple/__init__.py | 2 +- pyproject.toml | 6 +++--- setup.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pynapple/__init__.py b/pynapple/__init__.py index 06b74c00..f55f6194 100644 --- a/pynapple/__init__.py +++ b/pynapple/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.6.6" +__version__ = "0.7.0" from .core import ( IntervalSet, Ts, diff --git a/pyproject.toml b/pyproject.toml index 52d8682d..73afed77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pynapple" -version = "0.6.6" +version = "0.7.0" description = "PYthon Neural Analysis Package Pour Laboratoires d’Excellence" readme = "README.md" authors = [{ name = "Guillaume Viejo", email = "guillaume.viejo@gmail.com" }] @@ -36,8 +36,8 @@ requires-python = ">=3.8" include = ["pynapple", "pynapple.*"] [project.urls] -homepage = "https://github.com/pynapple-org/pynapple" -documentation = "https://pynapple-org.github.io/pynapple/" +homepage = "http://pynapple.org/" +documentation = "http://pynapple.org/" repository = "https://github.com/pynapple-org/pynapple" ########################################################################## diff --git a/setup.py b/setup.py index b4fbb0c3..d8fc430f 100644 --- a/setup.py +++ b/setup.py @@ -59,8 +59,8 @@ test_suite='tests', tests_require=test_requirements, url='https://github.com/pynapple-org/pynapple', - version='v0.6.6', + version='v0.7.0', zip_safe=False, long_description_content_type='text/markdown', - download_url='https://github.com/pynapple-org/pynapple/archive/refs/tags/v0.6.6.tar.gz' + download_url='https://github.com/pynapple-org/pynapple/archive/refs/tags/v0.7.0.tar.gz' ) From 79855e00f9cbbb4090caf4605fa4e08c26dd18a1 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Fri, 13 Sep 2024 15:24:02 -0400 Subject: [PATCH 195/195] Cleaning --- MANIFEST.in | 11 --- draft_pynapple_fastplotlib.py | 172 ---------------------------------- 2 files changed, 183 deletions(-) delete mode 100644 MANIFEST.in delete mode 100644 draft_pynapple_fastplotlib.py diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 965b2dda..00000000 --- a/MANIFEST.in +++ /dev/null @@ -1,11 +0,0 @@ -include AUTHORS.rst -include CONTRIBUTING.rst -include HISTORY.rst -include LICENSE -include README.rst - -recursive-include tests * -recursive-exclude * __pycache__ -recursive-exclude * *.py[co] - -recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif diff --git a/draft_pynapple_fastplotlib.py b/draft_pynapple_fastplotlib.py deleted file mode 100644 index a3dd423d..00000000 --- a/draft_pynapple_fastplotlib.py +++ /dev/null @@ -1,172 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Fastplotlib -=========== - -Working with calcium data. - -For the example dataset, we will be working with a recording of a freely-moving mouse imaged with a Miniscope (1-photon imaging). The area recorded for this experiment is the postsubiculum - a region that is known to contain head-direction cells, or cells that fire when the animal's head is pointing in a specific direction. - -The NWB file for the example is hosted on [OSF](https://osf.io/sbnaw). We show below how to stream it. - -See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. - -This tutorial was made by Sofia Skromne Carrasco and Guillaume Viejo. - -""" -# %% -# %gui qt - -import pynapple as nap -import numpy as np -import fastplotlib as fpl - -import sys -# mkdocs_gallery_thumbnail_path = '../_static/fastplotlib_demo.png' - -def get_memory_map(filepath, nChannels, frequency=20000): - n_channels = int(nChannels) - f = open(filepath, 'rb') - startoffile = f.seek(0, 0) - endoffile = f.seek(0, 2) - bytes_size = 2 - n_samples = int((endoffile-startoffile)/n_channels/bytes_size) - duration = n_samples/frequency - interval = 1/frequency - f.close() - fp = np.memmap(filepath, np.int16, 'r', shape = (n_samples, n_channels)) - timestep = np.arange(0, n_samples)/frequency - - return fp, timestep - - -#### LFP -data_array, time_array = get_memory_map("your/path/to/MyProject/sub-A2929/A2929-200711/A2929-200711.dat", 16) -lfp = nap.TsdFrame(t=time_array, d=data_array) - -lfp2 = lfp.get(0, 20)[:,14] -lfp2 = np.vstack((lfp2.t, lfp2.d)).T - -#### NWB -nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb") -units = nwb['units']#.getby_category("location")['adn'] -tmp = units.to_tsd().get(0, 20) -tmp = np.vstack((tmp.index.values, tmp.values)).T - - - -fig = fpl.Figure(canvas="glfw", shape=(2,1)) -fig[0,0].add_line(data=lfp2, thickness=1, cmap="autumn") -fig[1,0].add_scatter(tmp) -fig.show(maintain_aspect=False) -# fpl.run() - - - - -# grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = ['lfp', 'wavelet']) -# grid_plot['lfp'].add_line(lfp.t, lfp[:,14].d) - - -import numpy as np -import fastplotlib as fpl - -fig = fpl.Figure(canvas="glfw")#, shape=(2,1), controller_ids="sync") -fig[0,0].add_line(data=np.random.randn(1000)) -fig.show(maintain_aspect=False) - -fig2 = fpl.Figure(canvas="glfw", controllers=fig.controllers)#, shape=(2,1), controller_ids="sync") -fig2[0,0].add_line(data=np.random.randn(1000)*1000) -fig2.show(maintain_aspect=False) - - - -# Not sure about this : -fig[1,0].controller.controls["mouse1"] = "pan", "drag", (1.0, 0.0) - -fig[1,0].controller.controls.pop("mouse2") -fig[1,0].controller.controls.pop("mouse4") -fig[1,0].controller.controls.pop("wheel") - -import pygfx - -controller = pygfx.PanZoomController() -controller.controls.pop("mouse1") -controller.add_camera(fig[0, 0].camera) -controller.register_events(fig[0, 0].viewport) - -controller2 = pygfx.PanZoomController() -controller2.add_camera(fig[1, 0].camera) -controller2.controls.pop("mouse1") -controller2.register_events(fig[1, 0].viewport) - - - - - - - - - - - - - - - - -sys.exit() - -################################################################################################# - - -nwb = nap.load_file("your/path/to/MyProject/sub-A2929/A2929-200711/pynapplenwb/A2929-200711.nwb") -units = nwb['units']#.getby_category("location")['adn'] -tmp = units.to_tsd() -tmp = np.vstack((tmp.index.values, tmp.values)).T - -# Example 1 - -fplot = fpl.Plot() -fplot.add_scatter(tmp) -fplot.graphics[0].cmap = "jet" -fplot.graphics[0].cmap.values = tmp[:, 1] -fplot.show(maintain_aspect=False) - -# Example 2 - -names = [['raster'], ['position']] -grid_plot = fpl.GridPlot(shape=(2, 1), controller_ids="sync", names = names) -grid_plot['raster'].add_scatter(tmp) -grid_plot['position'].add_line(np.vstack((nwb['ry'].t, nwb['ry'].d)).T) -grid_plot.show(maintain_aspect=False) -grid_plot['raster'].auto_scale(maintain_aspect=False) - - -# Example 3 -#frames = iio.imread("/Users/gviejo/pynapple/A0670-221213_filtered.avi") -#frames = frames[:,:,:,0] -frames = np.random.randn(10, 100, 100) - -iw = fpl.ImageWidget(frames, cmap="gnuplot2") - -#iw.show() - -# Example 4 - -from PyQt6 import QtWidgets - - -mainwidget = QtWidgets.QWidget() - -hlayout = QtWidgets.QHBoxLayout(mainwidget) - -iw.widget.setParent(mainwidget) - -hlayout.addWidget(iw.widget) - -grid_plot.widget.setParent(mainwidget) - -hlayout.addWidget(grid_plot.widget) - -mainwidget.show()