diff --git a/src/lava/utils/weightutils.py b/src/lava/utils/weightutils.py index f6eadf61f..ac81ef46c 100644 --- a/src/lava/utils/weightutils.py +++ b/src/lava/utils/weightutils.py @@ -32,10 +32,10 @@ def determine_sign_mode(weights: np.ndarray) -> SignMode: The sign mode that best describes the values in the given weight matrix. """ - if np.max(weights) < 0: - sign_mode = SignMode.INHIBITORY - elif np.min(weights) >= 0: + if np.min(weights) >= 0: sign_mode = SignMode.EXCITATORY + elif np.max(weights) <= 0: + sign_mode = SignMode.INHIBITORY else: sign_mode = SignMode.MIXED @@ -70,10 +70,7 @@ def optimize_weight_bits( OptimizedWeights An object that wraps the optimized weight matrix and weight parameters. """ - if np.any(weights > 255) or np.any(weights < -256): - raise ValueError(f"weights have to be between -256 and 255. Got " - f"weights between {np.min(weights)} and " - f"{np.max(weights)}.") + _validate_weights(weights, sign_mode) weight_exp = _determine_weight_exp(weights, sign_mode) num_weight_bits = _determine_num_weight_bits(weights, weight_exp, sign_mode) @@ -91,6 +88,32 @@ def optimize_weight_bits( return optimized_weights +def _validate_weights(weights: np.ndarray, + sign_mode: SignMode) -> None: + """Validate the weight values against the given sign mode. + + Parameters + ---------- + weights : numpy.ndarray + Weight matrix + sign_mode : SignMode + Sign mode specified for the weight matrix + """ + mixed_flag = int(sign_mode == SignMode.MIXED) + excitatory_flag = int(sign_mode == SignMode.EXCITATORY) + inhibitory_flag = int(sign_mode == SignMode.INHIBITORY) + + min_weight = (-2 ** 8) * (mixed_flag + inhibitory_flag) + min_weight += inhibitory_flag + max_weight = (2 ** 8 - 1) * (mixed_flag + excitatory_flag) + + if np.any(weights > max_weight) or np.any(weights < min_weight): + raise ValueError(f"weights have to be between {min_weight} and " + f"{max_weight} for {sign_mode=}. Got " + f"weights between {np.min(weights)} and " + f"{np.max(weights)}.") + + def _determine_weight_exp(weights: np.ndarray, sign_mode: SignMode) -> int: """Determines the weight exponent to be used to optimally represent the @@ -118,7 +141,7 @@ def _determine_weight_exp(weights: np.ndarray, neg_scale = -128 / min_weight if min_weight < 0 else np.inf scale = np.min([pos_scale, neg_scale]) elif sign_mode == SignMode.INHIBITORY: - scale = -256 / min_weight + scale = -255 / min_weight elif sign_mode == SignMode.EXCITATORY: scale = 255 / max_weight @@ -197,6 +220,9 @@ def truncate_weights(weights: np.ndarray, """ weights = np.copy(weights).astype(np.int32) + if sign_mode == SignMode.INHIBITORY: + weights = -weights + mixed_flag = int(sign_mode == SignMode.MIXED) num_truncate_bits = max_num_weight_bits - num_weight_bits + mixed_flag @@ -204,6 +230,9 @@ def truncate_weights(weights: np.ndarray, np.right_shift(weights, num_truncate_bits), num_truncate_bits).astype(np.int32) + if sign_mode == SignMode.INHIBITORY: + truncated_weights = -truncated_weights + return truncated_weights @@ -218,8 +247,7 @@ def clip_weights(weights: np.ndarray, weights : numpy.ndarray Weight matrix that is to be truncated. sign_mode : SignMode - Sign mode to use for truncation. See SignMode class for the - correct values. + Sign mode to use for truncation. num_bits : int Number of bits to use to clip the weights to. @@ -231,12 +259,17 @@ def clip_weights(weights: np.ndarray, weights = np.copy(weights).astype(np.int32) mixed_flag = int(sign_mode == SignMode.MIXED) - excitatory_flag = int(sign_mode == SignMode.EXCITATORY) inhibitory_flag = int(sign_mode == SignMode.INHIBITORY) - min_wgt = (-2 ** num_bits) * (mixed_flag + inhibitory_flag) - max_wgt = (2 ** num_bits - 1) * (mixed_flag + excitatory_flag) + if inhibitory_flag: + weights = -weights + + min_wgt = (-2 ** num_bits) * mixed_flag + max_wgt = 2 ** num_bits - 1 clipped_weights = np.clip(weights, min_wgt, max_wgt) + if inhibitory_flag: + clipped_weights = -clipped_weights + return clipped_weights diff --git a/tests/lava/proc/scif/test_models.py b/tests/lava/proc/scif/test_models.py index f986051f4..04cebf693 100644 --- a/tests/lava/proc/scif/test_models.py +++ b/tests/lava/proc/scif/test_models.py @@ -45,6 +45,8 @@ def run_test( spk_src = SpikeSource(data=np.array([[0] * num_neurons]).reshape( num_neurons, 1).astype(int)) + # TODO (MR): The weight of -1 is now being correctly encoded as -1. + # It was written assuming the weight would be truncated to -2. dense_in = Dense(weights=(-1) * np.eye(num_neurons), num_message_bits=16) csp_scif = CspScif(shape=(num_neurons,), @@ -408,12 +410,12 @@ def test_scif_fp_no_noise_interrupt_rfct_mid(self) -> None: self.assertTrue(np.all(v_scif[spk_idxs_pre_inj] == neg_tau_ref)) self.assertTrue(np.all(v_lif_wta[wta_pos_spk_pre_inj] == 1)) self.assertTrue(np.all( - v_lif_sig[sig_pos_spk_pre_inj] == cost_diag + wt * step_size)) - v_gt_inh_inj = (inh_inj - spk_idxs_pre_inj + 1) - t_inj_spk[inh_inj] + v_lif_sig[sig_pos_spk_pre_inj] == cost_diag + 1)) + v_gt_inh_inj = step_size - t_inj_spk[inh_inj] self.assertTrue(np.all(v_scif[inh_inj] == v_gt_inh_inj)) self.assertTrue(np.all(v_lif_wta[wta_spk_rfct_interrupt] == 1)) self.assertTrue(np.all( - v_lif_sig[sig_spk_rfct_interrupt] == cost_diag + wt * step_size)) + v_lif_sig[sig_spk_rfct_interrupt] == cost_diag + 1)) # Test post-inhibitory-injection SCIF voltage and spiking idx_lst = [inj_times[2] + (theta // step_size) - 1 + j * total_period for j in range(num_epochs)] diff --git a/tests/lava/utils/test_weightutils.py b/tests/lava/utils/test_weightutils.py index d62171dc2..01645f129 100644 --- a/tests/lava/utils/test_weightutils.py +++ b/tests/lava/utils/test_weightutils.py @@ -34,6 +34,22 @@ def test_optimize_raises_error_when_weights_out_of_bounds(self) -> None: optimize_weight_bits(weights=np.array([-257]), sign_mode=SignMode.EXCITATORY) + with self.assertRaises(ValueError): + optimize_weight_bits(weights=np.array([-256]), + sign_mode=SignMode.INHIBITORY) + + with self.assertRaises(ValueError): + optimize_weight_bits(weights=np.array([1]), + sign_mode=SignMode.INHIBITORY) + + with self.assertRaises(ValueError): + optimize_weight_bits(weights=np.array([-257]), + sign_mode=SignMode.MIXED) + + with self.assertRaises(ValueError): + optimize_weight_bits(weights=np.array([257]), + sign_mode=SignMode.MIXED) + def test_optimize_weight_bits_excitatory_8bit(self) -> None: weights = np.arange(0, 255, 1, dtype=int) sign_mode = SignMode.EXCITATORY @@ -60,7 +76,7 @@ def test_optimize_weight_bits_excitatory_7bit(self) -> None: self.assertEqual(optimized.num_weight_bits, 7) def test_optimize_weight_bits_inhibitory_8bit(self) -> None: - weights = np.arange(-256, -1, 1, dtype=int) + weights = np.arange(-255, 0, 1, dtype=int) sign_mode = SignMode.INHIBITORY optimized = optimize_weight_bits(weights=weights, @@ -72,7 +88,7 @@ def test_optimize_weight_bits_inhibitory_8bit(self) -> None: self.assertEqual(optimized.num_weight_bits, 8) def test_optimize_weight_bits_inhibitory_7bit(self) -> None: - weights = np.arange(-256, -1, 2, dtype=int) + weights = np.arange(-254, 0, 2, dtype=int) sign_mode = SignMode.INHIBITORY optimized = optimize_weight_bits(weights=weights, @@ -123,12 +139,12 @@ def test_optimize_weight_bits_weight_exp(self) -> None: self.assertEqual(optimized.num_weight_bits, 3) def test_determine_weight_exp_inhibitory_0(self) -> None: - weight_exp = _determine_weight_exp(weights=np.array([-256, -128, -1]), + weight_exp = _determine_weight_exp(weights=np.array([-255, -128, -1]), sign_mode=SignMode.INHIBITORY) self.assertEqual(weight_exp, 0) def test_determine_weight_exp_inhibitory_1(self) -> None: - weight_exp = _determine_weight_exp(weights=np.array([-512, -256, -1]), + weight_exp = _determine_weight_exp(weights=np.array([-510, -256, -1]), sign_mode=SignMode.INHIBITORY) self.assertEqual(weight_exp, 1) @@ -285,13 +301,13 @@ def test_truncate_weights_inhibitory_8(self) -> None: def test_truncate_weights_inhibitory_7(self) -> None: truncated_weights = truncate_weights( - weights=np.array([-256, -255, -254, -253, -252]), + weights=np.array([-255, -254, -253, -252, -251]), sign_mode=SignMode.INHIBITORY, num_weight_bits=7 ) np.testing.assert_array_equal( truncated_weights, - np.array([-256, -256, -254, -254, -252]) + np.array([-254, -254, -252, -252, -250]) ) def test_truncate_weights_mixed_8(self) -> None: @@ -342,24 +358,24 @@ def test_clip_weights_excitatory_7(self) -> None: def test_clip_weights_inhibitory_8(self) -> None: clipped_weights = clip_weights( - weights=np.array([-257, -256, -1, 0, 1]), + weights=np.array([-256, -255, -1, 0, 1]), sign_mode=SignMode.INHIBITORY, num_bits=8 ) np.testing.assert_array_equal( clipped_weights, - np.array([-256, -256, -1, 0, 0]) + np.array([-255, -255, -1, 0, 0]) ) def test_clip_weights_inhibitory_7(self) -> None: clipped_weights = clip_weights( - weights=np.array([-129, -128, -1, 0, 1]), + weights=np.array([-128, -127, -1, 0, 1]), sign_mode=SignMode.INHIBITORY, num_bits=7 ) np.testing.assert_array_equal( clipped_weights, - np.array([-128, -128, -1, 0, 0]) + np.array([-127, -127, -1, 0, 0]) ) def test_clip_weights_mixed_8(self) -> None: