6
6
from spikeinterface .preprocessing .basepreprocessor import BasePreprocessor , BasePreprocessorSegment
7
7
from spikeinterface .preprocessing .filter import fix_dtype
8
8
9
+ from .motion_utils import ensure_time_bin_edges , ensure_time_bins
10
+
9
11
10
12
def correct_motion_on_peaks (peaks , peak_locations , motion , recording ) -> np .ndarray :
11
13
"""
@@ -54,14 +56,19 @@ def interpolate_motion_on_traces(
54
56
segment_index = None ,
55
57
channel_inds = None ,
56
58
interpolation_time_bin_centers_s = None ,
59
+ interpolation_time_bin_edges_s = None ,
57
60
spatial_interpolation_method = "kriging" ,
58
61
spatial_interpolation_kwargs = {},
59
62
dtype = None ,
60
63
):
61
64
"""
62
65
Apply inverse motion with spatial interpolation on traces.
63
66
64
- Traces can be full traces, but also waveforms snippets.
67
+ Traces can be full traces, but also waveforms snippets. Times used for looking up
68
+ displacements are controlled by interpolation_time_bin_edges_s or
69
+ interpolation_time_bin_centers_s, or fall back to the Motion object's time bins
70
+ by default; times in the recording outside these time bins use the closest edge
71
+ bin's displacement value during interpolation.
65
72
66
73
Parameters
67
74
----------
@@ -80,6 +87,9 @@ def interpolate_motion_on_traces(
80
87
interpolation_time_bin_centers_s : None or np.array
81
88
Manually specify the time bins which the interpolation happens
82
89
in for this segment. If None, these are the motion estimate's time bins.
90
+ interpolation_time_bin_edges_s : None or np.array
91
+ If present, interpolation chunks will be the time bins defined by these edges
92
+ rather than interpolation_time_bin_centers_s or the motion's bins.
83
93
spatial_interpolation_method : "idw" | "kriging", default: "kriging"
84
94
The spatial interpolation method used to interpolate the channel locations:
85
95
* idw : Inverse Distance Weighing
@@ -119,26 +129,33 @@ def interpolate_motion_on_traces(
119
129
total_num_chans = channel_locations .shape [0 ]
120
130
121
131
# -- determine the blocks of frames that will land in the same interpolation time bin
122
- time_bins = interpolation_time_bin_centers_s
123
- if time_bins is None :
124
- time_bins = motion .temporal_bins_s [segment_index ]
125
- bin_s = time_bins [1 ] - time_bins [0 ]
126
- bins_start = time_bins [0 ] - 0.5 * bin_s
127
- # nearest bin center for each frame?
128
- bin_inds = (times - bins_start ) // bin_s
129
- bin_inds = bin_inds .astype (int )
132
+ if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None :
133
+ interpolation_time_bin_centers_s = motion .temporal_bins_s [segment_index ]
134
+ interpolation_time_bin_edges_s = motion .temporal_bin_edges_s [segment_index ]
135
+ else :
136
+ interpolation_time_bin_centers_s , interpolation_time_bin_edges_s = ensure_time_bins (
137
+ interpolation_time_bin_centers_s , interpolation_time_bin_edges_s
138
+ )
139
+
140
+ # bin the frame times according to the interpolation time bins.
141
+ # searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
142
+ # hence the -1. doing it with "left" is not as nice -- we want t==b[0]
143
+ # to lead to i=1 (rounding down).
144
+ interpolation_bin_inds = np .searchsorted (interpolation_time_bin_edges_s , times , side = "right" ) - 1
145
+
130
146
# the time bins may not cover the whole set of times in the recording,
131
147
# so we need to clip these indices to the valid range
132
- np .clip (bin_inds , 0 , time_bins .size , out = bin_inds )
148
+ n_bins = interpolation_time_bin_edges_s .shape [0 ] - 1
149
+ np .clip (interpolation_bin_inds , 0 , n_bins - 1 , out = interpolation_bin_inds )
133
150
134
151
# -- what are the possibilities here anyway?
135
- bins_here = np .arange (bin_inds [0 ], bin_inds [- 1 ] + 1 )
152
+ interpolation_bins_here = np .arange (interpolation_bin_inds [0 ], interpolation_bin_inds [- 1 ] + 1 )
136
153
137
154
# inperpolation kernel will be the same per temporal bin
138
155
interp_times = np .empty (total_num_chans )
139
156
current_start_index = 0
140
- for bin_ind in bins_here :
141
- bin_time = time_bins [ bin_ind ]
157
+ for interp_bin_ind in interpolation_bins_here :
158
+ bin_time = interpolation_time_bin_centers_s [ interp_bin_ind ]
142
159
interp_times .fill (bin_time )
143
160
channel_motions = motion .get_displacement_at_time_and_depth (
144
161
interp_times ,
@@ -166,16 +183,17 @@ def interpolate_motion_on_traces(
166
183
# ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}")
167
184
# plt.show()
168
185
186
+ # quick search logic to find frames corresponding to this interpolation bin in the recording
169
187
# quickly find the end of this bin, which is also the start of the next
170
188
next_start_index = current_start_index + np .searchsorted (
171
- bin_inds [current_start_index :], bin_ind + 1 , side = "left"
189
+ interpolation_bin_inds [current_start_index :], interp_bin_ind + 1 , side = "left"
172
190
)
173
- in_bin = slice (current_start_index , next_start_index )
191
+ frames_in_bin = slice (current_start_index , next_start_index )
174
192
175
193
# here we use a simple np.matmul even if dirft_kernel can be super sparse.
176
194
# because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing
177
195
# in ChunkRecordingExecutor)
178
- np .matmul (traces [in_bin ], drift_kernel , out = traces_corrected [in_bin ])
196
+ np .matmul (traces [frames_in_bin ], drift_kernel , out = traces_corrected [frames_in_bin ])
179
197
current_start_index = next_start_index
180
198
181
199
return traces_corrected
@@ -297,6 +315,7 @@ def __init__(
297
315
p = 1 ,
298
316
num_closest = 3 ,
299
317
interpolation_time_bin_centers_s = None ,
318
+ interpolation_time_bin_edges_s = None ,
300
319
interpolation_time_bin_size_s = None ,
301
320
dtype = None ,
302
321
** spatial_interpolation_kwargs ,
@@ -363,9 +382,14 @@ def __init__(
363
382
364
383
# handle manual interpolation_time_bin_centers_s
365
384
# the case where interpolation_time_bin_size_s is set is handled per-segment below
366
- if interpolation_time_bin_centers_s is None :
385
+ if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None :
367
386
if interpolation_time_bin_size_s is None :
368
387
interpolation_time_bin_centers_s = motion .temporal_bins_s
388
+ interpolation_time_bin_edges_s = motion .temporal_bin_edges_s
389
+ else :
390
+ interpolation_time_bin_centers_s , interpolation_time_bin_edges_s = ensure_time_bins (
391
+ interpolation_time_bin_centers_s , interpolation_time_bin_edges_s
392
+ )
369
393
370
394
for segment_index , parent_segment in enumerate (recording ._recording_segments ):
371
395
# finish the per-segment part of the time bin logic
@@ -375,8 +399,13 @@ def __init__(
375
399
t_start , t_end = parent_segment .sample_index_to_time (np .array ([0 , s_end ]))
376
400
halfbin = interpolation_time_bin_size_s / 2.0
377
401
segment_interpolation_time_bins_s = np .arange (t_start + halfbin , t_end , interpolation_time_bin_size_s )
402
+ segment_interpolation_time_bin_edges_s = np .arange (
403
+ t_start , t_end + halfbin , interpolation_time_bin_size_s
404
+ )
405
+ assert segment_interpolation_time_bin_edges_s .shape == (segment_interpolation_time_bins_s .shape [0 ] + 1 ,)
378
406
else :
379
407
segment_interpolation_time_bins_s = interpolation_time_bin_centers_s [segment_index ]
408
+ segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s [segment_index ]
380
409
381
410
rec_segment = InterpolateMotionRecordingSegment (
382
411
parent_segment ,
@@ -387,6 +416,7 @@ def __init__(
387
416
channel_inds ,
388
417
segment_index ,
389
418
segment_interpolation_time_bins_s ,
419
+ segment_interpolation_time_bin_edges_s ,
390
420
dtype = dtype_ ,
391
421
)
392
422
self .add_recording_segment (rec_segment )
@@ -420,6 +450,7 @@ def __init__(
420
450
channel_inds ,
421
451
segment_index ,
422
452
interpolation_time_bin_centers_s ,
453
+ interpolation_time_bin_edges_s ,
423
454
dtype = "float32" ,
424
455
):
425
456
BasePreprocessorSegment .__init__ (self , parent_recording_segment )
@@ -429,13 +460,11 @@ def __init__(
429
460
self .channel_inds = channel_inds
430
461
self .segment_index = segment_index
431
462
self .interpolation_time_bin_centers_s = interpolation_time_bin_centers_s
463
+ self .interpolation_time_bin_edges_s = interpolation_time_bin_edges_s
432
464
self .dtype = dtype
433
465
self .motion = motion
434
466
435
467
def get_traces (self , start_frame , end_frame , channel_indices ):
436
- if self .time_vector is not None :
437
- raise NotImplementedError ("InterpolateMotionRecording does not yet support recordings with time_vectors." )
438
-
439
468
if start_frame is None :
440
469
start_frame = 0
441
470
if end_frame is None :
@@ -453,7 +482,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
453
482
channel_inds = self .channel_inds ,
454
483
spatial_interpolation_method = self .spatial_interpolation_method ,
455
484
spatial_interpolation_kwargs = self .spatial_interpolation_kwargs ,
456
- interpolation_time_bin_centers_s = self .interpolation_time_bin_centers_s ,
485
+ interpolation_time_bin_edges_s = self .interpolation_time_bin_edges_s ,
457
486
)
458
487
459
488
if channel_indices is not None :
0 commit comments