Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Jun 29, 2023
1 parent 9dbfa3b commit 9098b2e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 11 deletions.
Binary file not shown.
Binary file not shown.
64 changes: 53 additions & 11 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import numpy as np
from numpy.testing import (assert_allclose, assert_array_almost_equal,
assert_array_less)
import pandas as pd
import pytest
from mne import (EpochsArray, SourceEstimate, create_info)
from mne.filter import filter_data
Expand Down Expand Up @@ -408,6 +410,44 @@ def test_spectral_connectivity(method, mode):
assert (out_lens[0] == 10)


def test_multivariate_spectral_connectivity_epochs_regression():
"""Test multivar. spectral connectivity over epochs for regression.
The multivariate methods were originally implemented in MATLAB by their
respective authors. To show that this Python implementation is identical
and to avoid any future regressions, we compare the results of the Python
and MATLAB implementations on some example data (randomly generated).
As the code for computing the cross-spectral density matrix is not
available in MATLAB, the CSD matrix was computed using MNE and then loaded
into MATLAB to compute the connectivity from the original MATLAB
implementations using the same settings in MATLAB and Python.
It is therefore important that no changes are made to the settings for
computing the CSD or the final connectivity scores!
"""
fpath = os.path.dirname(os.path.realpath(__file__))
data = pd.read_pickle(os.path.join(fpath, 'example_multivariate_data.pkl'))
sfreq = 100
indices = tuple([[0, 1], [2, 3]])
methods = ['mic', 'mim', 'gc', 'gc_tr']
con = spectral_connectivity_epochs(
data, method=methods, indices=indices,
mode='multitaper', sfreq=sfreq, fskip=0, faverage=False, tmin=0,
tmax=None, mt_bandwidth=4, mt_low_bias=True, mt_adaptive=False,
gc_n_lags=20, rank=tuple([[2], [2]]), n_jobs=1)

# should take the absolute of the MIC scores, as the signs vary depending
# on the eigendecomposition implementation (leads to sign differences
# between MATLAB & Python)
mne_results = {this_con.method: np.abs(this_con.get_data())
for this_con in con}
matlab_results = pd.read_pickle(
os.path.join(fpath, 'example_multivariate_matlab_results.pkl'))
for method in methods:
assert_allclose(matlab_results[method], mne_results[method], 1e-5)


@pytest.mark.parametrize('method', ['mic', 'mim', 'gc', 'gc_tr',
['mic', 'mim', 'gc', 'gc_tr']])
@pytest.mark.parametrize('mode', ['multitaper', 'fourier', 'cwt_morlet'])
Expand Down Expand Up @@ -540,7 +580,7 @@ def test_multivar_spectral_connectivity_error_catch(method, mode):
con[method.index('gc_tr')].get_data())


@pytest.mark.parametrize('method', ['mic', 'mim', 'gc'])
@pytest.mark.parametrize('method', ['mic', 'mim', 'gc', 'gc_tr'])
def test_multivar_spectral_connectivity_parallel(method):
"""Test multivar. freq.-domain connectivity methods run in parallel."""
sfreq = 50.
Expand Down Expand Up @@ -669,28 +709,30 @@ def test_spectral_connectivity_time_phaselocked(method, mode, data_option):
con = spectral_connectivity_time(
data, freqs, method=method, mode=mode, sfreq=sfreq,
fmin=freq_band_low_limit, fmax=freq_band_high_limit, n_jobs=1,
faverage=True if method not in multivar_methods else False,
average=True if method not in multivar_methods else False, sm_times=0)
faverage=True if method != 'mic' else False,
average=True if method != 'mic' else False, sm_times=0)
con_matrix = con.get_data()

# MIC values can be pos. and neg., so averaging across freqs./epochs means
# connectivity estimates != 1
# MIC values can be pos. and neg., so must be averaged after taking the
# absolute values for the test to work
if method in multivar_methods:
assert con.shape == (n_epochs, 1, len(con.freqs))
if method == 'mic':
con_matrix = np.mean(np.abs(con_matrix), axis=(0, 2))
assert con.shape == (n_epochs, 1, len(con.freqs))
else:
assert con.shape == (1, len(con.freqs))
else:
assert con.shape == (n_channels ** 2, len(con.freqs))

con_matrix = con.get_data()[..., 0]
if data_option == 'sync':
# signals are perfectly phase-locked, connectivity matrix should be
# a matrix of ones
assert np.allclose(
np.abs(con_matrix), np.ones(con_matrix.shape), atol=0.01)
assert np.allclose(con_matrix, np.ones(con_matrix.shape), atol=0.01)
if data_option == 'random':
# signals are random, all connectivity values should be small
# 0.5 is picked rather arbitrarily such that the obsolete wrong
# implementation fails
assert np.all(np.round(np.abs(con_matrix), 2) <= 0.5)

assert np.all(con_matrix <= 0.5)


@pytest.mark.parametrize('method', ['coh', 'plv', 'pli', 'wpli', 'ciplv'])
Expand Down

0 comments on commit 9098b2e

Please sign in to comment.