Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add raw stc #12001

Merged
merged 17 commits into from
Oct 6, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/changes/devel.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Enhancements
- Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array <numpy.ndarray>` data (:gh:`11803` by `Alex Rockhill`_)
- Refactored internals of :func:`mne.read_annotations` (:gh:`11964` by `Paul Roujansky`_)
- Enhance :func:`~mne.viz.plot_evoked_field` with a GUI that has controls for time, colormap, and contour lines (:gh:`11942` by `Marijn van Vliet`_)
- Add extracting all time courses in a label using :func:`mne.extract_label_time_course` without applying an aggregation function (like ``mean``) (:gh:`12001` by `Hamza Abdelhedi`_)

Bugs
~~~~
Expand Down
25 changes: 17 additions & 8 deletions mne/source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3240,6 +3240,7 @@ def _pca_flip(flip, data):
"mean_flip": lambda flip, data: np.mean(flip * data, axis=0),
"max": lambda flip, data: np.max(np.abs(data), axis=0),
"pca_flip": _pca_flip,
"raw": lambda flip, data: data, # Return Identity: Preserves all vertices.
}


Expand Down Expand Up @@ -3572,7 +3573,11 @@ def _gen_extract_label_time_course(
)

# do the extraction
label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
if mode == "raw":
label_tc = [] # Initialize an empty list for raw mode.
else:
# For other modes, initialize the label_tc array
label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)):
if vertidx is not None:
if isinstance(vertidx, sparse.csr_matrix):
Expand All @@ -3583,15 +3588,19 @@ def _gen_extract_label_time_course(
this_data.shape = (this_data.shape[0],) + stc.data.shape[1:]
else:
this_data = stc.data[vertidx]
label_tc[i] = func(flip, this_data)
if mode == "raw":
label_tc.append(func(flip, this_data))
else:
label_tc[i] = func(flip, this_data)

# extract label time series for the vol src space (only mean supported)
offset = nvert[:-n_mean].sum() # effectively :2 or :0
for i, nv in enumerate(nvert[2:]):
if nv != 0:
v2 = offset + nv
label_tc[n_mode + i] = np.mean(stc.data[offset:v2], axis=0)
offset = v2
if mode != "raw":
offset = nvert[:-n_mean].sum() # effectively :2 or :0
for i, nv in enumerate(nvert[2:]):
if nv != 0:
v2 = offset + nv
label_tc[n_mode + i] = np.mean(stc.data[offset:v2], axis=0)
offset = v2

# this is a generator!
yield label_tc
Expand Down
30 changes: 24 additions & 6 deletions mne/tests/test_source_estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ def test_extract_label_time_course(kind, vector):

label_tcs = dict(mean=np.arange(n_labels)[:, None] * np.ones((n_labels, n_times)))
label_tcs["max"] = label_tcs["mean"]
label_tcs["raw"] = label_tcs["mean"]

# compute the mean with sign flip
label_tcs["mean_flip"] = np.zeros_like(label_tcs["mean"])
Expand Down Expand Up @@ -734,32 +735,49 @@ def test_extract_label_time_course(kind, vector):
assert_array_equal(arr[1:], vol_means_t)

# test the different modes
modes = ["mean", "mean_flip", "pca_flip", "max", "auto"]
modes = ["mean", "mean_flip", "pca_flip", "max", "auto", "raw"]

for mode in modes:
if vector and mode not in ("mean", "max", "auto"):
with pytest.raises(ValueError, match="when using a vector"):
extract_label_time_course(stcs, labels, src, mode=mode)
continue
with _record_warnings(): # SVD convergence on arm64
print(stcs)
label_tc = extract_label_time_course(stcs, labels, src, mode=mode)
label_tc_method = [
stc.extract_label_time_course(labels, src, mode=mode) for stc in stcs
]
assert len(label_tc) == n_stcs
assert len(label_tc_method) == n_stcs
for tc1, tc2 in zip(label_tc, label_tc_method):
assert tc1.shape == (n_labels + len(vol_means),) + end_shape
assert tc2.shape == (n_labels + len(vol_means),) + end_shape
assert_allclose(tc1, tc2, rtol=1e-8, atol=1e-16)
if mode == "raw":
assert all(arr.shape[1] == tc1[0].shape[1] for arr in tc1)
assert all(arr.shape[1] == tc2[0].shape[1] for arr in tc2)
assert (len(tc1), tc1[0].shape[1]) == (n_labels,) + end_shape
assert (len(tc2), tc2[0].shape[1]) == (n_labels,) + end_shape
for arr1, arr2 in zip(tc1, tc2): # list of arrays
assert_allclose(arr1, arr2, rtol=1e-8, atol=1e-16)
else:
assert tc1.shape == (n_labels + len(vol_means),) + end_shape
assert tc2.shape == (n_labels + len(vol_means),) + end_shape
assert_allclose(tc1, tc2, rtol=1e-8, atol=1e-16)
# XXX we don't check pca_flip, probably should someday...
if mode == "auto":
use_mode = "mean" if vector else "mean_flip"
else:
use_mode = mode
# XXX we don't check pca_flip, probably should someday...
if use_mode in ("mean", "max", "mean_flip"):
assert_array_almost_equal(tc1[:n_labels], label_tcs[use_mode])
assert_array_almost_equal(tc1[n_labels:], vol_means_t)
elif use_mode == "raw":
for arr1, arr2 in zip(
tc1[:n_labels], label_tcs[use_mode]
): # list of arrays
assert_allclose(
arr1, np.tile(arr2, (arr1.shape[0], 1)), rtol=1e-8, atol=1e-16
)
if mode != "raw":
assert_array_almost_equal(tc1[n_labels:], vol_means_t)

# test label with very few vertices (check SVD conditionals)
label = Label(vertices=src[0]["vertno"][:2], hemi="lh")
Expand Down
Loading