diff --git a/doc/preprocessing.rst b/doc/preprocessing.rst index 7c4e430780f..cbfda1ac49b 100644 --- a/doc/preprocessing.rst +++ b/doc/preprocessing.rst @@ -57,6 +57,7 @@ Projections: get_builtin_ch_adjacencies read_ch_adjacency equalize_channels + unify_bad_channels rename_channels generate_2d_layout make_1020_channel_selections diff --git a/mne/channels/__init__.py b/mne/channels/__init__.py index 13b002e5a59..3591d7aeeb4 100644 --- a/mne/channels/__init__.py +++ b/mne/channels/__init__.py @@ -22,6 +22,7 @@ "_EEG_SELECTIONS", "_divide_to_regions", "get_builtin_ch_adjacencies", + "unify_bad_channels", ], "layout": [ "Layout", diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 0b022569ef5..7c3de44fdd8 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -5,6 +5,8 @@ # Andrew Dykstra # Teon Brooks # Daniel McCloy +# Ana Radanovic +# Erica Peterson # # License: BSD-3-Clause @@ -206,6 +208,83 @@ def equalize_channels(instances, copy=True, verbose=None): return equalized_instances +def unify_bad_channels(insts): + """Unify bad channels across a list of instances. + + All instances must be of the same type and have matching channel names and channel + order. The ``.info["bads"]`` of each instance will be set to the union of + ``.info["bads"]`` across all instances. + + Parameters + ---------- + insts : list + List of instances (:class:`~mne.io.Raw`, :class:`~mne.Epochs`, + :class:`~mne.Evoked`, :class:`~mne.time_frequency.Spectrum`, + :class:`~mne.time_frequency.EpochsSpectrum`) across which to unify bad channels. + + Returns + ------- + insts : list + List of instances with bad channels unified across instances. + + See Also + -------- + mne.channels.equalize_channels + mne.channels.rename_channels + mne.channels.combine_channels + + Notes + ----- + This function modifies the instances in-place. + + .. versionadded:: 1.6 + """ + from ..io import BaseRaw + from ..epochs import Epochs + from ..evoked import Evoked + from ..time_frequency.spectrum import BaseSpectrum + + # ensure input is list-like + _validate_type(insts, (list, tuple), "insts") + # ensure non-empty + if len(insts) == 0: + raise ValueError("insts must not be empty") + # ensure all insts are MNE objects, and all the same type + inst_type = type(insts[0]) + valid_types = (BaseRaw, Epochs, Evoked, BaseSpectrum) + for inst in insts: + _validate_type(inst, valid_types, "each object in insts") + if type(inst) != inst_type: + raise ValueError("All insts must be the same type") + + # ensure all insts have the same channels and channel order + ch_names = insts[0].ch_names + for inst in insts[1:]: + dif = set(inst.ch_names) ^ set(ch_names) + if len(dif): + raise ValueError( + "Channels do not match across the objects in insts. Consider calling " + "equalize_channels before calling this function." + ) + elif inst.ch_names != ch_names: + raise ValueError( + "Channel names are sorted differently across instances. Please use " + "mne.channels.equalize_channels." + ) + + # collect bads as dict keys so that insertion order is preserved, then cast to list + all_bads = dict() + for inst in insts: + all_bads.update(dict.fromkeys(inst.info["bads"])) + all_bads = list(all_bads) + + # update bads on all instances + for inst in insts: + inst.info["bads"] = all_bads + + return insts + + class ReferenceMixin(MontageMixin): """Mixin class for Raw, Evoked, Epochs.""" diff --git a/mne/channels/tests/test_unify_bads.py b/mne/channels/tests/test_unify_bads.py new file mode 100644 index 00000000000..ac04983802b --- /dev/null +++ b/mne/channels/tests/test_unify_bads.py @@ -0,0 +1,53 @@ +import pytest +from mne.channels import unify_bad_channels + + +def test_error_raising(raw, epochs): + """Tests input checking.""" + with pytest.raises(TypeError, match=r"must be an instance of list"): + unify_bad_channels("bad input") + with pytest.raises(ValueError, match=r"insts must not be empty"): + unify_bad_channels([]) + with pytest.raises(TypeError, match=r"each object in insts must be an instance of"): + unify_bad_channels(["bad_instance"]) + with pytest.raises(ValueError, match=r"same type"): + unify_bad_channels([raw, epochs]) + with pytest.raises(ValueError, match=r"Channels do not match across"): + raw_alt1 = raw.copy() + raw_alt1.drop_channels(raw.info["ch_names"][-1]) + unify_bad_channels([raw, raw_alt1]) # ch diff preserving order + with pytest.raises(ValueError, match=r"sorted differently"): + raw_alt2 = raw.copy() + new_order = [raw.ch_names[-1]] + raw.ch_names[:-1] + raw_alt2.reorder_channels(new_order) + unify_bad_channels([raw, raw_alt2]) + + +def test_bads_compilation(raw): + """Tests that bads are compiled properly. + + Tests two cases: a) single instance passed to function with an existing + bad, and b) multiple instances passed to function with varying compilation + scenarios including empty bads, unique bads, and partially duplicated bads + listed out-of-order. + + Only the Raw instance type is tested, since bad channel implementation is + controlled across instance types with a MixIn class. + """ + assert raw.info["bads"] == [] + chns = raw.ch_names[:3] + no_bad = raw.copy() + one_bad = raw.copy() + one_bad.info["bads"] = [chns[1]] + three_bad = raw.copy() + three_bad.info["bads"] = chns + # scenario 1: single instance passed with actual bads + s_out = unify_bad_channels([one_bad]) + assert len(s_out) == 1, len(s_out) + assert s_out[0].info["bads"] == [chns[1]], (s_out[0].info["bads"], chns[1]) + # scenario 2: multiple instances passed + m_out = unify_bad_channels([one_bad, no_bad, three_bad]) + assert len(m_out) == 3, len(m_out) + expected_order = [chns[1], chns[0], chns[2]] + for inst in m_out: + assert inst.info["bads"] == expected_order