diff --git a/pynapple/core/jitted_functions.py b/pynapple/core/jitted_functions.py
index 8deb7576..310f4e1e 100644
--- a/pynapple/core/jitted_functions.py
+++ b/pynapple/core/jitted_functions.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
 # @Author: guillaume
 # @Date: 2022-10-31 16:44:31
-# @Last Modified by: gviejo
-# @Last Modified time: 2023-10-15 16:05:27
+# @Last Modified by: Guillaume Viejo
+# @Last Modified time: 2023-11-19 18:27:43
 
 import numpy as np
 from numba import jit
@@ -749,6 +749,47 @@ def jitin_interval(time_array, starts, ends):
     return data
 
 
+@jit(nopython=True)
+def jitremove_nan(time_array, index_nan):
+    """
+    Find the boundaries of the runs of consecutive non-NaN samples.
+
+    Parameters
+    ----------
+    time_array : numpy.ndarray
+        Timestamps, one per sample.
+    index_nan : numpy.ndarray
+        Boolean array of the same length, True where the sample is NaN.
+
+    Returns
+    -------
+    tuple
+        (starts, ends): timestamps of the first and last sample of each
+        run of consecutive non-NaN samples.
+    """
+    n = len(time_array)
+    ix_start = np.zeros(n, dtype=np.bool_)
+    ix_end = np.zeros(n, dtype=np.bool_)
+
+    # Only look at the edges when there is at least one sample: nopython
+    # code is compiled without bounds checking by default.
+    if n > 0:
+        if not index_nan[0]:  # First start
+            ix_start[0] = True
+        if not index_nan[-1]:  # Last stop
+            ix_end[-1] = True
+
+    for t in range(1, n):
+        if index_nan[t - 1] and not index_nan[t]:  # start
+            ix_start[t] = True
+        if not index_nan[t - 1] and index_nan[t]:  # end
+            ix_end[t - 1] = True
+
+    starts = time_array[ix_start]
+    ends = time_array[ix_end]
+    return (starts, ends)
+
+
 @jit(nopython=True)
 def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5):
     y = y.astype(np.float64)
diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py
index ac31831d..e1626245 100644
--- a/pynapple/core/time_series.py
+++ b/pynapple/core/time_series.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
 # @Author: gviejo
 # @Date: 2022-01-27 18:33:31
-# @Last Modified by: gviejo
-# @Last Modified time: 2023-11-08 18:44:24
+# @Last Modified by: Guillaume Viejo
+# @Last Modified time: 2023-11-19 18:59:08
 
 """
 
@@ -39,6 +39,7 @@
     jitbin,
     jitbin_array,
     jitcount,
+    jitremove_nan,
     jitrestrict,
     jitthreshold,
     jittsrestrict,
@@ -787,6 +788,60 @@ def get(self, start, end=None, time_units="s"):
             idx_end = np.searchsorted(time_array, end, side="right")
             return self[idx_start:idx_end]
 
+    def dropna(self, update_time_support=True):
+        """Drop every row containing NaNs.
+
+        By default, the time support is updated to start and end around the
+        time points that are not NaN. To keep the current time support
+        instead, set update_time_support=False.
+
+        Parameters
+        ----------
+        update_time_support : bool, optional
+            If True (default), the time support of the output is rebuilt
+            from the runs of consecutive non-NaN time points. If False, the
+            time support of the input is kept.
+
+        Returns
+        -------
+        Tsd, TsdFrame or TsdTensor
+            The time series without the NaNs
+        """
+        index_nan = np.any(np.isnan(self.values), axis=tuple(range(1, self.ndim)))
+
+        if np.all(index_nan):  # In case it's only NaNs
+            empty = {
+                "t": np.array([]),
+                "d": np.empty((0,) + tuple(self.shape[1:])),
+            }
+            # No time point survives, but a time support the caller asked to
+            # keep must still be handed over to the new object.
+            if not update_time_support:
+                empty["time_support"] = self.time_support
+            return self.__class__(**empty)
+
+        if not np.any(index_nan):
+            return self
+
+        if not update_time_support:
+            return self[~index_nan]
+
+        time_array = self.index.values
+        starts, ends = jitremove_nan(time_array, index_nan)
+
+        # A run made of a single time point has start == end; widen it by
+        # 1e-6 time units (1 microsecond if the index is in seconds) so that
+        # start < end.
+        to_fix = starts == ends
+        if np.any(to_fix):
+            ends[to_fix] += 1e-6
+
+        ep = IntervalSet(starts, ends)
+
+        return self.__class__(
+            t=time_array[~index_nan], d=self.values[~index_nan], time_support=ep
+        )
+
 
 class TsdTensor(NDArrayOperatorsMixin, _AbstractTsd):
     """
diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py
index f783c72f..6d935e70 100644
--- a/pynapple/process/perievent.py
+++ b/pynapple/process/perievent.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
 # @Author: gviejo
 # @Date: 2022-01-30 22:59:00
