Skip to content

Commit

Permalink
BUG: Fix bug with spectrum warning (mne-tools#12186)
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner authored and snwnde committed Mar 20, 2024
1 parent bb54cb3 commit b79ebd6
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 11 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Bugs
- Fix bug with multi-plot 3D rendering where only one plot was updated (:gh:`11896` by `Eric Larson`_)
- Fix bug where ``verbose`` level was not respected inside parallel jobs (:gh:`12154` by `Eric Larson`_)
- Fix bug where subject birthdays were not correctly read by :func:`mne.io.read_raw_snirf` (:gh:`11912` by `Eric Larson`_)
- Fix bug where warnings were emitted when computing spectra for channels marked as bad (:gh:`12186` by `Eric Larson`_)
- Fix bug with :func:`mne.chpi.compute_head_pos` for CTF data where digitization points were modified in-place, producing an incorrect result during a save-load round-trip (:gh:`11934` by `Eric Larson`_)
- Fix bug where non-compliant stimulus data streams were not ignored by :func:`mne.io.read_raw_snirf` (:gh:`11915` by `Johann Benerradi`_)
- Fix bug with ``pca=False`` in :func:`mne.minimum_norm.compute_source_psd` (:gh:`11927` by `Alex Gramfort`_)
Expand Down
12 changes: 7 additions & 5 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,17 @@ def _check_values(self):
"""Check PSD results for correct shape and bad values."""
assert len(self._dims) == self._data.ndim, (self._dims, self._data.ndim)
assert self._data.shape == self._shape
# negative values OK if the spectrum is really fourier coefficients
if "taper" in self._dims:
return
# TODO: should this be more fine-grained (report "chan X in epoch Y")?
ch_dim = self._dims.index("channel")
dims = np.arange(self._data.ndim).tolist()
dims = list(range(self._data.ndim))
dims.pop(ch_dim)
# take min() across all but the channel axis
bad_value = self._data.min(axis=tuple(dims)) <= 0
# (if the abs becomes memory intensive we could iterate over channels)
use_data = self._data
if use_data.dtype.kind == "c":
use_data = np.abs(use_data)
bad_value = use_data.min(axis=tuple(dims)) == 0
bad_value &= ~np.isin(self.ch_names, self.info["bads"])
if bad_value.any():
chs = np.array(self.ch_names)[bad_value].tolist()
s = _pl(bad_value.sum())
Expand Down
13 changes: 8 additions & 5 deletions mne/time_frequency/tests/test_spectrum.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from contextlib import nullcontext
from functools import partial

import numpy as np
Expand Down Expand Up @@ -359,7 +358,6 @@ def test_spectrum_complex(method, average):
assert len(epochs) == 5
assert len(epochs.times) == 2 * sfreq
kwargs = dict(output="complex", method=method)
ctx = pytest.warns(UserWarning, match="Zero value")
if method == "welch":
kwargs["n_fft"] = sfreq
want_dims = ("epoch", "channel", "freq")
Expand All @@ -371,11 +369,9 @@ def test_spectrum_complex(method, average):
else:
assert method == "multitaper"
assert not average
ctx = nullcontext()
want_dims = ("epoch", "channel", "taper", "freq")
want_shape = (5, 1, 7, sfreq + 1)
with ctx:
spectrum = epochs.compute_psd(**kwargs)
spectrum = epochs.compute_psd(**kwargs)
idx = np.argmin(np.abs(spectrum.freqs - freq))
assert spectrum.freqs[idx] == freq
assert spectrum._dims == want_dims
Expand All @@ -389,6 +385,13 @@ def test_spectrum_complex(method, average):
coef = coef.mean(-1) # over segments
coef = coef.item()
assert_allclose(np.angle(coef), phase, rtol=1e-4)
# Now test that it warns appropriately
epochs._data[0, 0, :] = 0 # actually zero for one epoch and ch
with pytest.warns(UserWarning, match="Zero value.*channel 0"):
epochs.compute_psd(**kwargs)
# But not if we mark that channel as bad
epochs.info["bads"] = epochs.ch_names[:1]
epochs.compute_psd(**kwargs)


def test_spectrum_kwarg_triaging(raw):
Expand Down
3 changes: 2 additions & 1 deletion tutorials/preprocessing/50_artifact_correction_ssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@
# individual spectrum for each sensor, or an average (with confidence band)
# across sensors:

spectrum = empty_room_raw.compute_psd(verbose="error") # ignore zero value warning
raw.info["bads"] = ["MEG 2443"]
spectrum = empty_room_raw.compute_psd()
for average in (False, True):
spectrum.plot(average=average, dB=False, xscale="log", picks="data", exclude="bads")

Expand Down

0 comments on commit b79ebd6

Please sign in to comment.