-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding nan method to interpolate channels (#12027)
Co-authored-by: Daniel McCloy <[email protected]>
- Loading branch information
1 parent
04e05d4
commit 9f31cf3
Showing
3 changed files
with
86 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
# Authors: Denis Engemann <[email protected]> | ||
# Ana Radanovic <[email protected]> | ||
# | ||
# License: BSD-3-Clause | ||
|
||
|
@@ -191,10 +192,14 @@ def _interpolate_bads_meeg( | |
eeg=True, | ||
ref_meg=False, | ||
exclude=(), | ||
*, | ||
method=None, | ||
verbose=None, | ||
): | ||
from ..forward import _map_meg_or_eeg_channels | ||
|
||
if method is None: | ||
method = {"meg": "MNE", "eeg": "MNE"} | ||
bools = dict(meg=meg, eeg=eeg) | ||
info = _simplify_info(inst.info) | ||
for ch_type, do in bools.items(): | ||
|
@@ -210,6 +215,12 @@ def _interpolate_bads_meeg( | |
continue | ||
# select the bad channels to be interpolated | ||
picks_bad = pick_channels(inst.info["ch_names"], bads_type, exclude=[]) | ||
|
||
if method[ch_type] == "nan": | ||
inst._data[picks_bad] = np.nan | ||
continue | ||
|
||
# do MNE based interpolation | ||
if ch_type == "eeg": | ||
picks_to = picks_type | ||
bad_sel = np.isin(picks_type, picks_bad) | ||
|
@@ -243,7 +254,7 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None): | |
chs = [inst.info["chs"][i] for i in picks_nirs] | ||
locs3d = np.array([ch["loc"][:3] for ch in chs]) | ||
|
||
_check_option("fnirs_method", method, ["nearest"]) | ||
_check_option("fnirs_method", method, ["nearest", "nan"]) | ||
|
||
if method == "nearest": | ||
dist = pdist(locs3d) | ||
|
@@ -258,7 +269,10 @@ def _interpolate_bads_nirs(inst, method="nearest", exclude=(), verbose=None): | |
# Find closest remaining channels for same frequency | ||
closest_idx = np.argmin(dists_to_bad) + (bad % 2) | ||
inst._data[bad] = inst._data[closest_idx] | ||
|
||
inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude] | ||
else: | ||
assert method == "nan" | ||
inst._data[picks_bad] = np.nan | ||
# TODO: this seems like a bug because it does not respect reset_bads | ||
inst.info["bads"] = [ch for ch in inst.info["bads"] if ch in exclude] | ||
|
||
return inst |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters