From 165cb31068a3dac62eee65921b4ab09ba1111d85 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 22 Oct 2024 08:12:01 -0600 Subject: [PATCH] Add more friendly error when writing recording with multiple offsets (#1111) --- CHANGELOG.md | 1 + .../tools/spikeinterface/spikeinterface.py | 33 ++++++++++++++++--- .../test_ecephys/test_tools_spikeinterface.py | 33 +++++++++++++++++++ 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d748d79e6..931383066 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ ## Features * Using in-house `GenericDataChunkIterator` [PR #1068](https://github.com/catalystneuro/neuroconv/pull/1068) * Data interfaces now perform source (argument inputs) validation with the json schema [PR #1020](https://github.com/catalystneuro/neuroconv/pull/1020) +* Improve the error message when writing a recording extractor with multiple offsets [PR #1111](https://github.com/catalystneuro/neuroconv/pull/1111) * Added `channels_to_skip` to `EDFRecordingInterface` so the user can skip non-neural channels [PR #1110](https://github.com/catalystneuro/neuroconv/pull/1110) ## Improvements diff --git a/src/neuroconv/tools/spikeinterface/spikeinterface.py b/src/neuroconv/tools/spikeinterface/spikeinterface.py index 9128078a6..1be86862a 100644 --- a/src/neuroconv/tools/spikeinterface/spikeinterface.py +++ b/src/neuroconv/tools/spikeinterface/spikeinterface.py @@ -99,7 +99,8 @@ def add_devices_to_nwbfile(nwbfile: pynwb.NWBFile, metadata: Optional[DeepDict] metadata["Ecephys"]["Device"] = [defaults] for device_metadata in metadata["Ecephys"]["Device"]: if device_metadata.get("name", defaults["name"]) not in nwbfile.devices: - nwbfile.create_device(**dict(defaults, **device_metadata)) + device_kwargs = dict(defaults, **device_metadata) + nwbfile.create_device(**device_kwargs) def add_electrode_groups(recording: BaseRecording, nwbfile: pynwb.NWBFile, metadata: dict = None): @@ -778,6 +779,28 @@ def add_electrical_series( ) +def _report_variable_offset(channel_offsets, channel_ids): + """ + Helper function to report variable offsets per channel IDs. + Groups the different available offsets per channel IDs and raises a ValueError. + """ + # Group the different offsets per channel IDs + offset_to_channel_ids = {} + for offset, channel_id in zip(channel_offsets, channel_ids): + if offset not in offset_to_channel_ids: + offset_to_channel_ids[offset] = [] + offset_to_channel_ids[offset].append(channel_id) + + # Create a user-friendly message + message_lines = ["Recording extractors with heterogeneous offsets are not supported."] + message_lines.append("Multiple offsets were found per channel IDs:") + for offset, ids in offset_to_channel_ids.items(): + message_lines.append(f" Offset {offset}: Channel IDs {ids}") + message = "\n".join(message_lines) + + raise ValueError(message) + + def add_electrical_series_to_nwbfile( recording: BaseRecording, nwbfile: pynwb.NWBFile, @@ -905,14 +928,16 @@ def add_electrical_series_to_nwbfile( # Spikeinterface guarantees data in micro volts when return_scaled=True. This multiplies by gain and adds offsets # In nwb to get traces in Volts we take data*channel_conversion*conversion + offset channel_conversion = recording.get_channel_gains() - channel_offset = recording.get_channel_offsets() + channel_offsets = recording.get_channel_offsets() unique_channel_conversion = np.unique(channel_conversion) unique_channel_conversion = unique_channel_conversion[0] if len(unique_channel_conversion) == 1 else None - unique_offset = np.unique(channel_offset) + unique_offset = np.unique(channel_offsets) if unique_offset.size > 1: - raise ValueError("Recording extractors with heterogeneous offsets are not supported") + channel_ids = recording.get_channel_ids() + # This prints a user friendly error where the user is provided with a map from offset to channels + _report_variable_offset(channel_offsets, channel_ids) unique_offset = unique_offset[0] if unique_offset[0] is not None else 0 micro_to_volts_conversion_factor = 1e-6 diff --git a/tests/test_ecephys/test_tools_spikeinterface.py b/tests/test_ecephys/test_tools_spikeinterface.py index 11c29b31f..3436a2e70 100644 --- a/tests/test_ecephys/test_tools_spikeinterface.py +++ b/tests/test_ecephys/test_tools_spikeinterface.py @@ -1,3 +1,4 @@ +import re import unittest from datetime import datetime from pathlib import Path @@ -8,9 +9,11 @@ import numpy as np import psutil import pynwb.ecephys +import pytest from hdmf.data_utils import DataChunkIterator from hdmf.testing import TestCase from pynwb import NWBFile +from pynwb.testing.mock.file import mock_NWBFile from spikeinterface.core.generate import ( generate_ground_truth_recording, generate_recording, @@ -394,6 +397,36 @@ def test_variable_offsets_assertion(self): ) +def test_error_with_multiple_offset(): + # Generate a mock recording with 5 channels and 1 second duration + recording = generate_recording(num_channels=5, durations=[1.0]) + # Rename channels to specific identifiers for clarity in error messages + recording = recording.rename_channels(new_channel_ids=["a", "b", "c", "d", "e"]) + # Set different offsets for the channels + recording.set_channel_offsets(offsets=[0, 0, 1, 1, 2]) + + # Create a mock NWBFile object + nwbfile = mock_NWBFile() + + # Expected error message + expected_message_lines = [ + "Recording extractors with heterogeneous offsets are not supported.", + "Multiple offsets were found per channel IDs:", + " Offset 0: Channel IDs ['a', 'b']", + " Offset 1: Channel IDs ['c', 'd']", + " Offset 2: Channel IDs ['e']", + ] + expected_message = "\n".join(expected_message_lines) + + # Use re.escape to escape any special regex characters in the expected message + expected_message_regex = re.escape(expected_message) + + # Attempt to add electrical series to the NWB file + # Expecting a ValueError due to multiple offsets, matching the expected message + with pytest.raises(ValueError, match=expected_message_regex): + add_electrical_series_to_nwbfile(recording=recording, nwbfile=nwbfile) + + class TestAddElectricalSeriesChunking(unittest.TestCase): @classmethod def setUpClass(cls):