Skip to content

Commit 853d8a4

Browse files
authored
Merge pull request SpikeInterface#3517 from cwindolf/time_bins
Fix a cross-band interpolation bug, and allow time_vector in interpolate_motion
2 parents 3fd3d97 + 38e0ada commit 853d8a4

File tree

4 files changed

+184
-60
lines changed

4 files changed

+184
-60
lines changed

src/spikeinterface/core/baserecording.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
23
import warnings
34
from pathlib import Path
45

@@ -7,14 +8,9 @@
78

89
from .base import BaseSegment
910
from .baserecordingsnippets import BaseRecordingSnippets
10-
from .core_tools import (
11-
convert_bytes_to_str,
12-
convert_seconds_to_str,
13-
)
14-
from .recording_tools import write_binary_recording
15-
16-
11+
from .core_tools import convert_bytes_to_str, convert_seconds_to_str
1712
from .job_tools import split_job_kwargs
13+
from .recording_tools import write_binary_recording
1814

1915

2016
class BaseRecording(BaseRecordingSnippets):
@@ -950,11 +946,11 @@ def time_to_sample_index(self, time_s):
950946
sample_index = time_s * self.sampling_frequency
951947
else:
952948
sample_index = (time_s - self.t_start) * self.sampling_frequency
953-
sample_index = round(sample_index)
949+
sample_index = np.round(sample_index).astype(int)
954950
else:
955951
sample_index = np.searchsorted(self.time_vector, time_s, side="right") - 1
956952

957-
return int(sample_index)
953+
return sample_index
958954