-# @Last Modified by: gviejo
-# @Last Modified time: 2023-11-16 11:34:48
+# @Last Modified by: Guillaume Viejo
+# @Last Modified time: 2023-11-19 19:13:24
 
 import numpy as np
 from scipy.linalg import hankel
@@ -168,6 +168,11 @@ def compute_event_trigger_average(
 
     tmp = feature.bin_average(binsize, ep)
 
+    # Check for any NaNs in feature
+    if np.any(np.isnan(tmp)):
+        tmp = tmp.dropna()
+        count = count.restrict(tmp.time_support)
+
     # Build the Hankel matrix
     n_p = len(idx1)
     n_f = len(idx2)
diff --git a/tests/test_numpy_compatibility.py b/tests/test_numpy_compatibility.py
index 254248b3..8438019d 100644
--- a/tests/test_numpy_compatibility.py
+++ b/tests/test_numpy_compatibility.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
 # @Author: Guillaume Viejo
 # @Date: 2023-09-18 18:11:24
-# @Last Modified by: gviejo
-# @Last Modified time: 2023-11-08 18:14:12
+# @Last Modified by: Guillaume Viejo
+# @Last Modified time: 2023-11-19 16:55:26
 
 
 
@@ -17,7 +17,9 @@
 
 tsd = nap.TsdTensor(t=np.arange(100), d=np.random.rand(100, 5, 3), time_units="s")
 
-tsd = nap.TsdFrame(t=np.arange(100), d=np.random.randn(100, 6))
+# Replace every value above 0.9 with NaN
+tsd.d[tsd.values > 0.9] = np.nan
+
 
 @pytest.mark.parametrize(
     "tsd",
diff --git a/tests/test_time_series.py b/tests/test_time_series.py
index f0750b8d..449cbfa7 100755
--- a/tests/test_time_series.py
+++ b/tests/test_time_series.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
 # @Author: gviejo
 # @Date: 2022-04-01 09:57:55
-# @Last Modified by: gviejo
-# @Last Modified time: 2023-11-08 18:46:52
+# @Last Modified by: Guillaume Viejo
+# @Last Modified time: 2023-11-19 18:48:57
 #!/usr/bin/env python
 
 """Tests of time series for `pynapple` package."""
@@ -271,7 +271,7 @@ def __init__(self):
 @pytest.mark.parametrize(
     "tsd",
     [
-        nap.Tsd(t=np.arange(100), d=np.arange(100), time_units="s"),
+        nap.Tsd(t=np.arange(100), d=np.random.rand(100), time_units="s"),
         nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 5), time_units="s"),
         nap.TsdTensor(t=np.arange(100), d=np.random.rand(100, 5, 2), time_units="s"),
         nap.Ts(t=np.arange(100), time_units="s"),
@@ -393,6 +393,47 @@ def test_get_timepoint(self, tsd):
         np.testing.assert_array_equal(tsd.get(1), tsd[1])
         np.testing.assert_array_equal(tsd.get(1000), tsd[-1])
 
+    def test_dropna(self, tsd):
+        if isinstance(tsd, nap.Ts):  # built from timestamps only, no values
+            return
+
+        # Work on a private copy: the parametrized objects are shared between
+        # tests and the checks below overwrite values in place.
+        tsd = tsd.__class__(
+            t=tsd.index.values, d=tsd.values.copy(), time_support=tsd.time_support
+        )
+
+        # Without NaNs, nothing is dropped
+        new_tsd = tsd.dropna()
+        np.testing.assert_array_equal(tsd.index.values, new_tsd.index.values)
+        np.testing.assert_array_equal(tsd.values, new_tsd.values)
+
+        # Some NaNs: rows with at least one NaN are dropped
+        tsd.values[tsd.values > 0.9] = np.nan
+        tokeep = ~np.isnan(tsd.values.reshape(len(tsd), -1)).any(axis=1)
+        new_tsd = tsd.dropna()
+        assert not np.any(np.isnan(new_tsd.values))
+        np.testing.assert_array_equal(tsd.index.values[tokeep], new_tsd.index.values)
+        np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values)
+
+        # The updated time support selects exactly the kept time points
+        newtsd2 = tsd.restrict(new_tsd.time_support)
+        np.testing.assert_array_equal(newtsd2.index.values, new_tsd.index.values)
+        np.testing.assert_array_equal(newtsd2.values, new_tsd.values)
+
+        # The time support can be left untouched
+        new_tsd = tsd.dropna(update_time_support=False)
+        np.testing.assert_array_equal(tsd.index.values[tokeep], new_tsd.index.values)
+        np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values)
+        pd.testing.assert_frame_equal(new_tsd.time_support, tsd.time_support)
+
+        # Only NaNs: the output and its time support are empty
+        tsd.values[:] = np.nan
+        new_tsd = tsd.dropna()
+        assert len(new_tsd) == 0
+        assert len(new_tsd.time_support) == 0
+
+
 ####################################################
 # Test for tsd
 ####################################################