Skip to content

Commit

Permalink
Merge pull request #207 from pynapple-org/dev
Browse files Browse the repository at this point in the history
Adding dropna
  • Loading branch information
gviejo authored Nov 20, 2023
2 parents f0cebd8 + 288662b commit a3530d1
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 12 deletions.
29 changes: 27 additions & 2 deletions pynapple/core/jitted_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
# @Author: guillaume
# @Date: 2022-10-31 16:44:31
# @Last Modified by: gviejo
# @Last Modified time: 2023-10-15 16:05:27
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 18:27:43
import numpy as np
from numba import jit

Expand Down Expand Up @@ -749,6 +749,31 @@ def jitin_interval(time_array, starts, ends):
return data


@jit(nopython=True)
def jitremove_nan(time_array, index_nan):
    """Find the boundaries of every contiguous run of samples not flagged as NaN.

    Parameters
    ----------
    time_array : numpy.ndarray
        Timestamps of the samples. It is indexed with a 1-D boolean mask of
        the same length, so it is expected to be 1-D.
    index_nan : numpy.ndarray
        Boolean flags, one per entry of `time_array`; a true value marks a
        sample to be excluded.

    Returns
    -------
    tuple
        ``(starts, ends)``: the timestamps of the first and of the last sample
        of each uninterrupted run of unflagged samples. A run made of a single
        sample yields the same timestamp in both arrays. Both arrays are empty
        when the input is empty or when every sample is flagged.
    """
    n = len(time_array)
    ix_start = np.zeros(n, dtype=np.bool_)
    ix_end = np.zeros(n, dtype=np.bool_)

    # Nothing to scan: reading index_nan[0] / index_nan[-1] below would be an
    # out-of-bounds access on an empty array.
    if n == 0:
        return (time_array[ix_start], time_array[ix_end])

    # A run is already open at the very first sample.
    if not index_nan[0]:
        ix_start[0] = True

    for t in range(1, n):
        # flagged -> unflagged transition: a run starts at t
        if index_nan[t - 1] and not index_nan[t]:
            ix_start[t] = True
        # unflagged -> flagged transition: the run ended on the previous sample
        if not index_nan[t - 1] and index_nan[t]:
            ix_end[t - 1] = True

    # A run is still open at the very last sample.
    if not index_nan[-1]:
        ix_end[-1] = True

    return (time_array[ix_start], time_array[ix_end])


@jit(nopython=True)
def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5):
y = y.astype(np.float64)
Expand Down
47 changes: 45 additions & 2 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
# @Author: gviejo
# @Date: 2022-01-27 18:33:31
# @Last Modified by: gviejo
# @Last Modified time: 2023-11-08 18:44:24
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 18:59:08

"""
Expand Down Expand Up @@ -39,6 +39,7 @@
jitbin,
jitbin_array,
jitcount,
jitremove_nan,
jitrestrict,
jitthreshold,
jittsrestrict,
Expand Down Expand Up @@ -787,6 +788,48 @@ def get(self, start, end=None, time_units="s"):
idx_end = np.searchsorted(time_array, end, side="right")
return self[idx_start:idx_end]

def dropna(self, update_time_support=True):
    """Drop every row containing NaNs.

    By default, the time support is updated to start and end around the
    time points that are not NaNs. To keep the original time support,
    set ``update_time_support=False``.

    Parameters
    ----------
    update_time_support : bool, optional
        If True (default), the returned object gets a time support rebuilt
        from the runs of consecutive non-NaN rows. If False, the time
        support of the current object is kept.

    Returns
    -------
    Tsd, TsdFrame or TsdTensor
        The time series without the NaNs. When no row contains a NaN, the
        object itself is returned (not a copy).
    """
    # A row is dropped as soon as any of its elements (over every
    # non-time axis) is NaN.
    index_nan = np.any(np.isnan(self.values), axis=tuple(range(1, self.ndim)))

    if np.all(index_nan):  # In case it's only NaNs
        empty_kwargs = {
            "t": np.array([]),
            "d": np.empty((0,) + tuple(self.shape[1:])),
        }
        if not update_time_support:
            # Honour the caller's request to keep the original time support
            # even though no data point survives.
            empty_kwargs["time_support"] = self.time_support
        return self.__class__(**empty_kwargs)

    if not np.any(index_nan):
        return self

    if not update_time_support:
        return self[~index_nan]

    time_array = self.index.values
    starts, ends = jitremove_nan(time_array, index_nan)

    # An isolated valid sample gives start == end; widen the interval by
    # 1e-6 so it is not zero-length (1 microsecond if times are in
    # seconds -- TODO confirm the unit of `self.index.values`).
    to_fix = starts == ends
    if np.any(to_fix):
        ends[to_fix] += 1e-6

    ep = IntervalSet(starts, ends)

    # NOTE(review): only t, d and time_support are forwarded to the
    # constructor here; confirm no other per-instance metadata needs to be
    # carried over for subclasses.
    return self.__class__(
        t=time_array[~index_nan], d=self.values[~index_nan], time_support=ep
    )


class TsdTensor(NDArrayOperatorsMixin, _AbstractTsd):
"""
Expand Down
9 changes: 7 additions & 2 deletions pynapple/process/perievent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
# @Author: gviejo
# @Date: 2022-01-30 22:59:00
# @Last Modified by: gviejo
# @Last Modified time: 2023-11-16 11:34:48
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 19:13:24

