-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MRG: Introduce ICA.get_explained_variance_ratio() to easily retrieve relative explained variances after a fit #11141
Changes from all commits
9188fb1
47f63a7
905be68
5d96641
08544d5
c966a43
d63a502
1aca2b2
bbcd19d
b499dcf
dd49c6c
eda143a
5cb3cfd
bd2ccc2
33e3a28
0857cee
b4ec360
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,11 +8,13 @@ | |
|
||
from inspect import isfunction | ||
from collections import namedtuple | ||
from collections.abc import Sequence | ||
from copy import deepcopy | ||
from numbers import Integral | ||
from time import time | ||
from dataclasses import dataclass | ||
from typing import Optional, List | ||
import warnings | ||
|
||
import math | ||
import json | ||
|
@@ -452,7 +454,6 @@ class _InfosForRepr: | |
fit_n_samples: Optional[int] | ||
fit_n_components: Optional[int] | ||
fit_n_pca_components: Optional[int] | ||
fit_explained_variance: Optional[float] | ||
ch_types: List[str] | ||
excludes: List[str] | ||
|
||
|
@@ -470,11 +471,6 @@ class _InfosForRepr: | |
fit_n_pca_components = getattr(self, 'pca_components_', None) | ||
if fit_n_pca_components is not None: | ||
fit_n_pca_components = len(self.pca_components_) | ||
fit_explained_variance = getattr(self, 'pca_explained_variance_', None) | ||
if fit_explained_variance is not None: | ||
abs_vars = self.pca_explained_variance_ | ||
rel_vars = abs_vars / abs_vars.sum() | ||
fit_explained_variance = rel_vars[:fit_n_components].sum() | ||
|
||
if self.info is not None: | ||
ch_types = [c for c in _DATA_CH_TYPES_SPLIT if c in self] | ||
|
@@ -493,7 +489,6 @@ class _InfosForRepr: | |
fit_n_samples=fit_n_samples, | ||
fit_n_components=fit_n_components, | ||
fit_n_pca_components=fit_n_pca_components, | ||
fit_explained_variance=fit_explained_variance, | ||
ch_types=ch_types, | ||
excludes=excludes | ||
) | ||
|
@@ -511,8 +506,6 @@ def __repr__(self): | |
f' (fit in {infos.fit_n_iter} iterations on ' | ||
f'{infos.fit_n_samples} samples), ' | ||
f'{infos.fit_n_components} ICA components ' | ||
f'explaining {round(infos.fit_explained_variance * 100, 1)} % ' | ||
f'of variance ' | ||
f'({infos.fit_n_pca_components} PCA components available), ' | ||
f'channel types: {", ".join(infos.ch_types)}, ' | ||
f'{len(infos.excludes) or "no"} sources marked for exclusion' | ||
|
@@ -531,7 +524,6 @@ def _repr_html_(self): | |
n_samples=infos.fit_n_samples, | ||
n_components=infos.fit_n_components, | ||
n_pca_components=infos.fit_n_pca_components, | ||
explained_variance=infos.fit_explained_variance, | ||
ch_types=infos.ch_types, | ||
excludes=infos.excludes | ||
) | ||
|
@@ -962,6 +954,141 @@ def get_components(self): | |
return np.dot(self.mixing_matrix_[:, :self.n_components_].T, | ||
self.pca_components_[:self.n_components_]).T | ||
|
||
def get_explained_variance_ratio(
        self, inst, *, components=None, ch_type=None
):
    """Compute the proportion of data variance explained by ICA components.

    Parameters
    ----------
    inst : mne.io.BaseRaw | mne.BaseEpochs | mne.Evoked
        The uncleaned data.
    components : array-like of int | int | None
        The component(s) to evaluate. If several components are given,
        their explained variance is computed jointly. ``None`` (default)
        uses all available components.
    ch_type : 'mag' | 'grad' | 'planar1' | 'planar2' | 'eeg' | array-like of str | None
        The channel type(s) to include in the calculation. ``None``
        selects every available data channel type.

    Returns
    -------
    dict (str, float)
        Mapping from channel type to the fraction of variance in ``inst``
        that the requested ICA components can explain, computed
        separately per channel type.

    Notes
    -----
    The value is comparable to EEGLAB's ``pvaf`` (percent variance
    accounted for) for the specified component(s).

    Because ICA components need not be mutually orthogonal, the explained
    variance ratios of all components may not sum to 1; an individual
    component's ratio can even be negative.

    .. versionadded:: 1.2
    """  # noqa: E501
    # Operating on an unfitted ICA makes no sense — bail out early.
    if self.current_fit == 'unfitted':
        raise ValueError('ICA must be fitted first.')

    _validate_type(
        item=inst, types=(BaseRaw, BaseEpochs, Evoked),
        item_name='inst'
    )
    _validate_type(
        item=components, types=(None, 'int-like', Sequence, np.ndarray),
        item_name='components', type_name='int, array-like of int, or None'
    )
    if isinstance(components, (Sequence, np.ndarray)):
        # Each element must itself be an integer-like component index.
        for element in components:
            _validate_type(
                item=element, types='int-like',
                item_name='Elements of "components"'
            )

    _validate_type(
        item=ch_type, types=(Sequence, np.ndarray, str, None),
        item_name='ch_type', type_name='str, array-like of str, or None'
    )
    # Normalize ch_type into a list of channel-type strings.
    if ch_type is None:
        ch_types = inst.get_channel_types(unique=True, only_data_chs=True)
    elif isinstance(ch_type, str):
        ch_types = [ch_type]
    else:
        assert isinstance(ch_type, (Sequence, np.ndarray))
        ch_types = ch_type

    assert len(ch_types) >= 1
    allowed_ch_types = ('mag', 'grad', 'planar1', 'planar2', 'eeg')
    for ch_type in ch_types:
        if ch_type not in allowed_ch_types:
            raise ValueError(
                f'You requested operation on the channel type '
                f'"{ch_type}", but only the following channel types are '
                f'supported: {", ".join(allowed_ch_types)}'
            )
    del ch_type

    # Input data validation ends here
    if components is None:
        components = range(self.n_components_)

    # Compute the ratio per channel type and collect the results.
    result = {}
    for one_ch_type in ch_types:
        result[one_ch_type] = self._get_explained_variance_ratio_one_ch_type(
            inst=inst, components=components, ch_type=one_ch_type
        )
    return result
