Skip to content

Commit

Permalink
bug fix patterns dims
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Jun 29, 2023
1 parent 754dfc7 commit 35f8b1f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
6 changes: 4 additions & 2 deletions mne_connectivity/spectral/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1973,12 +1973,14 @@ def spectral_connectivity_epochs(data, names=None, method='coh', indices=None,
this_con = this_con_bands

if this_patterns is not None:
patterns_shape = (2, n_cons, n_bands) + this_patterns.shape[3:]
patterns_shape = ((2, n_cons, len(indices[0]), n_bands) +
this_patterns.shape[4:])
this_patterns_bands = np.empty(patterns_shape,
dtype=this_patterns.dtype)
for band_idx in range(n_bands):
this_patterns_bands[:, :, band_idx] = np.mean(
this_patterns[:, :, freq_idx_bands[band_idx]], axis=2)
this_patterns[:, :, :, freq_idx_bands[band_idx]],
axis=3)
this_patterns = this_patterns_bands

con.append(this_con)
Expand Down
12 changes: 6 additions & 6 deletions mne_connectivity/spectral/tests/test_spectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,8 @@ def test_multivar_spectral_connectivity_error_catch(method, mode):
patterns_shape = (
(len(indices[0]), len(con.freqs)),
(len(indices[1]), len(con.freqs)))
assert con.attrs["patterns"][0][0].shape == patterns_shape[0]
assert con.attrs["patterns"][1][0].shape == patterns_shape[1]
assert np.shape(con.attrs["patterns"][0][0]) == patterns_shape[0]
assert np.shape(con.attrs["patterns"][1][0]) == patterns_shape[1]

# only check these once for speed
if mode == 'multitaper':
Expand All @@ -497,16 +497,16 @@ def test_multivar_spectral_connectivity_error_catch(method, mode):
con = spectral_connectivity_epochs(
data, method=method, mode=mode, indices=indices, sfreq=sfreq,
fmin=fmin, fmax=fmax, faverage=True)
assert con.attrs["patterns"][0][0].shape[1] == len(fmin)
assert con.attrs["patterns"][1][0].shape[1] == len(fmin)
assert np.shape(con.attrs["patterns"][0][0])[1] == len(fmin)
assert np.shape(con.attrs["patterns"][1][0])[1] == len(fmin)

# check patterns shape matches input data, not rank
rank = (np.array([1]), np.array([1]))
con = spectral_connectivity_epochs(
data, method=method, mode=mode, indices=indices, sfreq=sfreq,
cwt_freqs=cwt_freqs, rank=rank)
assert con.attrs["patterns"][0][0].shape[0] == len(indices[0])
assert con.attrs["patterns"][1][0].shape[0] == len(indices[1])
assert np.shape(con.attrs["patterns"][0][0])[0] == len(indices[0])
assert np.shape(con.attrs["patterns"][1][0])[0] == len(indices[1])

# check bad rank args caught
too_low_rank = (np.array([0]), np.array([0]))
Expand Down

0 comments on commit 35f8b1f

Please sign in to comment.