From b08f759efd08fa060e68f0bdee45f45e3fc5347a Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Tue, 6 Aug 2024 22:02:24 +0200 Subject: [PATCH] [ENH] (Re)implement complex data support for `Spectrum` and `SpectrumArray` classes (#12747) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Larson Co-authored-by: Daniel McCloy Co-authored-by: Alex Rockhill --- doc/changes/devel/12747.newfeature.rst | 3 + mne/time_frequency/spectrum.py | 272 +++++++++++++----- mne/time_frequency/tests/test_spectrum.py | 326 +++++++++++++++++++--- mne/utils/docs.py | 12 +- mne/utils/spectrum.py | 4 + 5 files changed, 511 insertions(+), 106 deletions(-) create mode 100644 doc/changes/devel/12747.newfeature.rst diff --git a/doc/changes/devel/12747.newfeature.rst b/doc/changes/devel/12747.newfeature.rst new file mode 100644 index 00000000000..2957117b778 --- /dev/null +++ b/doc/changes/devel/12747.newfeature.rst @@ -0,0 +1,3 @@ +Add support for storing Fourier coefficients in :class:`mne.time_frequency.Spectrum`, +:class:`mne.time_frequency.EpochsSpectrum`, :class:`mne.time_frequency.SpectrumArray`, +and :class:`mne.time_frequency.EpochsSpectrumArray` objects, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index 45dadf9741a..902d1c70e30 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -43,7 +43,7 @@ _is_numeric, check_fname, ) -from ..utils.misc import _identity_function, _pl +from ..utils.misc import _pl from ..utils.spectrum import _get_instance_type_string, _split_psd_kwargs from ..viz.topo import _plot_timeseries, _plot_timeseries_unified, _plot_topo from ..viz.topomap import _make_head_outlines, _prepare_topomap_plot, plot_psds_topomap @@ -55,7 +55,7 @@ _prepare_sensor_names, plt_show, ) -from .multitaper import psd_array_multitaper +from .multitaper import _psd_from_mt, psd_array_multitaper from .psd import _check_nfft, psd_array_welch @@ -314,13 +314,7 @@ def __init__( # method self._inst_type = type(inst) method = _validate_method(method, _get_instance_type_string(self)) - # don't allow complex output psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper) - if method_kw.get("output", "") == "complex": - raise ValueError( - f"Complex output is not supported in {type(self).__name__} objects. " - f"Please use mne.time_frequency.{psd_funcs[method].__name__}() instead." - ) # triage method and kwargs. partial() doesn't check validity of kwargs, # so we do it manually to save compute time if any are invalid. psd_funcs = dict(welch=psd_array_welch, multitaper=psd_array_multitaper) @@ -352,9 +346,13 @@ def __init__( ) if method_kw.get("average", "") in (None, False): self._dims += ("segment",) + if self._returns_complex_tapers(**method_kw): + self._dims = self._dims[:-1] + ("taper",) + self._dims[-1:] # record data type (for repr and html_repr) self._data_type = ( - "Fourier Coefficients" if "taper" in self._dims else "Power Spectrum" + "Fourier Coefficients" + if method_kw.get("output") == "complex" + else "Power Spectrum" ) # set nave (child constructor overrides this for Evoked input) self._nave = None @@ -376,6 +374,7 @@ def __getstate__(self): data_type=self._data_type, info=self.info, nave=self.nave, + weights=self.weights, ) return out @@ -393,6 +392,7 @@ def __setstate__(self, state): self.info = Info(**state["info"]) self._data_type = state["data_type"] self._nave = state.get("nave") # objs saved before #11282 won't have `nave` + self._weights = state.get("weights") # objs saved before #12747 won't have self.preload = True # instance type inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray) @@ -440,22 +440,35 @@ def _check_values(self): s = _pl(bad_value.sum()) warn(f'Zero value in spectrum for channel{s} {", ".join(chs)}', UserWarning) + def _returns_complex_tapers(self, **method_kw): + return self.method == "multitaper" and method_kw.get("output") == "complex" + def _compute_spectra(self, data, fmin, fmax, n_jobs, method_kw, verbose): # make the spectra result = self._psd_func( data, self.sfreq, fmin=fmin, fmax=fmax, n_jobs=n_jobs, verbose=verbose ) - # assign ._data ._freqs, ._shape - psds, freqs = result - self._data = psds + # assign ._data (handling unaggregated multitaper output) + if self._returns_complex_tapers(**method_kw): + fourier_coefs, freqs, weights = result + self._data = fourier_coefs + self._weights = weights + else: + psds, freqs = result + self._data = psds + self._weights = None + # assign properties (._data already assigned above) self._freqs = freqs # this is *expected* shape, it gets asserted later in _check_values() # (and then deleted afterwards) self._shape = (len(self.ch_names), len(self.freqs)) - # append n_welch_segments + # append n_welch_segments (use "" as .get() default since None considered valid) if method_kw.get("average", "") in (None, False): n_welch_segments = _compute_n_welch_segments(data.shape[-1], method_kw) self._shape += (n_welch_segments,) + # insert n_tapers + if self._returns_complex_tapers(**method_kw): + self._shape = self._shape[:-1] + (self._weights.size,) + self._shape[-1:] # we don't need these anymore, and they make save/load harder del self._picks del self._psd_func @@ -486,6 +499,10 @@ def method(self): def nave(self): return self._nave + @property + def weights(self): + return self._weights + @property def sfreq(self): return self._sfreq @@ -643,34 +660,13 @@ def plot( (picks_list, units_list, scalings_list, titles_list) = _split_picks_by_type( self, picks, units, scalings, titles ) - # handle unaggregated Welch - if "segment" in self._dims: - logger.info("Aggregating Welch estimates (median) before plotting...") - seg_axis = self._dims.index("segment") - _f = partial(np.nanmedian, axis=seg_axis) - else: # "normal" cases - _f = _identity_function - ch_axis = self._dims.index("channel") - psd_list = [_f(self._data.take(_p, axis=ch_axis)) for _p in picks_list] - # handle epochs - if "epoch" in self._dims: - # XXX TODO FIXME decide how to properly aggregate across repeated - # measures (epochs) and non-repeated but correlated measures - # (channels) when calculating stddev or a CI. For across-channel - # aggregation, doi:10.1007/s10162-012-0321-8 used hotellings T**2 - # with a correction factor that estimated data rank using monte - # carlo simulations; seems like we could use our own data rank - # estimation methods to similar effect. Their exact approach used - # complex spectra though, here we've already converted to power; - # not sure if that makes an important difference? Anyway that - # aggregation would need to happen in the _plot_psd function - # though, not here... for now we just average like we always did. - - # only log message if averaging will actually have an effect - if self._data.shape[0] > 1: - logger.info("Averaging across epochs...") - # epoch axis should always be the first axis - psd_list = [_p.mean(axis=0) for _p in psd_list] + # prepare data (e.g. aggregate across dims, convert complex to power) + psd_list = [ + self._prepare_data_for_plot( + self._data.take(_p, axis=self._dims.index("channel")) + ) + for _p in picks_list + ] # initialize figure fig, axes = _line_figure(self, axes, picks=picks) # don't add ylabels & titles if figure has unexpected number of axes @@ -739,8 +735,8 @@ def plot_topo( layout = find_layout(self.info) psds, freqs = self.get_data(return_freqs=True) - if "epoch" in self._dims: - psds = np.mean(psds, axis=self._dims.index("epoch")) + # prepare data (e.g. aggregate across dims, convert complex to power) + psds = self._prepare_data_for_plot(psds) if dB: psds = 10 * np.log10(psds) y_label = "dB" @@ -852,8 +848,8 @@ def plot_topomap( outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) psds, freqs = self.get_data(picks=picks, return_freqs=True) - if "epoch" in self._dims: - psds = np.mean(psds, axis=self._dims.index("epoch")) + # prepare data (e.g. aggregate across dims, convert complex to power) + psds = self._prepare_data_for_plot(psds) psds *= scaling**2 if merge_channels: @@ -891,6 +887,42 @@ def plot_topomap( show=show, ) + def _prepare_data_for_plot(self, data): + # handle unaggregated Welch + if "segment" in self._dims: + logger.info("Aggregating Welch estimates (median) before plotting...") + data = np.nanmedian(data, axis=self._dims.index("segment")) + # handle unaggregated multitaper (also handles complex -> power) + elif "taper" in self._dims: + logger.info("Aggregating multitaper estimates before plotting...") + data = _psd_from_mt(data, self.weights) + + # handle complex data (should only be Welch remaining) + if np.iscomplexobj(data): + data = (data * data.conj()).real # Scaling may be slightly off + + # handle epochs + if "epoch" in self._dims: + # XXX TODO FIXME decide how to properly aggregate across repeated + # measures (epochs) and non-repeated but correlated measures + # (channels) when calculating stddev or a CI. For across-channel + # aggregation, doi:10.1007/s10162-012-0321-8 used hotellings T**2 + # with a correction factor that estimated data rank using monte + # carlo simulations; seems like we could use our own data rank + # estimation methods to similar effect. Their exact approach used + # complex spectra though, here we've already converted to power; + # not sure if that makes an important difference? Anyway that + # aggregation would need to happen in the _plot_psd function + # though, not here... for now we just average like we always did. + + # only log message if averaging will actually have an effect + if data.shape[0] > 1: + logger.info("Averaging across epochs before plotting...") + # epoch axis should always be the first axis + data = data.mean(axis=0) + + return data + @verbose def save(self, fname, *, overwrite=False, verbose=None): """Save spectrum data to disk (in HDF5 format). @@ -1057,11 +1089,16 @@ class Spectrum(BaseSpectrum): Frequencies at which the amplitude, power, or fourier coefficients have been computed. %(info_not_none)s - method : str - The method used to compute the spectrum (``'welch'`` or ``'multitaper'``). + method : ``'welch'``| ``'multitaper'`` + The method used to compute the spectrum. nave : int | None The number of trials averaged together when generating the spectrum. ``None`` indicates no averaging is known to have occurred. + weights : array | None + The weights for each taper. Only present if spectra computed with + ``method='multitaper'`` and ``output='complex'``. + + .. versionadded:: 1.8 See Also -------- @@ -1179,21 +1216,59 @@ def __getitem__(self, item): return BaseRaw._getitem(self, item, return_times=False) -def _check_data_shape(data, freqs, info, ndim): - if data.ndim != ndim: - raise ValueError(f"Data must be a {ndim}D array.") +def _check_data_shape(data, info, freqs, dim_names, weights, is_epoched): + if data.ndim != len(dim_names): + raise ValueError( + f"Expected data to have {len(dim_names)} dimensions, got {data.ndim}." + ) + + allowed_dims = ["epoch", "channel", "freq", "segment", "taper"] + if not is_epoched: + allowed_dims.remove("epoch") + # TODO maybe we should be nice and allow plural versions of each dimname? + for dim in dim_names: + _check_option("dim_names", dim, allowed_dims) + if "channel" not in dim_names or "freq" not in dim_names: + raise ValueError("Both 'channel' and 'freq' must be present in `dim_names`.") + + if list(dim_names).index("channel") != int(is_epoched): + raise ValueError( + f"'channel' must be the {'second' if is_epoched else 'first'} dimension of " + "the data." + ) want_n_chan = _pick_data_channels(info).size - want_n_freq = freqs.size - got_n_chan, got_n_freq = data.shape[-2:] + got_n_chan = data.shape[list(dim_names).index("channel")] if got_n_chan != want_n_chan: raise ValueError( - f"The number of channels in `data` ({got_n_chan}) must match the " - f"number of good data channels in `info` ({want_n_chan})." + f"The number of channels in `data` ({got_n_chan}) must match the number of " + f"good data channels in `info` ({want_n_chan})." ) + + # given we limit max array size and ensure channel & freq dims present, only one of + # taper or segment can be present + if "taper" in dim_names: + if dim_names[-2] != "taper": # _psd_from_mt assumes this (called when plotting) + raise ValueError( + "'taper' must be the second to last dimension of the data." + ) + # expect weights for each taper + actual = None if weights is None else weights.size + expected = data.shape[list(dim_names).index("taper")] + if actual != expected: + raise ValueError( + f"Expected size of `weights` to be {expected} to match 'n_tapers' in " + f"`data`, got {actual}." + ) + elif "segment" in dim_names and dim_names[-1] != "segment": + raise ValueError("'segment' must be the last dimension of the data.") + + # freq being in wrong position ruled out by above checks + want_n_freq = freqs.size + got_n_freq = data.shape[list(dim_names).index("freq")] if got_n_freq != want_n_freq: raise ValueError( - f"The last dimension of `data` ({got_n_freq}) must have the same " - f"number of elements as `freqs` ({want_n_freq})." + f"The number of frequencies in `data` ({got_n_freq}) must match the number " + f"of elements in `freqs` ({want_n_freq})." ) @@ -1203,10 +1278,22 @@ class SpectrumArray(Spectrum): Parameters ---------- - data : array, shape (n_channels, n_freqs) - The power spectral density for each channel. + data : ndarray, shape (n_channels, [n_tapers], n_freqs, [n_segments]) + The spectra for each channel. %(info_not_none)s %(freqs_tfr_array)s + dim_names : tuple of str + The name of the dimensions in the data, in the order they occur. Must contain + ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include + either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., + multitaper algorithms) dimension. If including ``'taper'``, you should also pass + a ``weights`` parameter. + + .. versionadded:: 1.8 + weights : ndarray | None + Weights for the ``'taper'`` dimension, if present (see ``dim_names``). + + .. versionadded:: 1.8 %(verbose)s See Also @@ -1229,21 +1316,31 @@ def __init__( data, info, freqs, + dim_names=("channel", "freq"), + weights=None, *, verbose=None, ): - _check_data_shape(data, freqs, info, ndim=2) + # (channel, [taper], freq, [segment]) + _check_option("data.ndim", data.ndim, (2, 3)) # only allow one extra dimension + + _check_data_shape(data, info, freqs, dim_names, weights, is_epoched=False) self.__setstate__( dict( method="unknown", data=data, sfreq=info["sfreq"], - dims=("channel", "freq"), + dims=dim_names, freqs=freqs, inst_type_str="Array", - data_type="Power Spectrum", + data_type=( + "Fourier Coefficients" + if np.iscomplexobj(data) + else "Power Spectrum" + ), info=info, + weights=weights, ) ) @@ -1279,8 +1376,13 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin): Frequencies at which the amplitude, power, or fourier coefficients have been computed. %(info_not_none)s - method : str - The method used to compute the spectrum ('welch' or 'multitaper'). + method : ``'welch'``| ``'multitaper'`` + The method used to compute the spectrum. + weights : array | None + The weights for each taper. Only present if spectra computed with + ``method='multitaper'`` and ``output='complex'``. + + .. versionadded:: 1.8 See Also -------- @@ -1420,6 +1522,11 @@ def average(self, method="mean"): "supported. Consider averaging the signals before computing " "the Welch spectrum estimates." ) + if "taper" in self._dims: + raise NotImplementedError( + "Averaging multitaper tapers across epochs is not supported. Consider " + "averaging the signals before computing the complex spectrum." + ) # serialize the object and update data, dims, and data type state = super().__getstate__() state["nave"] = state["data"].shape[0] @@ -1449,12 +1556,24 @@ class EpochsSpectrumArray(EpochsSpectrum): Parameters ---------- - data : array, shape (n_epochs, n_channels, n_freqs) - The power spectral density for each channel in each epoch. + data : ndarray, shape (n_epochs, n_channels, [n_tapers], n_freqs, [n_segments]) + The spectra for each channel in each epoch. %(info_not_none)s %(freqs_tfr_array)s %(events_epochs)s %(event_id)s + dim_names : tuple of str + The name of the dimensions in the data, in the order they occur. Must contain + ``'channel'`` and ``'freq'``; if data are unaggregated estimates, also include + either a ``'segment'`` (e.g., Welch-like algorithms) or ``'taper'`` (e.g., + multitaper algorithms) dimension. If including ``'taper'``, you should also pass + a ``weights`` parameter. + + .. versionadded:: 1.8 + weights : ndarray | None + Weights for the ``'taper'`` dimension, if present (see ``dim_names``). + + .. versionadded:: 1.8 %(verbose)s See Also @@ -1478,31 +1597,44 @@ def __init__( freqs, events=None, event_id=None, + dim_names=("epoch", "channel", "freq"), + weights=None, *, verbose=None, ): - _check_data_shape(data, freqs, info, ndim=3) + # (epoch, channel, [taper], freq, [segment]) + _check_option("data.ndim", data.ndim, (3, 4)) # only allow one extra dimension + + if list(dim_names).index("epoch") != 0: + raise ValueError("'epoch' must be the first dimension of `data`.") if events is not None and data.shape[0] != events.shape[0]: raise ValueError( - f"The first dimension of `data` ({data.shape[0]}) must match the " - f"first dimension of `events` ({events.shape[0]})." + f"The first dimension of `data` ({data.shape[0]}) must match the first " + f"dimension of `events` ({events.shape[0]})." ) + _check_data_shape(data, info, freqs, dim_names, weights, is_epoched=True) + self.__setstate__( dict( method="unknown", data=data, sfreq=info["sfreq"], - dims=("epoch", "channel", "freq"), + dims=dim_names, freqs=freqs, inst_type_str="Array", - data_type="Power Spectrum", + data_type=( + "Fourier Coefficients" + if np.iscomplexobj(data) + else "Power Spectrum" + ), info=info, events=events, event_id=event_id, metadata=None, selection=np.arange(data.shape[0]), drop_log=tuple(tuple() for _ in range(data.shape[0])), + weights=weights, ) ) diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index a44c6aeaa17..980df42d791 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -6,10 +6,12 @@ import numpy as np import pytest from matplotlib.colors import same_color -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose, assert_array_equal -from mne import Annotations +from mne import Annotations, create_info, make_fixed_length_epochs +from mne.io import RawArray from mne.time_frequency import read_spectrum +from mne.time_frequency.multitaper import _psd_from_mt from mne.time_frequency.spectrum import EpochsSpectrumArray, SpectrumArray from mne.utils import _record_warnings @@ -22,8 +24,6 @@ def test_compute_psd_errors(raw): raw.compute_psd(foo=None) with pytest.raises(TypeError, match="keyword arguments foo, bar for"): raw.compute_psd(foo=None, bar=None) - with pytest.raises(ValueError, match="Complex output is not supported in "): - raw.compute_psd(output="complex") raw.set_annotations(Annotations(onset=0.01, duration=0.01, description="bad_foo")) with pytest.raises(NotImplementedError, match='Cannot use method="multitaper"'): raw.compute_psd(method="multitaper", reject_by_annotation=True) @@ -33,7 +33,7 @@ def test_compute_psd_errors(raw): @pytest.mark.parametrize( ( "fmin, fmax, tmin, tmax, picks, proj, n_fft, n_overlap, n_per_seg, " - "average, window, bandwidth, adaptive, low_bias, normalization" + "average, window, bandwidth, adaptive, low_bias, normalization, output" ), [ [ @@ -52,6 +52,7 @@ def test_compute_psd_errors(raw): False, True, "length", + "power", ], # defaults [ 5, @@ -69,7 +70,26 @@ def test_compute_psd_errors(raw): True, False, "full", - ], # non-defaults + "power", + ], # non-defaults (excluding output) + [ + 0, + np.inf, + None, + None, + None, + False, + 256, + 0, + None, + "mean", + "hamming", + None, + False, + True, + "length", + "complex", + ], # complex (testing with non-defaults doesn't increase coverage) ], ) def test_spectrum_params( @@ -89,6 +109,7 @@ def test_spectrum_params( adaptive, low_bias, normalization, + output, raw, ): """Test valid parameter combinations in the .compute_psd() method.""" @@ -100,6 +121,7 @@ def test_spectrum_params( tmax=tmax, picks=picks, proj=proj, + output=output, ) if method == "welch": kwargs.update( @@ -260,6 +282,76 @@ def test_spectrum_to_data_frame(inst, request, evoked): assert_frame_equal(_pick_first, _pick_last) +def _agg_helper(df, weights, group_cols): + """Aggregate complex multitaper spectrum after conversion to DataFrame.""" + from pandas import Series + + unagged_columns = df[group_cols].iloc[0].values.tolist() + x_mt = df.drop(columns=group_cols).values[np.newaxis].T + psd = _psd_from_mt(x_mt, weights) + psd = np.atleast_1d(np.squeeze(psd)).tolist() + _df = dict(zip(df.columns, unagged_columns + psd)) + return Series(_df) + + +@pytest.mark.parametrize("long_format", (False, True)) +@pytest.mark.parametrize( + "method, output", + [("welch", "complex"), ("welch", "power"), ("multitaper", "complex")], +) +def test_unaggregated_spectrum_to_data_frame(raw, long_format, method, output): + """Test converting unaggregated spectra (multiple segments/tapers) to data frame.""" + pytest.importorskip("pandas") + from pandas.testing import assert_frame_equal + + from mne.utils.dataframe import _inplace + + # aggregated spectrum → dataframe + orig_df = raw.compute_psd(method=method).to_data_frame(long_format=long_format) + # unaggregated welch or complex multitaper → + # aggregate w/ pandas (to make sure we did reshaping right) + kwargs = dict() + if method == "welch": + kwargs.update(average=False) + spectrum = raw.compute_psd(method=method, output=output, **kwargs) + df = spectrum.to_data_frame(long_format=long_format) + grouping_cols = ["freq"] + drop_cols = ["segment"] if method == "welch" else ["taper"] + if long_format: + grouping_cols.append("channel") + drop_cols.append("ch_type") + orig_df.drop(columns="ch_type", inplace=True) + # only do a couple freq bins, otherwise test takes forever for multitaper + subset = partial(np.isin, test_elements=spectrum.freqs[:2]) + df = df.loc[subset(df["freq"])] + orig_df = orig_df.loc[subset(orig_df["freq"])] + # sort orig_df, because at present we can't actually prevent pandas from + # sorting at the agg step *sigh* + _inplace(orig_df, "sort_values", by=grouping_cols, ignore_index=True) + # aggregate + df = df.drop(columns=drop_cols) + gb = df.groupby(grouping_cols, as_index=False, observed=False) + if method == "welch": + if output == "complex": + + def _fun(x): + return np.mean(np.real(x * np.conj(x))) # use mean to aggregate + + agg_df = gb.agg(_fun) + else: + agg_df = gb.mean() # excludes missing values itself + else: + gb = gb[df.columns] # XXX: try removing when minimum pandas >= 2.1 is required + agg_df = gb.apply(_agg_helper, spectrum.weights, grouping_cols) + # even with check_categorical=False, we know that the *data* matches; + # what may differ is the order of the "levels" in the *metadata* for the + # channel name column + agg_df.sort_values(by=grouping_cols, ignore_index=True, inplace=True) + orig_df.sort_values(by=grouping_cols, ignore_index=True, inplace=True) + # One can have categorical dtype and the other plain object, so don't check that + assert_frame_equal(agg_df, orig_df, check_categorical=False, check_dtype=False) + + # not testing with Evoked because it already has projs applied @pytest.mark.parametrize("inst", ("raw", "epochs")) def test_spectrum_proj(inst, request): @@ -275,6 +367,58 @@ def test_spectrum_proj(inst, request): assert has_proj == no_proj +@pytest.mark.parametrize( + "method, average", [("welch", False), ("welch", "mean"), ("multitaper", None)] +) +def test_spectrum_complex(method, average): + """Test output='complex' support.""" + sfreq = 100 + n = 10 * sfreq + freq = 3.0 + phase = np.pi / 4 # should be recoverable + data = np.cos(2 * np.pi * freq * np.arange(n) / sfreq + phase)[np.newaxis] + raw = RawArray(data, create_info(1, sfreq, "eeg")) + epochs = make_fixed_length_epochs(raw, duration=2.0, preload=True) + assert len(epochs) == 5 + assert len(epochs.times) == 2 * sfreq + kwargs = dict(output="complex", method=method) + if method == "welch": + kwargs["n_fft"] = sfreq + want_dims = ("epoch", "channel", "freq") + want_shape = (5, 1, sfreq // 2 + 1) + if not average: + want_dims = want_dims + ("segment",) + want_shape = want_shape + (2,) + kwargs["average"] = average + else: + assert method == "multitaper" + assert not average + want_dims = ("epoch", "channel", "taper", "freq") + want_shape = (5, 1, 7, sfreq + 1) + spectrum = epochs.compute_psd(**kwargs) + idx = np.argmin(np.abs(spectrum.freqs - freq)) + assert spectrum.freqs[idx] == freq + assert spectrum._dims == want_dims + assert spectrum.shape == want_shape + data = spectrum.get_data() + assert data.dtype == np.complex128 + coef = spectrum.get_data(fmin=freq, fmax=freq).mean(0) + if method == "multitaper": + coef = coef[..., 0, :] # first taper + elif not average: + coef = coef.mean(-1) # over segments + coef = coef.item() + # Test phase matches what was simulated + assert_allclose(np.angle(coef), phase, rtol=1e-4) + # Now test that it warns appropriately + epochs._data[0, 0, :] = 0 # actually zero for one epoch and ch + with pytest.warns(UserWarning, match="Zero value.*channel 0"): + epochs.compute_psd(**kwargs) + # But not if we mark that channel as bad + epochs.info["bads"] = epochs.ch_names[:1] + epochs.compute_psd(**kwargs) + + def test_spectrum_kwarg_triaging(raw): """Test kwarg triaging in legacy plot_psd() method.""" import matplotlib.pyplot as plt @@ -295,44 +439,162 @@ def _check_spectrum_equivalent(spect1, spect2, tmp_path): assert_array_equal(spect1.freqs, spect2.freqs) -def test_spectrum_array_errors(epochs_spectrum): - """Test EpochsSpectrumArray constructor errors.""" - data, freqs = epochs_spectrum.get_data(return_freqs=True) - info = epochs_spectrum.info - with pytest.raises(ValueError, match="Data must be a 3D array"): - EpochsSpectrumArray(np.empty((2, 3, 4, 5)), info, freqs) - with pytest.raises(ValueError, match=r"number of channels.*good data channels"): - EpochsSpectrumArray(data[:, :-1], info, freqs) - with pytest.raises(ValueError, match=r"last dimension.*same number of elements"): - EpochsSpectrumArray(data[..., :-1], info, freqs) +def test_spectrum_array_errors(): + """Test (Epochs)SpectrumArray constructor errors.""" + n_epochs = 10 + n_chans = 5 + n_freqs = 50 + freqs = np.arange(n_freqs) + sfreq = 100 + rng = np.random.default_rng(44) + data = rng.random((n_epochs, n_chans, n_freqs)) + dim_names = ("epoch", "channel", "freq") + info = create_info(n_chans, sfreq, "eeg") + # test incorrect ndims (for SpectrumArray; allows 2-3D data) + with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): + SpectrumArray(data[0, 0, :], info, freqs, dim_names=dim_names) + with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): + SpectrumArray(np.expand_dims(data, axis=3), info, freqs, dim_names=dim_names) + # test incorrect ndims (for EpochsSpectrumArray; allows 3-4D data) + with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): + EpochsSpectrumArray(data[0, :, :], info, freqs, dim_names=dim_names) + with pytest.raises(ValueError, match="Invalid value for the 'data.ndim' parameter"): + EpochsSpectrumArray( + np.expand_dims(data, axis=(3, 4)), info, freqs, dim_names=dim_names + ) + # test incorrect epochs location + with pytest.raises(ValueError, match="'epoch' must be the first dimension"): + EpochsSpectrumArray(data, info, freqs, dim_names=("channel", "epoch", "freq")) # test mismatching events shape - n_epo = data.shape[0] + 1 # +1 so they purposely don't match events = np.vstack( - (np.arange(n_epo), np.zeros(n_epo, dtype=int), np.ones(n_epo, dtype=int)) + ( + np.arange(n_epochs + 1), + np.zeros(n_epochs + 1, dtype=int), + np.ones(n_epochs + 1, dtype=int), + ) ).T with pytest.raises(ValueError, match=r"first dimension.*dimension of `events`"): - EpochsSpectrumArray(data, info, freqs, events) + EpochsSpectrumArray(data, info, freqs, events, dim_names=dim_names) + # test data-dimname mismatch + with pytest.raises(ValueError, match=r"Expected data to have.*dimensions, got.*"): + EpochsSpectrumArray(data, info, freqs, dim_names=dim_names[:-1]) + # test unrecognised dim_names (for SpectrumArray; epoch not allowed) + with pytest.raises(ValueError, match="Invalid value for the 'dim_names' parameter"): + SpectrumArray(data[0, :, :], info, freqs, dim_names=("epoch", "channel")) + # test unrecognised dim_names (for EpochsSpectrumArray) + with pytest.raises(ValueError, match="Invalid value for the 'dim_names' parameter"): + EpochsSpectrumArray( + data, info, freqs, dim_names=("epoch", "channel", "notfreq") + ) + # test missing dim_names + with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): + EpochsSpectrumArray( + data, info, freqs, dim_names=("epoch", "channel", "channel") + ) + with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): + EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "freq", "freq")) + with pytest.raises(ValueError, match="Both 'channel' and 'freq' must be present"): + EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "epoch", "epoch")) + # test incorrect channel location (for SpectrumArray; must be 1st dim) + with pytest.raises(ValueError, match="'channel' must be the first dimension"): + SpectrumArray(data[0, :, :], info, freqs, dim_names=("freq", "channel")) + # test incorrect channel location (for EpochsSpectrumArray; must be 2nd dim) + with pytest.raises(ValueError, match="'channel' must be the second dimension"): + EpochsSpectrumArray(data, info, freqs, dim_names=("epoch", "freq", "channel")) + # test mismatching number of channels + with pytest.raises(ValueError, match=r"number of channels.*good data channels"): + EpochsSpectrumArray(data[:, :-1, :], info, freqs, dim_names=dim_names) + # test incorrect taper position + with pytest.raises(ValueError, match="'taper' must be the second to last dim"): + EpochsSpectrumArray( + np.expand_dims(data, axis=3), info, freqs, dim_names=dim_names + ("taper",) + ) + # test incorrect weight size + with pytest.raises(ValueError, match=r"Expected size of `weights` to be.*, got.*"): + EpochsSpectrumArray( + np.expand_dims(data, axis=2), + info, + freqs, + dim_names=("epoch", "channel", "taper", "freq"), + weights=None, + ) + with pytest.raises(ValueError, match=r"Expected size of `weights` to be.*, got.*"): + EpochsSpectrumArray( + np.expand_dims(data, axis=2), + info, + freqs, + dim_names=("epoch", "channel", "taper", "freq"), + weights=np.ones((1, 2, 1)), + ) + # test incorrect segment position + with pytest.raises(ValueError, match="'segment' must be the last dim"): + EpochsSpectrumArray( + np.expand_dims(data, axis=2), + info, + freqs, + dim_names=("epoch", "channel", "segment", "freq"), + ) + # test mismatching number of frequencies + with pytest.raises(ValueError, match=r"number of frequencies.*number of elements"): + EpochsSpectrumArray(data[:, :, :-1], info, freqs, dim_names=dim_names) -@pytest.mark.parametrize("kind", ("raw", "epochs")) -def test_spectrum_array(kind, tmp_path, request): +@pytest.mark.parametrize( + "kind, method, output, average", + [ + ("raw", "welch", "power", "mean"), # test with precomputed spectrum + ("epochs", "welch", "power", False), # test with segments + ("epochs", "multitaper", "complex", None), # test with tapers + ], # additional variants don't improve coverage +) +def test_spectrum_array(kind, method, output, average, tmp_path, request): """Test EpochsSpectrumArray and SpectrumArray constructors.""" - spectrum = request.getfixturevalue(f"{kind}_spectrum") + dim_names = ("epoch", "channel") if kind == "epochs" else ("channel",) + if method == "welch": + dim_names += ("freq",) if average else ("freq", "segment") + else: # i.e. multitaper + dim_names += ("freq",) if output == "power" else ("taper", "freq") + if method == "welch" and output == "power" and average: + spectrum = request.getfixturevalue(f"{kind}_spectrum") + else: + data = request.getfixturevalue(kind) + kwargs = dict() + if method == "welch": + kwargs.update(average=average) + spectrum = data.compute_psd(method=method, output=output, **kwargs) data, freqs = spectrum.get_data(return_freqs=True) Klass = SpectrumArray if kind == "raw" else EpochsSpectrumArray - spect_arr = Klass(data=data, info=spectrum.info, freqs=freqs) + spect_arr = Klass( + data=data, + info=spectrum.info, + freqs=freqs, + dim_names=dim_names, + weights=spectrum.weights, + ) _check_spectrum_equivalent(spectrum, spect_arr, tmp_path) -@pytest.mark.parametrize("kind", ("raw", "epochs")) -@pytest.mark.parametrize("array", (False, True)) -def test_plot_spectrum(kind, array, request): - """Test plotting (Epochs)Spectrum(Array).""" - spectrum = request.getfixturevalue(f"{kind}_spectrum") - if array: - data, freqs = spectrum.get_data(return_freqs=True) - Klass = SpectrumArray if kind == "raw" else EpochsSpectrumArray - spectrum = Klass(data=data, info=spectrum.info, freqs=freqs) +@pytest.mark.parametrize( + "method, output, average", + [ + ("welch", "power", "mean"), # test with precomputed spectrum + ("welch", "complex", False), # test aggr over segments & conversion to power + ("multitaper", "complex", None), # test aggr over tapers & conversion to power + ], # additional variants don't improve coverage +) +def test_plot_spectrum(method, output, average, request): + """Test plotting EpochsSpectrum(Array). + + Testing Spectrum(Array) with raw data doesn't improve coverage. + """ + if method == "welch" and output == "power" and average: + spectrum = request.getfixturevalue("epochs_spectrum") + else: + data = request.getfixturevalue("epochs") + kwargs = dict() + if method == "welch": + kwargs.update(average=average) + spectrum = data.compute_psd(method=method, output=output, **kwargs) spectrum.info["bads"] = spectrum.ch_names[:1] # one grad channel spectrum.plot(average=True, amplitude=True, spatial_colors=True) spectrum.plot(average=True, amplitude=False, spatial_colors=False) diff --git a/mne/utils/docs.py b/mne/utils/docs.py index ff9e11ee776..57a0999fd1e 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -2922,11 +2922,15 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["notes_plot_psd_meth"] = _notes_plot_psd.format("method") docdict["notes_spectrum_array"] = """ -It is assumed that the data passed in represent spectral *power* (not amplitude, -phase, model coefficients, etc) and downstream methods (such as +If the data passed in is real-valued, it is assumed to represent spectral *power* (not +amplitude, phase, etc), and downstream methods (such as :meth:`~mne.time_frequency.SpectrumArray.plot`) assume power data. If you pass in -something other than power, at the very least axis labels will be inaccurate (and -other things may also not work or be incorrect). +real-valued data that is not power, axis labels will be incorrect. + +If the data passed in is complex-valued, it is assumed to represent Fourier +coefficients. Downstream plotting methods will treat the data as such, attempting to +convert this to power before visualisation. If you pass in complex-valued data that is +not Fourier coefficients, axis labels will be incorrect. """ docdict["notes_timefreqs_tfr_plot_joint"] = """ diff --git a/mne/utils/spectrum.py b/mne/utils/spectrum.py index 92ed4170c83..4425616f93d 100644 --- a/mne/utils/spectrum.py +++ b/mne/utils/spectrum.py @@ -9,6 +9,8 @@ def _get_instance_type_string(inst): """Get string representation of the originating instance type.""" + from numpy import ndarray + from ..epochs import BaseEpochs from ..evoked import Evoked, EvokedArray from ..io import BaseRaw @@ -20,6 +22,8 @@ def _get_instance_type_string(inst): inst_type_str = "Epochs" elif inst._inst_type in (Evoked, EvokedArray): inst_type_str = "Evoked" + elif inst._inst_type == ndarray: + inst_type_str = "Array" else: raise RuntimeError( f"Unknown instance type {inst._inst_type} in {type(inst).__name__}"