959955
def get_num_samples(self) -> int:
960956
"""Returns the number of samples in this signal segment

src/spikeinterface/sortingcomponents/motion/motion_interpolation.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment
77
from spikeinterface.preprocessing.filter import fix_dtype
88

9+
from .motion_utils import ensure_time_bin_edges, ensure_time_bins
10+
911

1012
def correct_motion_on_peaks(peaks, peak_locations, motion, recording) -> np.ndarray:
1113
"""
@@ -54,14 +56,19 @@ def interpolate_motion_on_traces(
5456
segment_index=None,
5557
channel_inds=None,
5658
interpolation_time_bin_centers_s=None,
59+
interpolation_time_bin_edges_s=None,
5760
spatial_interpolation_method="kriging",
5861
spatial_interpolation_kwargs={},
5962
dtype=None,
6063
):
6164
"""
6265
Apply inverse motion with spatial interpolation on traces.
6366
64-
Traces can be full traces, but also waveforms snippets.
67+
Traces can be full traces, but also waveforms snippets. Times used for looking up
68+
displacements are controlled by interpolation_time_bin_edges_s or
69+
interpolation_time_bin_centers_s, or fall back to the Motion object's time bins
70+
by default; times in the recording outside these time bins use the closest edge
71+
bin's displacement value during interpolation.
6572
6673
Parameters
6774
----------
@@ -80,6 +87,9 @@ def interpolate_motion_on_traces(
8087
interpolation_time_bin_centers_s : None or np.array
8188
Manually specify the time bins which the interpolation happens
8289
in for this segment. If None, these are the motion estimate's time bins.
90+
interpolation_time_bin_edges_s : None or np.array
91+
If present, interpolation chunks will be the time bins defined by these edges
92+
rather than interpolation_time_bin_centers_s or the motion's bins.
8393
spatial_interpolation_method : "idw" | "kriging", default: "kriging"
8494
The spatial interpolation method used to interpolate the channel locations:
8595
* idw : Inverse Distance Weighing
@@ -119,26 +129,33 @@ def interpolate_motion_on_traces(
119129
total_num_chans = channel_locations.shape[0]
120130

121131
# -- determine the blocks of frames that will land in the same interpolation time bin
122-
time_bins = interpolation_time_bin_centers_s
123-
if time_bins is None:
124-
time_bins = motion.temporal_bins_s[segment_index]
125-
bin_s = time_bins[1] - time_bins[0]
126-
bins_start = time_bins[0] - 0.5 * bin_s
127-
# nearest bin center for each frame?
128-
bin_inds = (times - bins_start) // bin_s
129-
bin_inds = bin_inds.astype(int)
132+
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None:
133+
interpolation_time_bin_centers_s = motion.temporal_bins_s[segment_index]
134+
interpolation_time_bin_edges_s = motion.temporal_bin_edges_s[segment_index]
135+
else:
136+
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(
137+
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s
138+
)
139+
140+
# bin the frame times according to the interpolation time bins.
141+
# searchsorted(b, t, side="right") == i means that b[i-1] <= t < b[i]
142+
# hence the -1. doing it with "left" is not as nice -- we want t==b[0]
143+
# to lead to i=1 (rounding down).
144+
interpolation_bin_inds = np.searchsorted(interpolation_time_bin_edges_s, times, side="right") - 1
145+
130146
# the time bins may not cover the whole set of times in the recording,
131147
# so we need to clip these indices to the valid range
132-
np.clip(bin_inds, 0, time_bins.size, out=bin_inds)
148+
n_bins = interpolation_time_bin_edges_s.shape[0] - 1
149+
np.clip(interpolation_bin_inds, 0, n_bins - 1, out=interpolation_bin_inds)
133150

134151
# -- what are the possibilities here anyway?
135-
bins_here = np.arange(bin_inds[0], bin_inds[-1] + 1)
152+
interpolation_bins_here = np.arange(interpolation_bin_inds[0], interpolation_bin_inds[-1] + 1)
136153

137154
# inperpolation kernel will be the same per temporal bin
138155
interp_times = np.empty(total_num_chans)
139156
current_start_index = 0
140-
for bin_ind in bins_here:
141-
bin_time = time_bins[bin_ind]
157+
for interp_bin_ind in interpolation_bins_here:
158+
bin_time = interpolation_time_bin_centers_s[interp_bin_ind]
142159
interp_times.fill(bin_time)
143160
channel_motions = motion.get_displacement_at_time_and_depth(
144161
interp_times,
@@ -166,16 +183,17 @@ def interpolate_motion_on_traces(
166183
# ax.set_title(f"bin_ind {bin_ind} - {bin_time}s - {spatial_interpolation_method}")
167184
# plt.show()
168185

186+
# quick search logic to find frames corresponding to this interpolation bin in the recording
169187
# quickly find the end of this bin, which is also the start of the next
170188
next_start_index = current_start_index + np.searchsorted(
171-
bin_inds[current_start_index:], bin_ind + 1, side="left"
189+
interpolation_bin_inds[current_start_index:], interp_bin_ind + 1, side="left"
172190
)
173-
in_bin = slice(current_start_index, next_start_index)
191+
frames_in_bin = slice(current_start_index, next_start_index)
174192

175193
# here we use a simple np.matmul even if dirft_kernel can be super sparse.
176194
# because the speed for a sparse matmul is not so good when we disable multi threaad (due multi processing
177195
# in ChunkRecordingExecutor)
178-
np.matmul(traces[in_bin], drift_kernel, out=traces_corrected[in_bin])
196+
np.matmul(traces[frames_in_bin], drift_kernel, out=traces_corrected[frames_in_bin])
179197
current_start_index = next_start_index
180198

181199
return traces_corrected
@@ -297,6 +315,7 @@ def __init__(
297315
p=1,
298316
num_closest=3,
299317
interpolation_time_bin_centers_s=None,
318+
interpolation_time_bin_edges_s=None,
300319
interpolation_time_bin_size_s=None,
301320
dtype=None,
302321
**spatial_interpolation_kwargs,
@@ -363,9 +382,14 @@ def __init__(
363382

364383
# handle manual interpolation_time_bin_centers_s
365384
# the case where interpolation_time_bin_size_s is set is handled per-segment below
366-
if interpolation_time_bin_centers_s is None:
385+
if interpolation_time_bin_centers_s is None and interpolation_time_bin_edges_s is None:
367386
if interpolation_time_bin_size_s is None:
368387
interpolation_time_bin_centers_s = motion.temporal_bins_s
388+
interpolation_time_bin_edges_s = motion.temporal_bin_edges_s
389+
else:
390+
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s = ensure_time_bins(
391+
interpolation_time_bin_centers_s, interpolation_time_bin_edges_s
392+
)
369393

370394
for segment_index, parent_segment in enumerate(recording._recording_segments):
371395
# finish the per-segment part of the time bin logic
@@ -375,8 +399,13 @@ def __init__(
375399
t_start, t_end = parent_segment.sample_index_to_time(np.array([0, s_end]))
376400
halfbin = interpolation_time_bin_size_s / 2.0
377401
segment_interpolation_time_bins_s = np.arange(t_start + halfbin, t_end, interpolation_time_bin_size_s)
402+
segment_interpolation_time_bin_edges_s = np.arange(
403+
t_start, t_end + halfbin, interpolation_time_bin_size_s
404+
)
405+
assert segment_interpolation_time_bin_edges_s.shape == (segment_interpolation_time_bins_s.shape[0] + 1,)
378406
else:
379407
segment_interpolation_time_bins_s = interpolation_time_bin_centers_s[segment_index]
408+
segment_interpolation_time_bin_edges_s = interpolation_time_bin_edges_s[segment_index]
380409

381410
rec_segment = InterpolateMotionRecordingSegment(
382411
parent_segment,
@@ -387,6 +416,7 @@ def __init__(
387416
channel_inds,
388417
segment_index,
389418
segment_interpolation_time_bins_s,
419+
segment_interpolation_time_bin_edges_s,
390420
dtype=dtype_,
391421
)
392422
self.add_recording_segment(rec_segment)
@@ -420,6 +450,7 @@ def __init__(
420450
channel_inds,
421451
segment_index,
422452
interpolation_time_bin_centers_s,
453+
interpolation_time_bin_edges_s,
423454
dtype="float32",
424455
):
425456
BasePreprocessorSegment.__init__(self, parent_recording_segment)
@@ -429,13 +460,11 @@ def __init__(
429460
self.channel_inds = channel_inds
430461
self.segment_index = segment_index
431462
self.interpolation_time_bin_centers_s = interpolation_time_bin_centers_s
463+
self.interpolation_time_bin_edges_s = interpolation_time_bin_edges_s
432464
self.dtype = dtype
433465
self.motion = motion
434466

435467
def get_traces(self, start_frame, end_frame, channel_indices):
436-
if self.time_vector is not None:
437-
raise NotImplementedError("InterpolateMotionRecording does not yet support recordings with time_vectors.")
438-
439468
if start_frame is None:
440469
start_frame = 0
441470
if end_frame is None:
@@ -453,7 +482,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
453482
channel_inds=self.channel_inds,
454483
spatial_interpolation_method=self.spatial_interpolation_method,
455484
spatial_interpolation_kwargs=self.spatial_interpolation_kwargs,
456-
interpolation_time_bin_centers_s=self.interpolation_time_bin_centers_s,
485+
interpolation_time_bin_edges_s=self.interpolation_time_bin_edges_s,
457486
)
458487

459488
if channel_indices is not None:

src/spikeinterface/sortingcomponents/motion/motion_utils.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import warnings
21
import json
2+
import warnings
33
from pathlib import Path
44

55
import numpy as np
@@ -54,6 +54,7 @@ def __init__(self, displacement, temporal_bins_s, spatial_bins_um, direction="y"
5454
self.direction = direction
5555
self.dim = ["x", "y", "z"].index(direction)
5656
self.check_properties()
57+
self.temporal_bin_edges_s = [ensure_time_bin_edges(tbins) for tbins in self.temporal_bins_s]
5758

5859
def check_properties(self):
5960
assert all(d.ndim == 2 for d in self.displacement)
@@ -576,3 +577,40 @@ def make_3d_motion_histograms(
576577
motion_histograms = np.log2(1 + motion_histograms)
577578

578579
return motion_histograms, temporal_bin_edges, spatial_bin_edges
580+
581+
582+
def ensure_time_bins(time_bin_centers_s=None, time_bin_edges_s=None):
583+
"""Ensure that both bin edges and bin centers are present
584+
585+
If either of the inputs are None but not both, the missing is reconstructed
586+
from the present. Going from edges to centers is done by taking midpoints.
587+
Going from centers to edges is done by taking midpoints and padding with the
588+
left and rightmost centers.
589+
590+
Parameters
591+
----------
592+
time_bin_centers_s : None or np.array
593+
time_bin_edges_s : None or np.array
594+
595+
Returns
596+
-------
597+
time_bin_centers_s, time_bin_edges_s
598+
"""
599+
if time_bin_centers_s is None and time_bin_edges_s is None:
600+
raise ValueError("Need at least one of time_bin_centers_s or time_bin_edges_s.")
601+
602+
if time_bin_centers_s is None:
603+
assert time_bin_edges_s.ndim == 1 and time_bin_edges_s.size >= 2
604+
time_bin_centers_s = 0.5 * (time_bin_edges_s[1:] + time_bin_edges_s[:-1])
605+
606+
if time_bin_edges_s is None:
607+
time_bin_edges_s = np.empty(time_bin_centers_s.shape[0] + 1, dtype=time_bin_centers_s.dtype)
608+
time_bin_edges_s[[0, -1]] = time_bin_centers_s[[0, -1]]
609+
if time_bin_centers_s.size > 2:
610+
time_bin_edges_s[1:-1] = 0.5 * (time_bin_centers_s[1:] + time_bin_centers_s[:-1])
611+
612+
return time_bin_centers_s, time_bin_edges_s
613+
614+
615+
def ensure_time_bin_edges(time_bin_centers_s=None, time_bin_edges_s=None):
616+
return ensure_time_bins(time_bin_centers_s, time_bin_edges_s)[1]

0 commit comments

Comments
 (0)