
[ENH] Reduce compute time for multivariate coherency methods #184

Merged · 10 commits · May 30, 2024

Conversation

@tsbinns (Collaborator) commented May 23, 2024

Problem

A major step in computing connectivity for the CaCoh, MIC, and MIM methods is computing a transformation matrix, taken as the inverse square root of the CSD, which must be computed in turn for each frequency and time entry of the CSD.

The current approach uses SciPy's linalg.fractional_matrix_power() function; however, this is only compatible with 2D arrays, meaning we must loop over all time and frequency entries, which is not very fast. Unfortunately, NumPy has no equivalent function for working with ND arrays (and also no equivalent of the linalg.sqrtm() function, which would be an alternative).

The workaround so far has been to offer parallelisation of this computation over frequencies; however, this is still not as fast as operating on the ND arrays directly, and it still requires a loop over time entries.
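For reference, the per-slice computation being replaced can be sketched roughly as follows (the function and variable names here are illustrative, not the package's actual code):

```python
import numpy as np
from scipy.linalg import fractional_matrix_power


def invsqrtm_looped(C_r):
    """Inverse matrix square root of each (freq, time) slice, one at a time."""
    T = np.zeros_like(C_r, dtype=np.float64)
    for freq_i in range(C_r.shape[0]):
        for time_i in range(C_r.shape[1]):
            # fractional_matrix_power only accepts 2D arrays,
            # hence the double loop over frequencies and times
            T[freq_i, time_i] = fractional_matrix_power(
                C_r[freq_i, time_i], -0.5
            )
    return T
```

Each call handles a single (n_channels, n_channels) matrix, so the cost scales directly with n_freqs × n_times.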

Solution

A faster solution can be achieved using a couple of tensor-compatible NumPy functions:

T = np.zeros_like(C_r, dtype=np.float64)
# seeds
eigvals, eigvects = np.linalg.eigh(C_r[:, :, :n_seeds, :n_seeds])
T[:, :, :n_seeds, :n_seeds] = np.linalg.inv(
    np.matmul(
        (eigvects * np.expand_dims(np.sqrt(eigvals), (2))),
        eigvects.transpose(0, 1, 3, 2),
    )
)
# targets
eigvals, eigvects = np.linalg.eigh(C_r[:, :, n_seeds:, n_seeds:])
T[:, :, n_seeds:, n_seeds:] = np.linalg.inv(
    np.matmul(
        (eigvects * np.expand_dims(np.sqrt(eigvals), (2))),
        eigvects.transpose(0, 1, 3, 2),
    )
)
return T

For a CSD with 4 channels, 96 frequencies, and no parallelisation, the current approach took ~80 ms to run. This new approach takes ~10 ms. Even with parallelisation, the new approach still offers improvements: if n_jobs is less than the number of frequencies, some looping over frequency bins is still involved, and there is the overhead of initialising the workers.

Also, it is not a super rigorous check, but running the test suite of spectral.py drops the run time from ~30 s to ~20 s on my machine, so there's a noticeable improvement.


The results are also identical down to ~1e-16, and the regression tests checking the consistency of the connectivity results pass without any problems.
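As a quick sanity check of the eigendecomposition route: for a symmetric positive-definite matrix C = V Λ V^T, the square root is V Λ^(1/2) V^T, and inverting it gives C^(-1/2). A minimal sketch on random data (illustrative names and shapes only):

```python
import numpy as np

rng = np.random.default_rng(42)
n_freqs, n_times, n_ch = 3, 2, 4
A = rng.standard_normal((n_freqs, n_times, n_ch, n_ch))
C_r = A @ A.transpose(0, 1, 3, 2) + n_ch * np.eye(n_ch)  # symmetric, full rank

# batched over the leading (freq, time) axes, as in the PR
eigvals, eigvects = np.linalg.eigh(C_r)
T = np.linalg.inv(
    np.matmul(
        eigvects * np.expand_dims(np.sqrt(eigvals), (2)),
        eigvects.transpose(0, 1, 3, 2),
    )
)

# T is the inverse square root: T @ C_r @ T should be identity per slice
identity = np.broadcast_to(np.eye(n_ch), C_r.shape)
print(np.allclose(T @ C_r @ T, identity))  # True
```

The multiplication by np.expand_dims(np.sqrt(eigvals), 2) scales each eigenvector column by the square root of its eigenvalue via broadcasting, avoiding an explicit diagonal matrix.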


The biggest difference is how invalid cases (e.g. singular CSDs) are handled. Currently, the function still runs in these cases, but NaN/inf/imaginary values can be present in the output, which is what we were checking for. In the new approach, a LinAlgError is raised, which we can catch:

try:
    return self._invsqrtm(C_r, n_seeds)
except np.linalg.LinAlgError as error:
    raise RuntimeError(
        "the transformation matrix of the data must be real-valued "
        "and contain no NaN or infinity values; check that you are "
        "using full rank data or specify an appropriate rank for the "
        "seeds and targets that is less than or equal to their ranks"
    ) from error

Again, the unit tests checking that these invalid cases are caught also pass without any problems.


Also, since we are no longer parallelising the function, it does not have to live outside of the class, so we can tidy things up and make it a class method:

def _invsqrtm(self, C_r, n_seeds):
    """Compute inverse sqrt of CSD over frequencies and times.

    Parameters
    ----------
    C_r : np.ndarray, shape=(n_freqs, n_times, n_channels, n_channels)
        Real part of the CSD. Expected to be symmetric and non-singular.
    n_seeds : int
        Number of seed channels for the connection.

    Returns
    -------
    T : np.ndarray, shape=(n_freqs, n_times, n_channels, n_channels)
        Inverse square root of the real-valued CSD. Name comes from Ewald
        et al. (2012).

    Notes
    -----
    This approach is a workaround for computing the inverse square root of
    an ND array. SciPy has dedicated functions for this purpose, e.g.
    `sp.linalg.fractional_matrix_power(A, -0.5)` or `sp.linalg.inv(
    sp.linalg.sqrtm(A))`, however these only work with 2D arrays, meaning
    frequencies and times must be looped over, which is very slow. There
    are no equivalent functions in NumPy for working with ND arrays (as of
    v1.26).

    The data array is expected to be symmetric and non-singular, otherwise
    a LinAlgError is raised.

    See Eq. 3 of Ewald et al. (2012). NeuroImage. DOI:
    10.1016/j.neuroimage.2011.11.084.
    """

Conclusion

Altogether, I think this is a very nice improvement that maintains the existing functionality while giving a big speed-up to the CaCoh, MIC, and MIM methods.

Let me know if anyone has any thoughts/comments. Cheers!

@tsbinns (Collaborator, Author) commented May 23, 2024

It could also be good to update the error message for the invalid cases to better reflect what is now happening in the code:

try:
    return self._invsqrtm(C_r, n_seeds)
except np.linalg.LinAlgError as error:
    raise RuntimeError(
        "the transformation matrix of the data could not be computed "
        "from the cross-spectral density; check that you are using "
        "full rank data or specify an appropriate rank for the seeds "
        "and targets that is less than or equal to their ranks"
    ) from error

Before: "the transformation matrix of the data must be real-valued and contain no NaN or infinity values".
After: "the transformation matrix of the data could not be computed from the cross-spectral density".

@tsbinns (Collaborator, Author) commented May 23, 2024

There are a few runtime errors on macOS from some NumPy functions; however, these do not occur on Ubuntu or Windows. It is not immediately clear to me why, since all are running Python 3.11.9 with NumPy 1.26.4.

The failing tests also include the Granger causality methods, where nothing has been changed.

I would need to look into this more, unless someone has any ideas?

@tsbinns (Collaborator, Author) commented May 23, 2024

The documentation also fails to build due to an error in cwt_sensor_connectivity.py (not touched in this PR).

Hopefully #185 solves this.

@larsoner (Member) left a comment

Indeed, I often resort to np.linalg.* over scipy.linalg exactly for its ability to operate over the last two dimensions of an N-dimensional array.

@larsoner (Member) commented

The failure is now related. You probably need to add a filterwarnings entry. You can make it specific to macOS arm64 like I did in #187, or you can just decorate the offending function with @pytest.mark.filterwarnings(...).

@tsbinns (Collaborator, Author) commented May 24, 2024

> Failure is now related. You probably need to add a filterwarnings. You can make it specific to macOS arm64 like I did in #187 or you can just @pytest.mark.filterwarnings(...) decorate the offending function

Thanks for the help, @larsoner! I will add the filters for the coherency functions.

@tsbinns (Collaborator, Author) commented May 27, 2024

@larsoner I added macOS-specific filters for the coherency methods alongside your GC filters. Now everything is passing!

Just want to check that there is nothing outstanding?

@Hugo-W commented May 27, 2024

Hi everyone,

About this new approach: if I follow the logic correctly and did not miss anything, you could simply take the square roots and reciprocals of the eigenvalues rather than inverting the entire matrices (which was maybe intended, as this is the natural way to compute powers of matrices).

I mean you would change:

T[:, :, :n_seeds, :n_seeds] = np.linalg.inv(
    np.matmul(
        (eigvects * np.expand_dims(np.sqrt(eigvals), (2))),
        eigvects.transpose(0, 1, 3, 2),
    )
)

into

T[:, :, :n_seeds, :n_seeds] = (
    eigvects * np.expand_dims(1.0 / np.sqrt(eigvals), (2))
) @ eigvects.transpose(0, 1, 3, 2)

for both the seeds and targets. It saves the linalg.inv operation twice.

This change has little impact for a small number of channels, but for many channels, where the matrix inversion becomes the bottleneck, it makes a relatively big difference. I observed a consistent decrease in computation time; e.g. with C_r of shape (32, 50, 60, 60), I go from ~430 ms to ~330 ms on my machine.
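The equivalence of the two forms can be checked directly: since C = V Λ V^T, inverting V Λ^(1/2) V^T is the same as forming V Λ^(-1/2) V^T. A quick sketch (the shapes and names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((2, 3, 5, 5))
C = A @ A.transpose(0, 1, 3, 2) + 5 * np.eye(5)  # SPD slices

eigvals, eigvects = np.linalg.eigh(C)

# original form: build the square root, then invert the whole matrix
T_inv = np.linalg.inv(
    (eigvects * np.expand_dims(np.sqrt(eigvals), (2)))
    @ eigvects.transpose(0, 1, 3, 2)
)

# suggested form: take reciprocals of the (scalar) eigenvalues instead
T_recip = (
    eigvects * np.expand_dims(1.0 / np.sqrt(eigvals), (2))
) @ eigvects.transpose(0, 1, 3, 2)

print(np.allclose(T_inv, T_recip))  # True
```

Replacing a batched matrix inverse with an elementwise reciprocal of the eigenvalues is where the saving comes from.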

@tsbinns (Collaborator, Author) commented May 27, 2024

@Hugo-W Nice suggestion! Yes, this approach also works, and all regression tests pass.

The only difference is how cases of non-full rank data are handled. In the extreme cases we use in the unit tests, the current approach catches everything as a LinAlgError; with this new approach, however, there is instead a RuntimeWarning for division by zero (from 1.0/np.sqrt(eigvals)). @larsoner @drammock What is the MNE view on treating warnings as errors?

Even if we do this, I think it's good to keep catching LinAlgError as well, since this is what should happen in more realistic data scenarios.
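The difference in failure mode can be reproduced in isolation: with a zero eigenvalue, 1.0 / np.sqrt(eigvals) emits a RuntimeWarning and yields inf rather than raising an exception (a minimal illustration, not the package's code):

```python
import warnings

import numpy as np

eigvals = np.array([2.0, 1.0, 0.0])  # rank-deficient: one zero eigenvalue
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    inv_sqrt_vals = 1.0 / np.sqrt(eigvals)

print(inv_sqrt_vals[-1])  # inf, not an exception
print(any(issubclass(w.category, RuntimeWarning) for w in caught))  # True
```

So without an explicit check, rank-deficient data would silently produce inf entries instead of the LinAlgError the earlier inv()-based code raised.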

@larsoner (Member) commented

> I think it's good to keep catching LinAlgError as well, since this should be what happens in more real-data scenarios.

Indeed, it's probably worth having an if (eigvals == 0).any(): raise LinAlgError(...) check so that the new code raises at least the same error class as it did before. Then the warnings-as-errors question isn't really relevant, I think.

@tsbinns (Collaborator, Author) commented May 29, 2024

Added @larsoner's suggestion to raise a LinAlgError when encountering zero-valued eigenvalues.

I think that's all of the points addressed.

@tsbinns tsbinns mentioned this pull request May 29, 2024
Comment on lines 347 to 348:

if (eigvals == 0).any():  # sign of non-full rank data
    raise np.linalg.LinAlgError()
@larsoner (Member):

Suggested change:

- if (eigvals == 0).any():  # sign of non-full rank data
-     raise np.linalg.LinAlgError()
+ n_zero = (eigvals == 0).sum()
+ if n_zero:  # sign of non-full rank data
+     raise np.linalg.LinAlgError(
+         "Cannot compute inverse square root of rank-deficient matrix "
+         f"with {n_zero}/{len(eigvals)} zero eigenvalue(s)"
+     )

@tsbinns (Collaborator, Author):
Sure, added for both the seed and target eigvals.

@larsoner (Member) left a comment

Just one tiny suggestion to make the raised error more meaningful/informative

EDIT: And a question, too

_gc_marks = []
if platform.system() == "Darwin" and platform.processor() == "arm":
    _coh_marks.extend([
        pytest.mark.filterwarnings("ignore:invalid value encountered in sqrt:")
@larsoner (Member):

Can this be removed now that there is a raise LinAlgError in there? Or does it come from another code path?

@tsbinns (Collaborator, Author):

Yeah, the filters for the cohy methods should now be redundant. I have removed them and will wait to see if the macOS tests pass to be sure.

@tsbinns tsbinns mentioned this pull request May 30, 2024
@larsoner merged commit 67df38b into mne-tools:main on May 30, 2024 (10 checks passed)
@larsoner (Member) commented

Thanks @tsbinns !

@tsbinns tsbinns deleted the efficient_invsqrtm branch May 30, 2024 19:22