diff --git a/src/lava/magma/core/model/py/connection.py b/src/lava/magma/core/model/py/connection.py index 87242b197..f42911944 100644 --- a/src/lava/magma/core/model/py/connection.py +++ b/src/lava/magma/core/model/py/connection.py @@ -186,18 +186,16 @@ def _build_active_traces(self) -> None: """Build and store boolean numpy arrays specifying which x and y traces are active.""" # Shape : (2, ) - self._active_x_traces = np.logical_or( - self._active_x_traces_per_dependency[0], - self._active_x_traces_per_dependency[1], - self._active_x_traces_per_dependency[2], - ) + self._active_x_traces = \ + self._active_x_traces_per_dependency[0] \ + | self._active_x_traces_per_dependency[1] \ + | self._active_x_traces_per_dependency[2] # Shape : (3, ) - self._active_y_traces = np.logical_or( - self._active_y_traces_per_dependency[0], - self._active_y_traces_per_dependency[1], - self._active_y_traces_per_dependency[2], - ) + self._active_y_traces = \ + self._active_y_traces_per_dependency[0] \ + | self._active_y_traces_per_dependency[1] \ + | self._active_y_traces_per_dependency[2] def _build_learning_rule_appliers(self) -> None: """Build and store LearningRuleApplier for each active learning @@ -493,7 +491,7 @@ def _record_pre_spike_times(self, s_in: np.ndarray) -> None: Pre-synaptic spikes. """ self.x0[s_in] = True - multi_spike_x = np.logical_and(self.tx > 0, s_in) + multi_spike_x = (self.tx > 0) & s_in x_traces = self._x_traces x_traces[:, multi_spike_x] = self._add_impulse( @@ -519,7 +517,7 @@ def _record_post_spike_times(self, s_in_bap: np.ndarray) -> None: Post-synaptic spikes. """ self.y0[s_in_bap] = True - multi_spike_y = np.logical_and(self.ty > 0, s_in_bap) + multi_spike_y = (self.ty > 0) & s_in_bap y_traces = self._y_traces y_traces[:, multi_spike_y] = self._add_impulse( @@ -956,12 +954,8 @@ def _evaluate_trace( t_diff = t_eval - t_spikes - decay_only = np.logical_and( - np.logical_or(t_spikes == 0, t_diff < 0), broad_taus > 0 - ) - decay_spike_decay = np.logical_and( - t_spikes != 0, t_diff >= 0, broad_taus > 0 - ) + decay_only = ((t_spikes == 0) | (t_diff < 0)) & (broad_taus > 0) + decay_spike_decay = (t_spikes != 0) & (t_diff >= 0) & (broad_taus > 0) result = trace_values.copy() @@ -1160,7 +1154,7 @@ def _record_pre_spike_times(self, s_in: np.ndarray) -> None: """ self.x0[s_in] = True - multi_spike_x = np.logical_and(self.tx > 0, s_in) + multi_spike_x = (self.tx > 0) & s_in x_traces = self._x_traces x_traces[:, multi_spike_x] += self._x_impulses[:, np.newaxis] @@ -1182,7 +1176,7 @@ def _record_post_spike_times(self, s_in_bap: np.ndarray) -> None: """ self.y0[s_in_bap] = True - multi_spike_y = np.logical_and(self.ty > 0, s_in_bap) + multi_spike_y = (self.ty > 0) & s_in_bap y_traces = self._y_traces y_traces[:, multi_spike_y] += self._y_impulses[:, np.newaxis] @@ -1377,12 +1371,8 @@ def _evaluate_trace( t_diff = t_eval - t_spikes - decay_only = np.logical_and( - np.logical_or(t_spikes == 0, t_diff < 0), broad_taus > 0 - ) - decay_spike_decay = np.logical_and( - t_spikes != 0, t_diff >= 0, broad_taus > 0 - ) + decay_only = ((t_spikes == 0) | (t_diff < 0)) & (broad_taus > 0) + decay_spike_decay = (t_spikes != 0) & (t_diff >= 0) & (broad_taus > 0) result = trace_values.copy()