|
| 1 | +from functools import partial |
| 2 | + |
| 3 | +import jax |
| 4 | +import jax.numpy as jnp |
| 5 | +import numpy as np |
| 6 | +import scipy.signal as signal |
| 7 | + |
| 8 | +from .utils import ( |
| 9 | + _get_shifted_indices, |
| 10 | + _get_slicing, |
| 11 | + _odd_ext_multiepoch, |
| 12 | + _revert_epochs, |
| 13 | +) |
| 14 | + |
| 15 | + |
| 16 | +@partial(jax.jit, static_argnums=(3, )) |
| 17 | +def _recursion_loop_sos(signal, sos, zi, nan_function): |
| 18 | + """ |
| 19 | + Applies a recursive second-order section (SOS) filter to the input signal. |
| 20 | +
|
| 21 | + Parameters |
| 22 | + ---------- |
| 23 | + signal : jnp.ndarray |
| 24 | + The input signal to be filtered, with shape (n_samples,). |
| 25 | + sos : jnp.ndarray |
| 26 | + Array of second-order filter coefficients in the 'sos' format, with shape (n_sections, 6). |
| 27 | + zi : jnp.ndarray |
| 28 | + Initial conditions for the filter, with shape (n_sections, 2, n_epochs). |
| 29 | + nan_function : callable |
| 30 | + A function that specifies how to re-initialize the initial conditions when a NaN is encountered in the signal. |
| 31 | + It should take two arguments: the epoch number and the current filter state, and return a tuple of the updated |
| 32 | + epoch number and the re-initialized filter state. |
| 33 | +
|
| 34 | + Returns |
| 35 | + ------- |
| 36 | + jnp.ndarray |
| 37 | + The filtered signal, with the same shape as the input signal. |
| 38 | + """ |
| 39 | + |
| 40 | + def internal_loop(s, x_zi): |
| 41 | + x_cur, zi_slice = x_zi |
| 42 | + x_new = sos[s, 0] * x_cur + zi_slice[s, 0] |
| 43 | + zi_slice = zi_slice.at[s, 0].set( |
| 44 | + sos[s, 1] * x_cur - sos[s, 4] * x_new + zi_slice[s, 1] |
| 45 | + ) |
| 46 | + zi_slice = zi_slice.at[s, 1].set( |
| 47 | + sos[s, 2] * x_cur - sos[s, 5] * x_new) |
| 48 | + x_cur = x_new |
| 49 | + return x_cur, zi_slice |
| 50 | + |
| 51 | + def recursion_step(carry, x): |
| 52 | + epoch_num, zi_slice = carry |
| 53 | + |
| 54 | + x_cur, zi_slice = jax.lax.fori_loop( |
| 55 | + lower=0, upper=sos.shape[0], body_fun=internal_loop, init_val=(x, zi_slice) |
| 56 | + ) |
| 57 | + |
| 58 | + # Use jax.lax.cond to choose between nan_case and not_nan_case |
| 59 | + epoch_num, zi_slice = jax.lax.cond( |
| 60 | + jnp.isnan(x), # Condition to check |
| 61 | + nan_function, # Function to call if x is NaN |
| 62 | + lambda i, x: (i, zi_slice), # Function to call if x is not NaN |
| 63 | + epoch_num, |
| 64 | + zi, |
| 65 | + ) |
| 66 | + |
| 67 | + return (epoch_num, zi_slice), x_cur |
| 68 | + |
| 69 | + _, res = jax.lax.scan(recursion_step, (0, zi[..., 0]), signal) |
| 70 | + |
| 71 | + return res |
| 72 | + |
| 73 | + |
| 74 | +# vectorize the recursion over signals. |
| 75 | +_vmap_recursion_sos = jax.vmap(_recursion_loop_sos, in_axes=(1, None, 2, None), out_axes=1) |
| 76 | + |
| 77 | + |
| 78 | +def _insert_constant(idx_start, idx_end, data_array, window_size, const=jnp.nan): |
| 79 | + """ |
| 80 | + Insert a constant value array between epochs in a time series data array. |
| 81 | +
|
| 82 | + This function interleaves a constant value array of specified size between each epoch in the data array. |
| 83 | +
|
| 84 | + Parameters |
| 85 | + ---------- |
| 86 | + idx_start : jnp.ndarray |
| 87 | + Array of start indices for each epoch. |
| 88 | + idx_end : jnp.ndarray |
| 89 | + Array of end indices for each epoch. |
| 90 | + data_array : jnp.ndarray |
| 91 | + The input data array, with shape (n_samples, ...). |
| 92 | + window_size : int |
| 93 | + The size of the constant array to be inserted between epochs. |
| 94 | + const : float, optional |
| 95 | + The constant value to be inserted, by default jnp.nan. |
| 96 | +
|
| 97 | + Returns |
| 98 | + ------- |
| 99 | + data_array: jnp.ndarray |
| 100 | + The modified data array with the constant arrays inserted. |
| 101 | + ix_orig: jnp.ndarray |
| 102 | + Indices corresponding to the samples in the original data array. |
| 103 | + ix_shift: jnp.ndarray |
| 104 | + The shifted indices after the constant array has been interleaved. |
| 105 | + idx_start_shift: |
| 106 | + The shifted start indices of the epochs in the modified array. |
| 107 | + idx_end_shift: |
| 108 | + The shifted end indices of the epochs in the modified array. |
| 109 | + """ |
| 110 | + # shift by a window every epoch |
| 111 | + idx_start_shift, idx_end_shift = _get_shifted_indices( |
| 112 | + idx_start, idx_end, window_size |
| 113 | + ) |
| 114 | + |
| 115 | + # get the indices for setting elements |
| 116 | + ix_orig = _get_slicing(idx_start, idx_end) |
| 117 | + ix_shift = _get_slicing(idx_start_shift, idx_end_shift) |
| 118 | + |
| 119 | + tot_size = ix_shift[-1] - ix_shift[0] + 1 |
| 120 | + data_array = ( |
| 121 | + jnp.full((tot_size, *data_array.shape[1:]), const) |
| 122 | + .at[ix_shift] |
| 123 | + .set(data_array[ix_orig]) |
| 124 | + ) |
| 125 | + return data_array, ix_orig, ix_shift, idx_start_shift, idx_end_shift |
| 126 | + |
| 127 | + |
| 128 | +def jax_sosfiltfilt(sos, time_array, data_array, starts, ends): |
| 129 | + """ |
| 130 | + Apply forward-backward filtering using a second-order section (SOS) filter. |
| 131 | +
|
| 132 | + This function applies an SOS filter to the data array in both forward and reverse directions, |
| 133 | + which results in zero-phase filtering. |
| 134 | +
|
| 135 | + Parameters |
| 136 | + ---------- |
| 137 | + sos : np.ndarray |
| 138 | + Array of second-order filter coefficients in the 'sos' format, with shape (n_sections, 6). |
| 139 | + time_array : np.ndarray |
| 140 | + The time array corresponding to the data, with shape (n_samples,). |
| 141 | + data_array : jnp.ndarray |
| 142 | + The data array to be filtered, with shape (n_samples, ...). |
| 143 | + starts : np.ndarray |
| 144 | + Array of start indices for the epochs in the data array. |
| 145 | + ends : np.ndarray |
| 146 | + Array of end indices for the epochs in the data array. |
| 147 | +
|
| 148 | + Returns |
| 149 | + ------- |
| 150 | + : jnp.ndarray |
| 151 | + The zero-phase filtered data array, with the same shape as the input data array. |
| 152 | + """ |
| 153 | + |
| 154 | + original_shape = data_array.shape |
| 155 | + data_array = data_array.reshape(data_array.shape[0], -1) |
| 156 | + |
| 157 | + # same default padding as scipy.sosfiltfilt ("pad" method and "odd" padtype). |
| 158 | + n_sections = sos.shape[0] |
| 159 | + ntaps = 2 * n_sections + 1 |
| 160 | + ntaps -= min((sos[:, 2] == 0).sum(), (sos[:, 5] == 0).sum()) |
| 161 | + pad_num = 3 * ntaps |
| 162 | + |
| 163 | + ext, ix_start_pad, ix_end_pad, ix_data = _odd_ext_multiepoch(pad_num, time_array, data_array, starts, ends) |
| 164 | + |
| 165 | + # get the start/end index of each epoch after padding |
| 166 | + ix_start_ep = np.hstack((ix_start_pad[0], ix_start_pad[1:-1] + pad_num)) |
| 167 | + ix_end_ep = np.hstack((ix_start_ep[1:], ix_end_pad[-1])) |
| 168 | + |
| 169 | + zi = signal.sosfilt_zi(sos) |
| 170 | + |
| 171 | + # this braodcast has shape (*zi.shape, data_array.shape[1], len(ix_start_pad)) |
| 172 | + z0 = zi[..., jnp.newaxis, jnp.newaxis] * ext.T[jnp.newaxis, jnp.newaxis, ..., ix_start_ep] |
| 173 | + |
| 174 | + if len(starts) > 1: |
| 175 | + # multi epoch case augmenting with nans. |
| 176 | + aug_data, ix_orig, ix_shift, idx_start_shift, idx_end_shift = _insert_constant( |
| 177 | + ix_start_ep, ix_end_ep, ext, window_size=1, const=np.nan |
| 178 | + ) |
| 179 | + |
| 180 | + # grab the next initial condition, increase the epoch counter |
| 181 | + nan_func = lambda ep_num, x: (ep_num + 1, x[..., ep_num + 1]) |
| 182 | + else: |
| 183 | + # single epoch, no augmentation |
| 184 | + nan_func = lambda ep_num, x: (ep_num + 1, x[..., 0]) |
| 185 | + aug_data = ext |
| 186 | + idx_start_shift = ix_start_ep |
| 187 | + idx_end_shift = ix_end_ep |
| 188 | + ix_shift = slice(None) |
| 189 | + |
| 190 | + |
| 191 | + # call forward recursion |
| 192 | + out = _vmap_recursion_sos(aug_data, sos, z0, nan_func) |
| 193 | + |
| 194 | + # reverse time axis |
| 195 | + irev = _revert_epochs(idx_start_shift, idx_end_shift) |
| 196 | + out = out.at[ix_shift].set(out[irev]) |
| 197 | + |
| 198 | + # compute new init cond |
| 199 | + z0 = zi[..., jnp.newaxis, jnp.newaxis] * out.T[jnp.newaxis, jnp.newaxis, ..., idx_start_shift] |
| 200 | + |
| 201 | + # call backward recursion |
| 202 | + out = _vmap_recursion_sos(out, sos, z0, nan_func) |
| 203 | + |
| 204 | + # re-flip axis |
| 205 | + out = out.at[ix_shift].set(out[irev]) |
| 206 | + |
| 207 | + # remove nans and padding |
| 208 | + out = out[ix_shift][ix_data] |
| 209 | + |
| 210 | + return out.reshape(original_shape) |
0 commit comments