Skip to content

Commit 9e84a62

Browse files
authored
Merge pull request SpikeInterface#3208 from samuelgarcia/apply_curation
Start apply_curation()
2 parents 63b295c + 0eacd1a commit 9e84a62

File tree

9 files changed

+410
-56
lines changed

9 files changed

+410
-56
lines changed

doc/api.rst

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,15 +324,21 @@ spikeinterface.curation
324324
------------------------
325325
.. automodule:: spikeinterface.curation
326326

327-
.. autoclass:: CurationSorting
328-
.. autoclass:: MergeUnitsSorting
329-
.. autoclass:: SplitUnitSorting
327+
.. autofunction:: apply_curation
330328
.. autofunction:: get_potential_auto_merge
331329
.. autofunction:: find_redundant_units
332330
.. autofunction:: remove_redundant_units
333331
.. autofunction:: remove_duplicated_spikes
334332
.. autofunction:: remove_excess_spikes
333+
334+
Deprecated
335+
~~~~~~~~~~
336+
.. automodule:: spikeinterface.curation
337+
335338
.. autofunction:: apply_sortingview_curation
339+
.. autoclass:: CurationSorting
340+
.. autoclass:: MergeUnitsSorting
341+
.. autoclass:: SplitUnitSorting
336342

337343

338344
spikeinterface.generation

doc/modules/curation.rst

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -261,11 +261,23 @@ format is the definition; the second part of the format is manual action):
261261
}
262262
263263
264-
.. note::
265-
The curation format was recently introduced (v0.101.0), and we are still working on
266-
properly integrating it into the SpikeInterface ecosystem.
267-
Soon there will be functions vailable, in the curation module, to apply this
268-
standardized curation format to ``SortingAnalyzer`` and a ``BaseSorting`` objects.
264+
The curation format can be loaded into a dictionary and directly applied to
265+
a ``BaseSorting`` or ``SortingAnalyzer`` object using the :py:func:`~spikeinterface.curation.apply_curation` function.
266+
267+
.. code-block:: python
268+
269+
from spikeinterface.curation import apply_curation
270+
271+
# load the curation JSON file
272+
curation_json = "path/to/curation.json"
273+
with open(curation_json, 'r') as f:
274+
curation_dict = json.load(f)
275+
276+
# apply the curation to the sorting output
277+
clean_sorting = apply_curation(sorting, curation_dict=curation_dict)
278+
279+
# apply the curation to the sorting analyzer
280+
clean_sorting_analyzer = apply_curation(sorting_analyzer, curation_dict=curation_dict)
269281
270282
271283
Using the ``SpikeInterface GUI``

