diff --git a/src/lava/utils/weightutils.py b/src/lava/utils/weightutils.py index 009e7a29c..0230da93e 100644 --- a/src/lava/utils/weightutils.py +++ b/src/lava/utils/weightutils.py @@ -239,6 +239,7 @@ def truncate_weights(weights: ty.Union[np.ndarray, spmatrix], weights = (weights >> num_truncate_bits) << num_truncate_bits elif isinstance(weights, spmatrix): weights.data = (weights.data >> num_truncate_bits) << num_truncate_bits + weights.eliminate_zeros() if sign_mode == SignMode.INHIBITORY: weights = -weights