diff --git a/CHANGELOG.md b/CHANGELOG.md index 11dce4e7a8..029db89789 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ ## Features * Propagate the `unit_electrode_indices` argument from the spikeinterface tools to `BaseSortingExtractorInterface`. This allows users to map units to the electrode table when adding sorting data [PR #1124](https://github.com/catalystneuro/neuroconv/pull/1124) +* Added `SortedRecordingConverter` to convert sorted recordings to NWB with correct metadata mapping between units and electrodes [PR #1132](https://github.com/catalystneuro/neuroconv/pull/1132) * Imaging interfaces have a new conversion option `always_write_timestamps` that can be used to force writing timestamps even if neuroconv's heuristics indicates regular sampling rate [PR #1125](https://github.com/catalystneuro/neuroconv/pull/1125) * Added .csv support to DeepLabCutInterface [PR #1140](https://github.com/catalystneuro/neuroconv/pull/1140) * `SpikeGLXRecordingInterface` now also accepts `folder_path` making its behavior equivalent to SpikeInterface [#1150](https://github.com/catalystneuro/neuroconv/pull/1150) diff --git a/docs/user_guide/index.rst b/docs/user_guide/index.rst index bf9aaf253b..6abfa51cad 100644 --- a/docs/user_guide/index.rst +++ b/docs/user_guide/index.rst @@ -25,6 +25,7 @@ and synchronize data across multiple sources. csvs expand_path backend_configuration + linking_sorted_data yaml docker_demo aws_demo diff --git a/docs/user_guide/linking_sorted_data.rst b/docs/user_guide/linking_sorted_data.rst new file mode 100644 index 0000000000..8b9d596b45 --- /dev/null +++ b/docs/user_guide/linking_sorted_data.rst @@ -0,0 +1,74 @@ +.. _linking_sorted_data: +How to Link Sorted Data to Electrodes +===================================== +The ``SortedRecordingConverter`` maintains proper linkage between sorted units and their corresponding recording channels in NWB files. 
+It handles the critical relationship between ``Units`` and ``Electrodes`` tables by: + +* Creating electrode table regions for each unit +* Maintaining electrode group and device relationships +* Mapping channel IDs to electrode indices correctly + +This automated handling ensures proper provenance tracking in the NWB file, which is essential for interpreting and analyzing sorted electrophysiology data. + +Basic Usage +---------- + +Single Probe and Single Recording +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +This example demonstrates linking data from a single Neuropixel probe recorded with SpikeGLX and sorted with Kilosort. + +The converter requires three components: + +1. A recording interface (:py:class:`~neuroconv.datainterfaces.ecephys.spikeglx.spikeglxrecordinginterface.SpikeGLXRecordingInterface`) +2. A sorting interface (:py:class:`~neuroconv.datainterfaces.ecephys.kilosort.kilosortinterface.KiloSortSortingInterface`) +3. A mapping between unit IDs and their associated channel IDs + +First, instantiate the interfaces:: + + from neuroconv import SortedRecordingConverter + from neuroconv.datainterfaces import SpikeGLXRecordingInterface, KiloSortSortingInterface + + # Initialize interfaces + recording_interface = SpikeGLXRecordingInterface( + folder_path="path/to/spikeglx_data", + stream_id="imec0.ap" + ) + sorting_interface = KiloSortSortingInterface( + folder_path="path/to/kilosort_data" + ) + +Access channel and unit IDs through interface properties:: + + # Access channel IDs + print(recording_interface.channel_ids) + # Example output: ['imec0.ap#AP0', 'imec0.ap#AP1', 'imec0.ap#AP2', ...] + + # Access unit IDs + print(sorting_interface.unit_ids) + # Example output: ['0', '1', '2', ...] 
+ +Define the mapping between units and channels:: + + unit_ids_to_channel_ids = { + "0": ["imec0.ap#AP0", "imec0.ap#AP1"], # Unit 0 detected on two channels + "1": ["imec0.ap#AP2"], # Unit 1 detected on one channel + "2": ["imec0.ap#AP3", "imec0.ap#AP4"], # Unit 2 detected on two channels + ... # Map all remaining units to their respective channels + } + +.. note:: + + Every unit from the sorting interface must have a corresponding channel mapping. + +Create the converter and run the conversion:: + + converter = SortedRecordingConverter( + recording_interface=recording_interface, + sorting_interface=sorting_interface, + unit_ids_to_channel_ids=unit_ids_to_channel_ids + ) + + nwbfile = converter.run_conversion(nwbfile_path="path/to/output.nwb") diff --git a/src/neuroconv/converters/__init__.py b/src/neuroconv/converters/__init__.py index 7fa38bfd58..a741ff3319 100644 --- a/src/neuroconv/converters/__init__.py +++ b/src/neuroconv/converters/__init__.py @@ -9,6 +9,7 @@ from ..datainterfaces.behavior.lightningpose.lightningposeconverter import ( LightningPoseConverter, ) +from ..datainterfaces.ecephys.sortedrecordinginterface import SortedRecordingConverter from ..datainterfaces.ecephys.spikeglx.spikeglxconverter import SpikeGLXConverterPipe from ..datainterfaces.ophys.brukertiff.brukertiffconverter import ( BrukerTiffMultiPlaneConverter, @@ -22,4 +23,5 @@ BrukerTiffMultiPlaneConverter, BrukerTiffSinglePlaneConverter, MiniscopeConverter, + SortedRecordingConverter, ] diff --git a/src/neuroconv/datainterfaces/ecephys/baserecordingextractorinterface.py b/src/neuroconv/datainterfaces/ecephys/baserecordingextractorinterface.py index 6d0df14c12..8b650897b2 100644 --- a/src/neuroconv/datainterfaces/ecephys/baserecordingextractorinterface.py +++ b/src/neuroconv/datainterfaces/ecephys/baserecordingextractorinterface.py @@ -107,6 +107,11 @@ def get_metadata(self) -> DeepDict: return metadata + @property + def channel_ids(self): + "Gets the channel ids of the data." 
from typing import Optional, Union

from neuroconv import ConverterPipe
from neuroconv.datainterfaces.ecephys.baserecordingextractorinterface import (
    BaseRecordingExtractorInterface,
)
from neuroconv.datainterfaces.ecephys.basesortingextractorinterface import (
    BaseSortingExtractorInterface,
)
from neuroconv.tools.spikeinterface.spikeinterface import (
    _get_electrode_table_indices_for_recording,
)


class SortedRecordingConverter(ConverterPipe):
    """
    Converter that links simultaneously acquired recording and sorting data,
    keeping the mapping between sorted units and their electrodes correct.

    Maintaining the linkage between units and recording channels matters when
    data from one or more probes (e.g. Neuropixels) is written to a single NWB
    file: the units written by the sorting interface must reference electrode
    table rows that belong to the recording interface they were sorted from.
    """

    keywords = (
        "electrophysiology",
        "spike sorting",
    )
    display_name = "SortedRecordingConverter"
    associated_suffixes = ("None",)
    info = "A converter for handling simultaneous recording and sorting data linking metadata properly."

    def __init__(
        self,
        recording_interface: BaseRecordingExtractorInterface,
        sorting_interface: BaseSortingExtractorInterface,
        unit_ids_to_channel_ids: dict[Union[str, int], list[Union[str, int]]],
    ):
        """
        Parameters
        ----------
        recording_interface : BaseRecordingExtractorInterface
            Interface for the raw recording data, typically a single probe
            (e.g. a SpikeGLXRecordingInterface).
        sorting_interface : BaseSortingExtractorInterface
            Interface for the sorted units, typically the output of a spike
            sorter (e.g. a KiloSortSortingInterface).
        unit_ids_to_channel_ids : dict[str | int, list[str | int]]
            Maps every unit ID to the list of channel IDs the unit was
            detected on, linking sorted units to recording channels.

        Raises
        ------
        ValueError
            If the mapping references a channel the recording does not have,
            or if any unit of the sorting interface is missing from the mapping.
        """
        self.recording_interface = recording_interface
        self.sorting_interface = sorting_interface
        self.unit_ids_to_channel_ids = unit_ids_to_channel_ids

        self.channel_ids = recording_interface.channel_ids
        self.unit_ids = sorting_interface.units_ids

        # Every channel referenced by the mapping must exist in the recording.
        available_channels = set(self.channel_ids)
        for unit_id, channel_ids in unit_ids_to_channel_ids.items():
            unknown_channels = set(channel_ids) - available_channels
            if unknown_channels:
                raise ValueError(
                    f"Inexistent channel IDs {unknown_channels} referenced in mapping for unit {unit_id} "
                    f"not found in recording. Available channels are {available_channels}"
                )

        # Every unit of the sorting must have a channel mapping.
        unmapped_units = set(self.unit_ids) - set(unit_ids_to_channel_ids)
        if unmapped_units:
            raise ValueError(f"Units {unmapped_units} from sorting interface have no channel mapping")

        super().__init__(data_interfaces=[recording_interface, sorting_interface])

    def add_to_nwbfile(self, nwbfile, metadata, conversion_options: Optional[dict] = None):
        """
        Add the recording and the sorting to ``nwbfile``, wiring each unit's
        electrode region to the electrode-table rows of its mapped channels.

        Parameters
        ----------
        nwbfile : NWBFile
            In-memory NWB file to write into.
        metadata : dict
            Metadata passed through to both wrapped interfaces.
        conversion_options : dict, optional
            Per-interface options under the keys ``"recording"`` and ``"sorting"``.
        """
        conversion_options = conversion_options or dict()

        # Write the recording first so the electrode table exists before the
        # units table references it.
        self.recording_interface.add_to_nwbfile(
            nwbfile=nwbfile,
            metadata=metadata,
            **conversion_options.get("recording", dict()),
        )

        # Electrode-table indices corresponding, in order, to the recording's channel ids.
        electrode_table_indices = _get_electrode_table_indices_for_recording(
            recording=self.recording_interface.recording_extractor,
            nwbfile=nwbfile,
        )
        channel_id_to_electrode_table_index = dict(zip(self.channel_ids, electrode_table_indices))

        # For each unit (in unit-id order), the electrode-table indices of the
        # channels that unit was detected on.
        unit_electrode_indices = [
            [channel_id_to_electrode_table_index[channel_id] for channel_id in self.unit_ids_to_channel_ids[unit_id]]
            for unit_id in self.unit_ids
        ]

        self.sorting_interface.add_to_nwbfile(
            nwbfile=nwbfile,
            metadata=metadata,
            unit_electrode_indices=unit_electrode_indices,
            **conversion_options.get("sorting", dict()),
        )
import pytest

from neuroconv.converters import SortedRecordingConverter
from neuroconv.tools.testing.mock_interfaces import (
    MockRecordingInterface,
    MockSortingInterface,
)


class TestSortedRecordingConverter:
    """Tests for SortedRecordingConverter's unit-to-electrode linkage."""

    def _make_recording_interface(self, channel_ids):
        """Build a mock recording interface whose channels carry the given ids."""
        # num_channels is derived from the ids so rename_channels cannot receive
        # a list whose length differs from the number of channels.
        interface = MockRecordingInterface(num_channels=len(channel_ids), durations=[0.100])
        interface.recording_extractor = interface.recording_extractor.rename_channels(
            new_channel_ids=channel_ids
        )
        return interface

    def _make_sorting_interface(self, unit_ids):
        """Build a mock sorting interface whose units carry the given ids."""
        interface = MockSortingInterface(num_units=len(unit_ids), durations=[0.100])
        interface.sorting_extractor = interface.sorting_extractor.rename_units(new_unit_ids=unit_ids)
        return interface

    def test_basic(self):
        """Units' electrode regions point at the electrodes of their mapped channels."""
        # NOTE: renamed from `basic_test` so pytest actually collects and runs it.
        recording_interface = self._make_recording_interface(["A", "B", "C"])
        sorting_interface = self._make_sorting_interface(["a", "b", "c", "d", "e"])

        unit_ids_to_channel_ids = {
            "a": ["A"],
            "b": ["B"],
            "c": ["C"],
            "d": ["A", "B"],
            "e": ["C", "A"],
        }
        converter = SortedRecordingConverter(
            recording_interface=recording_interface,
            sorting_interface=sorting_interface,
            unit_ids_to_channel_ids=unit_ids_to_channel_ids,
        )

        nwbfile = converter.create_nwbfile()

        # Expected electrode-table indices for each unit's channels.
        expected_indices_per_unit = {
            "a": [0],
            "b": [1],
            "c": [2],
            "d": [0, 1],
            "e": [2, 0],
        }
        for unit_row in nwbfile.units.to_dataframe().itertuples(index=False):
            # Neuroconv writes unit ids as unit_names
            unit_id = unit_row.unit_name
            assert unit_row.electrodes == expected_indices_per_unit[unit_id]

    def test_invalid_channel_in_mapping(self):
        """A mapping that references a channel absent from the recording must fail."""
        recording_interface = self._make_recording_interface(["ch1", "ch2", "ch3", "ch4"])
        sorting_interface = self._make_sorting_interface(["unit1", "unit2", "unit3"])

        # ch5 doesn't exist in the recording
        invalid_channel_mapping = {"unit1": ["ch1", "ch5"], "unit2": ["ch2"], "unit3": ["ch3"]}

        with pytest.raises(ValueError, match="Inexistent channel IDs {'ch5'} referenced in mapping for unit unit1"):
            SortedRecordingConverter(
                recording_interface=recording_interface,
                sorting_interface=sorting_interface,
                unit_ids_to_channel_ids=invalid_channel_mapping,
            )

    def test_missing_unit_in_mapping(self):
        """Every unit of the sorting must be mapped; a missing unit must fail."""
        recording_interface = self._make_recording_interface(["ch1", "ch2", "ch3", "ch4"])
        sorting_interface = self._make_sorting_interface(["unit1", "unit2", "unit3"])

        incomplete_mapping = {
            "unit1": ["ch1"],
            "unit2": ["ch2"],
            # unit3 is missing
        }

        with pytest.raises(ValueError, match="Units {'unit3'} from sorting interface have no channel mapping"):
            SortedRecordingConverter(
                recording_interface=recording_interface,
                sorting_interface=sorting_interface,
                unit_ids_to_channel_ids=incomplete_mapping,
            )