|
||
def _get_explained_variance_ratio_one_ch_type(
        self, *, inst, components, ch_type
):
    # The algorithm below is equivalent to the EEGLAB approach:
    # https://sccn.ucsd.edu/pipermail/eeglablist/2014/009134.html
    #
    # Reconstruct ("back-project") the data from the specified ICA
    # components only. We deliberately pass n_pca_components=0 so no
    # "spare" PCA components contribute — only the ICA components do.
    apply_kwargs = dict(
        inst=inst.copy(),  # apply() works in place, so operate on a copy
        include=[components],
        exclude=[],
        n_pca_components=0,
        verbose=False,
    )
    baseline_corrected = (
        isinstance(inst, (BaseEpochs, Evoked)) and
        inst.baseline is not None
    )
    with warnings.catch_warnings():
        if baseline_corrected:
            # The data was baseline-corrected on purpose — silence the
            # corresponding RuntimeWarning that apply() would emit.
            warnings.filterwarnings(
                action='ignore',
                message='The data.*was baseline-corrected',
                category=RuntimeWarning
            )
        inst_recon = self.apply(**apply_kwargs)

    reconstructed = inst_recon.get_data(picks=ch_type)
    original = inst.get_data(picks=ch_type)
    residual = original - reconstructed

    # Estimate data variance by taking the variance across channels at
    # each time point, then averaging those variances over time.
    residual_var = residual.var(axis=0).mean()
    original_var = original.var(axis=0).mean()

    # Fraction of original variance NOT left over in the residual.
    return 1 - residual_var / original_var
|
||
def get_sources(self, inst, add_channels=None, start=None, stop=None): | ||
"""Estimate sources given the unmixing matrix. | ||
|
||
|
@@ -2247,6 +2374,8 @@ def _find_sources(sources, target, score_func): | |
def _ica_explained_variance(ica, inst, normalize=False): | ||
"""Check variance accounted for by each component in supplied data. | ||
|
||
This function is only used for sorting the components. | ||
|
||
Parameters | ||
---------- | ||
ica : ICA | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is the copy really needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately, yes, because we use `ica.apply()` below, and it works in place. We could do without the copy if, instead of using `ica.apply()`, we did the matrix multiplication "manually". But `ica.apply()` does quite a few additional things, and I'm worried I'd forget something important 😅 I trust that if we discover the `copy()` here causes issues, we can optimize things later on.