Skip to content

Commit

Permalink
Adding nan method to interpolate channels (#12027)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel McCloy <[email protected]>
  • Loading branch information
anaradanovic and drammock authored Sep 29, 2023
1 parent 04e05d4 commit 9f31cf3
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 12 deletions.
53 changes: 44 additions & 9 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,17 +830,19 @@ def interpolate_bads(
.. versionadded:: 0.17
method : dict | None
Method to use for each channel type.
Currently only the key ``"eeg"`` has multiple options:
All channel types support "nan".
The key ``"eeg"`` has two additional options:
- ``"spline"`` (default)
Use spherical spline interpolation.
- ``"MNE"``
Use minimum-norm projection to a sphere and back.
This is the method used for MEG channels.
The value for ``"meg"`` is ``"MNE"``, and the value for
``"fnirs"`` is ``"nearest"``. The default (None) is thus an alias
for::
The default value for ``"meg"`` is ``"MNE"``, and the default value
for ``"fnirs"`` is ``"nearest"``.
The default (None) is thus an alias for::
method=dict(meg="MNE", eeg="spline", fnirs="nearest")
Expand All @@ -858,6 +860,10 @@ def interpolate_bads(
Notes
-----
.. versionadded:: 0.9.0
.. warning::
Be careful when using ``method="nan"``; the default value
``reset_bads=True`` may not be what you want.
"""
from .interpolation import (
_interpolate_bads_eeg,
Expand All @@ -869,9 +875,31 @@ def interpolate_bads(
method = _handle_default("interpolation_method", method)
for key in method:
_check_option("method[key]", key, ("meg", "eeg", "fnirs"))
_check_option("method['eeg']", method["eeg"], ("spline", "MNE"))
_check_option("method['meg']", method["meg"], ("MNE",))
_check_option("method['fnirs']", method["fnirs"], ("nearest",))
_check_option(
"method['eeg']",
method["eeg"],
(
"spline",
"MNE",
"nan",
),
)
_check_option(
"method['meg']",
method["meg"],
(
"MNE",
"nan",
),
)
_check_option(
"method['fnirs']",
method["fnirs"],
(
"nearest",
"nan",
),
)

if len(self.info["bads"]) == 0:
warn("No bad channels to interpolate. Doing nothing...")
Expand All @@ -884,11 +912,18 @@ def interpolate_bads(
else:
eeg_mne = True
_interpolate_bads_meeg(
self, mode=mode, origin=origin, eeg=eeg_mne, exclude=exclude
self, mode=mode, origin=origin, eeg=eeg_mne, exclude=exclude, method=method
)
_interpolate_bads_nirs(self, exclude=exclude)
_interpolate_bads_nirs(self, exclude=exclude, method=method["fnirs"])

if reset_bads is True:
if "nan" in method.values():
warn(
"interpolate_bads was called with method='nan' and "
"reset_bads=True. Consider setting reset_bads=False so that the "
"nan-containing channels can be easily excluded from later "
"computations."
)
self.info["bads"] = [ch for ch in self.info["bads"] if ch in exclude]

return self
Expand Down
20 changes: 17 additions & 3 deletions mne/channels/interpolation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Authors: Denis Engemann <[email protected]>
# Ana Radanovic <[email protected]>
#
# License: BSD-3-Clause

Expand Down Expand Up @@ -191,10 +192,14 @@ def _interpolate_bads_meeg(
eeg=True,
ref_meg=False,
exclude=(),
*,
method=None,
verbose=None,
):
from ..forward import _map_meg_or_eeg_channels

if method is None:
method = {"meg": "MNE", "eeg": "MNE"}
bools = dict(meg=meg, eeg=eeg)
info = _simplify_info(inst.info)
for ch_type, do in bools.items():
Expand All @@ -210,6 +215,12 @@ def _interpolate_bads_meeg(
continue
# select the bad channels to be interpolated
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[])

if method[ch_type] == "nan":
inst._data[picks_bad] = np.nan
continue

# do MNE based interpolation
if ch_type == "eeg":
picks_to = picks_type
bad_sel = np.isin(picks_type, picks_bad)
Expand Down Expand Up @@ -243,7 +254,7 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None):
chs = [inst.info["chs"][i] for i in picks_nirs]
locs3d = np.array([ch["loc"][:3] for ch in chs])

_check_option("fnirs_method", method, ["nearest"])
_check_option("fnirs_method", method, ["nearest", "nan"])

if method == "nearest":
dist = pdist(locs3d)
Expand All @@ -258,7 +269,10 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None):
# Find closest remaining channels for same frequency
closest_idx = np.argmin(dists_to_bad) + (bad % 2)
inst._data[bad] = inst._data[closest_idx]

inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude]
else:
assert method == "nan"
inst._data[picks_bad] = np.nan
# TODO: this seems like a bug because it does not respect reset_bads
inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude]

return inst
25 changes: 25 additions & 0 deletions mne/channels/tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from mne._fiff.proj import _has_eeg_average_ref_proj
from mne.utils import _record_warnings


base_dir = Path(__file__).parent.parent.parent / "io" / "tests" / "data"
raw_fname = base_dir / "test_raw.fif"
event_name = base_dir / "test-eve.fif"
Expand Down Expand Up @@ -324,3 +325,27 @@ def test_interpolation_nirs():
assert raw_haemo.info["bads"] == ["S1_D2 hbo", "S1_D2 hbr"]
raw_haemo.interpolate_bads()
assert raw_haemo.info["bads"] == []


def test_nan_interpolation(raw):
"""Test 'nan' method for interpolating bads."""
ch_to_interp = [raw.ch_names[1]] # don't use channel 0 (type is IAS not MEG)
raw.info["bads"] = ch_to_interp

# test that warning appears for reset_bads = True
with pytest.warns(RuntimeWarning, match="Consider setting reset_bads=False"):
raw.interpolate_bads(method="nan", reset_bads=True)

# despite warning, interpolation still happened, make sure the channel is NaN
bad_chs = raw.get_data(ch_to_interp)
assert np.isnan(bad_chs).all()

# make sure reset_bads=False works as expected
raw.info["bads"] = ch_to_interp
raw.interpolate_bads(method="nan", reset_bads=False)
assert raw.info["bads"] == ch_to_interp

# make sure other channels are untouched
raw.drop_channels(ch_to_interp)
good_chs = raw.get_data()
assert np.isfinite(good_chs).all()

0 comments on commit 9f31cf3

Please sign in to comment.