Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GSOC] Add support for multiple components of multivariate connectivity #213

Merged
merged 67 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
67 commits
Select commit Hold shift + click to select a range
2fa5aca
Add filter storage
tsbinns May 28, 2024
6e206ce
Merge branch 'mne-tools:main' into decoding
tsbinns May 28, 2024
68a0890
Refactor results reshaping
tsbinns May 29, 2024
0368367
Fix filter indexing for storage
tsbinns May 29, 2024
bb75cb1
Update fill_doc dictionary
tsbinns May 29, 2024
115757c
Add n_components to ingored numpydoc words
tsbinns May 29, 2024
e7c9da9
Add decoding module
tsbinns May 29, 2024
5e51923
Update API with decoding module
tsbinns May 29, 2024
9fe0f28
Rename file and add suport for cwt_morlet mode
tsbinns May 29, 2024
6302a4d
Merge branch 'main' into decoding
tsbinns May 30, 2024
fa6a001
Make property docstrings private
tsbinns Jun 3, 2024
15e99e4
Bug fix error check
tsbinns Jun 3, 2024
58eca90
Bug fix fit_transform no return
tsbinns Jun 3, 2024
9b43dfb
Bug fix _check_X 2d array
tsbinns Jun 3, 2024
bb0b520
Add preliminary decomp example
tsbinns Jun 3, 2024
cdc7ce4
Merge branch 'main' into decoding
tsbinns Jun 3, 2024
8cd98e1
Merge branch 'main' into decoding
tsbinns Jun 5, 2024
a3bf253
Switch to cleaner epoch indexing
tsbinns Jun 5, 2024
ea48ce3
Fix spelling error
tsbinns Jun 5, 2024
9bcff58
Update error checking
tsbinns Jun 5, 2024
17cfeab
Merge branch 'decoding' of https://github.com/tsbinns/mne-connectivit…
tsbinns Jun 5, 2024
6c8ae97
Bug fix indices setter wrong format
tsbinns Jun 6, 2024
55b14ab
Add unit tests
tsbinns Jun 6, 2024
3861da5
Update example from review
tsbinns Jun 6, 2024
025e6c1
Update cwt_morlet params
tsbinns Jun 6, 2024
cbfcc13
Fix platform-specific failing unit test
tsbinns Jun 6, 2024
8339d40
Refactor decomposition classes
tsbinns Jun 6, 2024
bdfb626
Add test reminder
tsbinns Jun 10, 2024
a0e9560
Add decomposition plotting
tsbinns Jun 10, 2024
840f4f6
Update tests and fix getter/setters
tsbinns Jun 12, 2024
f9bc90f
Merge branch 'main' into decoding
tsbinns Jun 12, 2024
d1b35cf
Merge branch 'main' into decoding
tsbinns Jun 12, 2024
8afa1b0
Merge remote-tracking branch 'upstream/main' into decoding
tsbinns Jun 12, 2024
b8b57bb
Switch from matmul to at
tsbinns Jun 12, 2024
8cde3a6
Merge branch 'decoding' into decoding_plotting
tsbinns Jun 12, 2024
88eafb5
Shorten tests with kwargs
tsbinns Jun 13, 2024
19931fd
Merge branch 'decoding' into decoding_plotting
tsbinns Jun 13, 2024
163d802
Add decomp class to main init
tsbinns Jun 13, 2024
865fa67
Update plotting docstrings
tsbinns Jun 13, 2024
ccdad63
Add docs authorship
tsbinns Jun 13, 2024
c026d3e
Archive old example
tsbinns Jun 19, 2024
ad35318
Merge remote-tracking branch 'upstream/main' into decoding_plotting
tsbinns Jun 19, 2024
03186a0
Update decomp example formatting
tsbinns Jun 20, 2024
1314cb4
Add decomp plotting example
tsbinns Jun 20, 2024
0a3f7ac
Update epochs_multivar formatting and docs
tsbinns Jun 20, 2024
18eb0d5
Merge remote-tracking branch 'upstream/main' into decoding_ncomps
tsbinns Jun 28, 2024
926e6d6
Add support for multiple comps in Decomp Class
tsbinns Jun 28, 2024
72bcee3
Merge remote-tracking branch 'upstream/main' into decoding_ncomps
tsbinns Jul 10, 2024
f981e05
Merge missing changes from multi comps and upstream
tsbinns Jul 10, 2024
a2e3885
Add support for multiple comps in spec_conn_epochs
tsbinns Jul 15, 2024
80e9360
Refactor decomposition n_comps checking
tsbinns Jul 15, 2024
a357214
Fix failing linux tests
tsbinns Jul 15, 2024
d6a0688
Add support for multiple comps in spec_conn_time
tsbinns Jul 15, 2024
0f5ad95
Improve container reprs
tsbinns Jul 15, 2024
69a8abd
Simplify multiple comps check
tsbinns Jul 15, 2024
00ea4be
Update container tests
tsbinns Jul 15, 2024
01ee79d
Update container get_data method
tsbinns Jul 15, 2024
9fad5dd
Merge remote-tracking branch 'upstream/main' into decoding_ncomps
tsbinns Jul 15, 2024
1e16cba
Add comps to xref ignore
tsbinns Jul 15, 2024
6d4a403
Clean up container tests
tsbinns Jul 15, 2024
571ab17
Merge branch 'main' into decoding_ncomps
tsbinns Jul 16, 2024
93a3bdd
Simplify data transformation from review
tsbinns Jul 16, 2024
ca3fac0
Update comments
tsbinns Jul 22, 2024
d331bce
Update from review
tsbinns Jul 22, 2024
69d9a3d
Update to np.testing assertions
tsbinns Jul 22, 2024
dc4229d
Merge branch 'decoding_ncomps' of https://github.com/tsbinns/mne-conn…
tsbinns Jul 22, 2024
d566b3b
Merge branch 'main' into decoding_ncomps
tsbinns Jul 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@
"epochs",
"freqs",
"times",
"components",
"arrays",
"lists",
"func",
Expand Down
4 changes: 3 additions & 1 deletion examples/decoding/cohy_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@
mode="multitaper",
fmin=FMIN,
fmax=FMAX,
rank=(3, 3),
rank=(3, 3), # project to rank subspace to avoid overfitting to noise
n_components=1, # the data contains only one simulated component of connectivity
)

