diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 7c3de44fdd8..b6c82f27be2 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -830,7 +830,8 @@ def interpolate_bads( .. versionadded:: 0.17 method : dict | None Method to use for each channel type. - Currently only the key ``"eeg"`` has multiple options: + All channel types support "nan". + The key ``"eeg"`` has two additional options: - ``"spline"`` (default) Use spherical spline interpolation. @@ -838,9 +839,10 @@ def interpolate_bads( Use minimum-norm projection to a sphere and back. This is the method used for MEG channels. - The value for ``"meg"`` is ``"MNE"``, and the value for - ``"fnirs"`` is ``"nearest"``. The default (None) is thus an alias - for:: + The default value for ``"meg"`` is ``"MNE"``, and the default value + for ``"fnirs"`` is ``"nearest"``. + + The default (None) is thus an alias for:: method=dict(meg="MNE", eeg="spline", fnirs="nearest") @@ -858,6 +860,10 @@ def interpolate_bads( Notes ----- .. versionadded:: 0.9.0 + + .. warning:: + Be careful when using ``method="nan"``; the default value + ``reset_bads=True`` may not be what you want. """ from .interpolation import ( _interpolate_bads_eeg, @@ -869,9 +875,31 @@ def interpolate_bads( method = _handle_default("interpolation_method", method) for key in method: _check_option("method[key]", key, ("meg", "eeg", "fnirs")) - _check_option("method['eeg']", method["eeg"], ("spline", "MNE")) - _check_option("method['meg']", method["meg"], ("MNE",)) - _check_option("method['fnirs']", method["fnirs"], ("nearest",)) + _check_option( + "method['eeg']", + method["eeg"], + ( + "spline", + "MNE", + "nan", + ), + ) + _check_option( + "method['meg']", + method["meg"], + ( + "MNE", + "nan", + ), + ) + _check_option( + "method['fnirs']", + method["fnirs"], + ( + "nearest", + "nan", + ), + ) if len(self.info["bads"]) == 0: warn("No bad channels to interpolate. Doing nothing...") @@ -884,11 +912,18 @@ def interpolate_bads( else: eeg_mne = True _interpolate_bads_meeg( - self, mode=mode, origin=origin, eeg=eeg_mne, exclude=exclude + self, mode=mode, origin=origin, eeg=eeg_mne, exclude=exclude, method=method ) - _interpolate_bads_nirs(self, exclude=exclude) + _interpolate_bads_nirs(self, exclude=exclude, method=method["fnirs"]) if reset_bads is True: + if "nan" in method.values(): + warn( + "interpolate_bads was called with method='nan' and " + "reset_bads=True. Consider setting reset_bads=False so that the " + "nan-containing channels can be easily excluded from later " + "computations." + ) self.info["bads"] = [ch for ch in self.info["bads"] if ch in exclude] return self diff --git a/mne/channels/interpolation.py b/mne/channels/interpolation.py index 5dc84cad538..8f5418f9b85 100644 --- a/mne/channels/interpolation.py +++ b/mne/channels/interpolation.py @@ -1,4 +1,5 @@ # Authors: Denis Engemann +# Ana Radanovic # # License: BSD-3-Clause @@ -191,10 +192,14 @@ def _interpolate_bads_meeg( eeg=True, ref_meg=False, exclude=(), + *, + method=None, verbose=None, ): from ..forward import _map_meg_or_eeg_channels + if method is None: + method = {"meg": "MNE", "eeg": "MNE"} bools = dict(meg=meg, eeg=eeg) info = _simplify_info(inst.info) for ch_type, do in bools.items(): @@ -210,6 +215,12 @@ def _interpolate_bads_meeg( continue # select the bad channels to be interpolated picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[]) + + if method[ch_type] == "nan": + inst._data[picks_bad] = np.nan + continue + + # do MNE based interpolation if ch_type == "eeg": picks_to = picks_type bad_sel = np.isin(picks_type, picks_bad) @@ -243,7 +254,7 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None): chs = [inst.info["chs"][i] for i in picks_nirs] locs3d = np.array([ch["loc"][:3] for ch in chs]) - _check_option("fnirs_method", method, ["nearest"]) + _check_option("fnirs_method", method, ["nearest", "nan"]) if method == "nearest": dist = pdist(locs3d) @@ -258,7 +269,10 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None): # Find closest remaining channels for same frequency closest_idx = np.argmin(dists_to_bad) + (bad % 2) inst._data[bad] = inst._data[closest_idx] - - inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude] + else: + assert method == "nan" + inst._data[picks_bad] = np.nan + # TODO: this seems like a bug because it does not respect reset_bads + inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude] return inst diff --git a/mne/channels/tests/test_interpolation.py b/mne/channels/tests/test_interpolation.py index e6b8f9698a0..9e8032c915b 100644 --- a/mne/channels/tests/test_interpolation.py +++ b/mne/channels/tests/test_interpolation.py @@ -17,6 +17,7 @@ from mne._fiff.proj import _has_eeg_average_ref_proj from mne.utils import _record_warnings + base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data" raw_fname = base_dir / "test_raw.fif" event_name = base_dir / "test-eve.fif" @@ -324,3 +325,27 @@ def test_interpolation_nirs(): assert raw_haemo.info["bads"] == ["S1_D2 hbo", "S1_D2 hbr"] raw_haemo.interpolate_bads() assert raw_haemo.info["bads"] == [] + + +def test_nan_interpolation(raw): + """Test 'nan' method for interpolating bads.""" + ch_to_interp = [raw.ch_names[1]] # don't use channel 0 (type is IAS not MEG) + raw.info["bads"] = ch_to_interp + + # test that warning appears for reset_bads = True + with pytest.warns(RuntimeWarning, match="Consider setting reset_bads=False"): + raw.interpolate_bads(method="nan", reset_bads=True) + + # despite warning, interpolation still happened, make sure the channel is NaN + bad_chs = raw.get_data(ch_to_interp) + assert np.isnan(bad_chs).all() + + # make sure reset_bads=False works as expected + raw.info["bads"] = ch_to_interp + raw.interpolate_bads(method="nan", reset_bads=False) + assert raw.info["bads"] == ch_to_interp + + # make sure other channels are untouched + raw.drop_channels(ch_to_interp) + good_chs = raw.get_data() + assert np.isfinite(good_chs).all()