Skip to content

Commit b61f837

Browse files
committed
Add SortingAnalyzer.set_unit_property() function
1 parent dd0c6d3 commit b61f837

File tree

2 files changed

+79
-4
lines changed

2 files changed

+79
-4
lines changed

src/spikeinterface/core/sortinganalyzer.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from __future__ import annotations
2-
from typing import Literal, Optional
2+
from typing import Literal, Optional, Any
33

44
from pathlib import Path
55
from itertools import chain
@@ -742,6 +742,53 @@ def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool =
742742
warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.")
743743
self._temporary_recording = recording
744744

745+
def set_unit_property(
746+
self,
747+
key,
748+
values: list | np.ndarray | tuple,
749+
ids: list | np.ndarray | tuple | None = None,
750+
missing_value: Any = None,
751+
) -> None:
752+
"""
753+
Set property vector for unit ids.
754+
755+
If the SortingAnalyzer backend is in memory, the property will be only set in memory.
756+
If the SortingAnalyzer backend is in binary_folder or zarr, the property will also
757+
be saved to to the backend.
758+
759+
Parameters
760+
----------
761+
key : str
762+
The property name
763+
values : np.array
764+
Array of values for the property
765+
ids : list/np.array, default: None
766+
List of subset of ids to set the values, default: None
767+
if None which is the default all the ids are set or changed
768+
missing_value : object, default: None
769+
In case the property is set on a subset of values ("ids" not None),
770+
it specifies the how the missing values should be filled.
771+
The missing_value has to be specified for types int and unsigned int.
772+
"""
773+
self.sorting.set_property(key, values, ids=ids, missing_value=missing_value)
774+
if not self.is_read_only():
775+
if self.format == "binary_folder":
776+
np.save(self.folder / "sorting" / "properties" / f"{key}.npy", self.sorting.get_property(key))
777+
elif self.format == "zarr":
778+
import zarr
779+
780+
zarr_root = self._get_zarr_root(mode="r+")
781+
prop_values = self.sorting.get_property(key)
782+
if prop_values.dtype.kind == "O":
783+
warnings.warn(f"Property {key} not saved because it is a python Object type")
784+
else:
785+
if key in zarr_root["sorting"]["properties"]:
786+
zarr_root["sorting"]["properties"][key][:] = prop_values
787+
else:
788+
zarr_root["sorting"]["properties"].create_dataset(name=key, data=prop_values, compressor=None)
789+
# IMPORTANT: we need to re-consolidate the zarr store!
790+
zarr.consolidate_metadata(zarr_root.store)
791+
745792
def _save_or_select_or_merge(
746793
self,
747794
format="binary_folder",

src/spikeinterface/core/tests/test_sortinganalyzer.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ def test_SortingAnalyzer_memory(tmp_path, dataset):
6666
)
6767
assert not sorting_analyzer.return_scaled
6868

69+
# test set_unit_property
70+
sorting_analyzer.set_unit_property(key="quality", values=["good"] * len(sorting_analyzer.unit_ids))
71+
sorting_analyzer.set_unit_property(key="number", values=np.arange(len(sorting_analyzer.unit_ids)))
72+
assert "quality" in sorting_analyzer.sorting.get_property_keys()
73+
assert "number" in sorting_analyzer.sorting.get_property_keys()
74+
6975

7076
def test_SortingAnalyzer_binary_folder(tmp_path, dataset):
7177
recording, sorting = dataset
@@ -103,6 +109,15 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset):
103109
assert not sorting_analyzer.return_scaled
104110
_check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path)
105111

112+
# test set_unit_property
113+
sorting_analyzer.set_unit_property(key="quality", values=["good"] * len(sorting_analyzer.unit_ids))
114+
sorting_analyzer.set_unit_property(key="number", values=np.arange(len(sorting_analyzer.unit_ids)))
115+
assert "quality" in sorting_analyzer.sorting.get_property_keys()
116+
assert "number" in sorting_analyzer.sorting.get_property_keys()
117+
sorting_analyzer_reloded = load_sorting_analyzer(folder, format="auto")
118+
assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys()
119+
assert "number" in sorting_analyzer.sorting.get_property_keys()
120+
106121

107122
def test_SortingAnalyzer_zarr(tmp_path, dataset):
108123
recording, sorting = dataset
@@ -176,6 +191,15 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):
176191
== LZMA.codec_id
177192
)
178193

194+
# test set_unit_property
195+
sorting_analyzer.set_unit_property(key="quality", values=["good"] * len(sorting_analyzer.unit_ids))
196+
sorting_analyzer.set_unit_property(key="number", values=np.arange(len(sorting_analyzer.unit_ids)))
197+
assert "quality" in sorting_analyzer.sorting.get_property_keys()
198+
assert "number" in sorting_analyzer.sorting.get_property_keys()
199+
sorting_analyzer_reloded = load_sorting_analyzer(sorting_analyzer.folder, format="auto")
200+
assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys()
201+
assert "number" in sorting_analyzer.sorting.get_property_keys()
202+
179203

180204
def test_load_without_runtime_info(tmp_path, dataset):
181205
import zarr
@@ -262,9 +286,6 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
262286
assert "sampling_frequency" in sorting_analyzer.rec_attributes
263287
assert "num_samples" in sorting_analyzer.rec_attributes
264288

265-
probe = sorting_analyzer.get_probe()
266-
sparsity = sorting_analyzer.sparsity
267-
268289
# compute
269290
sorting_analyzer.compute("dummy", param1=5.5)
270291
# equivalent
@@ -367,6 +388,9 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
367388
else:
368389
folder = None
369390
sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1]], format=format, folder=folder)
391+
assert 0 not in sorting_analyzer4.unit_ids
392+
assert 1 not in sorting_analyzer4.unit_ids
393+
assert len(sorting_analyzer4.unit_ids) == len(sorting_analyzer.unit_ids) - 1
370394

371395
if format != "memory":
372396
if format == "zarr":
@@ -380,6 +404,10 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
380404
sorting_analyzer5 = sorting_analyzer.merge_units(
381405
merge_unit_groups=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, merging_mode="hard"
382406
)
407+
assert 0 not in sorting_analyzer5.unit_ids
408+
assert 1 not in sorting_analyzer5.unit_ids
409+
assert len(sorting_analyzer5.unit_ids) == len(sorting_analyzer.unit_ids) - 1
410+
assert 50 in sorting_analyzer5.unit_ids
383411

384412
# test compute with extension-specific params
385413
sorting_analyzer.compute(["dummy"], extension_params={"dummy": {"param1": 5.5}})

0 commit comments

Comments
 (0)