Skip to content

Commit

Permalink
add multivariate con base class
Browse files Browse the repository at this point in the history
  • Loading branch information
tsbinns committed Feb 22, 2023
1 parent d71b067 commit 446c7fe
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions mne_connectivity/spectral/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,73 @@ def combine(self, other):
self._acc += other._acc


class _EpochMeanMultivariateConEstBase(_AbstractConEstBase):
"""Base class for mean epoch-wise multivar. con. estimation methods."""

n_steps = None
patterns = None

def __init__(self, n_signals, n_cons, n_freqs, n_times, n_jobs=1):
self.n_signals = n_signals
self.n_cons = n_cons
self.n_freqs = n_freqs
self.n_times = n_times
self.n_jobs = n_jobs

# include time dimension, even when unused for indexing flexibility
if n_times == 0:
self.csd_shape = (n_signals**2, n_freqs)
self.con_scores = np.zeros((n_cons, n_freqs, 1))
else:
self.csd_shape = (n_signals**2, n_freqs, n_times)
self.con_scores = np.zeros((n_cons, n_freqs, n_times))

# allocate space for accumulation of CSD
self._acc = np.zeros(self.csd_shape, dtype=np.complex128)

self._compute_n_progress_bar_steps()

def start_epoch(self): # noqa: D401
"""Called at the start of each epoch."""
pass # for this type of con. method we don't do anything

def combine(self, other):
"""Include con. accumulated for some epochs in this estimate."""
self._acc += other._acc

def accumulate(self, con_idx, csd_xy):
"""Accumulate CSD for some connections."""
self._acc[con_idx] += csd_xy

def _compute_n_progress_bar_steps(self):
"""Calculate the number of steps to include in the progress bar."""
self.n_steps = int(np.ceil(self.n_freqs / self.n_jobs))

def _log_connection_number(self, con_i):
"""Log the number of the connection being computed."""
logger.info('Computing %i for connection %i of %i' % (
self.name, con_i + 1, self.n_cons, ))

def _get_block_indices(self, block_i, limit):
"""Get indices for a computation block capped by a limit."""
indices = np.arange(block_i * self.n_jobs, (block_i + 1) * self.n_jobs)

return indices[np.nonzero(indices < limit)]

def reshape_csd(self):
"""Reshape CSD into a matrix of times x freqs x signals x signals."""
if self.n_times == 0:
return (
np.reshape(self._acc, (self.n_signals, self.n_signals,
self.n_freqs, 1)).transpose(3, 2, 0, 1)
)
return (
np.reshape(self._acc, (self.n_signals, self.n_signals,
self.n_freqs, self.n_times)
).transpose(3, 2, 0, 1)
)


class _CohEstBase(_EpochMeanConEstBase):
"""Base Estimator for Coherence, Coherency, Imag. Coherence."""

Expand Down

0 comments on commit 446c7fe

Please sign in to comment.