src/spikeinterface/core/sorting_tools.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def random_spikes_selection(
227227

228228

229229
def apply_merges_to_sorting(
230-
sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_kept=False, new_id_strategy="append"
230+
sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, new_id_strategy="append"
231231
):
232232
"""
233233
Apply a resolved representation of the merges to a sorting object.
@@ -250,8 +250,8 @@ def apply_merges_to_sorting(
250250
merged units will have the first unit_id of every list of merges.
251251
censor_ms: float | None, default: None
252252
When applying the merges, if not None, discard consecutive spikes violating the given refractory period.
253-
return_kept : bool, default: False
254-
If True, also return also a boolean mask of kept spikes.
253+
return_extra : bool, default: False
254+
If True, also return a boolean mask of kept spikes and new_unit_ids.
255255
new_id_strategy : "append" | "take_first", default: "append"
256256
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
257257
@@ -316,8 +316,8 @@ def apply_merges_to_sorting(
316316
spikes = spikes[keep_mask]
317317
sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids)
318318

319-
if return_kept:
320-
return sorting, keep_mask
319+
if return_extra:
320+
return sorting, keep_mask, new_unit_ids
321321
else:
322322
return sorting
323323

@@ -384,11 +384,13 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_
384384
new_unit_ids : list | None, default: None
385385
Optional new unit_ids for merged units. If given, it needs to have the same length as `merge_unit_groups`.
386386
If None, new ids will be generated.
387-
new_id_strategy : "append" | "take_first", default: "append"
387+
new_id_strategy : "append" | "take_first" | "join", default: "append"
388388
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
389389
390390
* "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
391391
* "take_first" : new_unit_ids will be the first unit_id of every list of merges
392+
* "join" : new_unit_ids will join unit_ids of groups with a "-".
393+
Only works if unit_ids are str; otherwise it falls back to "append".
392394
393395
Returns
394396
-------
@@ -423,6 +425,12 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_
423425
else:
424426
# dtype int
425427
new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
428+
elif new_id_strategy == "join":
429+
if np.issubdtype(dtype, np.character):
430+
new_unit_ids = ["-".join(group) for group in merge_unit_groups]
431+
else:
432+
# dtype int
433+
new_unit_ids = list(max(old_unit_ids) + 1 + np.arange(num_merge, dtype=dtype))
426434
else:
427435
raise ValueError("wrong new_id_strategy")
428436

src/spikeinterface/core/sortinganalyzer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -732,12 +732,12 @@ def _save_or_select_or_merge(
732732
else:
733733
from spikeinterface.core.sorting_tools import apply_merges_to_sorting
734734

735-
sorting_provenance, keep_mask = apply_merges_to_sorting(
735+
sorting_provenance, keep_mask, _ = apply_merges_to_sorting(
736736
sorting=sorting_provenance,
737737
merge_unit_groups=merge_unit_groups,
738738
new_unit_ids=new_unit_ids,
739739
censor_ms=censor_ms,
740-
return_kept=True,
740+
return_extra=True,
741741
)
742742
if censor_ms is None:
743743
# in this case having keep_mask None is faster instead of having a vector of ones
@@ -885,6 +885,7 @@ def merge_units(
885885
merging_mode="soft",
886886
sparsity_overlap=0.75,
887887
new_id_strategy="append",
888+
return_new_unit_ids=False,
888889
format="memory",
889890
folder=None,
890891
verbose=False,
@@ -917,14 +918,15 @@ def merge_units(
917918
The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
918919
* "append" : new_units_ids will be added at the end of max(sorting.unit_ids)
919920
* "take_first" : new_unit_ids will be the first unit_id of every list of merges
921+
return_new_unit_ids : bool, default: False
922+
Also return new_unit_ids, which are the ids of the new units.
920923
folder : Path | None, default: None
921924
The new folder where the analyzer with merged units is copied for `format` "binary_folder" or "zarr"
922925
format : "memory" | "binary_folder" | "zarr", default: "memory"
923926
The format of SortingAnalyzer
924927
verbose : bool, default: False
925928
Whether to display calculations (such as sparsity estimation)
926929
927-
928930
Returns
929931
-------
930932
analyzer : SortingAnalyzer
@@ -952,7 +954,7 @@ def merge_units(
952954
)
953955
all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids)
954956

955-
return self._save_or_select_or_merge(
957+
new_analyzer = self._save_or_select_or_merge(
956958
format=format,
957959
folder=folder,
958960
merge_unit_groups=merge_unit_groups,
@@ -964,6 +966,10 @@ def merge_units(
964966
new_unit_ids=new_unit_ids,
965967
**job_kwargs,
966968
)
969+
if return_new_unit_ids:
970+
return new_analyzer, new_unit_ids
971+
else:
972+
return new_analyzer
967973

968974
def copy(self):
969975
"""

src/spikeinterface/core/tests/test_sorting_tools.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_apply_merges_to_sorting():
9696
spikes1[spikes1["unit_index"] == 2]["sample_index"], spikes2[spikes2["unit_index"] == 0]["sample_index"]
9797
)
9898

99-
sorting3, keep_mask = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_kept=True)
99+
sorting3, keep_mask, _ = apply_merges_to_sorting(sorting1, [["a", "b"]], censor_ms=1.5, return_extra=True)
100100
spikes3 = sorting3.to_spike_vector()
101101
assert spikes3.size < spikes1.size
102102
assert not keep_mask[1]
@@ -153,6 +153,11 @@ def test_generate_unit_ids_for_merge_group():
153153
)
154154
assert np.array_equal(new_unit_ids, ["0", "9"])
155155

156+
new_unit_ids = generate_unit_ids_for_merge_group(
157+
["0", "5", "12", "9", "15"], [["0", "5"], ["9", "15"]], new_id_strategy="join"
158+
)
159+
assert np.array_equal(new_unit_ids, ["0-5", "9-15"])
160+
156161

157162
if __name__ == "__main__":
158163
# test_spike_vector_to_spike_trains()

0 commit comments

Comments
 (0)