########################################################################################
Expand Down Expand Up @@ -371,6 +372,7 @@
fmin=FMIN,
fmax=FMAX,
rank=(3, 3),
n_components=1,
)

# Time fitting of filters
Expand Down
10 changes: 6 additions & 4 deletions examples/decoding/cohy_decomposition_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,12 @@
# bound between :math:`[-1, 1]`.
#
# Plotting the patterns for 20-30 Hz connectivity below, we find the strongest
# connectivity between the left and right hemispheres comes from centromedial left and
# frontolateral right sensors, based on the areas with the largest absolute values. As
# these patterns come from decomposition on sensor-space data, we make no assumptions
# about the underlying brain regions involved in this connectivity.
# connectivity ('MIC0', i.e. 1st component) between the left and right hemispheres comes
# from centromedial left and frontolateral right sensors, based on the areas with the
# largest absolute values. Patterns for the weaker connectivity components ('MIC1' &
# 'MIC2' are also shown). As these patterns come from decomposition on sensor-space
# data, we make no assumptions about the underlying brain regions involved in this
# connectivity.

# %%

Expand Down
111 changes: 66 additions & 45 deletions mne_connectivity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,9 @@ def __repr__(self) -> str:
if "times" in self.dims:
r += f"time : [{self.times[0]}, {self.times[-1]}], " # type: ignore
r += f", nave : {self.n_epochs_used}"
r += f", nodes, n_estimated : {self.n_nodes}, " f"{self.n_estimated_nodes}"
r += f", nodes, n_estimated : {self.n_nodes}, {self.n_estimated_nodes}"
if "components" in self.dims:
r += f", n_components : {len(self.coords['components'])}, "
r += f", ~{sizeof_fmt(self._size)}"
r += ">"
return r
Expand Down Expand Up @@ -488,6 +490,9 @@ def _prepare_xarray(
if self.is_epoched:
coords["epochs"] = list(map(str, range(data.shape[0])))
coords["node_in -> node_out"] = n_estimated_list
if "components" in kwargs:
coords["components"] = kwargs.pop("components")
dims.append("components")
if "freqs" in kwargs:
coords["freqs"] = kwargs.pop("freqs")
dims.append("freqs")
Expand Down Expand Up @@ -531,17 +536,17 @@ def _check_data_consistency(self, data, indices, n_nodes):
raise TypeError("Connectivity data must be passed in as a numpy array.")

if self.is_epoched:
if data.ndim < 2 or data.ndim > 4:
if data.ndim < 2 or data.ndim > 5:
raise RuntimeError(
"Data using an epoched data structure should have at least 2 "
f"dimensions and at most 4 dimensions. Your data was {data.shape} "
f"dimensions and at most 5 dimensions. Your data was {data.shape} "
"shape."
)
else:
if data.ndim > 3:
if data.ndim > 4:
raise RuntimeError(
"Data not using an epoched data structure should have at least 1 "
f"dimensions and at most 3 dimensions. Your data was {data.shape} "
f"dimensions and at most 4 dimensions. Your data was {data.shape} "
"shape."
)

Expand Down Expand Up @@ -709,11 +714,15 @@ def get_data(self, output="compact"):
if output == "raveled":
data = self._data
else:
if self.method in ["cacoh", "mic", "mim", "gc", "gc_tr"]:
# multivariate results cannot be returned in a dense form as a
# single set of results would correspond to multiple entries in
# the matrix, and there could also be cases where multiple
# results correspond to the same entries in the matrix.
if (
isinstance(self.indices, tuple)
and not isinstance(self.indices[0], int)
and not isinstance(self.indices[1], int)
): # i.e. check if multivariate results based on nested indices
# multivariate results cannot be returned in a dense form as a single
# set of results would correspond to multiple entries in the matrix, and
# there could also be cases where multiple results correspond to the
# same entries in the matrix.
raise ValueError(
"cannot return multivariate connectivity data in a dense form"
)
Expand All @@ -728,6 +737,8 @@ def get_data(self, output="compact"):
# and thus appends the connectivity matrices side by side, so the
# shape is N x N * lags
new_shape.extend([self.n_nodes, self.n_nodes])
if "components" in self.dims:
new_shape.append(len(self.coords["components"]))
if "freqs" in self.dims:
new_shape.append(len(self.coords["freqs"]))
if "times" in self.dims:
Expand Down Expand Up @@ -870,9 +881,10 @@ def save(self, fname):
class SpectralConnectivity(BaseConnectivity, SpectralMixin):
"""Spectral connectivity class.

This class stores connectivity data that varies over
frequencies. The underlying data is an array of shape
(n_connections, n_freqs), or (n_nodes, n_nodes, n_freqs).
This class stores connectivity data that varies over frequencies. The underlying
data is an array of shape (n_connections, [n_components], n_freqs), or (n_nodes,
n_nodes, [n_components], n_freqs). ``n_components`` is an optional dimension for
multivariate methods where each connection has multiple components of connectivity.

Parameters
----------
Expand Down Expand Up @@ -924,11 +936,12 @@ def __init__(
class TemporalConnectivity(BaseConnectivity, TimeMixin):
"""Temporal connectivity class.

This is an array of shape (n_connections, n_times),
or (n_nodes, n_nodes, n_times). This describes how connectivity
varies over time. It describes sample-by-sample time-varying
connectivity (usually on the order of milliseconds). Here
time (t=0) is the same for all connectivity measures.
This is an array of shape (n_connections, [n_components], n_times), or (n_nodes,
n_nodes, [n_components], n_times). This describes how connectivity varies over
time. It describes sample-by-sample time-varying connectivity (usually on the order
of milliseconds). Here time (t=0) is the same for all connectivity measures.
``n_components`` is an optional dimension for multivariate methods where each
connection has multiple components of connectivity.

Parameters
----------
Expand All @@ -943,12 +956,11 @@ class TemporalConnectivity(BaseConnectivity, TimeMixin):

Notes
-----
`mne_connectivity.EpochConnectivity` is a similar connectivity
class to this one. However, that describes one connectivity snapshot
for each epoch. These epochs might be chunks of time that have
different meaning for time ``t=0``. Epochs can mean separate trials,
where the beginning of the trial implies t=0. These Epochs may
also be discontiguous.
`mne_connectivity.EpochConnectivity` is a similar connectivity class to this one.
However, that describes one connectivity snapshot for each epoch. These epochs might
be chunks of time that have different meaning for time ``t=0``. Epochs can mean
separate trials, where the beginning of the trial implies t=0. These Epochs may also
be discontiguous.
"""

expected_n_dim = 2
Expand Down Expand Up @@ -980,13 +992,14 @@ def __init__(
class SpectroTemporalConnectivity(BaseConnectivity, SpectralMixin, TimeMixin):
"""Spectrotemporal connectivity class.

This class stores connectivity data that varies over both frequency
and time. The temporal part describes sample-by-sample time-varying
connectivity (usually on the order of milliseconds). Note the
difference relative to Epochs.
This class stores connectivity data that varies over both frequency and time. The
temporal part describes sample-by-sample time-varying connectivity (usually on the
order of milliseconds). Note the difference relative to Epochs.

The underlying data is an array of shape (n_connections, n_freqs,
n_times), or (n_nodes, n_nodes, n_freqs, n_times).
The underlying data is an array of shape (n_connections, [n_components], n_freqs,
n_times), or (n_nodes, n_nodes, [n_components], n_freqs, n_times). ``n_components``
is an optional dimension for multivariate methods where each connection has multiple
components of connectivity.

Parameters
----------
Expand Down Expand Up @@ -1038,9 +1051,11 @@ def __init__(
class EpochSpectralConnectivity(SpectralConnectivity):
"""Spectral connectivity class over Epochs.

This is an array of shape (n_epochs, n_connections, n_freqs),
or (n_epochs, n_nodes, n_nodes, n_freqs). This describes how
connectivity varies over frequencies for different epochs.
This is an array of shape (n_epochs, n_connections, [n_components], n_freqs), or
(n_epochs, n_nodes, n_nodes, [n_components], n_freqs). This describes how
connectivity varies over frequencies for different epochs. ``n_components`` is an
optional dimension for multivariate methods where each connection has multiple
components of connectivity.

Parameters
----------
Expand Down Expand Up @@ -1088,9 +1103,11 @@ def __init__(
class EpochTemporalConnectivity(TemporalConnectivity):
"""Temporal connectivity class over Epochs.

This is an array of shape (n_epochs, n_connections, n_times),
or (n_epochs, n_nodes, n_nodes, n_times). This describes how
connectivity varies over time for different epochs.
This is an array of shape (n_epochs, n_connections, [n_components], n_times), or
(n_epochs, n_nodes, n_nodes, [n_components], n_times). This describes how
connectivity varies over time for different epochs. ``n_components`` is an optional
dimension for multivariate methods where each connection has multiple components of
connectivity.

Parameters
----------
Expand Down Expand Up @@ -1129,9 +1146,11 @@ def __init__(
class EpochSpectroTemporalConnectivity(SpectroTemporalConnectivity):
"""Spectrotemporal connectivity class over Epochs.

This is an array of shape (n_epochs, n_connections, n_freqs, n_times),
or (n_epochs, n_nodes, n_nodes, n_freqs, n_times). This describes how
connectivity varies over frequencies and time for different epochs.
This is an array of shape (n_epochs, n_connections, [n_components], n_freqs,
n_times), or (n_epochs, n_nodes, n_nodes, [n_components], n_freqs, n_times). This
describes how connectivity varies over frequencies and time for different epochs.
``n_components`` is an optional dimension for multivariate methods where each
connection has multiple components of connectivity.

Parameters
----------
Expand Down Expand Up @@ -1178,9 +1197,10 @@ def __init__(
class Connectivity(BaseConnectivity):
"""Connectivity class without frequency or time component.

This is an array of shape (n_connections,),
or (n_nodes, n_nodes). This describes a connectivity matrix/graph
that does not vary over time, frequency, or epochs.
This is an array of shape (n_connections, [n_components]), or (n_nodes, n_nodes,
[n_components]). This describes a connectivity matrix/graph that does not vary
over time, frequency, or epochs. ``n_components`` is an optional dimension for
multivariate methods where each connection has multiple components of connectivity.

Parameters
----------
Expand Down Expand Up @@ -1222,9 +1242,10 @@ def __init__(
class EpochConnectivity(BaseConnectivity):
"""Epoch connectivity class.

This is an array of shape (n_epochs, n_connections),
or (n_epochs, n_nodes, n_nodes). This describes how
connectivity varies for different epochs.
This is an array of shape (n_epochs, n_connections, [n_components]), or (n_epochs,
n_nodes, n_nodes, [n_components]). This describes how connectivity varies for
different epochs. ``n_components`` is an optional dimension for multivariate methods
where each connection has multiple components of connectivity.

Parameters
----------
Expand Down
Loading