From 03e061e7700a0cabb3093d7bdb306b6b66df68f8 Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Tue, 25 Oct 2022 14:46:15 +0900 Subject: [PATCH 01/24] implemented Atomize and Agglomerate Hierachical Clustering (AAHL) - updated pycrostates.io.fiff to write/read results - create test suite --- docs/references.bib | 41 +- pycrostates/cluster/__init__.py | 3 +- pycrostates/cluster/aahc.py | 372 ++++++++ pycrostates/cluster/tests/test_aahc.py | 1204 ++++++++++++++++++++++++ pycrostates/io/fiff.py | 69 +- pycrostates/utils/_checks.py | 1 + 6 files changed, 1678 insertions(+), 12 deletions(-) create mode 100644 pycrostates/cluster/aahc.py create mode 100644 pycrostates/cluster/tests/test_aahc.py diff --git a/docs/references.bib b/docs/references.bib index 69346c20..6c405da1 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -38,7 +38,46 @@ @article{MICHEL2018577 author = {Christoph M. Michel and Thomas Koenig}, keywords = {EEG microstates, Resting state networks, Consciousness, Psychiatric disease, State-dependent information processing, Metastability}, } -% Modified K-Means + +% Michel 2018 - microstates review +@article{MICHEL2018577, + title = {EEG microstates as a tool for studying the temporal dynamics of whole-brain neuronal networks: A review}, + journal = {NeuroImage}, + volume = {180}, + pages = {577-593}, + year = {2018}, + note = {Brain Connectivity Dynamics}, + issn = {1053-8119}, + doi = {10.1016/j.neuroimage.2017.11.062}, + author = {Christoph M. Michel and Thomas Koenig}, + keywords = {EEG microstates, Resting state networks, Consciousness, Psychiatric disease, State-dependent information processing, Metastability}, +} + +% Atomize and Agglomerate Hierarchical Clustering (AAHC) + +@article{Murray2008, + author = {Murray, Micah M. and Brunet, Denis and Michel, Christoph M.}, + journal = {Brain Topography}, + title = {Topographic {ERP} {Analyses}: {A} {Step}-by-{Step} {Tutorial} {Review}}, + year = {2008}, + volume = {20}, + number = {4}, + pages = {249--264}, + doi = {10.1007/s10548-008-0054-5} +} + + +@inproceedings{Roweis1997, + author = {Roweis, Sam}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {M. Jordan and M. Kearns and S. Solla}, + publisher = {MIT Press}, + title = {EM Algorithms for PCA and SPCA}, + url = {https://proceedings.neurips.cc/paper/1997/file/d9731321ef4e063ebbee79298fa36f56-Paper.pdf}, + volume = {10}, + year = {1997}, +} + @article{Marqui1995, author={Pascual-Marqui, R.D. and Michel, C.M. and Lehmann, D.}, journal={IEEE Transactions on Biomedical Engineering}, diff --git a/pycrostates/cluster/__init__.py b/pycrostates/cluster/__init__.py index 9da7e780..54a6014f 100644 --- a/pycrostates/cluster/__init__.py +++ b/pycrostates/cluster/__init__.py @@ -15,5 +15,6 @@ to segment.""" from .kmeans import ModKMeans # noqa: F401 +from .aahc import AAHCluster -__all__ = ("ModKMeans",) +__all__ = ("ModKMeans", "AAHCluster") diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py new file mode 100644 index 00000000..afdd89b7 --- /dev/null +++ b/pycrostates/cluster/aahc.py @@ -0,0 +1,372 @@ +"""Class and functions to use Atomize and Agglomerate Hierarchical Clustering + (AAHC).""" + +from functools import wraps as func_wraps +from pathlib import Path +from typing import Any, Optional, Tuple, Union + +import numpy as np +from mne import BaseEpochs +from mne.io import BaseRaw +from numpy.typing import NDArray + +from .._typing import Picks +from ..utils import _corr_vectors +from ..utils._checks import _check_type +from ..utils._docs import copy_doc, fill_doc +from ..utils._logs import _set_verbose, logger +from ._base import _BaseCluster + +# if we have numba, use its jit interface +try: + from numba import njit +except ImportError: + + def njit(cache=False): + def decorator(func): + @func_wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + return decorator + + +class AAHCluster(_BaseCluster): + r"""Atomize and Agglomerate Hierarchical Clustering (AAHC) algorithm. + + See :footcite:t:`Murray2008` for additional information. + + Parameters + ---------- + %(n_clusters)s + + ignore_polarity : bool + If true, polarity is ignored when computing distances. + normalize_input : bool + If set, the input data is normalized along the channel dimension. + tol : float + Relative tolerance with regards estimate residual noise in the cluster + centers of two consecutive iterations to declare convergence. + + References + ---------- + .. footbibliography:: + """ + + def __init__( + self, + n_clusters: int, + ignore_polarity: bool = True, + normalize_input: bool = False, + tol: float = 1e-6, + ): + super().__init__() + + self._n_clusters = _BaseCluster._check_n_clusters(n_clusters) + self._cluster_names = [str(k) for k in range(self.n_clusters)] + + self._ignore_polarity = AAHCluster._check_ignore_polarity( + ignore_polarity + ) + self._normalize_input = AAHCluster._check_ignore_polarity( + normalize_input + ) + self._tol = AAHCluster._check_tol(tol) + + # fit variables + self._GEV_ = None + + def _repr_html_(self, caption=None): + from ..html_templates import repr_templates_env + + template = repr_templates_env.get_template("ModKMeans.html.jinja") + if self.fitted: + n_samples = self._fitted_data.shape[-1] + ch_types, ch_counts = np.unique( + self.get_channel_types(), return_counts=True + ) + ch_repr = [ + f"{ch_count} {ch_type.upper()}" + for ch_type, ch_count in zip(ch_types, ch_counts) + ] + GEV = int(self._GEV_ * 100) + else: + n_samples = None + ch_repr = None + GEV = None + + return template.render( + name=self.__class__.__name__, + n_clusters=self._n_clusters, + GEV=GEV, + cluster_names=self._cluster_names, + fitted=self._fitted, + n_samples=n_samples, + ch_repr=ch_repr, + ) + + @copy_doc(_BaseCluster.__eq__) + def __eq__(self, other: Any) -> bool: + """Equality == method.""" + if isinstance(other, AAHCluster): + if not super().__eq__(other): + return False + + attributes = ( + "_ignore_polarity", + "_normalize_input", + "_tol", + "_GEV_", + ) + for attribute in attributes: + try: + attr1 = self.__getattribute__(attribute) + attr2 = other.__getattribute__(attribute) + except AttributeError: + return False + if attr1 != attr2: + return False + + return True + return False + + @copy_doc(_BaseCluster.__ne__) + def __ne__(self, other: Any) -> bool: + """Different != method.""" + return not self.__eq__(other) + + @copy_doc(_BaseCluster._check_fit) + def _check_fit(self): + super()._check_fit() + # sanity-check + assert self.GEV_ is not None + + @copy_doc(_BaseCluster.fit) + @fill_doc + def fit( + self, + inst: Union[BaseRaw, BaseEpochs], + picks: Picks = "eeg", + tmin: Optional[Union[int, float]] = None, + tmax: Optional[Union[int, float]] = None, + reject_by_annotation: bool = True, + n_jobs: int = 1, + *, + verbose: Optional[str] = None, + ) -> None: + """ + %(verbose)s + """ + _set_verbose(verbose) # TODO: decorator nesting is failing + data = super().fit( + inst, picks, tmin, tmax, reject_by_annotation, n_jobs + ) + + gev, maps, segmentation = AAHCluster._aahc( + data, + self._n_clusters, + self._ignore_polarity, + self._normalize_input, + self._tol, + ) + + if gev is not None: + logger.info("AAHC converged with GEV = %.2f%% ", gev * 100) + + self._GEV_ = gev + self._cluster_centers_ = maps + self._labels_ = segmentation + self._fitted = True + + @copy_doc(_BaseCluster.save) + def save(self, fname: Union[str, Path]): + super().save(fname) + # TODO: to be replaced by a general writer than infers the writer from + # the file extension. + from ..io.fiff import _write_cluster # pylint: disable=C0415 + + _write_cluster( + fname, + self._cluster_centers_, + self._info, + "AAHCluster", + self._cluster_names, + self._fitted_data, + self._labels_, + ignore_polarity=self._ignore_polarity, + normalize_input=self._normalize_input, + tol=self._tol, + GEV_=self._GEV_, + ) + + # -------------------------------------------------------------------- + @staticmethod + def _aahc( + data: NDArray[float], + n_clusters: int, + ignore_polarity: bool, + normalize_input: bool, + tol: Union[int, float], + ) -> Tuple[float, NDArray[float], NDArray[int]]: + """Run the AAHC algorithm.""" + gfp_sum_sq = np.sum(data**2) + maps, segmentation = AAHCluster._compute_maps( + data, n_clusters, ignore_polarity, normalize_input, tol + ) + map_corr = _corr_vectors(data, maps[segmentation].T) + gev = np.sum((data * map_corr) ** 2) / gfp_sum_sq + return gev, maps, segmentation + + @staticmethod + def _compute_maps( + data: NDArray[float], + n_clusters: int, + ignore_polarity: bool, + normalize_input: bool, + tol: Union[int, float], + ) -> Tuple[NDArray[float], NDArray[int]]: + """Compute microstates maps.""" + n_chan, n_frame = data.shape + + cluster = data.copy() + cluster /= np.linalg.norm(cluster, axis=0, keepdims=True) + + if normalize_input: + data = cluster.copy() + + GEV = np.sum(data * cluster, axis=0) + + assignment = np.arange(n_frame) + + while cluster.shape[1] > n_clusters: + + to_remove = np.argmin(GEV) + orphans = assignment == to_remove + + cluster = np.delete(cluster, to_remove, axis=1) + GEV = np.delete(GEV, to_remove, axis=0) + assignment[assignment > to_remove] = ( + assignment[assignment > to_remove] - 1 + ) + + fit = data[:, orphans].T @ cluster + if ignore_polarity: + fit = np.abs(fit) + new_assignment = np.argmax(fit, axis=1) + assignment[orphans] = new_assignment + + cluster_to_update = np.unique(new_assignment) + for c in cluster_to_update: + members = assignment == c + if ignore_polarity: + v, _ = AAHCluster._first_principal_component( + data[:, members], tol + ) + cluster[:, c] = v + else: + cluster[:, c] = np.mean(data[:, members], axis=0) + cluster[:, c] /= np.linalg.norm( + cluster[:, c], axis=0, keepdims=True + ) + new_fit = cluster[:, slice(c, c + 1)] * data[:, members] + if ignore_polarity: + new_fit = np.abs(new_fit) + GEV[c] = np.sum(new_fit) + return cluster.T, assignment + + @staticmethod + @njit(cache=True) + def _first_principal_component( + X: NDArray[float], tol: float, max_iter: int = 100 + ) -> Tuple[NDArray[float], float]: + """Compute first principal component. + + See :footcite:t:`Roweis1997` for additional information. + """ + + v = np.random.rand(X.shape[0]) + # v = np.ones((X.shape[0],)) + # v[::2] = -1 + v /= np.linalg.norm(v) + + for _ in range(max_iter): + s = np.sum((np.expand_dims(v, 0) @ X) * X, axis=1) + + eig = s.dot(v) + s_norm = np.linalg.norm(s) + + if np.linalg.norm(eig * v - s) / s_norm < tol: + break + v = s / s_norm + # else: + # logger.warn("First PC estimation: max iteration reached!") + return v, eig + + # -------------------------------------------------------------------- + + @property + def ignore_polarity(self) -> bool: + """If true, polarity is ignored when computing distances. + + :type: `bool` + """ + return self._ignore_polarity + + @property + def normalize_input(self) -> bool: + """If set, the input data is normalized along the channel dimension. + + :type: `bool` + """ + return self._normalize_input + + @property + def tol(self) -> Union[int, float]: + """Relative tolerance to reach convergence. + + :type: `float` + """ + return self._tol + + @property + def GEV_(self) -> float: + """Global Explained Variance. + + :type: `float` + """ + if self._GEV_ is None: + assert not self._fitted # sanity-check + logger.warning("Clustering algorithm has not been fitted.") + return self._GEV_ + + @_BaseCluster.fitted.setter + @copy_doc(_BaseCluster.fitted.setter) + def fitted(self, fitted): + super(self.__class__, self.__class__).fitted.__set__(self, fitted) + if not fitted: + self._GEV_ = None + + @staticmethod + def _check_ignore_polarity(ignore_polarity: bool) -> bool: + """Check that ignore_polarity is a boolean.""" + _check_type(ignore_polarity, ("bool",), item_name="ignore_polarity") + return ignore_polarity + + @staticmethod + def _check_normalize_input(normalize_input: bool) -> bool: + """Check that normalize_input is a boolean.""" + _check_type(normalize_input, ("bool",), item_name="normalize_input") + return normalize_input + + @staticmethod + def _check_tol(tol: Union[int, float]) -> Union[int, float]: + """Check that tol is a positive number.""" + _check_type(tol, ("numeric",), item_name="tol") + if tol <= 0: + raise ValueError( + "The tolerance must be a positive number. " + f"Provided: '{tol}'." + ) + return tol diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py new file mode 100644 index 00000000..ea3c68ac --- /dev/null +++ b/pycrostates/cluster/tests/test_aahc.py @@ -0,0 +1,1204 @@ + +"""Test AAHCluster.""" + +import logging +import re +from copy import deepcopy +from itertools import groupby + +import numpy as np +import pytest +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from mne import Annotations, Epochs, create_info, make_fixed_length_events +from mne.channels import DigMontage +from mne.datasets import testing +from mne.io import RawArray, read_raw_fif +from mne.io.pick import _picks_to_idx + +from pycrostates import __version__ +from pycrostates.cluster import AAHCluster +from pycrostates.io import ChData, ChInfo, read_cluster +from pycrostates.io.fiff import _read_cluster +from pycrostates.segmentation import EpochsSegmentation, RawSegmentation +from pycrostates.utils._logs import logger, set_log_level + +set_log_level("INFO") +logger.propagate = True + +directory = testing.data_path() / "MEG" / "sample" +fname = directory / "sample_audvis_trunc_raw.fif" + +# raw +raw_meg = read_raw_fif(fname, preload=False) +raw_meg.crop(0, 10) +raw_eeg = raw_meg.copy().pick("eeg").load_data().apply_proj() +raw_meg.pick_types(meg=True, eeg=True, exclude="bads") +raw_meg.load_data().apply_proj() +# epochs +events = make_fixed_length_events(raw_meg, duration=1) +epochs_meg = Epochs( + raw_meg, events, tmin=0, tmax=0.5, baseline=None, preload=True +) +epochs_eeg = Epochs( + raw_eeg, events, tmin=0, tmax=0.5, baseline=None, preload=True +) +# ch_data +ch_data = ChData(raw_eeg.get_data(), raw_eeg.info) +# Fit one for general purposes +n_clusters = 4 + +aahCluster = AAHCluster( + n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, + tol=1e-4 +) + +aahCluster.fit(ch_data) + +# pylint: disable=protected-access +def _check_fitted(aahCluster): + """ + Checks that the aahCluster is fitted. + """ + assert aahCluster.fitted + assert aahCluster.n_clusters == n_clusters + assert len(aahCluster._cluster_names) == n_clusters + assert len(aahCluster._cluster_centers_) == n_clusters + assert aahCluster._fitted_data is not None + assert aahCluster._info is not None + assert aahCluster.GEV_ is not None + assert aahCluster._labels_ is not None + + +def _check_unfitted(aahCluster): + """ + Checks that the aahCluster is not fitted. + """ + assert not aahCluster.fitted + assert aahCluster.n_clusters == n_clusters + assert len(aahCluster._cluster_names) == n_clusters + assert aahCluster._cluster_centers_ is None + assert aahCluster._fitted_data is None + assert aahCluster._info is None + assert aahCluster.GEV_ is None + assert aahCluster._labels_ is None + + +def _check_fitted_data_raw( + fitted_data, raw, picks, tmin, tmax, reject_by_annotation +): + """Check the fitted data array for a raw instance.""" + # Trust MNE .get_data() to correctly select data + picks = _picks_to_idx(raw.info, picks) + data = raw.get_data( + picks=picks, + tmin=tmin, + tmax=tmax, + reject_by_annotation=reject_by_annotation, + ) + assert data.shape == fitted_data.shape + + +def _check_fitted_data_epochs(fitted_data, epochs, picks, tmin, tmax): + """Check the fitted data array for an epoch instance.""" + picks = _picks_to_idx(epochs.info, picks) + # Trust MNE .get_data() to correctly select data + data = epochs.get_data(picks=picks, tmin=tmin, tmax=tmax) + # check channels + assert fitted_data.shape[0] == data.shape[1] + # check samples + assert fitted_data.shape[1] == int(data.shape[0] * data.shape[2]) + + +def test_aahClusterMeans(): + """Test K-Means default functionalities.""" + aahCluster1 = AAHCluster( + n_clusters=n_clusters, + ignore_polarity=True, + normalize_input=False, + tol=1e-4, + ) + + # Test properties + assert aahCluster1.ignore_polarity == True + assert aahCluster1.normalize_input == False + assert aahCluster1.tol == 1e-4 + _check_unfitted(aahCluster1) + + # Test default clusters names + assert aahCluster1._cluster_names == ["0", "1", "2", "3"] + + # Test fit on RAW + aahCluster1.fit(raw_eeg, n_jobs=1) + _check_fitted(aahCluster1) + assert aahCluster1._cluster_centers_.shape == ( + n_clusters, + len(raw_eeg.info["ch_names"]) - len(raw_eeg.info["bads"]), + ) + + # Test reset + aahCluster1.fitted = False + _check_unfitted(aahCluster1) + + # Test fit on Epochs + aahCluster1.fit(epochs_eeg, n_jobs=1) + _check_fitted(aahCluster1) + assert aahCluster1._cluster_centers_.shape == ( + n_clusters, + len(epochs_eeg.info["ch_names"]) - len(epochs_eeg.info["bads"]), + ) + + # Test fit on ChData + aahCluster1.fitted = False + aahCluster1.fit(ch_data, n_jobs=1) + _check_fitted(aahCluster1) + assert aahCluster1._cluster_centers_.shape == ( + n_clusters, + len(raw_eeg.info["ch_names"]) - len(raw_eeg.info["bads"]), + ) + + # Test copy + aahCluster2 = aahCluster1.copy() + _check_fitted(aahCluster2) + assert np.isclose(aahCluster2._cluster_centers_, aahCluster1._cluster_centers_).all() + assert np.isclose(aahCluster2.GEV_, aahCluster1.GEV_) + assert np.isclose(aahCluster2._labels_, aahCluster1._labels_).all() + aahCluster2.fitted = False + _check_fitted(aahCluster1) + _check_unfitted(aahCluster2) + + aahCluster3 = aahCluster1.copy(deep=False) + _check_fitted(aahCluster3) + assert np.isclose(aahCluster3._cluster_centers_, aahCluster1._cluster_centers_).all() + assert np.isclose(aahCluster3.GEV_, aahCluster1.GEV_) + assert np.isclose(aahCluster3._labels_, aahCluster1._labels_).all() + aahCluster3.fitted = False + _check_fitted(aahCluster1) + _check_unfitted(aahCluster3) + + # Test representation + expected = f"" + assert expected == aahCluster1.__repr__() + assert "" == aahCluster2.__repr__() + + # Test HTML representation + html = aahCluster1._repr_html_() + assert html is not None + assert "not fitted" not in html + html = aahCluster2._repr_html_() + assert html is not None + assert "not fitted" in html + + # Test plot + f = aahCluster1.plot(block=False) + assert isinstance(f, Figure) + with pytest.raises(RuntimeError, match="must be fitted before"): + aahCluster2.plot(block=False) + plt.close("all") + + +def test_invert_polarity(): + """Test invert polarity method.""" + # list/tuple + aahCluster_ = aahCluster.copy() + cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) + aahCluster_.invert_polarity([True, False, True, False]) + assert np.isclose( + aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :] + ).all() + assert np.isclose( + aahCluster_._cluster_centers_[1, :], cluster_centers_[1, :] + ).all() + assert np.isclose( + aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :] + ).all() + assert np.isclose( + aahCluster_._cluster_centers_[3, :], cluster_centers_[3, :] + ).all() + + # bool + aahCluster_ = aahCluster.copy() + cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) + aahCluster_.invert_polarity(True) + assert np.isclose( + aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :] + ).all() + assert np.isclose( + aahCluster_._cluster_centers_[1, :], -cluster_centers_[1, :] + ).all() + assert np.isclose( + aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :] + ).all() + assert np.isclose( + aahCluster_._cluster_centers_[3, :], -cluster_centers_[3, :] + ).all() + + # np.array + aahCluster_ = aahCluster.copy() + cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) + aahCluster_.invert_polarity(np.array([True, False, True, False])) + assert np.isclose( + aahCluster_._cluster_centers_[0, :], -cluster_centers_[0, :] + ).all() + assert np.isclose( + aahCluster_._cluster_centers_[1, :], cluster_centers_[1, :] + ).all() + assert np.isclose( + aahCluster_._cluster_centers_[2, :], -cluster_centers_[2, :] + ).all() + assert np.isclose( + aahCluster_._cluster_centers_[3, :], cluster_centers_[3, :] + ).all() + + # Test invalid arguments + with pytest.raises(ValueError, match="not a 2D iterable"): + aahCluster_.invert_polarity(np.zeros((2, 4))) + with pytest.raises( + ValueError, match=re.escape("list of bools of length 'n_clusters' (4)") + ): + aahCluster_.invert_polarity([True, False, True, False, True]) + with pytest.raises(TypeError, match="'invert' must be an instance of "): + aahCluster_.invert_polarity(101) + + # Test unfitted + aahCluster_.fitted = False + _check_unfitted(aahCluster_) + with pytest.raises(RuntimeError, match="must be fitted before"): + aahCluster_.invert_polarity([True, False, True, False]) + + +def test_rename(caplog): + """Test renaming of clusters.""" + alphabet = ["A", "B", "C", "D"] + + # Test mapping + aahCluster_ = aahCluster.copy() + mapping = {old: alphabet[k] for k, old in enumerate(aahCluster._cluster_names)} + for key, value in mapping.items(): + assert isinstance(key, str) + assert isinstance(value, str) + assert key != value + aahCluster_.rename_clusters(mapping=mapping) + assert aahCluster_._cluster_names == alphabet + assert aahCluster_._cluster_names != aahCluster._cluster_names + + # Test new_names + aahCluster_ = aahCluster.copy() + aahCluster_.rename_clusters(new_names=alphabet) + assert aahCluster_._cluster_names == alphabet + assert aahCluster_._cluster_names != aahCluster._cluster_names + + # Test invalid arguments + aahCluster_ = aahCluster.copy() + with pytest.raises(TypeError, match="'mapping' must be an instance of "): + aahCluster_.rename_clusters(mapping=101) + with pytest.raises(ValueError, match="Invalid value for the 'old name'"): + mapping = { + old + "101": alphabet[k] + for k, old in enumerate(aahCluster._cluster_names) + } + aahCluster_.rename_clusters(mapping=mapping) + with pytest.raises(TypeError, match="'new name' must be an instance of "): + mapping = {old: k for k, old in enumerate(aahCluster._cluster_names)} + aahCluster_.rename_clusters(mapping=mapping) + with pytest.raises( + ValueError, match="Argument 'new_names' should contain" + ): + aahCluster_.rename_clusters(new_names=alphabet + ["E"]) + + aahCluster_.rename_clusters() + assert "Either 'mapping' or 'new_names' should not be" in caplog.text + + with pytest.raises( + ValueError, match="Only one of 'mapping' or 'new_names'" + ): + mapping = { + old: alphabet[k] for k, old in enumerate(aahCluster._cluster_names) + } + aahCluster_.rename_clusters(mapping=mapping, new_names=alphabet) + + # Test unfitted + aahCluster_ = aahCluster.copy() + aahCluster_.fitted = False + _check_unfitted(aahCluster_) + with pytest.raises(RuntimeError, match="must be fitted before"): + mapping = { + old: alphabet[k] for k, old in enumerate(aahCluster._cluster_names) + } + aahCluster_.rename_clusters(mapping=mapping) + with pytest.raises(RuntimeError, match="must be fitted before"): + aahCluster_.rename_clusters(new_names=alphabet) + + +def test_reorder(caplog): + """Test reordering of clusters.""" + # Test mapping + aahCluster_ = aahCluster.copy() + aahCluster_.reorder_clusters(mapping={0: 1}) + assert np.isclose( + aahCluster._cluster_centers_[0, :], aahCluster_._cluster_centers_[1, :] + ).all() + assert np.isclose( + aahCluster._cluster_centers_[1, :], aahCluster_._cluster_centers_[0, :] + ).all() + assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] + assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] + + # Test order + aahCluster_ = aahCluster.copy() + aahCluster_.reorder_clusters(order=[1, 0, 2, 3]) + assert np.isclose( + aahCluster._cluster_centers_[0], aahCluster_._cluster_centers_[1] + ).all() + assert np.isclose( + aahCluster._cluster_centers_[1], aahCluster_._cluster_centers_[0] + ).all() + assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] + assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] + + aahCluster_ = aahCluster.copy() + aahCluster_.reorder_clusters(order=np.array([1, 0, 2, 3])) + assert np.isclose( + aahCluster._cluster_centers_[0], aahCluster_._cluster_centers_[1] + ).all() + assert np.isclose( + aahCluster._cluster_centers_[1], aahCluster_._cluster_centers_[0] + ).all() + assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] + assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] + + # test ._labels_ reordering + y = aahCluster._labels_[:20] + y[y == 0] = -1 + y[y == 1] = 0 + y[y == -1] = 1 + x = aahCluster_._labels_[:20] + assert np.all(x == y) + + # Test invalid arguments + aahCluster_ = aahCluster.copy() + with pytest.raises(TypeError, match="'mapping' must be an instance of "): + aahCluster_.reorder_clusters(mapping=101) + with pytest.raises( + ValueError, match="Invalid value for the 'old position'" + ): + aahCluster_.reorder_clusters(mapping={4: 1}) + with pytest.raises( + ValueError, match="Invalid value for the 'new position'" + ): + aahCluster_.reorder_clusters(mapping={0: 4}) + with pytest.raises( + ValueError, match="Position in the new order can not be repeated." + ): + aahCluster_.reorder_clusters(mapping={0: 1, 2: 1}) + with pytest.raises( + ValueError, match="A position can not be present in both" + ): + aahCluster_.reorder_clusters(mapping={0: 1, 1: 2}) + + with pytest.raises(TypeError, match="'order' must be an instance of "): + aahCluster_.reorder_clusters(order=101) + with pytest.raises(ValueError, match="Invalid value for the 'order'"): + aahCluster_.reorder_clusters(order=[4, 3, 1, 2]) + with pytest.raises( + ValueError, match="Argument 'order' should contain 'n_clusters'" + ): + aahCluster_.reorder_clusters(order=[0, 3, 1, 2, 0]) + with pytest.raises( + ValueError, match="Argument 'order' should be a 1D iterable" + ): + aahCluster_.reorder_clusters(order=np.array([[0, 1, 2, 3], [0, 1, 2, 3]])) + + aahCluster_.reorder_clusters() + assert "Either 'mapping' or 'order' should not be 'None' " in caplog.text + + with pytest.raises(ValueError, match="Only one of 'mapping' or 'order'"): + aahCluster_.reorder_clusters(mapping={0: 1}, order=[1, 0, 2, 3]) + + # Test unfitted + aahCluster_ = aahCluster.copy() + aahCluster_.fitted = False + _check_unfitted(aahCluster_) + with pytest.raises(RuntimeError, match="must be fitted before"): + aahCluster_.reorder_clusters(mapping={0: 1}) + with pytest.raises(RuntimeError, match="must be fitted before"): + aahCluster_.reorder_clusters(order=[1, 0, 2, 3]) + + +def test_properties(caplog): + """Test properties.""" + caplog.set_level(logging.WARNING) + + # Unfitted + aahCluster_ = AAHCluster( + n_clusters=n_clusters, + ignore_polarity=True, + normalize_input=False, + tol=1e-4, + ) + + aahCluster_.cluster_centers_ # pylint: disable=pointless-statement + assert "Clustering algorithm has not been fitted." in caplog.text + caplog.clear() + + aahCluster_.info # pylint: disable=pointless-statement + assert "Clustering algorithm has not been fitted." in caplog.text + caplog.clear() + + aahCluster_.fitted_data # pylint: disable=pointless-statement + assert "Clustering algorithm has not been fitted." in caplog.text + caplog.clear() + + # Fitted + aahCluster_ = aahCluster.copy() + + aahCluster_.cluster_centers_ + assert "Clustering algorithm has not been fitted." not in caplog.text + caplog.clear() + + aahCluster_.info + assert "Clustering algorithm has not been fitted." not in caplog.text + caplog.clear() + + aahCluster_.fitted_data + assert "Clustering algorithm has not been fitted." not in caplog.text + caplog.clear() + + # Test fitted property + aahCluster_ = AAHCluster( + n_clusters=n_clusters, + ignore_polarity=True, + normalize_input=False, + tol=1e-4, + ) + with pytest.raises(TypeError, match="'fitted' must be an instance of"): + aahCluster_.fitted = "101" + caplog.clear() + aahCluster_.fitted = True + log = "'fitted' can not be set to 'True' directly. Please use the .fit()" + assert log in caplog.text + caplog.clear() + aahCluster_ = aahCluster.copy() + aahCluster_.fitted = True + log = "'fitted' can not be set to 'True' directly. The clustering" + assert log in caplog.text + + +def test_invalid_arguments(): + """Test invalid arguments for init and for fit.""" + # n_clusters + with pytest.raises( + TypeError, match="'n_clusters' must be an instance of " + ): + aahCluster_ = AAHCluster(n_clusters="4") + with pytest.raises(ValueError, match="The number of clusters must be a"): + aahCluster_ = AAHCluster(n_clusters=0) + with pytest.raises(ValueError, match="The number of clusters must be a"): + aahCluster_ = AAHCluster(n_clusters=-101) + + # tol + with pytest.raises(TypeError, match="'tol' must be an instance of "): + aahCluster_ = AAHCluster(n_clusters=4, tol="100") + with pytest.raises(ValueError, match="The tolerance must be a"): + aahCluster_ = AAHCluster(n_clusters=4, tol=0) + with pytest.raises(ValueError, match="The tolerance must be a"): + aahCluster_ = AAHCluster(n_clusters=4, tol=-101) + + aahCluster_ = AAHCluster( + n_clusters=n_clusters, + ignore_polarity=True, + normalize_input=False, + tol=1e-4, + ) + # inst + with pytest.raises(TypeError, match="'inst' must be an instance of "): + aahCluster_.fit(epochs_eeg.average()) + + # tmin/tmax + with pytest.raises(TypeError, match="'tmin' must be an instance of "): + aahCluster_.fit(raw_eeg, tmin="101") + with pytest.raises(TypeError, match="'tmax' must be an instance of "): + aahCluster_.fit(raw_eeg, tmax="101") + with pytest.raises(ValueError, match="Argument 'tmin' must be positive"): + aahCluster_.fit(raw_eeg, tmin=-101, tmax=None) + with pytest.raises(ValueError, match="Argument 'tmax' must be positive"): + aahCluster_.fit(raw_eeg, tmin=None, tmax=-101) + with pytest.raises( + ValueError, + match="Argument 'tmax' must be strictly larger than 'tmin'.", + ): + aahCluster_.fit(raw_eeg, tmin=5, tmax=1) + with pytest.raises( + ValueError, + match="Argument 'tmin' must be shorter than the instance length.", + ): + aahCluster_.fit(raw_eeg, tmin=101, tmax=None) + with pytest.raises( + ValueError, + match="Argument 'tmax' must be shorter than the instance length.", + ): + aahCluster_.fit(raw_eeg, tmin=None, tmax=101) + + # reject_by_annotation + with pytest.raises( + TypeError, match="'reject_by_annotation' must be an instance of " + ): + aahCluster_.fit(raw_eeg, reject_by_annotation=1) + with pytest.raises(ValueError, match="only allows for"): + aahCluster_.fit(raw_eeg, reject_by_annotation="101") + + +def test_fit_data_shapes(): + """Test different tmin/tmax, rejection with fit.""" + aahCluster_ = AAHCluster( + n_clusters=n_clusters, + ignore_polarity=True, + normalize_input=False, + tol=1e-8, + ) + + # tmin + aahCluster_.fitted = False + _check_unfitted(aahCluster_) + aahCluster_.fit( + raw_eeg, + n_jobs=1, + picks="eeg", + tmin=5, + tmax=None, + reject_by_annotation=False, + ) + _check_fitted_data_raw(aahCluster_._fitted_data, raw_eeg, "eeg", 5, None, None) + # save for later + fitted_data_5_end = deepcopy(aahCluster_._fitted_data) + + aahCluster_.fitted = False + _check_unfitted(aahCluster_) + aahCluster_.fit( + epochs_eeg, + n_jobs=1, + picks="eeg", + tmin=0.2, + tmax=None, + reject_by_annotation=False, + ) + _check_fitted_data_epochs(aahCluster_._fitted_data, epochs_eeg, "eeg", 0.2, None) + + # tmax + aahCluster_.fitted = False + _check_unfitted(aahCluster_) + aahCluster_.fit( + raw_eeg, + n_jobs=1, + picks="eeg", + tmin=None, + tmax=5, + reject_by_annotation=False, + ) + _check_fitted_data_raw(aahCluster_._fitted_data, raw_eeg, "eeg", None, 5, None) + # save for later + fitted_data_0_5 = deepcopy(aahCluster_._fitted_data) + + aahCluster_.fitted = False + _check_unfitted(aahCluster_) + aahCluster_.fit( + epochs_eeg, + n_jobs=1, + picks="eeg", + tmin=None, + tmax=0.3, + reject_by_annotation=False, + ) + _check_fitted_data_epochs(aahCluster_._fitted_data, epochs_eeg, "eeg", None, 0.3) + + # tmin, tmax + aahCluster_.fitted = False + _check_unfitted(aahCluster_) + aahCluster_.fit( + raw_eeg, + n_jobs=1, + picks="eeg", + tmin=2, + tmax=8, + reject_by_annotation=False, + ) + _check_fitted_data_raw(aahCluster_._fitted_data, raw_eeg, "eeg", 2, 8, None) + + aahCluster_.fitted = False + _check_unfitted(aahCluster_) + aahCluster_.fit( + epochs_eeg, + n_jobs=1, + picks="eeg", + tmin=0.1, + tmax=0.4, + reject_by_annotation=False, + ) + _check_fitted_data_epochs(aahCluster_._fitted_data, epochs_eeg, "eeg", 0.1, 0.4) + + # --------------------- + # Reject by annotations + # --------------------- + bad_annot = Annotations([1], [2], "bad") + raw_ = raw_eeg.copy() + raw_.set_annotations(bad_annot) + + aahCluster_.fitted = False + _check_unfitted(aahCluster_) + + aahCluster_no_reject = aahCluster_.copy() + aahCluster_no_reject.fit(raw_, n_jobs=1, reject_by_annotation=False) + aahCluster_reject_True = aahCluster_.copy() + aahCluster_reject_True.fit(raw_, n_jobs=1, reject_by_annotation=True) + aahCluster_reject_omit = aahCluster_.copy() + aahCluster_reject_omit.fit(raw_, n_jobs=1, reject_by_annotation="omit") + + # Compare 'omit' and True + assert np.isclose( + aahCluster_reject_omit._fitted_data, + aahCluster_reject_True._fitted_data + ).all() + assert np.isclose(aahCluster_reject_omit.GEV_, aahCluster_reject_True.GEV_) + assert np.isclose( + aahCluster_reject_omit._labels_, aahCluster_reject_True._labels_ + ).all() + # due to internal randomness, the sign can be flipped + sgn = np.sign( + np.sum( + aahCluster_reject_True._cluster_centers_ * + aahCluster_reject_omit._cluster_centers_, axis=1 + ) + ) + aahCluster_reject_True._cluster_centers_ *= sgn[:, None] + assert np.isclose( + aahCluster_reject_omit._cluster_centers_, + aahCluster_reject_True._cluster_centers_ + ).all() + + # Make sure there is a shape diff between True and False + assert ( + aahCluster_reject_True._fitted_data.shape + != aahCluster_no_reject._fitted_data.shape + ) + + # Check fitted data + _check_fitted_data_raw( + aahCluster_reject_True._fitted_data, raw_, "eeg", None, None, "omit" + ) + _check_fitted_data_raw( + aahCluster_no_reject._fitted_data, raw_, "eeg", None, None, None + ) + + # Check with reject with tmin/tmax + aahCluster_rej_0_5 = aahCluster_.copy() + aahCluster_rej_0_5.fit(raw_, n_jobs=1, tmin=0, tmax=5, reject_by_annotation=True) + aahCluster_rej_5_end = aahCluster_.copy() + aahCluster_rej_5_end.fit( + raw_, n_jobs=1, tmin=5, tmax=None, reject_by_annotation=True + ) + _check_fitted(aahCluster_rej_0_5) + _check_fitted(aahCluster_rej_5_end) + _check_fitted_data_raw( + aahCluster_rej_0_5._fitted_data, raw_, "eeg", None, 5, "omit" + ) + _check_fitted_data_raw( + aahCluster_rej_5_end._fitted_data, raw_, "eeg", 5, None, "omit" + ) + assert aahCluster_rej_0_5._fitted_data.shape != fitted_data_0_5.shape + assert np.isclose(fitted_data_5_end, aahCluster_rej_5_end._fitted_data).all() + + +def test_refit(): + """Test that re-fit does not overwrite the current instance.""" + raw = raw_meg.copy().pick_types(meg=True, eeg=True, eog=True) + aahCluster_ = AAHCluster( + n_clusters=n_clusters, + ignore_polarity=True, + normalize_input=False, + tol=1e-4, + ) + aahCluster_.fit(raw, picks="eeg") + eeg_ch_names = aahCluster_.info["ch_names"] + eeg_cluster_centers = aahCluster_.cluster_centers_ + aahCluster_.fitted = False # unfit + aahCluster_.fit(raw, picks="mag") + mag_ch_names = aahCluster_.info["ch_names"] + mag_cluster_centers = aahCluster_.cluster_centers_ + assert eeg_ch_names != mag_ch_names + assert eeg_cluster_centers.shape != mag_cluster_centers.shape + + # invalid + raw = raw_meg.copy().pick_types(meg=True, eeg=True, eog=True) + aahCluster_ = AAHCluster( + n_clusters=n_clusters, + ignore_polarity=True, + normalize_input=False, + tol=1e-4, + ) + aahCluster_.fit(raw, picks="eeg") # works + eeg_ch_names = aahCluster_.info["ch_names"] + eeg_cluster_centers = aahCluster_.cluster_centers_ + with pytest.raises(RuntimeError, match="must be unfitted"): + aahCluster_.fit(raw, picks="mag") # works + assert eeg_ch_names == aahCluster_.info["ch_names"] + assert np.allclose(eeg_cluster_centers, aahCluster_.cluster_centers_) + + +def test_predict_default(caplog): + """Test predict method default behaviors.""" + # raw, no smoothing, no_edge + segmentation = aahCluster.predict(raw_eeg, factor=0, reject_edges=False) + assert isinstance(segmentation, RawSegmentation) + assert "Segmenting data without smoothing" in caplog.text + caplog.clear() + + # raw, no smoothing, with edge rejection + segmentation = aahCluster.predict(raw_eeg, factor=0, reject_edges=True) + assert isinstance(segmentation, RawSegmentation) + assert segmentation._labels[0] == -1 + assert segmentation._labels[-1] == -1 + assert "Rejecting first and last segments." in caplog.text + caplog.clear() + + # raw, with smoothing + segmentation = aahCluster.predict(raw_eeg, factor=3, reject_edges=True) + assert isinstance(segmentation, RawSegmentation) + assert segmentation._labels[0] == -1 + assert segmentation._labels[-1] == -1 + assert "Segmenting data with factor 3" in caplog.text + caplog.clear() + + # raw with min_segment_length + segmentation = aahCluster.predict( + raw_eeg, factor=0, reject_edges=False, min_segment_length=5 + ) + assert isinstance(segmentation, RawSegmentation) + segment_lengths = [ + len(list(group)) for _, group in groupby(segmentation._labels) + ] + assert all(5 <= size for size in segment_lengths[1:-1]) + assert "Rejecting segments shorter than" in caplog.text + caplog.clear() + + # epochs, no smoothing, no_edge + segmentation = aahCluster.predict(epochs_eeg, factor=0, reject_edges=False) + assert isinstance(segmentation, EpochsSegmentation) + assert "Segmenting data without smoothing" in caplog.text + caplog.clear() + + # epochs, no smoothing, with edge rejection + segmentation = aahCluster.predict(epochs_eeg, factor=0, reject_edges=True) + assert isinstance(segmentation, EpochsSegmentation) + for epoch_labels in segmentation._labels: + assert epoch_labels[0] == -1 + assert epoch_labels[-1] == -1 + assert "Rejecting first and last segments." in caplog.text + caplog.clear() + + # epochs, with smoothing + segmentation = aahCluster.predict(epochs_eeg, factor=3, reject_edges=True) + assert isinstance(segmentation, EpochsSegmentation) + for epoch_labels in segmentation._labels: + assert epoch_labels[0] == -1 + assert epoch_labels[-1] == -1 + assert "Segmenting data with factor 3" in caplog.text + caplog.clear() + + # epochs with min_segment_length + segmentation = aahCluster.predict( + epochs_eeg, factor=0, reject_edges=False, min_segment_length=5 + ) + assert isinstance(segmentation, EpochsSegmentation) + for epoch_labels in segmentation._labels: + segment_lengths = [ + len(list(group)) for _, group in groupby(epoch_labels) + ] + assert all(5 <= size for size in segment_lengths[1:-1]) + assert "Rejecting segments shorter than" in caplog.text + caplog.clear() + + # raw with reject_by_annotation + bad_annot = Annotations([1], [2], "bad") + raw_ = raw_eeg.copy() + raw_.set_annotations(bad_annot) + segmentation_rej_True = aahCluster.predict( + raw_, factor=0, reject_edges=True, reject_by_annotation=True + ) + segmentation_rej_False = aahCluster.predict( + raw_, factor=0, reject_edges=True, reject_by_annotation=False + ) + segmentation_rej_None = aahCluster.predict( + raw_, factor=0, reject_edges=True, reject_by_annotation=None + ) + segmentation_no_annot = aahCluster.predict( + raw_eeg, factor=0, reject_edges=True, reject_by_annotation="omit" + ) + assert not np.isclose( + segmentation_rej_True._labels, segmentation_rej_False._labels + ).all() + assert np.isclose( + segmentation_no_annot._labels, segmentation_rej_False._labels + ).all() + assert np.isclose( + segmentation_rej_None._labels, segmentation_rej_False._labels + ).all() + + # test different half_window_size + segmentation1 = aahCluster.predict( + raw_eeg, factor=3, reject_edges=False, half_window_size=3 + ) + segmentation2 = aahCluster.predict( + raw_eeg, factor=3, reject_edges=False, half_window_size=60 + ) + segmentation3 = aahCluster.predict( + raw_eeg, factor=0, reject_edges=False, half_window_size=3 + ) + assert not np.isclose(segmentation1._labels, segmentation2._labels).all() + assert not np.isclose(segmentation1._labels, segmentation3._labels).all() + assert not np.isclose(segmentation2._labels, segmentation3._labels).all() + + +def test_picks_fit_predict(caplog): + """Test fitting and prediction with different picks.""" + raw = raw_meg.copy().pick_types(meg=True, eeg=True, eog=True) + aahCluster_ = AAHCluster( + n_clusters=n_clusters, + ignore_polarity=True, + normalize_input=False, + tol=1e-4, + ) + + # test invalid fit + with pytest.raises(ValueError, match="Only one datatype can be selected"): + aahCluster_.fit(raw, picks=None) # fails -> eeg + grad + mag + with pytest.raises(ValueError, match="Only one datatype can be selected"): + aahCluster_.fit(raw, picks="meg") # fails -> grad + mag + with pytest.raises(ValueError, match="Only one datatype can be selected"): + aahCluster_.fit(raw, picks="data") # fails -> eeg + grad + mag + + # test valid fit + aahCluster_.fit(raw, picks="mag") # works + aahCluster_.fitted = False + aahCluster_.fit(raw, picks="eeg") # works + aahCluster_.fitted = False + + # create mock raw for fitting + info_ = create_info( + ["Fp1", "Fp2", "CP1", "CP2"], sfreq=1024, ch_types="eeg" + ) + info_.set_montage("standard_1020") + data = np.random.randn(4, 1024 * 10) + + # Ignore bad channel Fp2 during fitting + info = info_.copy() + info["bads"] = ["Fp2"] + raw = RawArray(data, info) + + caplog.clear() + aahCluster_.fit(raw, picks="eeg") + assert aahCluster_.info["ch_names"] == ["Fp1", "CP1", "CP2"] + assert "The channel Fp2 is set as bad and ignored" in caplog.text + caplog.clear() + + # predict with the same channels in the instance used for prediction + info = info_.copy() + raw_predict = RawArray(data, info) + caplog.clear() + aahCluster_.predict(raw_predict, picks="eeg") # -> warning for selected Fp2 + assert "Fp2 which was not used during fitting" in caplog.text + caplog.clear() + aahCluster_.predict(raw_predict, picks=["Fp1", "CP1", "CP2"]) + assert "Fp2 which was not used during fitting" not in caplog.text + caplog.clear() + raw_predict.info["bads"] = ["Fp2"] + aahCluster_.predict(raw_predict, picks="eeg") + assert "Fp2 which was not used during fitting" not in caplog.text + + # predict with a channel used for fitting that is now missing + # fails, because aahCluster.info includes Fp1 which is bad in prediction instance + raw_predict.info["bads"] = ["Fp1"] + with pytest.raises(ValueError, match="Fp1 is required to predict"): + aahCluster_.predict(raw_predict, picks="eeg") + caplog.clear() + + # predict with a channel used for fitting that is now bad + aahCluster_.predict(raw_predict, picks=["Fp1", "CP1", "CP2"]) + assert "Fp1 is set as bad in the instance but was selected" in caplog.text + caplog.clear() + + # fails, because aahCluster_.info includes Fp1 which is missing from prediction + # instance selection + with pytest.raises(ValueError, match="Fp1 is required to predict"): + aahCluster_.predict(raw_predict, picks=["CP2", "CP1"]) + + # Try with one additional channel in the instance used for prediction. + info_ = create_info( + ["Fp1", "Fp2", "Fpz", "CP2", "CP1"], sfreq=1024, ch_types="eeg" + ) + info_.set_montage("standard_1020") + data = np.random.randn(5, 1024 * 10) + raw_predict = RawArray(data, info_) + + # works, with warning because Fpz, Fp2 are missing from aahCluster_.info + caplog.clear() + aahCluster_.predict(raw_predict, picks="eeg") + # handle non-deterministic sets + msg1 = "Fp2, Fpz which were not used during fitting" + msg2 = "Fpz, Fp2 which were not used during fitting" + assert msg1 in caplog.text or msg2 in caplog.text + caplog.clear() + + # fails, because aahCluster_.info includes Fp1 which is missing from prediction + # instance selection + with pytest.raises(ValueError, match="Fp1 is required to predict"): + aahCluster_.predict(raw_predict, picks=["Fp2", "Fpz", "CP2", "CP1"]) + caplog.clear() + + # works, with warning because Fpz is missing from aahCluster_.info + aahCluster_.predict(raw_predict, picks=["Fp1", "Fpz", "CP2", "CP1"]) + assert "Fpz which was not used during fitting" in caplog.text + caplog.clear() + + # try with a missing channel from the prediction instance + # fails, because Fp1 is used in aahCluster.info + raw_predict.drop_channels(["Fp1"]) + with pytest.raises( + ValueError, match="Fp1 was used during fitting but is missing" + ): + aahCluster_.predict(raw_predict, picks="eeg") + + # set a bad channel during fitting + info = info_.copy() + info["bads"] = ["Fp2"] + raw = RawArray(data, info) + + aahCluster_.fitted = False + caplog.clear() + aahCluster_.fit(raw, picks=["Fp1", "Fp2", "CP2", "CP1"]) + assert aahCluster_.info["ch_names"] == ["Fp1", "Fp2", "CP2", "CP1"] + assert "Fp2 is set as bad and will be used" in caplog.text + caplog.clear() + + # predict with the same channels in the instance used for prediction + info = info_.copy() + raw_predict = RawArray(data, info) + # works, with warning because a channel is bads in aahCluster_.info + caplog.clear() + aahCluster_.predict(raw_predict, picks="eeg") + predict_warning = "fit contains bad channel Fp2 which will be used" + assert predict_warning in caplog.text + caplog.clear() + + # works, with warning because a channel is bads in aahCluster_.info + raw_predict.info["bads"] = [] + aahCluster_.predict(raw_predict, picks=["Fp1", "Fp2", "CP2", "CP1"]) + assert predict_warning in caplog.text + caplog.clear() + + # fails, because Fp2 is used in aahCluster_.info + with pytest.raises(ValueError, match="Fp2 is required to predict"): + aahCluster_.predict(raw_predict, picks=["Fp1", "CP2", "CP1"]) + + # fails, because Fp2 is used in aahCluster_.info + raw_predict.info["bads"] = ["Fp2"] + with pytest.raises(ValueError, match="Fp2 is required to predict"): + aahCluster_.predict(raw_predict, picks="eeg") + + # works, because same channels as aahCluster_.info + caplog.clear() + aahCluster_.predict(raw_predict, picks=["Fp1", "Fp2", "CP2", "CP1"]) + assert predict_warning in caplog.text + assert "Fp2 is set as bad in the instance but was selected" in caplog.text + caplog.clear() + + # fails, because aahCluster_.info includes Fp1 which is bad in prediction + # instance + raw_predict.info["bads"] = ["Fp1"] + with pytest.raises(ValueError, match="Fp1 is required to predict because"): + aahCluster_.predict(raw_predict, picks="eeg") + + # fails, because aahCluster_.info includes bad Fp2 + with pytest.raises(ValueError, match="Fp2 is required to predict"): + aahCluster_.predict(raw_predict, picks=["Fp1", "CP2", "CP1"]) + + # works, because same channels as aahCluster_.info (with warnings for Fp1, Fp2) + caplog.clear() + aahCluster_.predict(raw_predict, picks=["Fp1", "Fp2", "CP2", "CP1"]) + assert predict_warning in caplog.text + assert "Fp1 is set as bad in the instance but was selected" in caplog.text + caplog.clear() + + +def test_predict_invalid_arguments(): + """Test invalid arguments passed to predict.""" + with pytest.raises(TypeError, match="'inst' must be an instance of "): + aahCluster.predict(epochs_eeg.average()) + with pytest.raises(TypeError, match="'factor' must be an instance of "): + aahCluster.predict(raw_eeg, factor="0") + with pytest.raises( + TypeError, match="'reject_edges' must be an instance of " + ): + aahCluster.predict(raw_eeg, reject_edges=1) + with pytest.raises( + TypeError, match="'half_window_size' must be an instance of " + ): + aahCluster.predict(raw_eeg, half_window_size="1") + with pytest.raises(TypeError, match="'tol' must be an instance of "): + aahCluster.predict(raw_eeg, tol="0") + with pytest.raises( + TypeError, match="'min_segment_length' must be an instance of " + ): + aahCluster.predict(raw_eeg, min_segment_length="0") + with pytest.raises( + TypeError, match="'reject_by_annotation' must be an instance of " + ): + aahCluster.predict(raw_eeg, reject_by_annotation=1) + with pytest.raises(ValueError, match="'reject_by_annotation' can be"): + aahCluster.predict(raw_eeg, reject_by_annotation="101") + + +def test_contains_mixin(): + """Test contains mixin class.""" + assert "eeg" in aahCluster + assert aahCluster.compensation_grade is None + assert aahCluster.get_channel_types() == ["eeg"] * aahCluster._info["nchan"] + + # test raise with non-fitted instance + aahCluster_ = AAHCluster( + n_clusters=n_clusters, + ignore_polarity=True, + normalize_input=False, + tol=1e-4, + ) + with pytest.raises( + ValueError, match="Instance 'AAHCluster' attribute 'info' is None." + ): + "eeg" in aahCluster_ + with pytest.raises( + ValueError, match="Instance 'AAHCluster' attribute 'info' is None." + ): + aahCluster_.get_channel_types() + with pytest.raises( + ValueError, match="Instance 'AAHCluster' attribute 'info' is None." + ): + aahCluster_.compensation_grade + + +def test_montage_mixin(): + """Test montage mixin class.""" + aahCluster_ = aahCluster.copy() + montage = aahCluster.get_montage() + assert isinstance(montage, DigMontage) + assert montage.dig[-1]["r"][0] != 0 + montage.dig[-1]["r"][0] = 0 + aahCluster_.set_montage(montage) + montage_ = aahCluster_.get_montage() + assert montage_.dig[-1]["r"][0] == 0 + + # test raise with non-fitted instance + aahCluster_ = AAHCluster( + n_clusters=n_clusters, + ignore_polarity=True, + normalize_input=False, + tol=1e-4, + ) + with pytest.raises( + ValueError, match="Instance 'AAHCluster' attribute 'info' is None." + ): + aahCluster_.set_montage("standard_1020") + + with pytest.raises( + ValueError, match="Instance 'AAHCluster' attribute 'info' is None." + ): + aahCluster_.get_montage() + + +def test_save(tmp_path, caplog): + """Test .save() method.""" + # writing to .fif + fname1 = tmp_path / "cluster.fif" + aahCluster.save(fname1) + + # writing to .gz (compression) + fname2 = tmp_path / "cluster.fif.gz" + aahCluster.save(fname2) + + # re-load + caplog.clear() + aahCluster1 = read_cluster(fname1) + assert __version__ in caplog.text + caplog.clear() + aahCluster2, version = _read_cluster(fname2) + assert version == __version__ + assert __version__ not in caplog.text + + # compare + assert aahCluster == aahCluster1 + assert aahCluster == aahCluster2 + assert aahCluster1 == aahCluster2 # sanity-check + + # test prediction + segmentation = aahCluster.predict(raw_eeg, picks="eeg") + segmentation1 = aahCluster1.predict(raw_eeg, picks="eeg") + segmentation2 = aahCluster2.predict(raw_eeg, picks="eeg") + + assert np.allclose(segmentation._labels, segmentation1._labels) + assert np.allclose(segmentation._labels, segmentation2._labels) + assert np.allclose(segmentation1._labels, segmentation2._labels) + + +def test_comparison(caplog): + """Test == and != methods.""" + aahCluster1 = aahCluster.copy() + aahCluster2 = aahCluster.copy() + assert aahCluster1 == aahCluster2 + + # with different aahClustermeans variables + aahCluster1.fitted = False + assert aahCluster1 != aahCluster2 + aahCluster1 = aahCluster.copy() + aahCluster1._ignore_polarity = False + assert aahCluster1 != aahCluster2 + aahCluster1 = aahCluster.copy() + aahCluster1._normalize_input = True + assert aahCluster1 != aahCluster2 + aahCluster1 = aahCluster.copy() + aahCluster1._tol = 0.101 + assert aahCluster1 != aahCluster2 + aahCluster1 = aahCluster.copy() + aahCluster1._GEV_ = 0.101 + assert aahCluster1 != aahCluster2 + + # with different object + assert aahCluster1 != 101 + + # with different base variables + aahCluster1 = aahCluster.copy() + aahCluster2 = aahCluster.copy() + assert aahCluster1 == aahCluster2 + aahCluster1 = aahCluster.copy() + aahCluster1._n_clusters = 101 + assert aahCluster1 != aahCluster2 + aahCluster1 = aahCluster.copy() + aahCluster1._info = ChInfo( + ch_names=[str(k) for k in range(aahCluster1._cluster_centers_.shape[1])], + ch_types=["eeg"] * aahCluster1._cluster_centers_.shape[1], + ) + assert aahCluster1 != aahCluster2 + aahCluster1 = aahCluster.copy() + aahCluster1._labels_ = aahCluster1._labels_[::-1] + assert aahCluster1 != aahCluster2 + aahCluster1 = aahCluster.copy() + aahCluster1._fitted_data = aahCluster1._fitted_data[:, ::-1] + assert aahCluster1 != aahCluster2 + + # different cluster names + aahCluster1 = aahCluster.copy() + aahCluster2 = aahCluster.copy() + caplog.clear() + assert aahCluster1 == aahCluster2 + assert "Cluster names differ between both clustering" not in caplog.text + aahCluster1._cluster_names = aahCluster1._cluster_names[::-1] + caplog.clear() + assert aahCluster1 == aahCluster2 + assert "Cluster names differ between both clustering" in caplog.text diff --git a/pycrostates/io/fiff.py b/pycrostates/io/fiff.py index ada00075..7bc2dfee 100644 --- a/pycrostates/io/fiff.py +++ b/pycrostates/io/fiff.py @@ -34,7 +34,7 @@ from .. import __version__ from .._typing import CHInfo -from ..cluster import ModKMeans +from ..cluster import ModKMeans, AAHCluster from ..utils._checks import _check_type, _check_value from ..utils._docs import fill_doc from ..utils._logs import logger @@ -110,7 +110,7 @@ def _write_cluster( if isinstance(chinfo, Info): chinfo = ChInfo(chinfo) # convert to ChInfo if a MNE Info is provided _check_type(algorithm, (str,), "algorithm") - _check_value(algorithm, ("ModKMeans",), "algorithm") + _check_value(algorithm, ("ModKMeans", "AAHCluster"), "algorithm") _check_type(cluster_names, (list,), "cluster_names") if len(cluster_names) != cluster_centers_.shape[0]: raise ValueError( @@ -178,6 +178,10 @@ def _prepare_kwargs(algorithm: str, kwargs: dict): "parameters": ["n_init", "max_iter", "tol"], "variables": ["GEV_"], }, + "AAHCluster": { + "parameters": ["ignore_polarity", "normalize_input", "tol"], + "variables": ["GEV_"], + }, } # retrieve list of expected kwargs for this algorithm @@ -202,13 +206,23 @@ def _prepare_kwargs(algorithm: str, kwargs: dict): continue # ModKMeans - if key == "n_init": - fit_parameters["n_init"] = ModKMeans._check_n_init(value) - elif key == "max_iter": - fit_parameters["max_iter"] = ModKMeans._check_max_iter(value) - elif key == "tol": - fit_parameters["tol"] = ModKMeans._check_tol(value) - elif key == "GEV_": + if algorithm == "ModKMeans": + if key == "n_init": + fit_parameters["n_init"] = ModKMeans._check_n_init(value) + elif key == "max_iter": + fit_parameters["max_iter"] = ModKMeans._check_max_iter(value) + elif key == "tol": + fit_parameters["tol"] = ModKMeans._check_tol(value) + elif algorithm == "AAHCluster": + if key == "ignore_polarity": + fit_parameters["ignore_polarity"] = \ + AAHCluster._check_ignore_polarity(value) + elif key == "normalize_input": + fit_parameters["normalize_input"] = \ + AAHCluster._check_normalize_input(value) + elif key == "tol": + fit_parameters["tol"] = AAHCluster._check_tol(value) + if key == "GEV_": _check_type(value, ("numeric",), "GEV_") if value < 0 or 1 < value: raise ValueError( @@ -310,7 +324,10 @@ def _read_cluster(fname: Union[str, Path]): ) # reconstruct cluster instance - function = {"ModKMeans": _create_ModKMeans} + function = { + "ModKMeans": _create_ModKMeans, + "AAHCluster": _create_AAHCluster + } return ( function[algorithm]( @@ -336,6 +353,10 @@ def _check_fit_parameters_and_variables( "parameters": ["n_init", "max_iter", "tol"], "variables": ["GEV_"], }, + "AAHCluster": { + "parameters": ["ignore_polarity", "normalize_input", "tol"], + "variables": ["GEV_"], + }, } if "algorithm" not in fit_parameters: raise ValueError("Key 'algorithm' is missing from .fif file.") @@ -379,6 +400,32 @@ def _create_ModKMeans( return cluster +def _create_AAHCluster( + cluster_centers_: NDArray[float], + info: CHInfo, + cluster_names: List[str], + fitted_data: NDArray[float], + labels_: NDArray[int], + ignore_polarity: bool, + normalize_input: bool, + tol: Union[int, float], + GEV_: float, +): + + print(normalize_input, ignore_polarity) + """Create a AAHCluster object.""" + cluster = AAHCluster( + cluster_centers_.shape[0], ignore_polarity, normalize_input, tol + ) + cluster._cluster_centers_ = cluster_centers_ + cluster._info = info + cluster._cluster_names = cluster_names + cluster._fitted_data = fitted_data + cluster._labels_ = labels_ + cluster._GEV_ = GEV_ + cluster._fitted = True + return cluster + # ---------------------------------------------------------------------------- def _write_meas_info(fid, info: CHInfo): """Write measurement info into a file id (from a fif file). @@ -560,6 +607,8 @@ def _serialize(dict_: dict, outer_sep: str = ";", inner_sep: str = ":"): for key, value in dict_.items(): if callable(value): value = value.__name__ + elif isinstance(value, bool): + pass elif isinstance(value, Integral): value = int(value) elif isinstance(value, dict): diff --git a/pycrostates/utils/_checks.py b/pycrostates/utils/_checks.py index e1b50742..3da0528e 100644 --- a/pycrostates/utils/_checks.py +++ b/pycrostates/utils/_checks.py @@ -67,6 +67,7 @@ def __instancecheck__(cls, other): "path-like": (str, Path, os.PathLike), "int": (_IntLike(),), "callable": (_Callable(),), + "bool": (bool,), } From 192f75e9fab250e785766d7b2c2a10c1a3bbc550 Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Fri, 28 Oct 2022 10:32:38 +0900 Subject: [PATCH 02/24] - added AAHC to sphinx documentation - modified AAHC to use the previous cluster center as initialization for the first PC power iteration algorithm --- docs/references.bib | 15 ------ docs/source/api/cluster.rst | 1 + pycrostates/cluster/aahc.py | 46 +++++++++++++++---- .../html_templates/repr/AAHCluster.html.jinja | 29 ++++++++++++ 4 files changed, 66 insertions(+), 25 deletions(-) create mode 100644 pycrostates/html_templates/repr/AAHCluster.html.jinja diff --git a/docs/references.bib b/docs/references.bib index 6c405da1..814d5e83 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -39,22 +39,7 @@ @article{MICHEL2018577 keywords = {EEG microstates, Resting state networks, Consciousness, Psychiatric disease, State-dependent information processing, Metastability}, } -% Michel 2018 - microstates review -@article{MICHEL2018577, - title = {EEG microstates as a tool for studying the temporal dynamics of whole-brain neuronal networks: A review}, - journal = {NeuroImage}, - volume = {180}, - pages = {577-593}, - year = {2018}, - note = {Brain Connectivity Dynamics}, - issn = {1053-8119}, - doi = {10.1016/j.neuroimage.2017.11.062}, - author = {Christoph M. Michel and Thomas Koenig}, - keywords = {EEG microstates, Resting state networks, Consciousness, Psychiatric disease, State-dependent information processing, Metastability}, -} - % Atomize and Agglomerate Hierarchical Clustering (AAHC) - @article{Murray2008, author = {Murray, Micah M. and Brunet, Denis and Michel, Christoph M.}, journal = {Brain Topography}, diff --git a/docs/source/api/cluster.rst b/docs/source/api/cluster.rst index 76d71165..17dc1660 100644 --- a/docs/source/api/cluster.rst +++ b/docs/source/api/cluster.rst @@ -13,3 +13,4 @@ Cluster :toctree: generated/ ModKMeans + AAHCluster diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index afdd89b7..83807a66 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -21,7 +21,6 @@ try: from numba import njit except ImportError: - def njit(cache=False): def decorator(func): @func_wraps(func) @@ -33,6 +32,7 @@ def wrapper(*args, **kwargs): return decorator +@fill_doc class AAHCluster(_BaseCluster): r"""Atomize and Agglomerate Hierarchical Clustering (AAHC) algorithm. @@ -41,7 +41,6 @@ class AAHCluster(_BaseCluster): Parameters ---------- %(n_clusters)s - ignore_polarity : bool If true, polarity is ignored when computing distances. normalize_input : bool @@ -81,7 +80,7 @@ def __init__( def _repr_html_(self, caption=None): from ..html_templates import repr_templates_env - template = repr_templates_env.get_template("ModKMeans.html.jinja") + template = repr_templates_env.get_template("AAHCluster.html.jinja") if self.fitted: n_samples = self._fitted_data.shape[-1] ch_types, ch_counts = np.unique( @@ -262,7 +261,7 @@ def _compute_maps( members = assignment == c if ignore_polarity: v, _ = AAHCluster._first_principal_component( - data[:, members], tol + data[:, members], tol, cluster[:, c] ) cluster[:, c] = v else: @@ -279,16 +278,45 @@ def _compute_maps( @staticmethod @njit(cache=True) def _first_principal_component( - X: NDArray[float], tol: float, max_iter: int = 100 + X: NDArray[float], + tol: float, + v0: Optional[NDArray[float]] = None, + max_iter: int = 100, ) -> Tuple[NDArray[float], float]: """Compute first principal component. + Parameters + ---------- + X : numpy.ndarray + Input data matrix with dimensions (channels x observations) + tol : float + Tolerance for convergence. + v0 : numpy.ndarray or None (default) + Initial estimate of the principal vector with dimensions + (channels, ) + If None, a random initialization is used + max_iter : int + Maximum number of iterations to estiamte the first principal + component. + + Returns + ------- + + tuple(v, eig) + + v : numpy.ndarray + estimated principal component vector + eig : float + eigenvalue of the covariance matrix of X (channels x channels) + See :footcite:t:`Roweis1997` for additional information. """ - v = np.random.rand(X.shape[0]) - # v = np.ones((X.shape[0],)) - # v[::2] = -1 + if v0 is None: # use a random choice + v = np.random.rand(X.shape[0]) + else: + v = v0.flatten() + assert v.shape[0] == X.shape[0] v /= np.linalg.norm(v) for _ in range(max_iter): @@ -300,8 +328,6 @@ def _first_principal_component( if np.linalg.norm(eig * v - s) / s_norm < tol: break v = s / s_norm - # else: - # logger.warn("First PC estimation: max iteration reached!") return v, eig # -------------------------------------------------------------------- diff --git a/pycrostates/html_templates/repr/AAHCluster.html.jinja b/pycrostates/html_templates/repr/AAHCluster.html.jinja new file mode 100644 index 00000000..55737e99 --- /dev/null +++ b/pycrostates/html_templates/repr/AAHCluster.html.jinja @@ -0,0 +1,29 @@ + + + + + + + + + + + {% if fitted %} + + + + + + + + + + + + + + + + + {% endif %} +
Method{{ name }}
Fit{% if fitted %}fitted to {{ n_samples }} samples{% else %}not fitted{% endif %}
Cluster centers{{ n_clusters }}
GEV{{ GEV }} %
Cluster centers names{{ cluster_names|join(', ') }}
Channels{{ ch_repr|join(', ') }}
From c67f3a85ab814d4e30459d371287d4936403c184 Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Fri, 28 Oct 2022 11:22:34 +0900 Subject: [PATCH 03/24] aahc: changed initialization of first PC to weighted sum of old and new cluster members --- pycrostates/cluster/aahc.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 83807a66..2528b007 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -244,6 +244,8 @@ def _compute_maps( to_remove = np.argmin(GEV) orphans = assignment == to_remove + old_cluster = cluster[:, to_remove].copy() + cluster = np.delete(cluster, to_remove, axis=1) GEV = np.delete(GEV, to_remove, axis=0) assignment[assignment > to_remove] = ( @@ -260,8 +262,16 @@ def _compute_maps( for c in cluster_to_update: members = assignment == c if ignore_polarity: + + old_weight = len(new_assignment == c) / members.sum() + + sgn = np.sign(old_cluster @ cluster[:, c]) + + v0 = old_weight * sgn * old_cluster + \ + (1. - old_weight) * cluster[:, c] + v, _ = AAHCluster._first_principal_component( - data[:, members], tol, cluster[:, c] + data[:, members], tol, v0 ) cluster[:, c] = v else: From 66aaded5c538de7acf9e4c19257e23219a1bbb17 Mon Sep 17 00:00:00 2001 From: rkobler Date: Tue, 1 Nov 2022 06:46:32 +0900 Subject: [PATCH 04/24] Update pycrostates/cluster/__init__.py Co-authored-by: Mathieu Scheltienne --- pycrostates/cluster/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pycrostates/cluster/__init__.py b/pycrostates/cluster/__init__.py index 54a6014f..397907ce 100644 --- a/pycrostates/cluster/__init__.py +++ b/pycrostates/cluster/__init__.py @@ -15,6 +15,6 @@ to segment.""" from .kmeans import ModKMeans # noqa: F401 -from .aahc import AAHCluster +from .aahc import AAHCluster # noqa: F401 __all__ = ("ModKMeans", "AAHCluster") From 6074a7f5c3ef7433058e7c94da87c6395e916d06 Mon Sep 17 00:00:00 2001 From: rkobler Date: Tue, 1 Nov 2022 06:47:53 +0900 Subject: [PATCH 05/24] Update doc in pycrostates/cluster/aahc.py Co-authored-by: Mathieu Scheltienne --- pycrostates/cluster/aahc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 2528b007..e565a066 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -303,8 +303,7 @@ def _first_principal_component( Tolerance for convergence. v0 : numpy.ndarray or None (default) Initial estimate of the principal vector with dimensions - (channels, ) - If None, a random initialization is used + (channels,). If None, a random initialization is used. max_iter : int Maximum number of iterations to estiamte the first principal component. From 62dc4a1230ca8b045c4db79d6d8d7f98862add1e Mon Sep 17 00:00:00 2001 From: rkobler Date: Tue, 1 Nov 2022 07:08:56 +0900 Subject: [PATCH 06/24] Update doc in pycrostates/cluster/aahc.py Co-authored-by: Mathieu Scheltienne --- pycrostates/cluster/aahc.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index e565a066..0d9edf0c 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -310,13 +310,10 @@ def _first_principal_component( Returns ------- - - tuple(v, eig) - v : numpy.ndarray - estimated principal component vector + Estimated principal component vector. eig : float - eigenvalue of the covariance matrix of X (channels x channels) + Eigenvalue of the covariance matrix of X (channels x channels) See :footcite:t:`Roweis1997` for additional information. """ From 7c69bf8745bcc30a1f13c6b711799ba9aee6c80a Mon Sep 17 00:00:00 2001 From: rkobler Date: Tue, 1 Nov 2022 07:09:10 +0900 Subject: [PATCH 07/24] Update doc in pycrostates/cluster/aahc.py Co-authored-by: Mathieu Scheltienne --- pycrostates/cluster/aahc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 0d9edf0c..79de4f97 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -305,7 +305,7 @@ def _first_principal_component( Initial estimate of the principal vector with dimensions (channels,). If None, a random initialization is used. max_iter : int - Maximum number of iterations to estiamte the first principal + Maximum number of iterations to estimate the first principal component. Returns From a3f7afc062b2f43cf61a9bc2b43f7af29a470524 Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Tue, 1 Nov 2022 11:58:47 +0900 Subject: [PATCH 08/24] - BUGFIX in pycrostates/cluster/aahc.py: wrong axis in cluster map computation - test cases for ignore_polarity True/False - removed tol parameter in pycrostates/cluster/aahc.py - revered 'bool' check in pycrostates.utils/_checks.py --- pycrostates/cluster/__init__.py | 2 +- pycrostates/cluster/aahc.py | 84 ++++------- pycrostates/cluster/tests/test_aahc.py | 200 ++++++++++++++++++------- pycrostates/io/fiff.py | 21 +-- pycrostates/utils/_checks.py | 1 - 5 files changed, 191 insertions(+), 117 deletions(-) diff --git a/pycrostates/cluster/__init__.py b/pycrostates/cluster/__init__.py index 397907ce..712d5dc6 100644 --- a/pycrostates/cluster/__init__.py +++ b/pycrostates/cluster/__init__.py @@ -14,7 +14,7 @@ :class:`~pycrostates.segmentation.EpochsSegmentation` depending on the dataset to segment.""" -from .kmeans import ModKMeans # noqa: F401 from .aahc import AAHCluster # noqa: F401 +from .kmeans import ModKMeans # noqa: F401 __all__ = ("ModKMeans", "AAHCluster") diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 79de4f97..12d59f91 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -21,6 +21,7 @@ try: from numba import njit except ImportError: + def njit(cache=False): def decorator(func): @func_wraps(func) @@ -45,9 +46,6 @@ class AAHCluster(_BaseCluster): If true, polarity is ignored when computing distances. normalize_input : bool If set, the input data is normalized along the channel dimension. - tol : float - Relative tolerance with regards estimate residual noise in the cluster - centers of two consecutive iterations to declare convergence. References ---------- @@ -59,7 +57,6 @@ def __init__( n_clusters: int, ignore_polarity: bool = True, normalize_input: bool = False, - tol: float = 1e-6, ): super().__init__() @@ -72,7 +69,6 @@ def __init__( self._normalize_input = AAHCluster._check_ignore_polarity( normalize_input ) - self._tol = AAHCluster._check_tol(tol) # fit variables self._GEV_ = None @@ -116,7 +112,6 @@ def __eq__(self, other: Any) -> bool: attributes = ( "_ignore_polarity", "_normalize_input", - "_tol", "_GEV_", ) for attribute in attributes: @@ -168,7 +163,6 @@ def fit( self._n_clusters, self._ignore_polarity, self._normalize_input, - self._tol, ) if gev is not None: @@ -196,7 +190,6 @@ def save(self, fname: Union[str, Path]): self._labels_, ignore_polarity=self._ignore_polarity, normalize_input=self._normalize_input, - tol=self._tol, GEV_=self._GEV_, ) @@ -207,12 +200,11 @@ def _aahc( n_clusters: int, ignore_polarity: bool, normalize_input: bool, - tol: Union[int, float], ) -> Tuple[float, NDArray[float], NDArray[int]]: """Run the AAHC algorithm.""" gfp_sum_sq = np.sum(data**2) maps, segmentation = AAHCluster._compute_maps( - data, n_clusters, ignore_polarity, normalize_input, tol + data, n_clusters, ignore_polarity, normalize_input ) map_corr = _corr_vectors(data, maps[segmentation].T) gev = np.sum((data * map_corr) ** 2) / gfp_sum_sq @@ -224,7 +216,6 @@ def _compute_maps( n_clusters: int, ignore_polarity: bool, normalize_input: bool, - tol: Union[int, float], ) -> Tuple[NDArray[float], NDArray[int]]: """Compute microstates maps.""" n_chan, n_frame = data.shape @@ -267,15 +258,24 @@ def _compute_maps( sgn = np.sign(old_cluster @ cluster[:, c]) - v0 = old_weight * sgn * old_cluster + \ - (1. - old_weight) * cluster[:, c] + v0 = ( + old_weight * sgn * old_cluster + + (1.0 - old_weight) * cluster[:, c] + ) - v, _ = AAHCluster._first_principal_component( - data[:, members], tol, v0 + v, _, converged = AAHCluster._first_principal_component( + data[:, members], v0, max_iter=n_chan ) + if not converged: + # fall back to covariance estimation + # and eigenvalue computation + Cxx = data[:, members] @ data[:, members].T + _, V = np.linalg.eigh(Cxx) + v = V[:, -1] + cluster[:, c] = v else: - cluster[:, c] = np.mean(data[:, members], axis=0) + cluster[:, c] = np.mean(data[:, members], axis=1) cluster[:, c] /= np.linalg.norm( cluster[:, c], axis=0, keepdims=True ) @@ -289,22 +289,22 @@ def _compute_maps( @njit(cache=True) def _first_principal_component( X: NDArray[float], - tol: float, - v0: Optional[NDArray[float]] = None, + v0: NDArray[float], + tol: float = 1e-6, max_iter: int = 100, - ) -> Tuple[NDArray[float], float]: + ) -> Tuple[NDArray[float], float, bool]: """Compute first principal component. Parameters ---------- X : numpy.ndarray Input data matrix with dimensions (channels x observations) - tol : float - Tolerance for convergence. - v0 : numpy.ndarray or None (default) + v0 : numpy.ndarray Initial estimate of the principal vector with dimensions - (channels,). If None, a random initialization is used. - max_iter : int + (channels,). + tol : float (1e-6 by default) + Tolerance for convergence. + max_iter : int (100 by default) Maximum number of iterations to estimate the first principal component. @@ -314,17 +314,17 @@ def _first_principal_component( Estimated principal component vector. eig : float Eigenvalue of the covariance matrix of X (channels x channels) + converged : bool + False, if `max_iter` were reached. See :footcite:t:`Roweis1997` for additional information. """ - if v0 is None: # use a random choice - v = np.random.rand(X.shape[0]) - else: - v = v0.flatten() - assert v.shape[0] == X.shape[0] + v = v0.flatten() + assert v.shape[0] == X.shape[0] v /= np.linalg.norm(v) + converged = False for _ in range(max_iter): s = np.sum((np.expand_dims(v, 0) @ X) * X, axis=1) @@ -332,9 +332,10 @@ def _first_principal_component( s_norm = np.linalg.norm(s) if np.linalg.norm(eig * v - s) / s_norm < tol: + converged = True break v = s / s_norm - return v, eig + return v, eig, converged # -------------------------------------------------------------------- @@ -354,14 +355,6 @@ def normalize_input(self) -> bool: """ return self._normalize_input - @property - def tol(self) -> Union[int, float]: - """Relative tolerance to reach convergence. - - :type: `float` - """ - return self._tol - @property def GEV_(self) -> float: """Global Explained Variance. @@ -383,22 +376,11 @@ def fitted(self, fitted): @staticmethod def _check_ignore_polarity(ignore_polarity: bool) -> bool: """Check that ignore_polarity is a boolean.""" - _check_type(ignore_polarity, ("bool",), item_name="ignore_polarity") + _check_type(ignore_polarity, (bool,), item_name="ignore_polarity") return ignore_polarity @staticmethod def _check_normalize_input(normalize_input: bool) -> bool: """Check that normalize_input is a boolean.""" - _check_type(normalize_input, ("bool",), item_name="normalize_input") + _check_type(normalize_input, (bool,), item_name="normalize_input") return normalize_input - - @staticmethod - def _check_tol(tol: Union[int, float]) -> Union[int, float]: - """Check that tol is a positive number.""" - _check_type(tol, ("numeric",), item_name="tol") - if tol <= 0: - raise ValueError( - "The tolerance must be a positive number. " - f"Provided: '{tol}'." - ) - return tol diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index ea3c68ac..1df2ceef 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -1,4 +1,3 @@ - """Test AAHCluster.""" import logging @@ -49,12 +48,97 @@ n_clusters = 4 aahCluster = AAHCluster( - n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-4 + n_clusters=n_clusters, ignore_polarity=True, normalize_input=False ) aahCluster.fit(ch_data) +# simulated data + +# extract 3D positions from raw_eeg +pos = np.vstack([ch["loc"][:3] for ch in raw_eeg.info["chs"]]) +# place 4 sources [3D origin, 3D orientation] +sources = np.array( + [ + [0.000, 0.025, 0.060, 0.000, -0.050, 0.000], + [0.000, 0.015, 0.080, 0.000, 0.025, 0.040], + [0.000, -0.025, 0.050, 0.050, -0.040, 0.025], + [0.000, -0.025, 0.050, -0.050, -0.040, 0.025], + ], + dtype=np.double, +) +sim_n_ms = sources.shape[0] +sim_n_frames = 1000 # number of samples to generate +sim_n_chans = pos.shape[0] # number of channels +# compute forward model +A = np.sum( + (pos[None, ...] - sources[:, None, :3]) * sources[:, None, 3:], axis=2 +) +A /= np.linalg.norm(A, axis=1, keepdims=True) +# simulate source actvities for 4 sources +# with positive and negative polarity +mapping = np.arange(sim_n_frames) % (sim_n_ms * 2) +s = ( + np.sign(mapping - sim_n_ms + 0.01) + * np.eye(sim_n_ms)[:, mapping % sim_n_ms] +) +# apply forward model +X = A.T @ s +# add i.i.d. noise +sim_sigma = 0.05 +X += sim_sigma * np.random.randn(*X.shape) +# generate the mne object +raw_sim = RawArray(X, raw_eeg.info, copy="info") +raw_sim.info["bads"] = [] + + +def test_ignore_polarity_true(): + obj = AAHCluster(n_clusters=sim_n_ms, ignore_polarity=True) + obj.fit(raw_sim) + + # extract cluster centers + A_hat = obj._cluster_centers_ + + # compute Euclidean distances (using the sign that minimizes the distance) + sgn = np.sign(A @ A_hat.T) + dists = np.linalg.norm( + (A_hat[None, ...] - A[:, None] * sgn[..., None]), axis=2 + ) + # compute tolerance (2 times the expected noise level) + tol = ( + 2 * sim_sigma / np.sqrt(sim_n_frames / sim_n_ms) * np.sqrt(sim_n_chans) + ) + # check if there is a cluster center whose distance + # is within the tolerance + assert (dists.min(axis=0) < tol).all() + # ensure that all cluster centers were identified + assert len(set(dists.argmin(axis=0))) == sim_n_ms + + +def test_ignore_polarity_false(): + obj = AAHCluster(n_clusters=sim_n_ms * 2, ignore_polarity=False) + obj.fit(raw_sim) + + # extract cluster centers + A_hat = obj._cluster_centers_ + # create extended targets (pos. and neg. polarity) + A_ = np.concatenate((A, -A), axis=0) + # compute Euclidean distances + dists = np.linalg.norm((A_hat[None, ...] - A_[:, None]), axis=2) + # compute tolerance (2 times the expected noise level) + tol = ( + 2 + * sim_sigma + / np.sqrt(sim_n_frames / sim_n_ms / 2) + * np.sqrt(sim_n_chans) + ) + # check if there is a cluster center whose distance + # is within the tolerance + assert (dists.min(axis=0) < tol).all() + # ensure that all cluster centers were identified + assert len(set(dists.argmin(axis=0))) == 2 * sim_n_ms + + # pylint: disable=protected-access def _check_fitted(aahCluster): """ @@ -116,13 +200,11 @@ def test_aahClusterMeans(): n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-4, ) # Test properties - assert aahCluster1.ignore_polarity == True - assert aahCluster1.normalize_input == False - assert aahCluster1.tol == 1e-4 + assert aahCluster1.ignore_polarity is True + assert aahCluster1.normalize_input is False _check_unfitted(aahCluster1) # Test default clusters names @@ -160,7 +242,9 @@ def test_aahClusterMeans(): # Test copy aahCluster2 = aahCluster1.copy() _check_fitted(aahCluster2) - assert np.isclose(aahCluster2._cluster_centers_, aahCluster1._cluster_centers_).all() + assert np.isclose( + aahCluster2._cluster_centers_, aahCluster1._cluster_centers_ + ).all() assert np.isclose(aahCluster2.GEV_, aahCluster1.GEV_) assert np.isclose(aahCluster2._labels_, aahCluster1._labels_).all() aahCluster2.fitted = False @@ -169,7 +253,9 @@ def test_aahClusterMeans(): aahCluster3 = aahCluster1.copy(deep=False) _check_fitted(aahCluster3) - assert np.isclose(aahCluster3._cluster_centers_, aahCluster1._cluster_centers_).all() + assert np.isclose( + aahCluster3._cluster_centers_, aahCluster1._cluster_centers_ + ).all() assert np.isclose(aahCluster3.GEV_, aahCluster1.GEV_) assert np.isclose(aahCluster3._labels_, aahCluster1._labels_).all() aahCluster3.fitted = False @@ -273,7 +359,9 @@ def test_rename(caplog): # Test mapping aahCluster_ = aahCluster.copy() - mapping = {old: alphabet[k] for k, old in enumerate(aahCluster._cluster_names)} + mapping = { + old: alphabet[k] for k, old in enumerate(aahCluster._cluster_names) + } for key, value in mapping.items(): assert isinstance(key, str) assert isinstance(value, str) @@ -407,7 +495,9 @@ def test_reorder(caplog): with pytest.raises( ValueError, match="Argument 'order' should be a 1D iterable" ): - aahCluster_.reorder_clusters(order=np.array([[0, 1, 2, 3], [0, 1, 2, 3]])) + aahCluster_.reorder_clusters( + order=np.array([[0, 1, 2, 3], [0, 1, 2, 3]]) + ) aahCluster_.reorder_clusters() assert "Either 'mapping' or 'order' should not be 'None' " in caplog.text @@ -434,7 +524,6 @@ def test_properties(caplog): n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-4, ) aahCluster_.cluster_centers_ # pylint: disable=pointless-statement @@ -469,7 +558,6 @@ def test_properties(caplog): n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-4, ) with pytest.raises(TypeError, match="'fitted' must be an instance of"): aahCluster_.fitted = "101" @@ -496,19 +584,10 @@ def test_invalid_arguments(): with pytest.raises(ValueError, match="The number of clusters must be a"): aahCluster_ = AAHCluster(n_clusters=-101) - # tol - with pytest.raises(TypeError, match="'tol' must be an instance of "): - aahCluster_ = AAHCluster(n_clusters=4, tol="100") - with pytest.raises(ValueError, match="The tolerance must be a"): - aahCluster_ = AAHCluster(n_clusters=4, tol=0) - with pytest.raises(ValueError, match="The tolerance must be a"): - aahCluster_ = AAHCluster(n_clusters=4, tol=-101) - aahCluster_ = AAHCluster( n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-4, ) # inst with pytest.raises(TypeError, match="'inst' must be an instance of "): @@ -554,7 +633,6 @@ def test_fit_data_shapes(): n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-8, ) # tmin @@ -568,7 +646,9 @@ def test_fit_data_shapes(): tmax=None, reject_by_annotation=False, ) - _check_fitted_data_raw(aahCluster_._fitted_data, raw_eeg, "eeg", 5, None, None) + _check_fitted_data_raw( + aahCluster_._fitted_data, raw_eeg, "eeg", 5, None, None + ) # save for later fitted_data_5_end = deepcopy(aahCluster_._fitted_data) @@ -582,7 +662,9 @@ def test_fit_data_shapes(): tmax=None, reject_by_annotation=False, ) - _check_fitted_data_epochs(aahCluster_._fitted_data, epochs_eeg, "eeg", 0.2, None) + _check_fitted_data_epochs( + aahCluster_._fitted_data, epochs_eeg, "eeg", 0.2, None + ) # tmax aahCluster_.fitted = False @@ -595,7 +677,9 @@ def test_fit_data_shapes(): tmax=5, reject_by_annotation=False, ) - _check_fitted_data_raw(aahCluster_._fitted_data, raw_eeg, "eeg", None, 5, None) + _check_fitted_data_raw( + aahCluster_._fitted_data, raw_eeg, "eeg", None, 5, None + ) # save for later fitted_data_0_5 = deepcopy(aahCluster_._fitted_data) @@ -609,7 +693,9 @@ def test_fit_data_shapes(): tmax=0.3, reject_by_annotation=False, ) - _check_fitted_data_epochs(aahCluster_._fitted_data, epochs_eeg, "eeg", None, 0.3) + _check_fitted_data_epochs( + aahCluster_._fitted_data, epochs_eeg, "eeg", None, 0.3 + ) # tmin, tmax aahCluster_.fitted = False @@ -622,7 +708,9 @@ def test_fit_data_shapes(): tmax=8, reject_by_annotation=False, ) - _check_fitted_data_raw(aahCluster_._fitted_data, raw_eeg, "eeg", 2, 8, None) + _check_fitted_data_raw( + aahCluster_._fitted_data, raw_eeg, "eeg", 2, 8, None + ) aahCluster_.fitted = False _check_unfitted(aahCluster_) @@ -634,7 +722,9 @@ def test_fit_data_shapes(): tmax=0.4, reject_by_annotation=False, ) - _check_fitted_data_epochs(aahCluster_._fitted_data, epochs_eeg, "eeg", 0.1, 0.4) + _check_fitted_data_epochs( + aahCluster_._fitted_data, epochs_eeg, "eeg", 0.1, 0.4 + ) # --------------------- # Reject by annotations @@ -656,7 +746,7 @@ def test_fit_data_shapes(): # Compare 'omit' and True assert np.isclose( aahCluster_reject_omit._fitted_data, - aahCluster_reject_True._fitted_data + aahCluster_reject_True._fitted_data, ).all() assert np.isclose(aahCluster_reject_omit.GEV_, aahCluster_reject_True.GEV_) assert np.isclose( @@ -665,14 +755,15 @@ def test_fit_data_shapes(): # due to internal randomness, the sign can be flipped sgn = np.sign( np.sum( - aahCluster_reject_True._cluster_centers_ * - aahCluster_reject_omit._cluster_centers_, axis=1 + aahCluster_reject_True._cluster_centers_ + * aahCluster_reject_omit._cluster_centers_, + axis=1, ) ) aahCluster_reject_True._cluster_centers_ *= sgn[:, None] assert np.isclose( aahCluster_reject_omit._cluster_centers_, - aahCluster_reject_True._cluster_centers_ + aahCluster_reject_True._cluster_centers_, ).all() # Make sure there is a shape diff between True and False @@ -691,7 +782,9 @@ def test_fit_data_shapes(): # Check with reject with tmin/tmax aahCluster_rej_0_5 = aahCluster_.copy() - aahCluster_rej_0_5.fit(raw_, n_jobs=1, tmin=0, tmax=5, reject_by_annotation=True) + aahCluster_rej_0_5.fit( + raw_, n_jobs=1, tmin=0, tmax=5, reject_by_annotation=True + ) aahCluster_rej_5_end = aahCluster_.copy() aahCluster_rej_5_end.fit( raw_, n_jobs=1, tmin=5, tmax=None, reject_by_annotation=True @@ -705,7 +798,9 @@ def test_fit_data_shapes(): aahCluster_rej_5_end._fitted_data, raw_, "eeg", 5, None, "omit" ) assert aahCluster_rej_0_5._fitted_data.shape != fitted_data_0_5.shape - assert np.isclose(fitted_data_5_end, aahCluster_rej_5_end._fitted_data).all() + assert np.isclose( + fitted_data_5_end, aahCluster_rej_5_end._fitted_data + ).all() def test_refit(): @@ -715,7 +810,6 @@ def test_refit(): n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-4, ) aahCluster_.fit(raw, picks="eeg") eeg_ch_names = aahCluster_.info["ch_names"] @@ -733,7 +827,6 @@ def test_refit(): n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-4, ) aahCluster_.fit(raw, picks="eeg") # works eeg_ch_names = aahCluster_.info["ch_names"] @@ -865,7 +958,6 @@ def test_picks_fit_predict(caplog): n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-4, ) # test invalid fit @@ -904,7 +996,8 @@ def test_picks_fit_predict(caplog): info = info_.copy() raw_predict = RawArray(data, info) caplog.clear() - aahCluster_.predict(raw_predict, picks="eeg") # -> warning for selected Fp2 + aahCluster_.predict(raw_predict, picks="eeg") + # -> warning for selected Fp2 assert "Fp2 which was not used during fitting" in caplog.text caplog.clear() aahCluster_.predict(raw_predict, picks=["Fp1", "CP1", "CP2"]) @@ -915,7 +1008,8 @@ def test_picks_fit_predict(caplog): assert "Fp2 which was not used during fitting" not in caplog.text # predict with a channel used for fitting that is now missing - # fails, because aahCluster.info includes Fp1 which is bad in prediction instance + # fails, because aahCluster.info includes Fp1 which is bad + # in prediction instance raw_predict.info["bads"] = ["Fp1"] with pytest.raises(ValueError, match="Fp1 is required to predict"): aahCluster_.predict(raw_predict, picks="eeg") @@ -926,8 +1020,8 @@ def test_picks_fit_predict(caplog): assert "Fp1 is set as bad in the instance but was selected" in caplog.text caplog.clear() - # fails, because aahCluster_.info includes Fp1 which is missing from prediction - # instance selection + # fails, because aahCluster_.info includes Fp1 which is missing + # from prediction instance selection with pytest.raises(ValueError, match="Fp1 is required to predict"): aahCluster_.predict(raw_predict, picks=["CP2", "CP1"]) @@ -948,8 +1042,8 @@ def test_picks_fit_predict(caplog): assert msg1 in caplog.text or msg2 in caplog.text caplog.clear() - # fails, because aahCluster_.info includes Fp1 which is missing from prediction - # instance selection + # fails, because aahCluster_.info includes Fp1 which is missing from + # prediction instance selection with pytest.raises(ValueError, match="Fp1 is required to predict"): aahCluster_.predict(raw_predict, picks=["Fp2", "Fpz", "CP2", "CP1"]) caplog.clear() @@ -1021,7 +1115,8 @@ def test_picks_fit_predict(caplog): with pytest.raises(ValueError, match="Fp2 is required to predict"): aahCluster_.predict(raw_predict, picks=["Fp1", "CP2", "CP1"]) - # works, because same channels as aahCluster_.info (with warnings for Fp1, Fp2) + # works, because same channels as aahCluster_.info + # (with warnings for Fp1, Fp2) caplog.clear() aahCluster_.predict(raw_predict, picks=["Fp1", "Fp2", "CP2", "CP1"]) assert predict_warning in caplog.text @@ -1043,8 +1138,6 @@ def test_predict_invalid_arguments(): TypeError, match="'half_window_size' must be an instance of " ): aahCluster.predict(raw_eeg, half_window_size="1") - with pytest.raises(TypeError, match="'tol' must be an instance of "): - aahCluster.predict(raw_eeg, tol="0") with pytest.raises( TypeError, match="'min_segment_length' must be an instance of " ): @@ -1061,14 +1154,15 @@ def test_contains_mixin(): """Test contains mixin class.""" assert "eeg" in aahCluster assert aahCluster.compensation_grade is None - assert aahCluster.get_channel_types() == ["eeg"] * aahCluster._info["nchan"] + assert ( + aahCluster.get_channel_types() == ["eeg"] * aahCluster._info["nchan"] + ) # test raise with non-fitted instance aahCluster_ = AAHCluster( n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-4, ) with pytest.raises( ValueError, match="Instance 'AAHCluster' attribute 'info' is None." @@ -1100,7 +1194,6 @@ def test_montage_mixin(): n_clusters=n_clusters, ignore_polarity=True, normalize_input=False, - tol=1e-4, ) with pytest.raises( ValueError, match="Instance 'AAHCluster' attribute 'info' is None." @@ -1163,9 +1256,6 @@ def test_comparison(caplog): aahCluster1._normalize_input = True assert aahCluster1 != aahCluster2 aahCluster1 = aahCluster.copy() - aahCluster1._tol = 0.101 - assert aahCluster1 != aahCluster2 - aahCluster1 = aahCluster.copy() aahCluster1._GEV_ = 0.101 assert aahCluster1 != aahCluster2 @@ -1181,7 +1271,9 @@ def test_comparison(caplog): assert aahCluster1 != aahCluster2 aahCluster1 = aahCluster.copy() aahCluster1._info = ChInfo( - ch_names=[str(k) for k in range(aahCluster1._cluster_centers_.shape[1])], + ch_names=[ + str(k) for k in range(aahCluster1._cluster_centers_.shape[1]) + ], ch_types=["eeg"] * aahCluster1._cluster_centers_.shape[1], ) assert aahCluster1 != aahCluster2 diff --git a/pycrostates/io/fiff.py b/pycrostates/io/fiff.py index 7bc2dfee..52e6e3d7 100644 --- a/pycrostates/io/fiff.py +++ b/pycrostates/io/fiff.py @@ -34,7 +34,7 @@ from .. import __version__ from .._typing import CHInfo -from ..cluster import ModKMeans, AAHCluster +from ..cluster import AAHCluster, ModKMeans from ..utils._checks import _check_type, _check_value from ..utils._docs import fill_doc from ..utils._logs import logger @@ -179,7 +179,7 @@ def _prepare_kwargs(algorithm: str, kwargs: dict): "variables": ["GEV_"], }, "AAHCluster": { - "parameters": ["ignore_polarity", "normalize_input", "tol"], + "parameters": ["ignore_polarity", "normalize_input"], "variables": ["GEV_"], }, } @@ -215,11 +215,13 @@ def _prepare_kwargs(algorithm: str, kwargs: dict): fit_parameters["tol"] = ModKMeans._check_tol(value) elif algorithm == "AAHCluster": if key == "ignore_polarity": - fit_parameters["ignore_polarity"] = \ - AAHCluster._check_ignore_polarity(value) + fit_parameters[ + "ignore_polarity" + ] = AAHCluster._check_ignore_polarity(value) elif key == "normalize_input": - fit_parameters["normalize_input"] = \ - AAHCluster._check_normalize_input(value) + fit_parameters[ + "normalize_input" + ] = AAHCluster._check_normalize_input(value) elif key == "tol": fit_parameters["tol"] = AAHCluster._check_tol(value) if key == "GEV_": @@ -326,7 +328,7 @@ def _read_cluster(fname: Union[str, Path]): # reconstruct cluster instance function = { "ModKMeans": _create_ModKMeans, - "AAHCluster": _create_AAHCluster + "AAHCluster": _create_AAHCluster, } return ( @@ -408,14 +410,12 @@ def _create_AAHCluster( labels_: NDArray[int], ignore_polarity: bool, normalize_input: bool, - tol: Union[int, float], GEV_: float, ): - print(normalize_input, ignore_polarity) """Create a AAHCluster object.""" cluster = AAHCluster( - cluster_centers_.shape[0], ignore_polarity, normalize_input, tol + cluster_centers_.shape[0], ignore_polarity, normalize_input ) cluster._cluster_centers_ = cluster_centers_ cluster._info = info @@ -426,6 +426,7 @@ def _create_AAHCluster( cluster._fitted = True return cluster + # ---------------------------------------------------------------------------- def _write_meas_info(fid, info: CHInfo): """Write measurement info into a file id (from a fif file). diff --git a/pycrostates/utils/_checks.py b/pycrostates/utils/_checks.py index 3da0528e..e1b50742 100644 --- a/pycrostates/utils/_checks.py +++ b/pycrostates/utils/_checks.py @@ -67,7 +67,6 @@ def __instancecheck__(cls, other): "path-like": (str, Path, os.PathLike), "int": (_IntLike(),), "callable": (_Callable(),), - "bool": (bool,), } From 1161e8eeba3a0c522d5e5577454d45423393527a Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Tue, 1 Nov 2022 12:02:17 +0900 Subject: [PATCH 09/24] Updates in pycrostates.cluster.[kmeans|aahc].py - changed _repr_html to use GEV as formatted string. --- pycrostates/cluster/aahc.py | 2 +- pycrostates/cluster/kmeans.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 12d59f91..aceda9e2 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -86,7 +86,7 @@ def _repr_html_(self, caption=None): f"{ch_count} {ch_type.upper()}" for ch_type, ch_count in zip(ch_types, ch_counts) ] - GEV = int(self._GEV_ * 100) + GEV = f"{self._GEV_ * 100:.2f}" else: n_samples = None ch_repr = None diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index 31df0e58..5ab507a5 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -81,7 +81,7 @@ def _repr_html_(self, caption=None): f"{ch_count} {ch_type.upper()}" for ch_type, ch_count in zip(ch_types, ch_counts) ] - GEV = int(self._GEV_ * 100) + GEV = f"{self._GEV_ * 100:.2f}" else: n_samples = None ch_repr = None From 29fa0e53ffd2a1aefc633a32e47bec430afcb275 Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Tue, 1 Nov 2022 12:37:09 +0900 Subject: [PATCH 10/24] Update in pycrostates/cluster/aahc.py - added unit tests for invalid arguments `ignore_polarity` and `normalize_input` - added test case for `normalize_input` argument --- pycrostates/cluster/aahc.py | 2 +- pycrostates/cluster/tests/test_aahc.py | 39 ++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index aceda9e2..d4c78d6e 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -66,7 +66,7 @@ def __init__( self._ignore_polarity = AAHCluster._check_ignore_polarity( ignore_polarity ) - self._normalize_input = AAHCluster._check_ignore_polarity( + self._normalize_input = AAHCluster._check_normalize_input( normalize_input ) diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 1df2ceef..47dbf782 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -139,6 +139,31 @@ def test_ignore_polarity_false(): assert len(set(dists.argmin(axis=0))) == 2 * sim_n_ms +def test_normalize_input_true(): + obj = AAHCluster( + n_clusters=sim_n_ms, ignore_polarity=True, normalize_input=True + ) + obj.fit(raw_sim) + + # extract cluster centers + A_hat = obj._cluster_centers_ + + # compute Euclidean distances (using the sign that minimizes the distance) + sgn = np.sign(A @ A_hat.T) + dists = np.linalg.norm( + (A_hat[None, ...] - A[:, None] * sgn[..., None]), axis=2 + ) + # compute tolerance (2 times the expected noise level) + tol = ( + 2 * sim_sigma / np.sqrt(sim_n_frames / sim_n_ms) * np.sqrt(sim_n_chans) + ) + # check if there is a cluster center whose distance + # is within the tolerance + assert (dists.min(axis=0) < tol).all() + # ensure that all cluster centers were identified + assert len(set(dists.argmin(axis=0))) == sim_n_ms + + # pylint: disable=protected-access def _check_fitted(aahCluster): """ @@ -584,6 +609,20 @@ def test_invalid_arguments(): with pytest.raises(ValueError, match="The number of clusters must be a"): aahCluster_ = AAHCluster(n_clusters=-101) + # ignore_polarity + with pytest.raises( + TypeError, match="'ignore_polarity' must be an instance of bool" + ): + aahCluster_ = AAHCluster(n_clusters=n_clusters, ignore_polarity="asdf") + aahCluster_ = AAHCluster(n_clusters=n_clusters, ignore_polarity=None) + + # normalize_input + with pytest.raises( + TypeError, match="'normalize_input' must be an instance of bool" + ): + aahCluster_ = AAHCluster(n_clusters=n_clusters, normalize_input="asdf") + aahCluster_ = AAHCluster(n_clusters=n_clusters, normalize_input=None) + aahCluster_ = AAHCluster( n_clusters=n_clusters, ignore_polarity=True, From 814a94988a404c85a6e0a12a1b25c90aed1902de Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Tue, 1 Nov 2022 12:37:32 +0900 Subject: [PATCH 11/24] updated changelog --- docs/source/dev/changes/latest.rst | 1 + docs/source/dev/changes/names.inc | 2 ++ 2 files changed, 3 insertions(+) diff --git a/docs/source/dev/changes/latest.rst b/docs/source/dev/changes/latest.rst index 5347d47f..bf0645c1 100644 --- a/docs/source/dev/changes/latest.rst +++ b/docs/source/dev/changes/latest.rst @@ -25,6 +25,7 @@ Current 0.3.0.dev Enhancements ~~~~~~~~~~~~ - Improve changelog. (:pr:`86` by `Victor Férat`_) +- Support for AAHC clustering (:pr: `92` by `Reinmar Kobler`_) Bugs ~~~~ diff --git a/docs/source/dev/changes/names.inc b/docs/source/dev/changes/names.inc index 6372dcd3..06ad7b66 100644 --- a/docs/source/dev/changes/names.inc +++ b/docs/source/dev/changes/names.inc @@ -1,3 +1,5 @@ .. _Victor Férat: https://github.com/vferat .. _Mathieu Scheltienne: https://github.com/mscheltienne + +.. _Reinmar Kobler: https://github.com/rkobler From 869017b545a02a3e431c7011b2ffe8c247397e5d Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Wed, 2 Nov 2022 08:51:58 +0900 Subject: [PATCH 12/24] fixed sphinx warning in doc/source/dev/changes/ratest.rst --- docs/source/dev/changes/latest.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/dev/changes/latest.rst b/docs/source/dev/changes/latest.rst index bf0645c1..6ef17d2c 100644 --- a/docs/source/dev/changes/latest.rst +++ b/docs/source/dev/changes/latest.rst @@ -25,7 +25,7 @@ Current 0.3.0.dev Enhancements ~~~~~~~~~~~~ - Improve changelog. (:pr:`86` by `Victor Férat`_) -- Support for AAHC clustering (:pr: `92` by `Reinmar Kobler`_) +- Support for AAHC clustering (:pr:`92` by `Reinmar Kobler`_) Bugs ~~~~ From f5e5e4e3a0ba0aab5083b997c3ab65cda84447d0 Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Wed, 2 Nov 2022 08:53:32 +0900 Subject: [PATCH 13/24] fixed AAHC related pylint and pydoc warnings --- pycrostates/cluster/aahc.py | 21 ++- pycrostates/cluster/tests/test_aahc.py | 234 +++++++++++++------------ pycrostates/io/fiff.py | 5 +- 3 files changed, 136 insertions(+), 124 deletions(-) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index d4c78d6e..58d51571 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -1,5 +1,4 @@ -"""Class and functions to use Atomize and Agglomerate Hierarchical Clustering - (AAHC).""" +"""Atomize and Agglomerate Hierarchical Clustering (AAHC).""" from functools import wraps as func_wraps from pathlib import Path @@ -22,7 +21,9 @@ from numba import njit except ImportError: - def njit(cache=False): + def njit(cache=False): # pylint: disable=unused-argument + """Define dummy decorator for numba.""" + def decorator(func): @func_wraps(func) def wrapper(*args, **kwargs): @@ -74,8 +75,11 @@ def __init__( self._GEV_ = None def _repr_html_(self, caption=None): + # pylint: disable=import-outside-toplevel from ..html_templates import repr_templates_env + # pylint: enable=import-outside-toplevel + template = repr_templates_env.get_template("AAHCluster.html.jinja") if self.fitted: n_samples = self._fitted_data.shape[-1] @@ -139,7 +143,7 @@ def _check_fit(self): @copy_doc(_BaseCluster.fit) @fill_doc - def fit( + def fit( # pylint: disable=arguments-differ self, inst: Union[BaseRaw, BaseEpochs], picks: Picks = "eeg", @@ -178,7 +182,10 @@ def save(self, fname: Union[str, Path]): super().save(fname) # TODO: to be replaced by a general writer than infers the writer from # the file extension. - from ..io.fiff import _write_cluster # pylint: disable=C0415 + # pylint: disable=import-outside-toplevel + from ..io.fiff import _write_cluster + + # pylint: enable=import-outside-toplevel _write_cluster( fname, @@ -210,6 +217,7 @@ def _aahc( gev = np.sum((data * map_corr) ** 2) / gfp_sum_sq return gev, maps, segmentation + # pylint: disable=too-many-locals @staticmethod def _compute_maps( data: NDArray[float], @@ -285,6 +293,8 @@ def _compute_maps( GEV[c] = np.sum(new_fit) return cluster.T, assignment + # pylint: enable=too-many-locals + @staticmethod @njit(cache=True) def _first_principal_component( @@ -319,7 +329,6 @@ def _first_principal_component( See :footcite:t:`Roweis1997` for additional information. """ - v = v0.flatten() assert v.shape[0] == X.shape[0] v /= np.linalg.norm(v) diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 47dbf782..6c48ef7c 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -47,11 +47,11 @@ # Fit one for general purposes n_clusters = 4 -aahCluster = AAHCluster( +aah_cluster = AAHCluster( n_clusters=n_clusters, ignore_polarity=True, normalize_input=False ) -aahCluster.fit(ch_data) +aah_cluster.fit(ch_data) # simulated data @@ -68,7 +68,7 @@ dtype=np.double, ) sim_n_ms = sources.shape[0] -sim_n_frames = 1000 # number of samples to generate +sim_n_frames = 250 # number of samples to generate sim_n_chans = pos.shape[0] # number of channels # compute forward model A = np.sum( @@ -97,7 +97,7 @@ def test_ignore_polarity_true(): obj.fit(raw_sim) # extract cluster centers - A_hat = obj._cluster_centers_ + A_hat = obj.cluster_centers_ # compute Euclidean distances (using the sign that minimizes the distance) sgn = np.sign(A @ A_hat.T) @@ -120,7 +120,7 @@ def test_ignore_polarity_false(): obj.fit(raw_sim) # extract cluster centers - A_hat = obj._cluster_centers_ + A_hat = obj.cluster_centers_ # create extended targets (pos. and neg. polarity) A_ = np.concatenate((A, -A), axis=0) # compute Euclidean distances @@ -146,7 +146,7 @@ def test_normalize_input_true(): obj.fit(raw_sim) # extract cluster centers - A_hat = obj._cluster_centers_ + A_hat = obj.cluster_centers_ # compute Euclidean distances (using the sign that minimizes the distance) sgn = np.sign(A @ A_hat.T) @@ -165,32 +165,32 @@ def test_normalize_input_true(): # pylint: disable=protected-access -def _check_fitted(aahCluster): +def _check_fitted(aah_cluster): """ - Checks that the aahCluster is fitted. + Checks that the aah_cluster is fitted. """ - assert aahCluster.fitted - assert aahCluster.n_clusters == n_clusters - assert len(aahCluster._cluster_names) == n_clusters - assert len(aahCluster._cluster_centers_) == n_clusters - assert aahCluster._fitted_data is not None - assert aahCluster._info is not None - assert aahCluster.GEV_ is not None - assert aahCluster._labels_ is not None + assert aah_cluster.fitted + assert aah_cluster.n_clusters == n_clusters + assert len(aah_cluster._cluster_names) == n_clusters + assert len(aah_cluster._cluster_centers_) == n_clusters + assert aah_cluster._fitted_data is not None + assert aah_cluster._info is not None + assert aah_cluster.GEV_ is not None + assert aah_cluster._labels_ is not None -def _check_unfitted(aahCluster): +def _check_unfitted(aah_cluster): """ - Checks that the aahCluster is not fitted. + Checks that the aah_cluster is not fitted. """ - assert not aahCluster.fitted - assert aahCluster.n_clusters == n_clusters - assert len(aahCluster._cluster_names) == n_clusters - assert aahCluster._cluster_centers_ is None - assert aahCluster._fitted_data is None - assert aahCluster._info is None - assert aahCluster.GEV_ is None - assert aahCluster._labels_ is None + assert not aah_cluster.fitted + assert aah_cluster.n_clusters == n_clusters + assert len(aah_cluster._cluster_names) == n_clusters + assert aah_cluster._cluster_centers_ is None + assert aah_cluster._fitted_data is None + assert aah_cluster._info is None + assert aah_cluster.GEV_ is None + assert aah_cluster._labels_ is None def _check_fitted_data_raw( @@ -289,8 +289,8 @@ def test_aahClusterMeans(): # Test representation expected = f"" - assert expected == aahCluster1.__repr__() - assert "" == aahCluster2.__repr__() + assert expected == repr(aahCluster1) + assert "" == repr(aahCluster2) # Test HTML representation html = aahCluster1._repr_html_() @@ -311,7 +311,7 @@ def test_aahClusterMeans(): def test_invert_polarity(): """Test invert polarity method.""" # list/tuple - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) aahCluster_.invert_polarity([True, False, True, False]) assert np.isclose( @@ -328,7 +328,7 @@ def test_invert_polarity(): ).all() # bool - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) aahCluster_.invert_polarity(True) assert np.isclose( @@ -345,7 +345,7 @@ def test_invert_polarity(): ).all() # np.array - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() cluster_centers_ = deepcopy(aahCluster_._cluster_centers_) aahCluster_.invert_polarity(np.array([True, False, True, False])) assert np.isclose( @@ -383,9 +383,9 @@ def test_rename(caplog): alphabet = ["A", "B", "C", "D"] # Test mapping - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() mapping = { - old: alphabet[k] for k, old in enumerate(aahCluster._cluster_names) + old: alphabet[k] for k, old in enumerate(aah_cluster._cluster_names) } for key, value in mapping.items(): assert isinstance(key, str) @@ -393,26 +393,26 @@ def test_rename(caplog): assert key != value aahCluster_.rename_clusters(mapping=mapping) assert aahCluster_._cluster_names == alphabet - assert aahCluster_._cluster_names != aahCluster._cluster_names + assert aahCluster_._cluster_names != aah_cluster._cluster_names # Test new_names - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() aahCluster_.rename_clusters(new_names=alphabet) assert aahCluster_._cluster_names == alphabet - assert aahCluster_._cluster_names != aahCluster._cluster_names + assert aahCluster_._cluster_names != aah_cluster._cluster_names # Test invalid arguments - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() with pytest.raises(TypeError, match="'mapping' must be an instance of "): aahCluster_.rename_clusters(mapping=101) with pytest.raises(ValueError, match="Invalid value for the 'old name'"): mapping = { old + "101": alphabet[k] - for k, old in enumerate(aahCluster._cluster_names) + for k, old in enumerate(aah_cluster._cluster_names) } aahCluster_.rename_clusters(mapping=mapping) with pytest.raises(TypeError, match="'new name' must be an instance of "): - mapping = {old: k for k, old in enumerate(aahCluster._cluster_names)} + mapping = {old: k for k, old in enumerate(aah_cluster._cluster_names)} aahCluster_.rename_clusters(mapping=mapping) with pytest.raises( ValueError, match="Argument 'new_names' should contain" @@ -426,17 +426,19 @@ def test_rename(caplog): ValueError, match="Only one of 'mapping' or 'new_names'" ): mapping = { - old: alphabet[k] for k, old in enumerate(aahCluster._cluster_names) + old: alphabet[k] + for k, old in enumerate(aah_cluster._cluster_names) } aahCluster_.rename_clusters(mapping=mapping, new_names=alphabet) # Test unfitted - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() aahCluster_.fitted = False _check_unfitted(aahCluster_) with pytest.raises(RuntimeError, match="must be fitted before"): mapping = { - old: alphabet[k] for k, old in enumerate(aahCluster._cluster_names) + old: alphabet[k] + for k, old in enumerate(aah_cluster._cluster_names) } aahCluster_.rename_clusters(mapping=mapping) with pytest.raises(RuntimeError, match="must be fitted before"): @@ -446,42 +448,44 @@ def test_rename(caplog): def test_reorder(caplog): """Test reordering of clusters.""" # Test mapping - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() aahCluster_.reorder_clusters(mapping={0: 1}) assert np.isclose( - aahCluster._cluster_centers_[0, :], aahCluster_._cluster_centers_[1, :] + aah_cluster._cluster_centers_[0, :], + aahCluster_._cluster_centers_[1, :], ).all() assert np.isclose( - aahCluster._cluster_centers_[1, :], aahCluster_._cluster_centers_[0, :] + aah_cluster._cluster_centers_[1, :], + aahCluster_._cluster_centers_[0, :], ).all() - assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] - assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] + assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] + assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] # Test order - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() aahCluster_.reorder_clusters(order=[1, 0, 2, 3]) assert np.isclose( - aahCluster._cluster_centers_[0], aahCluster_._cluster_centers_[1] + aah_cluster._cluster_centers_[0], aahCluster_._cluster_centers_[1] ).all() assert np.isclose( - aahCluster._cluster_centers_[1], aahCluster_._cluster_centers_[0] + aah_cluster._cluster_centers_[1], aahCluster_._cluster_centers_[0] ).all() - assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] - assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] + assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] + assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() aahCluster_.reorder_clusters(order=np.array([1, 0, 2, 3])) assert np.isclose( - aahCluster._cluster_centers_[0], aahCluster_._cluster_centers_[1] + aah_cluster._cluster_centers_[0], aahCluster_._cluster_centers_[1] ).all() assert np.isclose( - aahCluster._cluster_centers_[1], aahCluster_._cluster_centers_[0] + aah_cluster._cluster_centers_[1], aahCluster_._cluster_centers_[0] ).all() - assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] - assert aahCluster._cluster_names[0] == aahCluster_._cluster_names[1] + assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] + assert aah_cluster._cluster_names[0] == aahCluster_._cluster_names[1] # test ._labels_ reordering - y = aahCluster._labels_[:20] + y = aah_cluster._labels_[:20] y[y == 0] = -1 y[y == 1] = 0 y[y == -1] = 1 @@ -489,7 +493,7 @@ def test_reorder(caplog): assert np.all(x == y) # Test invalid arguments - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() with pytest.raises(TypeError, match="'mapping' must be an instance of "): aahCluster_.reorder_clusters(mapping=101) with pytest.raises( @@ -531,7 +535,7 @@ def test_reorder(caplog): aahCluster_.reorder_clusters(mapping={0: 1}, order=[1, 0, 2, 3]) # Test unfitted - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() aahCluster_.fitted = False _check_unfitted(aahCluster_) with pytest.raises(RuntimeError, match="must be fitted before"): @@ -564,17 +568,17 @@ def test_properties(caplog): caplog.clear() # Fitted - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() - aahCluster_.cluster_centers_ + assert aahCluster_.cluster_centers_ is not None assert "Clustering algorithm has not been fitted." not in caplog.text caplog.clear() - aahCluster_.info + assert aahCluster_.info is not None assert "Clustering algorithm has not been fitted." not in caplog.text caplog.clear() - aahCluster_.fitted_data + assert aahCluster_.fitted_data is not None assert "Clustering algorithm has not been fitted." not in caplog.text caplog.clear() @@ -591,7 +595,7 @@ def test_properties(caplog): log = "'fitted' can not be set to 'True' directly. Please use the .fit()" assert log in caplog.text caplog.clear() - aahCluster_ = aahCluster.copy() + aahCluster_ = aah_cluster.copy() aahCluster_.fitted = True log = "'fitted' can not be set to 'True' directly. The clustering" assert log in caplog.text @@ -879,13 +883,13 @@ def test_refit(): def test_predict_default(caplog): """Test predict method default behaviors.""" # raw, no smoothing, no_edge - segmentation = aahCluster.predict(raw_eeg, factor=0, reject_edges=False) + segmentation = aah_cluster.predict(raw_eeg, factor=0, reject_edges=False) assert isinstance(segmentation, RawSegmentation) assert "Segmenting data without smoothing" in caplog.text caplog.clear() # raw, no smoothing, with edge rejection - segmentation = aahCluster.predict(raw_eeg, factor=0, reject_edges=True) + segmentation = aah_cluster.predict(raw_eeg, factor=0, reject_edges=True) assert isinstance(segmentation, RawSegmentation) assert segmentation._labels[0] == -1 assert segmentation._labels[-1] == -1 @@ -893,7 +897,7 @@ def test_predict_default(caplog): caplog.clear() # raw, with smoothing - segmentation = aahCluster.predict(raw_eeg, factor=3, reject_edges=True) + segmentation = aah_cluster.predict(raw_eeg, factor=3, reject_edges=True) assert isinstance(segmentation, RawSegmentation) assert segmentation._labels[0] == -1 assert segmentation._labels[-1] == -1 @@ -901,7 +905,7 @@ def test_predict_default(caplog): caplog.clear() # raw with min_segment_length - segmentation = aahCluster.predict( + segmentation = aah_cluster.predict( raw_eeg, factor=0, reject_edges=False, min_segment_length=5 ) assert isinstance(segmentation, RawSegmentation) @@ -913,13 +917,15 @@ def test_predict_default(caplog): caplog.clear() # epochs, no smoothing, no_edge - segmentation = aahCluster.predict(epochs_eeg, factor=0, reject_edges=False) + segmentation = aah_cluster.predict( + epochs_eeg, factor=0, reject_edges=False + ) assert isinstance(segmentation, EpochsSegmentation) assert "Segmenting data without smoothing" in caplog.text caplog.clear() # epochs, no smoothing, with edge rejection - segmentation = aahCluster.predict(epochs_eeg, factor=0, reject_edges=True) + segmentation = aah_cluster.predict(epochs_eeg, factor=0, reject_edges=True) assert isinstance(segmentation, EpochsSegmentation) for epoch_labels in segmentation._labels: assert epoch_labels[0] == -1 @@ -928,7 +934,7 @@ def test_predict_default(caplog): caplog.clear() # epochs, with smoothing - segmentation = aahCluster.predict(epochs_eeg, factor=3, reject_edges=True) + segmentation = aah_cluster.predict(epochs_eeg, factor=3, reject_edges=True) assert isinstance(segmentation, EpochsSegmentation) for epoch_labels in segmentation._labels: assert epoch_labels[0] == -1 @@ -937,7 +943,7 @@ def test_predict_default(caplog): caplog.clear() # epochs with min_segment_length - segmentation = aahCluster.predict( + segmentation = aah_cluster.predict( epochs_eeg, factor=0, reject_edges=False, min_segment_length=5 ) assert isinstance(segmentation, EpochsSegmentation) @@ -953,16 +959,16 @@ def test_predict_default(caplog): bad_annot = Annotations([1], [2], "bad") raw_ = raw_eeg.copy() raw_.set_annotations(bad_annot) - segmentation_rej_True = aahCluster.predict( + segmentation_rej_True = aah_cluster.predict( raw_, factor=0, reject_edges=True, reject_by_annotation=True ) - segmentation_rej_False = aahCluster.predict( + segmentation_rej_False = aah_cluster.predict( raw_, factor=0, reject_edges=True, reject_by_annotation=False ) - segmentation_rej_None = aahCluster.predict( + segmentation_rej_None = aah_cluster.predict( raw_, factor=0, reject_edges=True, reject_by_annotation=None ) - segmentation_no_annot = aahCluster.predict( + segmentation_no_annot = aah_cluster.predict( raw_eeg, factor=0, reject_edges=True, reject_by_annotation="omit" ) assert not np.isclose( @@ -976,13 +982,13 @@ def test_predict_default(caplog): ).all() # test different half_window_size - segmentation1 = aahCluster.predict( + segmentation1 = aah_cluster.predict( raw_eeg, factor=3, reject_edges=False, half_window_size=3 ) - segmentation2 = aahCluster.predict( + segmentation2 = aah_cluster.predict( raw_eeg, factor=3, reject_edges=False, half_window_size=60 ) - segmentation3 = aahCluster.predict( + segmentation3 = aah_cluster.predict( raw_eeg, factor=0, reject_edges=False, half_window_size=3 ) assert not np.isclose(segmentation1._labels, segmentation2._labels).all() @@ -1047,7 +1053,7 @@ def test_picks_fit_predict(caplog): assert "Fp2 which was not used during fitting" not in caplog.text # predict with a channel used for fitting that is now missing - # fails, because aahCluster.info includes Fp1 which is bad + # fails, because aah_cluster.info includes Fp1 which is bad # in prediction instance raw_predict.info["bads"] = ["Fp1"] with pytest.raises(ValueError, match="Fp1 is required to predict"): @@ -1093,7 +1099,7 @@ def test_picks_fit_predict(caplog): caplog.clear() # try with a missing channel from the prediction instance - # fails, because Fp1 is used in aahCluster.info + # fails, because Fp1 is used in aah_cluster.info raw_predict.drop_channels(["Fp1"]) with pytest.raises( ValueError, match="Fp1 was used during fitting but is missing" @@ -1166,35 +1172,35 @@ def test_picks_fit_predict(caplog): def test_predict_invalid_arguments(): """Test invalid arguments passed to predict.""" with pytest.raises(TypeError, match="'inst' must be an instance of "): - aahCluster.predict(epochs_eeg.average()) + aah_cluster.predict(epochs_eeg.average()) with pytest.raises(TypeError, match="'factor' must be an instance of "): - aahCluster.predict(raw_eeg, factor="0") + aah_cluster.predict(raw_eeg, factor="0") with pytest.raises( TypeError, match="'reject_edges' must be an instance of " ): - aahCluster.predict(raw_eeg, reject_edges=1) + aah_cluster.predict(raw_eeg, reject_edges=1) with pytest.raises( TypeError, match="'half_window_size' must be an instance of " ): - aahCluster.predict(raw_eeg, half_window_size="1") + aah_cluster.predict(raw_eeg, half_window_size="1") with pytest.raises( TypeError, match="'min_segment_length' must be an instance of " ): - aahCluster.predict(raw_eeg, min_segment_length="0") + aah_cluster.predict(raw_eeg, min_segment_length="0") with pytest.raises( TypeError, match="'reject_by_annotation' must be an instance of " ): - aahCluster.predict(raw_eeg, reject_by_annotation=1) + aah_cluster.predict(raw_eeg, reject_by_annotation=1) with pytest.raises(ValueError, match="'reject_by_annotation' can be"): - aahCluster.predict(raw_eeg, reject_by_annotation="101") + aah_cluster.predict(raw_eeg, reject_by_annotation="101") def test_contains_mixin(): """Test contains mixin class.""" - assert "eeg" in aahCluster - assert aahCluster.compensation_grade is None + assert "eeg" in aah_cluster + assert aah_cluster.compensation_grade is None assert ( - aahCluster.get_channel_types() == ["eeg"] * aahCluster._info["nchan"] + aah_cluster.get_channel_types() == ["eeg"] * aah_cluster._info["nchan"] ) # test raise with non-fitted instance @@ -1206,7 +1212,7 @@ def test_contains_mixin(): with pytest.raises( ValueError, match="Instance 'AAHCluster' attribute 'info' is None." ): - "eeg" in aahCluster_ + assert "eeg" in aahCluster_ with pytest.raises( ValueError, match="Instance 'AAHCluster' attribute 'info' is None." ): @@ -1214,13 +1220,13 @@ def test_contains_mixin(): with pytest.raises( ValueError, match="Instance 'AAHCluster' attribute 'info' is None." ): - aahCluster_.compensation_grade + _ = aahCluster_.compensation_grade def test_montage_mixin(): """Test montage mixin class.""" - aahCluster_ = aahCluster.copy() - montage = aahCluster.get_montage() + aahCluster_ = aah_cluster.copy() + montage = aah_cluster.get_montage() assert isinstance(montage, DigMontage) assert montage.dig[-1]["r"][0] != 0 montage.dig[-1]["r"][0] = 0 @@ -1249,11 +1255,11 @@ def test_save(tmp_path, caplog): """Test .save() method.""" # writing to .fif fname1 = tmp_path / "cluster.fif" - aahCluster.save(fname1) + aah_cluster.save(fname1) # writing to .gz (compression) fname2 = tmp_path / "cluster.fif.gz" - aahCluster.save(fname2) + aah_cluster.save(fname2) # re-load caplog.clear() @@ -1265,12 +1271,12 @@ def test_save(tmp_path, caplog): assert __version__ not in caplog.text # compare - assert aahCluster == aahCluster1 - assert aahCluster == aahCluster2 + assert aah_cluster == aahCluster1 + assert aah_cluster == aahCluster2 assert aahCluster1 == aahCluster2 # sanity-check # test prediction - segmentation = aahCluster.predict(raw_eeg, picks="eeg") + segmentation = aah_cluster.predict(raw_eeg, picks="eeg") segmentation1 = aahCluster1.predict(raw_eeg, picks="eeg") segmentation2 = aahCluster2.predict(raw_eeg, picks="eeg") @@ -1281,20 +1287,20 @@ def test_save(tmp_path, caplog): def test_comparison(caplog): """Test == and != methods.""" - aahCluster1 = aahCluster.copy() - aahCluster2 = aahCluster.copy() + aahCluster1 = aah_cluster.copy() + aahCluster2 = aah_cluster.copy() assert aahCluster1 == aahCluster2 # with different aahClustermeans variables aahCluster1.fitted = False assert aahCluster1 != aahCluster2 - aahCluster1 = aahCluster.copy() + aahCluster1 = aah_cluster.copy() aahCluster1._ignore_polarity = False assert aahCluster1 != aahCluster2 - aahCluster1 = aahCluster.copy() + aahCluster1 = aah_cluster.copy() aahCluster1._normalize_input = True assert aahCluster1 != aahCluster2 - aahCluster1 = aahCluster.copy() + aahCluster1 = aah_cluster.copy() aahCluster1._GEV_ = 0.101 assert aahCluster1 != aahCluster2 @@ -1302,13 +1308,13 @@ def test_comparison(caplog): assert aahCluster1 != 101 # with different base variables - aahCluster1 = aahCluster.copy() - aahCluster2 = aahCluster.copy() + aahCluster1 = aah_cluster.copy() + aahCluster2 = aah_cluster.copy() assert aahCluster1 == aahCluster2 - aahCluster1 = aahCluster.copy() + aahCluster1 = aah_cluster.copy() aahCluster1._n_clusters = 101 assert aahCluster1 != aahCluster2 - aahCluster1 = aahCluster.copy() + aahCluster1 = aah_cluster.copy() aahCluster1._info = ChInfo( ch_names=[ str(k) for k in range(aahCluster1._cluster_centers_.shape[1]) @@ -1316,16 +1322,16 @@ def test_comparison(caplog): ch_types=["eeg"] * aahCluster1._cluster_centers_.shape[1], ) assert aahCluster1 != aahCluster2 - aahCluster1 = aahCluster.copy() + aahCluster1 = aah_cluster.copy() aahCluster1._labels_ = aahCluster1._labels_[::-1] assert aahCluster1 != aahCluster2 - aahCluster1 = aahCluster.copy() + aahCluster1 = aah_cluster.copy() aahCluster1._fitted_data = aahCluster1._fitted_data[:, ::-1] assert aahCluster1 != aahCluster2 # different cluster names - aahCluster1 = aahCluster.copy() - aahCluster2 = aahCluster.copy() + aahCluster1 = aah_cluster.copy() + aahCluster2 = aah_cluster.copy() caplog.clear() assert aahCluster1 == aahCluster2 assert "Cluster names differ between both clustering" not in caplog.text diff --git a/pycrostates/io/fiff.py b/pycrostates/io/fiff.py index 52e6e3d7..d0c3da63 100644 --- a/pycrostates/io/fiff.py +++ b/pycrostates/io/fiff.py @@ -222,8 +222,6 @@ def _prepare_kwargs(algorithm: str, kwargs: dict): fit_parameters[ "normalize_input" ] = AAHCluster._check_normalize_input(value) - elif key == "tol": - fit_parameters["tol"] = AAHCluster._check_tol(value) if key == "GEV_": _check_type(value, ("numeric",), "GEV_") if value < 0 or 1 < value: @@ -356,7 +354,7 @@ def _check_fit_parameters_and_variables( "variables": ["GEV_"], }, "AAHCluster": { - "parameters": ["ignore_polarity", "normalize_input", "tol"], + "parameters": ["ignore_polarity", "normalize_input"], "variables": ["GEV_"], }, } @@ -412,7 +410,6 @@ def _create_AAHCluster( normalize_input: bool, GEV_: float, ): - """Create a AAHCluster object.""" cluster = AAHCluster( cluster_centers_.shape[0], ignore_polarity, normalize_input From 149e1eeb8a86b6445f893115af3b07779caa2b92 Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Wed, 2 Nov 2022 09:12:45 +0900 Subject: [PATCH 14/24] fixed some more linting warnings --- pycrostates/cluster/aahc.py | 4 +++- pycrostates/cluster/tests/test_aahc.py | 6 ++++-- pycrostates/io/fiff.py | 2 ++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 58d51571..2e62a9ed 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -141,9 +141,10 @@ def _check_fit(self): # sanity-check assert self.GEV_ is not None + # pylint: disable=arguments-differ @copy_doc(_BaseCluster.fit) @fill_doc - def fit( # pylint: disable=arguments-differ + def fit( self, inst: Union[BaseRaw, BaseEpochs], picks: Picks = "eeg", @@ -176,6 +177,7 @@ def fit( # pylint: disable=arguments-differ self._cluster_centers_ = maps self._labels_ = segmentation self._fitted = True + # pylint: enable=arguments-differ @copy_doc(_BaseCluster.save) def save(self, fname: Union[str, Path]): diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 6c48ef7c..794e825b 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -879,7 +879,7 @@ def test_refit(): assert eeg_ch_names == aahCluster_.info["ch_names"] assert np.allclose(eeg_cluster_centers, aahCluster_.cluster_centers_) - +# pylint: disable=too-many-statements def test_predict_default(caplog): """Test predict method default behaviors.""" # raw, no smoothing, no_edge @@ -994,8 +994,9 @@ def test_predict_default(caplog): assert not np.isclose(segmentation1._labels, segmentation2._labels).all() assert not np.isclose(segmentation1._labels, segmentation3._labels).all() assert not np.isclose(segmentation2._labels, segmentation3._labels).all() +# pylint: enable=too-many-statements - +# pylint: disable=too-many-statements def test_picks_fit_predict(caplog): """Test fitting and prediction with different picks.""" raw = raw_meg.copy().pick_types(meg=True, eeg=True, eog=True) @@ -1167,6 +1168,7 @@ def test_picks_fit_predict(caplog): assert predict_warning in caplog.text assert "Fp1 is set as bad in the instance but was selected" in caplog.text caplog.clear() +# pylint: enable=too-many-statements def test_predict_invalid_arguments(): diff --git a/pycrostates/io/fiff.py b/pycrostates/io/fiff.py index d0c3da63..314aced1 100644 --- a/pycrostates/io/fiff.py +++ b/pycrostates/io/fiff.py @@ -206,6 +206,7 @@ def _prepare_kwargs(algorithm: str, kwargs: dict): continue # ModKMeans + # pylint: disable=protected-access if algorithm == "ModKMeans": if key == "n_init": fit_parameters["n_init"] = ModKMeans._check_n_init(value) @@ -222,6 +223,7 @@ def _prepare_kwargs(algorithm: str, kwargs: dict): fit_parameters[ "normalize_input" ] = AAHCluster._check_normalize_input(value) + # pylint: enable=protected-access if key == "GEV_": _check_type(value, ("numeric",), "GEV_") if value < 0 or 1 < value: From 6ecd6660abc15410f01d90fb46b2917bff3e96e7 Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Wed, 2 Nov 2022 09:47:43 +0900 Subject: [PATCH 15/24] Changed AAHCluster interface (removed ignore_polarity parameter) as suggested in https://github.com/vferat/pycrostates/pull/92#issuecomment-1298221735 --- pycrostates/cluster/aahc.py | 12 ++--- pycrostates/cluster/tests/test_aahc.py | 63 ++++++++++++++++---------- pycrostates/io/fiff.py | 7 ++- 3 files changed, 51 insertions(+), 31 deletions(-) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 2e62a9ed..5853a595 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -43,8 +43,6 @@ class AAHCluster(_BaseCluster): Parameters ---------- %(n_clusters)s - ignore_polarity : bool - If true, polarity is ignored when computing distances. normalize_input : bool If set, the input data is normalized along the channel dimension. @@ -56,7 +54,6 @@ class AAHCluster(_BaseCluster): def __init__( self, n_clusters: int, - ignore_polarity: bool = True, normalize_input: bool = False, ): super().__init__() @@ -64,9 +61,11 @@ def __init__( self._n_clusters = _BaseCluster._check_n_clusters(n_clusters) self._cluster_names = [str(k) for k in range(self.n_clusters)] - self._ignore_polarity = AAHCluster._check_ignore_polarity( - ignore_polarity - ) + # TODO : ignor_polarity=True for now. + # After _BaseCluster and Metric support ignore_polarity + # make the parameter an argument + # https://github.com/vferat/pycrostates/pull/93#issue-1431122168 + self._ignore_polarity = True self._normalize_input = AAHCluster._check_normalize_input( normalize_input ) @@ -177,6 +176,7 @@ def fit( self._cluster_centers_ = maps self._labels_ = segmentation self._fitted = True + # pylint: enable=arguments-differ @copy_doc(_BaseCluster.save) diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 794e825b..4ae24b64 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -47,9 +47,7 @@ # Fit one for general purposes n_clusters = 4 -aah_cluster = AAHCluster( - n_clusters=n_clusters, ignore_polarity=True, normalize_input=False -) +aah_cluster = AAHCluster(n_clusters=n_clusters, normalize_input=False) aah_cluster.fit(ch_data) @@ -92,8 +90,9 @@ raw_sim.info["bads"] = [] -def test_ignore_polarity_true(): - obj = AAHCluster(n_clusters=sim_n_ms, ignore_polarity=True) +def test_default_algorithm(): + obj = AAHCluster(n_clusters=sim_n_ms) + assert obj.ignore_polarity is True obj.fit(raw_sim) # extract cluster centers @@ -116,7 +115,8 @@ def test_ignore_polarity_true(): def test_ignore_polarity_false(): - obj = AAHCluster(n_clusters=sim_n_ms * 2, ignore_polarity=False) + obj = AAHCluster(n_clusters=sim_n_ms * 2) + obj._ignore_polarity = False # pylint: disable=protected-access obj.fit(raw_sim) # extract cluster centers @@ -141,8 +141,11 @@ def test_ignore_polarity_false(): def test_normalize_input_true(): obj = AAHCluster( - n_clusters=sim_n_ms, ignore_polarity=True, normalize_input=True + n_clusters=sim_n_ms, + # ignore_polarity=True, + normalize_input=True, ) + assert obj.ignore_polarity is True obj.fit(raw_sim) # extract cluster centers @@ -223,12 +226,12 @@ def test_aahClusterMeans(): """Test K-Means default functionalities.""" aahCluster1 = AAHCluster( n_clusters=n_clusters, - ignore_polarity=True, + # ignore_polarity=True, normalize_input=False, ) # Test properties - assert aahCluster1.ignore_polarity is True + # assert aahCluster1.ignore_polarity is True assert aahCluster1.normalize_input is False _check_unfitted(aahCluster1) @@ -551,7 +554,7 @@ def test_properties(caplog): # Unfitted aahCluster_ = AAHCluster( n_clusters=n_clusters, - ignore_polarity=True, + # ignore_polarity=True, normalize_input=False, ) @@ -585,7 +588,7 @@ def test_properties(caplog): # Test fitted property aahCluster_ = AAHCluster( n_clusters=n_clusters, - ignore_polarity=True, + # ignore_polarity=True, normalize_input=False, ) with pytest.raises(TypeError, match="'fitted' must be an instance of"): @@ -613,12 +616,18 @@ def test_invalid_arguments(): with pytest.raises(ValueError, match="The number of clusters must be a"): aahCluster_ = AAHCluster(n_clusters=-101) - # ignore_polarity - with pytest.raises( - TypeError, match="'ignore_polarity' must be an instance of bool" - ): - aahCluster_ = AAHCluster(n_clusters=n_clusters, ignore_polarity="asdf") - aahCluster_ = AAHCluster(n_clusters=n_clusters, ignore_polarity=None) + # # ignore_polarity + # with pytest.raises( + # TypeError, match="'ignore_polarity' must be an instance of bool" + # ): + # aahCluster_ = AAHCluster( + # n_clusters=n_clusters, + # ignore_polarity="asdf" + # ) + # aahCluster_ = AAHCluster( + # n_clusters=n_clusters, + # ignore_polarity=None + # ) # normalize_input with pytest.raises( @@ -629,7 +638,7 @@ def test_invalid_arguments(): aahCluster_ = AAHCluster( n_clusters=n_clusters, - ignore_polarity=True, + # ignore_polarity=True, normalize_input=False, ) # inst @@ -674,7 +683,7 @@ def test_fit_data_shapes(): """Test different tmin/tmax, rejection with fit.""" aahCluster_ = AAHCluster( n_clusters=n_clusters, - ignore_polarity=True, + # ignore_polarity=True, normalize_input=False, ) @@ -851,7 +860,7 @@ def test_refit(): raw = raw_meg.copy().pick_types(meg=True, eeg=True, eog=True) aahCluster_ = AAHCluster( n_clusters=n_clusters, - ignore_polarity=True, + # ignore_polarity=True, normalize_input=False, ) aahCluster_.fit(raw, picks="eeg") @@ -868,7 +877,7 @@ def test_refit(): raw = raw_meg.copy().pick_types(meg=True, eeg=True, eog=True) aahCluster_ = AAHCluster( n_clusters=n_clusters, - ignore_polarity=True, + # ignore_polarity=True, normalize_input=False, ) aahCluster_.fit(raw, picks="eeg") # works @@ -879,6 +888,7 @@ def test_refit(): assert eeg_ch_names == aahCluster_.info["ch_names"] assert np.allclose(eeg_cluster_centers, aahCluster_.cluster_centers_) + # pylint: disable=too-many-statements def test_predict_default(caplog): """Test predict method default behaviors.""" @@ -994,15 +1004,18 @@ def test_predict_default(caplog): assert not np.isclose(segmentation1._labels, segmentation2._labels).all() assert not np.isclose(segmentation1._labels, segmentation3._labels).all() assert not np.isclose(segmentation2._labels, segmentation3._labels).all() + + # pylint: enable=too-many-statements + # pylint: disable=too-many-statements def test_picks_fit_predict(caplog): """Test fitting and prediction with different picks.""" raw = raw_meg.copy().pick_types(meg=True, eeg=True, eog=True) aahCluster_ = AAHCluster( n_clusters=n_clusters, - ignore_polarity=True, + # ignore_polarity=True, normalize_input=False, ) @@ -1168,6 +1181,8 @@ def test_picks_fit_predict(caplog): assert predict_warning in caplog.text assert "Fp1 is set as bad in the instance but was selected" in caplog.text caplog.clear() + + # pylint: enable=too-many-statements @@ -1208,7 +1223,7 @@ def test_contains_mixin(): # test raise with non-fitted instance aahCluster_ = AAHCluster( n_clusters=n_clusters, - ignore_polarity=True, + # ignore_polarity=True, normalize_input=False, ) with pytest.raises( @@ -1239,7 +1254,7 @@ def test_montage_mixin(): # test raise with non-fitted instance aahCluster_ = AAHCluster( n_clusters=n_clusters, - ignore_polarity=True, + # ignore_polarity=True, normalize_input=False, ) with pytest.raises( diff --git a/pycrostates/io/fiff.py b/pycrostates/io/fiff.py index 314aced1..5ef7f5f8 100644 --- a/pycrostates/io/fiff.py +++ b/pycrostates/io/fiff.py @@ -414,7 +414,12 @@ def _create_AAHCluster( ): """Create a AAHCluster object.""" cluster = AAHCluster( - cluster_centers_.shape[0], ignore_polarity, normalize_input + cluster_centers_.shape[0], + # TODO : ignor_polarity=True for now. + # After _BaseCluster and Metric support ignore_polarity + # make the parameter an argument + # ignore_polarity, + normalize_input, ) cluster._cluster_centers_ = cluster_centers_ cluster._info = info From 29c1d64abc5354e9999fe29841d5d7b4a63df968 Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Fri, 4 Nov 2022 10:20:10 +0900 Subject: [PATCH 16/24] Changes in clustering interfaces - moved verbose parameter to baseclass - removed ignore_polarity property (for now) - fixed some pylint warnings --- pycrostates/cluster/_base.py | 10 +++- pycrostates/cluster/aahc.py | 63 +++++++++++++++----------- pycrostates/cluster/kmeans.py | 15 +++--- pycrostates/cluster/tests/test_aahc.py | 4 +- pycrostates/io/fiff.py | 2 +- 5 files changed, 56 insertions(+), 38 deletions(-) diff --git a/pycrostates/cluster/_base.py b/pycrostates/cluster/_base.py index 86f39f33..dc9bbf76 100644 --- a/pycrostates/cluster/_base.py +++ b/pycrostates/cluster/_base.py @@ -25,7 +25,8 @@ _check_value, ) from ..utils._docs import fill_doc -from ..utils._logs import logger, verbose +from ..utils._logs import _set_verbose, logger +from ..utils._logs import verbose as verbose_decorator from ..utils.mixin import ChannelsMixin, ContainsMixin, MontageMixin from ..viz import plot_cluster_centers @@ -193,6 +194,8 @@ def fit( tmax: Optional[Union[int, float]] = None, reject_by_annotation: bool = True, n_jobs: int = 1, + *, + verbose: Optional[str] = None, ) -> NDArray[float]: """Compute cluster centers. @@ -214,9 +217,12 @@ def fit( %(tmax_raw)s %(reject_by_annotation_raw)s %(n_jobs)s + %(verbose)s """ from ..io import ChData, ChInfo + _set_verbose(verbose) + self._check_unfitted() n_jobs = _check_n_jobs(n_jobs) _check_type(inst, (BaseRaw, BaseEpochs, ChData), item_name="inst") @@ -554,7 +560,7 @@ def save(self, fname: Union[str, Path]): self._check_fit() _check_type(fname, ("path-like",), "fname") - @verbose + @verbose_decorator def predict( self, inst: Union[BaseRaw, BaseEpochs], diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index 5853a595..d675ebfb 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -13,7 +13,7 @@ from ..utils import _corr_vectors from ..utils._checks import _check_type from ..utils._docs import copy_doc, fill_doc -from ..utils._logs import _set_verbose, logger +from ..utils._logs import logger from ._base import _BaseCluster # if we have numba, use its jit interface @@ -140,9 +140,7 @@ def _check_fit(self): # sanity-check assert self.GEV_ is not None - # pylint: disable=arguments-differ @copy_doc(_BaseCluster.fit) - @fill_doc def fit( self, inst: Union[BaseRaw, BaseEpochs], @@ -154,12 +152,14 @@ def fit( *, verbose: Optional[str] = None, ) -> None: - """ - %(verbose)s - """ - _set_verbose(verbose) # TODO: decorator nesting is failing data = super().fit( - inst, picks, tmin, tmax, reject_by_annotation, n_jobs + inst, + picks=picks, + tmin=tmin, + tmax=tmax, + reject_by_annotation=reject_by_annotation, + n_jobs=n_jobs, + verbose=verbose, ) gev, maps, segmentation = AAHCluster._aahc( @@ -177,8 +177,6 @@ def fit( self._labels_ = segmentation self._fitted = True - # pylint: enable=arguments-differ - @copy_doc(_BaseCluster.save) def save(self, fname: Union[str, Path]): super().save(fname) @@ -240,8 +238,11 @@ def _compute_maps( assignment = np.arange(n_frame) + n_steps = n_frame - n_clusters while cluster.shape[1] > n_clusters: + step = n_frame - cluster.shape[1] + to_remove = np.argmin(GEV) orphans = assignment == to_remove @@ -268,22 +269,40 @@ def _compute_maps( sgn = np.sign(old_cluster @ cluster[:, c]) - v0 = ( + first_pc_init = ( old_weight * sgn * old_cluster + (1.0 - old_weight) * cluster[:, c] ) - v, _, converged = AAHCluster._first_principal_component( - data[:, members], v0, max_iter=n_chan + ( + first_pc, + _, + converged, + ) = AAHCluster._first_principal_component( + data[:, members], first_pc_init, max_iter=n_chan ) - if not converged: + if converged: + logger.debug( + "AAHC %d/%d: " "power iterations converged.", + step, + n_steps, + ) + else: + logger.debug( + "AAHC %d/%d " + "power iteration did not converge. " + "Computing fallback solution.", + step, + n_steps, + ) # fall back to covariance estimation # and eigenvalue computation - Cxx = data[:, members] @ data[:, members].T - _, V = np.linalg.eigh(Cxx) - v = V[:, -1] + data_cov = data[:, members] @ data[:, members].T + _, evecs = np.linalg.eigh(data_cov) + # eigenvector at largest eigenvalue + first_pc = evecs[:, -1] - cluster[:, c] = v + cluster[:, c] = first_pc else: cluster[:, c] = np.mean(data[:, members], axis=1) cluster[:, c] /= np.linalg.norm( @@ -350,14 +369,6 @@ def _first_principal_component( # -------------------------------------------------------------------- - @property - def ignore_polarity(self) -> bool: - """If true, polarity is ignored when computing distances. - - :type: `bool` - """ - return self._ignore_polarity - @property def normalize_input(self) -> bool: """If set, the input data is normalized along the channel dimension. diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index 5ab507a5..2f3e0f31 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -14,7 +14,7 @@ from ..utils import _corr_vectors from ..utils._checks import _check_random_state, _check_type from ..utils._docs import copy_doc, fill_doc -from ..utils._logs import _set_verbose, logger +from ..utils._logs import logger from ._base import _BaseCluster @@ -137,7 +137,6 @@ def _check_fit(self): assert self.GEV_ is not None @copy_doc(_BaseCluster.fit) - @fill_doc def fit( self, inst: Union[BaseRaw, BaseEpochs], @@ -149,12 +148,14 @@ def fit( *, verbose: Optional[str] = None, ) -> None: - """ - %(verbose)s - """ - _set_verbose(verbose) # TODO: decorator nesting is failing data = super().fit( - inst, picks, tmin, tmax, reject_by_annotation, n_jobs + inst, + picks=picks, + tmin=tmin, + tmax=tmax, + reject_by_annotation=reject_by_annotation, + n_jobs=n_jobs, + verbose=verbose, ) inits = self._random_state.randint( diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 4ae24b64..7502739a 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -92,7 +92,7 @@ def test_default_algorithm(): obj = AAHCluster(n_clusters=sim_n_ms) - assert obj.ignore_polarity is True + assert obj._ignore_polarity is True # pylint: disable=protected-access obj.fit(raw_sim) # extract cluster centers @@ -145,7 +145,7 @@ def test_normalize_input_true(): # ignore_polarity=True, normalize_input=True, ) - assert obj.ignore_polarity is True + assert obj._ignore_polarity is True # pylint: disable=protected-access obj.fit(raw_sim) # extract cluster centers diff --git a/pycrostates/io/fiff.py b/pycrostates/io/fiff.py index 5ef7f5f8..5e76e96f 100644 --- a/pycrostates/io/fiff.py +++ b/pycrostates/io/fiff.py @@ -408,7 +408,7 @@ def _create_AAHCluster( cluster_names: List[str], fitted_data: NDArray[float], labels_: NDArray[int], - ignore_polarity: bool, + ignore_polarity: bool, # pylint: disable=unused-argument normalize_input: bool, GEV_: float, ): From 06bff2403a5bfdf43ef798893adaf84ae084c711 Mon Sep 17 00:00:00 2001 From: rkobler Date: Fri, 4 Nov 2022 10:21:03 +0900 Subject: [PATCH 17/24] Update pycrostates/cluster/tests/test_aahc.py Co-authored-by: Mathieu Scheltienne --- pycrostates/cluster/tests/test_aahc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 7502739a..4023c540 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -223,7 +223,7 @@ def _check_fitted_data_epochs(fitted_data, epochs, picks, tmin, tmax): def test_aahClusterMeans(): - """Test K-Means default functionalities.""" + """Test AAHC default functionalities.""" aahCluster1 = AAHCluster( n_clusters=n_clusters, # ignore_polarity=True, From 8b12166e94bb8569074adee3677d9302a22c03b9 Mon Sep 17 00:00:00 2001 From: rkobler Date: Fri, 4 Nov 2022 10:21:18 +0900 Subject: [PATCH 18/24] Update pycrostates/cluster/tests/test_aahc.py Co-authored-by: Mathieu Scheltienne --- pycrostates/cluster/tests/test_aahc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 4023c540..e1a7c8c5 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -183,9 +183,7 @@ def _check_fitted(aah_cluster): def _check_unfitted(aah_cluster): - """ - Checks that the aah_cluster is not fitted. - """ + """Check that the aah_cluster is not fitted.""" assert not aah_cluster.fitted assert aah_cluster.n_clusters == n_clusters assert len(aah_cluster._cluster_names) == n_clusters From 7489d0b6078ad2d15dbf82f37657e59f1ce682a8 Mon Sep 17 00:00:00 2001 From: rkobler Date: Fri, 4 Nov 2022 10:21:29 +0900 Subject: [PATCH 19/24] Update pycrostates/cluster/tests/test_aahc.py Co-authored-by: Mathieu Scheltienne --- pycrostates/cluster/tests/test_aahc.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index e1a7c8c5..ec4b762d 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -169,9 +169,7 @@ def test_normalize_input_true(): # pylint: disable=protected-access def _check_fitted(aah_cluster): - """ - Checks that the aah_cluster is fitted. - """ + """Check that the aah_cluster is fitted.""" assert aah_cluster.fitted assert aah_cluster.n_clusters == n_clusters assert len(aah_cluster._cluster_names) == n_clusters From dcf7252e71bbc7a465932d86aae26b2be037f158 Mon Sep 17 00:00:00 2001 From: Reinmar Kobler Date: Fri, 4 Nov 2022 10:52:49 +0900 Subject: [PATCH 20/24] Changes in pycrostats/cluster/tests/test_aahc.py Fixed pytest.raises misue erro --- pycrostates/cluster/tests/test_aahc.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index ec4b762d..eb91bdef 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -612,24 +612,14 @@ def test_invalid_arguments(): with pytest.raises(ValueError, match="The number of clusters must be a"): aahCluster_ = AAHCluster(n_clusters=-101) - # # ignore_polarity - # with pytest.raises( - # TypeError, match="'ignore_polarity' must be an instance of bool" - # ): - # aahCluster_ = AAHCluster( - # n_clusters=n_clusters, - # ignore_polarity="asdf" - # ) - # aahCluster_ = AAHCluster( - # n_clusters=n_clusters, - # ignore_polarity=None - # ) - # normalize_input with pytest.raises( TypeError, match="'normalize_input' must be an instance of bool" ): aahCluster_ = AAHCluster(n_clusters=n_clusters, normalize_input="asdf") + with pytest.raises( + TypeError, match="'normalize_input' must be an instance of bool" + ): aahCluster_ = AAHCluster(n_clusters=n_clusters, normalize_input=None) aahCluster_ = AAHCluster( From 920b578eb3b723e98277eae01d877127770e1443 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Fri, 4 Nov 2022 10:37:33 +0100 Subject: [PATCH 21/24] remove duplicate docstrings --- pycrostates/cluster/aahc.py | 2 -- pycrostates/cluster/kmeans.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index d675ebfb..d1b1d4ab 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -107,7 +107,6 @@ def _repr_html_(self, caption=None): @copy_doc(_BaseCluster.__eq__) def __eq__(self, other: Any) -> bool: - """Equality == method.""" if isinstance(other, AAHCluster): if not super().__eq__(other): return False @@ -131,7 +130,6 @@ def __eq__(self, other: Any) -> bool: @copy_doc(_BaseCluster.__ne__) def __ne__(self, other: Any) -> bool: - """Different != method.""" return not self.__eq__(other) @copy_doc(_BaseCluster._check_fit) diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index 2f3e0f31..e5ba3dde 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -100,7 +100,6 @@ def _repr_html_(self, caption=None): @copy_doc(_BaseCluster.__eq__) def __eq__(self, other: Any) -> bool: - """Equality == method.""" if isinstance(other, ModKMeans): if not super().__eq__(other): return False @@ -127,7 +126,6 @@ def __eq__(self, other: Any) -> bool: @copy_doc(_BaseCluster.__ne__) def __ne__(self, other: Any) -> bool: - """Different != method.""" return not self.__eq__(other) @copy_doc(_BaseCluster._check_fit) From d628ee4b0cc81fb2bd99f603ad6d2e9ff23b3581 Mon Sep 17 00:00:00 2001 From: Victor Ferat Date: Fri, 4 Nov 2022 15:15:01 +0100 Subject: [PATCH 22/24] Remove n_jobs from _BaseCluster.fit --- pycrostates/cluster/_base.py | 4 ---- pycrostates/cluster/aahc.py | 2 -- pycrostates/cluster/kmeans.py | 7 +++++-- pycrostates/cluster/tests/test_aahc.py | 22 ++++++++-------------- 4 files changed, 13 insertions(+), 22 deletions(-) diff --git a/pycrostates/cluster/_base.py b/pycrostates/cluster/_base.py index dc9bbf76..4c2db066 100644 --- a/pycrostates/cluster/_base.py +++ b/pycrostates/cluster/_base.py @@ -17,7 +17,6 @@ from ..segmentation import EpochsSegmentation, RawSegmentation from ..utils import _corr_vectors from ..utils._checks import ( - _check_n_jobs, _check_picks_uniqueness, _check_reject_by_annotation, _check_tmin_tmax, @@ -193,7 +192,6 @@ def fit( tmin: Optional[Union[int, float]] = None, tmax: Optional[Union[int, float]] = None, reject_by_annotation: bool = True, - n_jobs: int = 1, *, verbose: Optional[str] = None, ) -> NDArray[float]: @@ -216,7 +214,6 @@ def fit( %(tmin_raw)s %(tmax_raw)s %(reject_by_annotation_raw)s - %(n_jobs)s %(verbose)s """ from ..io import ChData, ChInfo @@ -224,7 +221,6 @@ def fit( _set_verbose(verbose) self._check_unfitted() - n_jobs = _check_n_jobs(n_jobs) _check_type(inst, (BaseRaw, BaseEpochs, ChData), item_name="inst") if isinstance(inst, (BaseRaw, BaseEpochs)): tmin, tmax = _check_tmin_tmax(inst, tmin, tmax) diff --git a/pycrostates/cluster/aahc.py b/pycrostates/cluster/aahc.py index d1b1d4ab..84e3baa8 100644 --- a/pycrostates/cluster/aahc.py +++ b/pycrostates/cluster/aahc.py @@ -146,7 +146,6 @@ def fit( tmin: Optional[Union[int, float]] = None, tmax: Optional[Union[int, float]] = None, reject_by_annotation: bool = True, - n_jobs: int = 1, *, verbose: Optional[str] = None, ) -> None: @@ -156,7 +155,6 @@ def fit( tmin=tmin, tmax=tmax, reject_by_annotation=reject_by_annotation, - n_jobs=n_jobs, verbose=verbose, ) diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index e5ba3dde..fbeab92d 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -12,7 +12,7 @@ from .._typing import Picks, RANDomState from ..utils import _corr_vectors -from ..utils._checks import _check_random_state, _check_type +from ..utils._checks import _check_random_state, _check_type, _check_n_jobs from ..utils._docs import copy_doc, fill_doc from ..utils._logs import logger from ._base import _BaseCluster @@ -146,13 +146,16 @@ def fit( *, verbose: Optional[str] = None, ) -> None: + """ + %(n_jobs)s + """ + n_jobs = _check_n_jobs(n_jobs) data = super().fit( inst, picks=picks, tmin=tmin, tmax=tmax, reject_by_annotation=reject_by_annotation, - n_jobs=n_jobs, verbose=verbose, ) diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index eb91bdef..83fc0c7b 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -235,7 +235,7 @@ def test_aahClusterMeans(): assert aahCluster1._cluster_names == ["0", "1", "2", "3"] # Test fit on RAW - aahCluster1.fit(raw_eeg, n_jobs=1) + aahCluster1.fit(raw_eeg) _check_fitted(aahCluster1) assert aahCluster1._cluster_centers_.shape == ( n_clusters, @@ -247,7 +247,7 @@ def test_aahClusterMeans(): _check_unfitted(aahCluster1) # Test fit on Epochs - aahCluster1.fit(epochs_eeg, n_jobs=1) + aahCluster1.fit(epochs_eeg) _check_fitted(aahCluster1) assert aahCluster1._cluster_centers_.shape == ( n_clusters, @@ -256,7 +256,7 @@ def test_aahClusterMeans(): # Test fit on ChData aahCluster1.fitted = False - aahCluster1.fit(ch_data, n_jobs=1) + aahCluster1.fit(ch_data) _check_fitted(aahCluster1) assert aahCluster1._cluster_centers_.shape == ( n_clusters, @@ -678,7 +678,6 @@ def test_fit_data_shapes(): _check_unfitted(aahCluster_) aahCluster_.fit( raw_eeg, - n_jobs=1, picks="eeg", tmin=5, tmax=None, @@ -694,7 +693,6 @@ def test_fit_data_shapes(): _check_unfitted(aahCluster_) aahCluster_.fit( epochs_eeg, - n_jobs=1, picks="eeg", tmin=0.2, tmax=None, @@ -709,7 +707,6 @@ def test_fit_data_shapes(): _check_unfitted(aahCluster_) aahCluster_.fit( raw_eeg, - n_jobs=1, picks="eeg", tmin=None, tmax=5, @@ -725,7 +722,6 @@ def test_fit_data_shapes(): _check_unfitted(aahCluster_) aahCluster_.fit( epochs_eeg, - n_jobs=1, picks="eeg", tmin=None, tmax=0.3, @@ -740,7 +736,6 @@ def test_fit_data_shapes(): _check_unfitted(aahCluster_) aahCluster_.fit( raw_eeg, - n_jobs=1, picks="eeg", tmin=2, tmax=8, @@ -754,7 +749,6 @@ def test_fit_data_shapes(): _check_unfitted(aahCluster_) aahCluster_.fit( epochs_eeg, - n_jobs=1, picks="eeg", tmin=0.1, tmax=0.4, @@ -775,11 +769,11 @@ def test_fit_data_shapes(): _check_unfitted(aahCluster_) aahCluster_no_reject = aahCluster_.copy() - aahCluster_no_reject.fit(raw_, n_jobs=1, reject_by_annotation=False) + aahCluster_no_reject.fit(raw_, reject_by_annotation=False) aahCluster_reject_True = aahCluster_.copy() - aahCluster_reject_True.fit(raw_, n_jobs=1, reject_by_annotation=True) + aahCluster_reject_True.fit(raw_, reject_by_annotation=True) aahCluster_reject_omit = aahCluster_.copy() - aahCluster_reject_omit.fit(raw_, n_jobs=1, reject_by_annotation="omit") + aahCluster_reject_omit.fit(raw_, reject_by_annotation="omit") # Compare 'omit' and True assert np.isclose( @@ -821,11 +815,11 @@ def test_fit_data_shapes(): # Check with reject with tmin/tmax aahCluster_rej_0_5 = aahCluster_.copy() aahCluster_rej_0_5.fit( - raw_, n_jobs=1, tmin=0, tmax=5, reject_by_annotation=True + raw_, tmin=0, tmax=5, reject_by_annotation=True ) aahCluster_rej_5_end = aahCluster_.copy() aahCluster_rej_5_end.fit( - raw_, n_jobs=1, tmin=5, tmax=None, reject_by_annotation=True + raw_, tmin=5, tmax=None, reject_by_annotation=True ) _check_fitted(aahCluster_rej_0_5) _check_fitted(aahCluster_rej_5_end) From f6a74f7a3c28aadbeee8669d66a5441a7f45d894 Mon Sep 17 00:00:00 2001 From: Victor Ferat Date: Fri, 4 Nov 2022 15:15:27 +0100 Subject: [PATCH 23/24] Fix style --- pycrostates/cluster/kmeans.py | 2 +- pycrostates/cluster/tests/test_aahc.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index fbeab92d..28b884d0 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -12,7 +12,7 @@ from .._typing import Picks, RANDomState from ..utils import _corr_vectors -from ..utils._checks import _check_random_state, _check_type, _check_n_jobs +from ..utils._checks import _check_n_jobs, _check_random_state, _check_type from ..utils._docs import copy_doc, fill_doc from ..utils._logs import logger from ._base import _BaseCluster diff --git a/pycrostates/cluster/tests/test_aahc.py b/pycrostates/cluster/tests/test_aahc.py index 83fc0c7b..2be65595 100644 --- a/pycrostates/cluster/tests/test_aahc.py +++ b/pycrostates/cluster/tests/test_aahc.py @@ -814,9 +814,7 @@ def test_fit_data_shapes(): # Check with reject with tmin/tmax aahCluster_rej_0_5 = aahCluster_.copy() - aahCluster_rej_0_5.fit( - raw_, tmin=0, tmax=5, reject_by_annotation=True - ) + aahCluster_rej_0_5.fit(raw_, tmin=0, tmax=5, reject_by_annotation=True) aahCluster_rej_5_end = aahCluster_.copy() aahCluster_rej_5_end.fit( raw_, tmin=5, tmax=None, reject_by_annotation=True From aa4ef4df011017b062d7c8609fd3febcdd6b3bf4 Mon Sep 17 00:00:00 2001 From: Mathieu Scheltienne Date: Fri, 4 Nov 2022 15:31:50 +0100 Subject: [PATCH 24/24] fix docstring --- pycrostates/cluster/kmeans.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/pycrostates/cluster/kmeans.py b/pycrostates/cluster/kmeans.py index 28b884d0..ed1ab059 100644 --- a/pycrostates/cluster/kmeans.py +++ b/pycrostates/cluster/kmeans.py @@ -134,7 +134,7 @@ def _check_fit(self): # sanity-check assert self.GEV_ is not None - @copy_doc(_BaseCluster.fit) + @fill_doc def fit( self, inst: Union[BaseRaw, BaseEpochs], @@ -146,8 +146,27 @@ def fit( *, verbose: Optional[str] = None, ) -> None: - """ + """Compute cluster centers. + + Parameters + ---------- + inst : Raw | Epochs | ChData + MNE `~mne.io.Raw`, `~mne.Epochs` or `~pycrostates.io.ChData` object + from which to extract :term:`cluster centers`. + picks : str | list | slice | None + Channels to include. Note that all channels selected must have the + same type. Slices and lists of integers will be interpreted as + channel indices. In lists, channel name strings (e.g. + ``['Fp1', 'Fp2']``) will pick the given channels. Can also be the + string values “all” to pick all channels, or “data” to pick data + channels. ``"eeg"`` (default) will pick all eeg channels. + Note that channels in ``info['bads']`` will be included if their + names or indices are explicitly provided. + %(tmin_raw)s + %(tmax_raw)s + %(reject_by_annotation_raw)s %(n_jobs)s + %(verbose)s """ n_jobs = _check_n_jobs(n_jobs) data = super().fit(