Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate unit_electrode_indices to SortingInterface #1124

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
## Bug Fixes

## Features
* Propagate the `unit_electrode_indices` argument from the spikeinterface tools to `BaseSortingExtractorInterface`. This allows users to map units to the electrode table when adding sorting data [PR #1124](https://github.com/catalystneuro/neuroconv/pull/1124)

## Improvements

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def add_to_nwbfile(
write_as: Literal["units", "processing"] = "units",
units_name: str = "units",
units_description: str = "Autogenerated by neuroconv.",
unit_electrode_indices: Optional[list[list[int]]] = None,
):
"""
Primary function for converting the data in a SortingExtractor to NWB format.
Expand All @@ -312,9 +313,15 @@ def add_to_nwbfile(
units_name : str, default: 'units'
The name of the units table. If write_as=='units', then units_name must also be 'units'.
units_description : str, default: 'Autogenerated by neuroconv.'
unit_electrode_indices : list of lists of int, optional
A list of lists of integers indicating the indices of the electrodes that each unit is associated with.
The length of the list must match the number of units in the sorting extractor.
"""
from ...tools.spikeinterface import add_sorting_to_nwbfile

if metadata is None:
metadata = self.get_metadata()

metadata_copy = deepcopy(metadata)
if write_ecephys_metadata:
self.add_channel_metadata_to_nwb(nwbfile=nwbfile, metadata=metadata_copy)
Expand Down Expand Up @@ -346,4 +353,5 @@ def add_to_nwbfile(
write_as=write_as,
units_name=units_name,
units_description=units_description,
unit_electrode_indices=unit_electrode_indices,
)
2 changes: 1 addition & 1 deletion src/neuroconv/nwbconverter.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def validate_conversion_options(self, conversion_options: dict[str, dict]):

def create_nwbfile(self, metadata: Optional[dict] = None, conversion_options: Optional[dict] = None) -> NWBFile:
"""
Create and return an in-memory pynwb.NWBFile object with this interface's data added to it.
Create and return an in-memory pynwb.NWBFile object with the conversion data added to it.

Parameters
----------
Expand Down
19 changes: 14 additions & 5 deletions src/neuroconv/tools/spikeinterface/spikeinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ def add_units_table_to_nwbfile(
write_in_processing_module: bool = False,
waveform_means: Optional[np.ndarray] = None,
waveform_sds: Optional[np.ndarray] = None,
unit_electrode_indices=None,
unit_electrode_indices: Optional[list[list[int]]] = None,
null_values_for_properties: Optional[dict] = None,
):
"""
Expand Down Expand Up @@ -1443,15 +1443,23 @@ def add_units_table_to_nwbfile(
Waveform standard deviation for each unit. Shape: (num_units, num_samples, num_channels).
unit_electrode_indices : list of lists of int, optional
For each unit, a list of electrode indices corresponding to waveform data.
null_values_for_properties: dict, optional
A dictionary mapping properties to null values to use when the property is not present
unit_electrode_indices : list of lists of int, optional
A list of lists of integers indicating the indices of the electrodes that each unit is associated with.
The length of the list must match the number of units in the sorting extractor.
"""
unit_table_description = unit_table_description or "Autogenerated by neuroconv."

assert isinstance(
nwbfile, pynwb.NWBFile
), f"'nwbfile' should be of type pynwb.NWBFile but is of type {type(nwbfile)}"

if unit_electrode_indices is not None:
electrodes_table = nwbfile.electrodes
if electrodes_table is None:
raise ValueError(
"Electrodes table is required to map units to electrodes. Add an electrode table to the NWBFile first."
)

null_values_for_properties = dict() if null_values_for_properties is None else null_values_for_properties

if not write_in_processing_module and units_table_name != "units":
Expand Down Expand Up @@ -1706,7 +1714,7 @@ def add_sorting_to_nwbfile(
units_description: str = "Autogenerated by neuroconv.",
waveform_means: Optional[np.ndarray] = None,
waveform_sds: Optional[np.ndarray] = None,
unit_electrode_indices=None,
unit_electrode_indices: Optional[list[list[int]]] = None,
):
"""Add sorting data (units and their properties) to an NWBFile.

Expand Down Expand Up @@ -1741,7 +1749,8 @@ def add_sorting_to_nwbfile(
waveform_sds : np.ndarray, optional
Waveform standard deviation for each unit. Shape: (num_units, num_samples, num_channels).
unit_electrode_indices : list of lists of int, optional
For each unit, a list of electrode indices corresponding to waveform data.
A list of lists of integers indicating the indices of the electrodes that each unit is associated with.
The length of the list must match the number of units in the sorting extractor.
"""

if skip_features is not None:
Expand Down
79 changes: 52 additions & 27 deletions tests/test_ecephys/test_ecephys_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,59 @@

python_version = Version(get_python_version())

from neuroconv.tools.testing.data_interface_mixins import (
RecordingExtractorInterfaceTestMixin,
SortingExtractorInterfaceTestMixin,
)


class TestSortingInterface(SortingExtractorInterfaceTestMixin):

data_interface_cls = MockSortingInterface
interface_kwargs = dict(num_units=4, durations=[0.100])

def test_electrode_indices(self, setup_interface):

recording_interface = MockRecordingInterface(num_channels=4, durations=[0.100])
recording_extractor = recording_interface.recording_extractor
recording_extractor = recording_extractor.rename_channels(new_channel_ids=["a", "b", "c", "d"])
recording_extractor.set_property(key="property", values=["A", "B", "C", "D"])
recording_interface.recording_extractor = recording_extractor

nwbfile = recording_interface.create_nwbfile()

unit_electrode_indices = [[3], [0, 1], [1], [2]]
expected_properties_matching = [["D"], ["A", "B"], ["B"], ["C"]]
self.interface.add_to_nwbfile(nwbfile=nwbfile, unit_electrode_indices=unit_electrode_indices)

unit_table = nwbfile.units

for unit_row, electrode_indices, property in zip(
unit_table.to_dataframe().itertuples(index=False),
unit_electrode_indices,
expected_properties_matching,
):
electrode_table_region = unit_row.electrodes
electrode_table_region_indices = electrode_table_region.index.to_list()
assert electrode_table_region_indices == electrode_indices

electrode_table_region_properties = electrode_table_region["property"].to_list()
assert electrode_table_region_properties == property

class TestRecordingInterface(TestCase):
def test_electrode_indices_assertion_error_when_missing_table(self, setup_interface):
with pytest.raises(
ValueError,
match="Electrodes table is required to map units to electrodes. Add an electrode table to the NWBFile first.",
):
self.interface.create_nwbfile(unit_electrode_indices=[[0], [1], [2], [3]])


class TestRecordingInterface(RecordingExtractorInterfaceTestMixin):
data_interface_cls = MockRecordingInterface
interface_kwargs = dict(durations=[0.100])


class TestRecordingInterfaceOld(TestCase):
@classmethod
def setUpClass(cls):
cls.single_segment_recording_interface = MockRecordingInterface(durations=[0.100])
Expand Down Expand Up @@ -84,32 +135,6 @@ def test_spike2_import_assertions_3_11(self):
Spike2RecordingInterface.get_all_channels_info(file_path="does_not_matter.smrx")


class TestSortingInterface:

def test_run_conversion(self, tmp_path):

nwbfile_path = Path(tmp_path) / "test_sorting.nwb"
num_units = 4
interface = MockSortingInterface(num_units=num_units, durations=(1.0,))
interface.sorting_extractor = interface.sorting_extractor.rename_units(new_unit_ids=["a", "b", "c", "d"])

interface.run_conversion(nwbfile_path=nwbfile_path)
with NWBHDF5IO(nwbfile_path, "r") as io:
nwbfile = io.read()

units = nwbfile.units
assert len(units) == num_units
units_df = units.to_dataframe()
# Get index in units table
for unit_id in interface.sorting_extractor.unit_ids:
# In pynwb we write unit name as unit_id
row = units_df.query(f"unit_name == '{unit_id}'")
spike_times = interface.sorting_extractor.get_unit_spike_train(unit_id=unit_id, return_times=True)
written_spike_times = row["spike_times"].iloc[0]

np.testing.assert_array_equal(spike_times, written_spike_times)


class TestSortingInterfaceOld(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
Expand Down
9 changes: 0 additions & 9 deletions tests/test_ecephys/test_mock_recording_interface.py
Original file line number Diff line number Diff line change
@@ -1,9 +0,0 @@
from neuroconv.tools.testing.data_interface_mixins import (
RecordingExtractorInterfaceTestMixin,
)
from neuroconv.tools.testing.mock_interfaces import MockRecordingInterface


class TestMockRecordingInterface(RecordingExtractorInterfaceTestMixin):
data_interface_cls = MockRecordingInterface
interface_kwargs = dict(durations=[0.100])
Loading