Skip to content

Commit

Permalink
[ENH, MRG] Add EpochsSpectrumArray and SpectrumArray classes (mne-too…
Browse files Browse the repository at this point in the history
…ls#11803)

Co-authored-by: Daniel McCloy <[email protected]>
  • Loading branch information
2 people authored and snwnde committed Mar 20, 2024
1 parent ad33235 commit 01892d7
Show file tree
Hide file tree
Showing 9 changed files with 277 additions and 98 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Enhancements
- Added option ``remove_dc`` to to :meth:`Raw.compute_psd() <mne.io.Raw.compute_psd>`, :meth:`Epochs.compute_psd() <mne.Epochs.compute_psd>`, and :meth:`Evoked.compute_psd() <mne.Evoked.compute_psd>`, to allow skipping DC removal when computing Welch or multitaper spectra (:gh:`11769` by `Nikolai Chapochnikov`_)
- Add the possibility to provide a float between 0 and 1 as ``n_grad``, ``n_mag`` and ``n_eeg`` in `~mne.compute_proj_raw`, `~mne.compute_proj_epochs` and `~mne.compute_proj_evoked` to select the number of vectors based on the cumulative explained variance (:gh:`11919` by `Mathieu Scheltienne`_)
- Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_)
- Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array <numpy.ndarray>` data (:gh:`11803` by `Alex Rockhill`_)

Bugs
~~~~
Expand Down
2 changes: 2 additions & 0 deletions doc/time_frequency.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ Time-Frequency
EpochsTFR
CrossSpectralDensity
Spectrum
SpectrumArray
EpochsSpectrum
EpochsSpectrumArray

Functions that operate on mne-python objects:

Expand Down
12 changes: 12 additions & 0 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,12 @@ def raw_ctf():
return raw_ctf


@pytest.fixture(scope="function")
def raw_spectrum(raw):
"""Get raw with power spectral density computed from mne.io.tests.data."""
return raw.compute_psd()


@pytest.fixture(scope="function")
def events():
"""Get events from mne.io.tests.data."""
Expand Down Expand Up @@ -349,6 +355,12 @@ def epochs_full():
return _get_epochs(None).load_data()


@pytest.fixture()
def epochs_spectrum():
"""Get epochs with power spectral density computed from mne.io.tests.data."""
return _get_epochs().load_data().compute_psd()


@pytest.fixture()
def epochs_empty():
"""Get empty epochs from mne.io.tests.data."""
Expand Down
37 changes: 7 additions & 30 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3197,40 +3197,17 @@ class EpochsArray(BaseEpochs):
measure.
%(info_not_none)s Consider using :func:`mne.create_info` to populate this
structure.
events : None | array of int, shape (n_events, 3)
The events typically returned by the read_events function.
If some events don't match the events of interest as specified
by event_id, they will be marked as 'IGNORED' in the drop log.
If None (default), all event values are set to 1 and event time-samples
are set to range(n_epochs).
tmin : float
Start time before event. If nothing provided, defaults to 0.
event_id : int | list of int | dict | None
The id of the event to consider. If dict,
the keys can later be used to access associated events. Example:
dict(auditory=1, visual=3). If int, a dict will be created with
the id as string. If a list, all events with the IDs specified
in the list are used. If None, all events will be used with
and a dict is created with string integer names corresponding
to the event id integers.
%(events_epochs)s
%(tmin_epochs)s
%(event_id)s
%(reject_epochs)s
%(flat)s
reject_tmin : scalar | None
Start of the time window used to reject epochs (with the default None,
the window will start with tmin).
reject_tmax : scalar | None
End of the time window used to reject epochs (with the default None,
the window will end with tmax).
%(epochs_reject_tmin_tmax)s
%(baseline_epochs)s
Defaults to ``None``, i.e. no baseline correction.
proj : bool | 'delayed'
Apply SSP projection vectors. See :class:`mne.Epochs` for details.
on_missing : str
See :class:`mne.Epochs` docstring for details.
metadata : instance of pandas.DataFrame | None
See :class:`mne.Epochs` docstring for details.
.. versionadded:: 0.16
%(proj_epochs)s
%(on_missing_epochs)s
%(metadata_epochs)s
%(selection)s
%(drop_log)s
Expand Down
8 changes: 7 additions & 1 deletion mne/time_frequency/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@
"tfr_array_multitaper",
],
"psd": ["psd_array_welch"],
"spectrum": ["EpochsSpectrum", "Spectrum", "read_spectrum"],
"spectrum": [
"EpochsSpectrum",
"EpochsSpectrumArray",
"Spectrum",
"SpectrumArray",
"read_spectrum",
],
"tfr": [
"_BaseTFR",
"AverageTFR",
Expand Down
147 changes: 143 additions & 4 deletions mne/time_frequency/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,10 @@ def __setstate__(self, state):
self._data_type = state["data_type"]
self.preload = True
# instance type
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked)
inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Array=np.ndarray)
self._inst_type = inst_types[state["inst_type_str"]]
if "weights" in state and state["weights"] is not None:
self._mt_weights = state["weights"]

def __repr__(self):
"""Build string representation of the Spectrum object."""
Expand Down Expand Up @@ -486,6 +488,8 @@ def _get_instance_type_string(self):
inst_type_str = "Epochs"
elif self._inst_type in (Evoked, EvokedArray):
inst_type_str = "Evoked"
elif self._inst_type is np.ndarray:
inst_type_str = "Array"
else:
raise RuntimeError(f"Unknown instance type {self._inst_type} in Spectrum")
return inst_type_str
Expand Down Expand Up @@ -766,6 +770,8 @@ def plot_topo(
layout = find_layout(self.info)

psds, freqs = self.get_data(return_freqs=True)
if "epoch" in self._dims:
psds = np.mean(psds, axis=self._dims.index("epoch"))
if dB:
psds = 10 * np.log10(psds)
y_label = "dB"
Expand Down Expand Up @@ -977,7 +983,7 @@ def to_data_frame(
# check pandas once here, instead of in each private utils function
pd = _check_pandas_installed() # noqa
# triage for Epoch-derived or unaggregated spectra
from_epo = self._get_instance_type_string() == "Epochs"
from_epo = self._dims[0] == "epoch"
unagg_welch = "segment" in self._dims
unagg_mt = "taper" in self._dims
# arg checking
Expand Down Expand Up @@ -1089,6 +1095,7 @@ class Spectrum(BaseSpectrum):
See Also
--------
EpochsSpectrum
SpectrumArray
mne.io.Raw.compute_psd
mne.Epochs.compute_psd
mne.Evoked.compute_psd
Expand Down Expand Up @@ -1190,6 +1197,75 @@ def __getitem__(self, item):
return BaseRaw._getitem(self, item, return_times=False)


def _check_data_shape(data, freqs, info, ndim):
if data.ndim != ndim:
raise ValueError(f"Data must be a {ndim}D array.")
want_n_chan = _pick_data_channels(info).size
want_n_freq = freqs.size
got_n_chan, got_n_freq = data.shape[-2:]
if got_n_chan != want_n_chan:
raise ValueError(
f"The number of channels in `data` ({got_n_chan}) must match the "
f"number of good data channels in `info` ({want_n_chan})."
)
if got_n_freq != want_n_freq:
raise ValueError(
f"The last dimension of `data` ({got_n_freq}) must have the same "
f"number of elements as `freqs` ({want_n_freq})."
)


@fill_doc
class SpectrumArray(Spectrum):
"""Data object for precomputed spectral data (in NumPy array format).
Parameters
----------
data : array, shape (n_channels, n_freqs)
The power spectral density for each channel.
%(info_not_none)s
%(freqs_tfr)s
%(verbose)s
See Also
--------
mne.create_info
mne.EvokedArray
mne.io.RawArray
EpochsSpectrumArray
Notes
-----
%(notes_spectrum_array)s
.. versionadded:: 1.6
"""

@verbose
def __init__(
self,
data,
info,
freqs,
*,
verbose=None,
):
_check_data_shape(data, freqs, info, ndim=2)

self.__setstate__(
dict(
method="unknown",
data=data,
sfreq=info["sfreq"],
dims=("channel", "freq"),
freqs=freqs,
inst_type_str="Array",
data_type="Power Spectrum",
info=info,
)
)


@fill_doc
class EpochsSpectrum(BaseSpectrum, GetEpochsMixin):
"""Data object for spectral representations of epoched data.
Expand Down Expand Up @@ -1225,10 +1301,9 @@ class EpochsSpectrum(BaseSpectrum, GetEpochsMixin):
See Also
--------
EpochsSpectrumArray
Spectrum
mne.io.Raw.compute_psd
mne.Epochs.compute_psd
mne.Evoked.compute_psd
References
----------
Expand Down Expand Up @@ -1385,6 +1460,70 @@ def average(self, method="mean"):
return Spectrum(state, **defaults)


@fill_doc
class EpochsSpectrumArray(EpochsSpectrum):
"""Data object for precomputed epoched spectral data (in NumPy array format).
Parameters
----------
data : array, shape (n_epochs, n_channels, n_freqs)
The power spectral density for each channel in each epoch.
%(info_not_none)s
%(freqs_tfr)s
%(events_epochs)s
%(event_id)s
%(verbose)s
See Also
--------
mne.create_info
mne.EpochsArray
SpectrumArray
Notes
-----
%(notes_spectrum_array)s
.. versionadded:: 1.6
"""

@verbose
def __init__(
self,
data,
info,
freqs,
events=None,
event_id=None,
*,
verbose=None,
):
_check_data_shape(data, freqs, info, ndim=3)
if events is not None and data.shape[0] != events.shape[0]:
raise ValueError(
f"The first dimension of `data` ({data.shape[0]}) must match the "
f"first dimension of `events` ({events.shape[0]})."
)

self.__setstate__(
dict(
method="unknown",
data=data,
sfreq=info["sfreq"],
dims=("epoch", "channel", "freq"),
freqs=freqs,
inst_type_str="Array",
data_type="Power Spectrum",
info=info,
events=events,
event_id=event_id,
metadata=None,
selection=np.arange(data.shape[0]),
drop_log=tuple(tuple() for _ in range(data.shape[0])),
)
)


def read_spectrum(fname):
"""Load a :class:`mne.time_frequency.Spectrum` object from disk.
Expand Down
Loading

0 comments on commit 01892d7

Please sign in to comment.