Skip to content

Commit

Permalink
Add SortingAnalyzer.set_unit_property() function
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 committed Feb 14, 2025
1 parent dd0c6d3 commit b61f837
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 4 deletions.
49 changes: 48 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Literal, Optional
from typing import Literal, Optional, Any

from pathlib import Path
from itertools import chain
Expand Down Expand Up @@ -742,6 +742,53 @@ def set_temporary_recording(self, recording: BaseRecording, check_dtype: bool =
warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.")
self._temporary_recording = recording

def set_unit_property(
self,
key,
values: list | np.ndarray | tuple,
ids: list | np.ndarray | tuple | None = None,
missing_value: Any = None,
) -> None:
"""
Set property vector for unit ids.
If the SortingAnalyzer backend is in memory, the property will be only set in memory.
If the SortingAnalyzer backend is in binary_folder or zarr, the property will also
be saved to to the backend.
Parameters
----------
key : str
The property name
values : np.array
Array of values for the property
ids : list/np.array, default: None
List of subset of ids to set the values, default: None
if None which is the default all the ids are set or changed
missing_value : object, default: None
In case the property is set on a subset of values ("ids" not None),
it specifies the how the missing values should be filled.
The missing_value has to be specified for types int and unsigned int.
"""
self.sorting.set_property(key, values, ids=ids, missing_value=missing_value)
if not self.is_read_only():
if self.format == "binary_folder":
np.save(self.folder / "sorting" / "properties" / f"{key}.npy", self.sorting.get_property(key))
elif self.format == "zarr":
import zarr

zarr_root = self._get_zarr_root(mode="r+")
prop_values = self.sorting.get_property(key)
if prop_values.dtype.kind == "O":
warnings.warn(f"Property {key} not saved because it is a python Object type")
else:
if key in zarr_root["sorting"]["properties"]:
zarr_root["sorting"]["properties"][key][:] = prop_values
else:
zarr_root["sorting"]["properties"].create_dataset(name=key, data=prop_values, compressor=None)
# IMPORTANT: we need to re-consolidate the zarr store!
zarr.consolidate_metadata(zarr_root.store)

def _save_or_select_or_merge(
self,
format="binary_folder",
Expand Down
34 changes: 31 additions & 3 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def test_SortingAnalyzer_memory(tmp_path, dataset):
)
assert not sorting_analyzer.return_scaled

# test set_unit_property
sorting_analyzer.set_unit_property(key="quality", values=["good"] * len(sorting_analyzer.unit_ids))
sorting_analyzer.set_unit_property(key="number", values=np.arange(len(sorting_analyzer.unit_ids)))
assert "quality" in sorting_analyzer.sorting.get_property_keys()
assert "number" in sorting_analyzer.sorting.get_property_keys()


def test_SortingAnalyzer_binary_folder(tmp_path, dataset):
recording, sorting = dataset
Expand Down Expand Up @@ -103,6 +109,15 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset):
assert not sorting_analyzer.return_scaled
_check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path)

# test set_unit_property
sorting_analyzer.set_unit_property(key="quality", values=["good"] * len(sorting_analyzer.unit_ids))
sorting_analyzer.set_unit_property(key="number", values=np.arange(len(sorting_analyzer.unit_ids)))
assert "quality" in sorting_analyzer.sorting.get_property_keys()
assert "number" in sorting_analyzer.sorting.get_property_keys()
sorting_analyzer_reloded = load_sorting_analyzer(folder, format="auto")
assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys()
assert "number" in sorting_analyzer.sorting.get_property_keys()


def test_SortingAnalyzer_zarr(tmp_path, dataset):
recording, sorting = dataset
Expand Down Expand Up @@ -176,6 +191,15 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):
== LZMA.codec_id
)

# test set_unit_property
sorting_analyzer.set_unit_property(key="quality", values=["good"] * len(sorting_analyzer.unit_ids))
sorting_analyzer.set_unit_property(key="number", values=np.arange(len(sorting_analyzer.unit_ids)))
assert "quality" in sorting_analyzer.sorting.get_property_keys()
assert "number" in sorting_analyzer.sorting.get_property_keys()
sorting_analyzer_reloded = load_sorting_analyzer(sorting_analyzer.folder, format="auto")
assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys()
assert "number" in sorting_analyzer.sorting.get_property_keys()


def test_load_without_runtime_info(tmp_path, dataset):
import zarr
Expand Down Expand Up @@ -262,9 +286,6 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
assert "sampling_frequency" in sorting_analyzer.rec_attributes
assert "num_samples" in sorting_analyzer.rec_attributes

probe = sorting_analyzer.get_probe()
sparsity = sorting_analyzer.sparsity

# compute
sorting_analyzer.compute("dummy", param1=5.5)
# equivalent
Expand Down Expand Up @@ -367,6 +388,9 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
else:
folder = None
sorting_analyzer4 = sorting_analyzer.merge_units(merge_unit_groups=[[0, 1]], format=format, folder=folder)
assert 0 not in sorting_analyzer4.unit_ids
assert 1 not in sorting_analyzer4.unit_ids
assert len(sorting_analyzer4.unit_ids) == len(sorting_analyzer.unit_ids) - 1

if format != "memory":
if format == "zarr":
Expand All @@ -380,6 +404,10 @@ def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder):
sorting_analyzer5 = sorting_analyzer.merge_units(
merge_unit_groups=[[0, 1]], new_unit_ids=[50], format=format, folder=folder, merging_mode="hard"
)
assert 0 not in sorting_analyzer5.unit_ids
assert 1 not in sorting_analyzer5.unit_ids
assert len(sorting_analyzer5.unit_ids) == len(sorting_analyzer.unit_ids) - 1
assert 50 in sorting_analyzer5.unit_ids

# test compute with extension-specific params
sorting_analyzer.compute(["dummy"], extension_params={"dummy": {"param1": 5.5}})
Expand Down

0 comments on commit b61f837

Please sign in to comment.