import numpy as np
from scipy.linalg import hankel
Expand Down Expand Up @@ -168,6 +168,11 @@ def compute_event_trigger_average(

tmp = feature.bin_average(binsize, ep)

# Check for any NaNs in feature
if np.any(np.isnan(tmp)):
tmp = tmp.dropna()
count = count.restrict(tmp.time_support)

# Build the Hankel matrix
n_p = len(idx1)
n_f = len(idx2)
Expand Down
9 changes: 6 additions & 3 deletions tests/test_numpy_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
# @Author: Guillaume Viejo
# @Date: 2023-09-18 18:11:24
# @Last Modified by: gviejo
# @Last Modified time: 2023-11-08 18:14:12
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 16:55:26



Expand All @@ -17,7 +17,10 @@

tsd = nap.TsdTensor(t=np.arange(100), d=np.random.rand(100, 5, 3), time_units="s")

tsd = nap.TsdFrame(t=np.arange(100), d=np.random.randn(100, 6))
# tsd = nap.TsdFrame(t=np.arange(100), d=np.random.randn(100, 6))

tsd.d[tsd.values>0.9] = np.NaN


@pytest.mark.parametrize(
"tsd",
Expand Down
35 changes: 32 additions & 3 deletions tests/test_time_series.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
# @Author: gviejo
# @Date: 2022-04-01 09:57:55
# @Last Modified by: gviejo
# @Last Modified time: 2023-11-08 18:46:52
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 18:48:57
#!/usr/bin/env python

"""Tests of time series for `pynapple` package."""
Expand Down Expand Up @@ -271,7 +271,7 @@ def __init__(self):
@pytest.mark.parametrize(
"tsd",
[
nap.Tsd(t=np.arange(100), d=np.arange(100), time_units="s"),
nap.Tsd(t=np.arange(100), d=np.random.rand(100), time_units="s"),
nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 5), time_units="s"),
nap.TsdTensor(t=np.arange(100), d=np.random.rand(100, 5, 2), time_units="s"),
nap.Ts(t=np.arange(100), time_units="s"),
Expand Down Expand Up @@ -393,6 +393,35 @@ def test_get_timepoint(self, tsd):
np.testing.assert_array_equal(tsd.get(1), tsd[1])
np.testing.assert_array_equal(tsd.get(1000), tsd[-1])

def test_dropna(self, tsd):
    """Check `dropna` on data with no NaNs, with some NaNs, and with only NaNs."""
    # The test only reads `.values`, so it is skipped for `nap.Ts` objects.
    if isinstance(tsd, nap.Ts):
        return

    # 1. No NaNs: nothing must change.
    new_tsd = tsd.dropna()
    np.testing.assert_array_equal(tsd.index.values, new_tsd.index.values)
    np.testing.assert_array_equal(tsd.values, new_tsd.values)

    # 2. Some NaNs, time support updated.
    # NOTE(review): this writes into the parametrized `tsd` in place; confirm
    # that later tests sharing the same instance do not depend on its values.
    # `np.nan` is used because the `np.NaN` alias was removed in NumPy 2.0.
    tsd.values[tsd.values > 0.9] = np.nan
    new_tsd = tsd.dropna()
    # No NaN at all may survive (checking `np.all` would let most NaNs through).
    assert not np.any(np.isnan(new_tsd.values))
    tokeep = np.array([not np.any(np.isnan(row)) for row in tsd.values])
    np.testing.assert_array_equal(tsd.index.values[tokeep], new_tsd.index.values)
    np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values)

    # Restricting the original data to the new time support must select
    # exactly the rows that dropna kept.
    newtsd2 = tsd.restrict(new_tsd.time_support)
    np.testing.assert_array_equal(newtsd2.index.values, new_tsd.index.values)
    np.testing.assert_array_equal(newtsd2.values, new_tsd.values)

    # 3. Some NaNs, time support left untouched.
    new_tsd = tsd.dropna(update_time_support=False)
    np.testing.assert_array_equal(tsd.index.values[tokeep], new_tsd.index.values)
    np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values)
    pd.testing.assert_frame_equal(new_tsd.time_support, tsd.time_support)

    # 4. Only NaNs: the result and its time support are empty.
    tsd.values[:] = np.nan
    new_tsd = tsd.dropna()
    assert len(new_tsd) == 0
    assert len(new_tsd.time_support) == 0


####################################################
# Test for tsd
####################################################
Expand Down

0 comments on commit a3530d1

Please sign in to comment.