Skip to content

Commit

Permalink
Fixing Dense inhibitory sign_mode (#376)
Browse files Browse the repository at this point in the history
* Fix for inhibitory sign mode in Dense (wip).

Signed-off-by: Mathis Richter <[email protected]>

* Fixed failing unit test of scif neurons.

Signed-off-by: Mathis Richter <[email protected]>

* Removed debugging comments.

Signed-off-by: Mathis Richter <[email protected]>

* Fixed error in unit test.

Signed-off-by: Mathis Richter <[email protected]>

Signed-off-by: Mathis Richter <[email protected]>
  • Loading branch information
mathisrichter committed Sep 26, 2022
1 parent 49e2115 commit eed625a
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 26 deletions.
59 changes: 46 additions & 13 deletions src/lava/utils/weightutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def determine_sign_mode(weights: np.ndarray) -> SignMode:
The sign mode that best describes the values in the given weight
matrix.
"""
if np.max(weights) < 0:
sign_mode = SignMode.INHIBITORY
elif np.min(weights) >= 0:
if np.min(weights) >= 0:
sign_mode = SignMode.EXCITATORY
elif np.max(weights) <= 0:
sign_mode = SignMode.INHIBITORY
else:
sign_mode = SignMode.MIXED

Expand Down Expand Up @@ -70,10 +70,7 @@ def optimize_weight_bits(
OptimizedWeights
An object that wraps the optimized weight matrix and weight parameters.
"""
if np.any(weights > 255) or np.any(weights < -256):
raise ValueError(f"weights have to be between -256 and 255. Got "
f"weights between {np.min(weights)} and "
f"{np.max(weights)}.")
_validate_weights(weights, sign_mode)

weight_exp = _determine_weight_exp(weights, sign_mode)
num_weight_bits = _determine_num_weight_bits(weights, weight_exp, sign_mode)
Expand All @@ -91,6 +88,32 @@ def optimize_weight_bits(
return optimized_weights


def _validate_weights(weights: np.ndarray,
sign_mode: SignMode) -> None:
"""Validate the weight values against the given sign mode.
Parameters
----------
weights : numpy.ndarray
Weight matrix
sign_mode : SignMode
Sign mode specified for the weight matrix
"""
mixed_flag = int(sign_mode == SignMode.MIXED)
excitatory_flag = int(sign_mode == SignMode.EXCITATORY)
inhibitory_flag = int(sign_mode == SignMode.INHIBITORY)

min_weight = (-2 ** 8) * (mixed_flag + inhibitory_flag)
min_weight += inhibitory_flag
max_weight = (2 ** 8 - 1) * (mixed_flag + excitatory_flag)

if np.any(weights > max_weight) or np.any(weights < min_weight):
raise ValueError(f"weights have to be between {min_weight} and "
f"{max_weight} for {sign_mode=}. Got "
f"weights between {np.min(weights)} and "
f"{np.max(weights)}.")


def _determine_weight_exp(weights: np.ndarray,
sign_mode: SignMode) -> int:
"""Determines the weight exponent to be used to optimally represent the
Expand Down Expand Up @@ -118,7 +141,7 @@ def _determine_weight_exp(weights: np.ndarray,
neg_scale = -128 / min_weight if min_weight < 0 else np.inf
scale = np.min([pos_scale, neg_scale])
elif sign_mode == SignMode.INHIBITORY:
scale = -256 / min_weight
scale = -255 / min_weight
elif sign_mode == SignMode.EXCITATORY:
scale = 255 / max_weight

Expand Down Expand Up @@ -197,13 +220,19 @@ def truncate_weights(weights: np.ndarray,
"""
weights = np.copy(weights).astype(np.int32)

if sign_mode == SignMode.INHIBITORY:
weights = -weights

mixed_flag = int(sign_mode == SignMode.MIXED)
num_truncate_bits = max_num_weight_bits - num_weight_bits + mixed_flag

truncated_weights = np.left_shift(
np.right_shift(weights, num_truncate_bits),
num_truncate_bits).astype(np.int32)

if sign_mode == SignMode.INHIBITORY:
truncated_weights = -truncated_weights

return truncated_weights


Expand All @@ -218,8 +247,7 @@ def clip_weights(weights: np.ndarray,
weights : numpy.ndarray
Weight matrix that is to be truncated.
sign_mode : SignMode
Sign mode to use for truncation. See SignMode class for the
correct values.
Sign mode to use for truncation.
num_bits : int
Number of bits to use to clip the weights to.
Expand All @@ -231,12 +259,17 @@ def clip_weights(weights: np.ndarray,
weights = np.copy(weights).astype(np.int32)

mixed_flag = int(sign_mode == SignMode.MIXED)
excitatory_flag = int(sign_mode == SignMode.EXCITATORY)
inhibitory_flag = int(sign_mode == SignMode.INHIBITORY)

min_wgt = (-2 ** num_bits) * (mixed_flag + inhibitory_flag)
max_wgt = (2 ** num_bits - 1) * (mixed_flag + excitatory_flag)
if inhibitory_flag:
weights = -weights

min_wgt = (-2 ** num_bits) * mixed_flag
max_wgt = 2 ** num_bits - 1

clipped_weights = np.clip(weights, min_wgt, max_wgt)

if inhibitory_flag:
clipped_weights = -clipped_weights

return clipped_weights
8 changes: 5 additions & 3 deletions tests/lava/proc/scif/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def run_test(

spk_src = SpikeSource(data=np.array([[0] * num_neurons]).reshape(
num_neurons, 1).astype(int))
# TODO (MR): The weight of -1 is now being correctly encoded as -1.
# It was written assuming the weight would be truncated to -2.
dense_in = Dense(weights=(-1) * np.eye(num_neurons),
num_message_bits=16)
csp_scif = CspScif(shape=(num_neurons,),
Expand Down Expand Up @@ -408,12 +410,12 @@ def test_scif_fp_no_noise_interrupt_rfct_mid(self) -> None:
self.assertTrue(np.all(v_scif[spk_idxs_pre_inj] == neg_tau_ref))
self.assertTrue(np.all(v_lif_wta[wta_pos_spk_pre_inj] == 1))
self.assertTrue(np.all(
v_lif_sig[sig_pos_spk_pre_inj] == cost_diag + wt * step_size))
v_gt_inh_inj = (inh_inj - spk_idxs_pre_inj + 1) - t_inj_spk[inh_inj]
v_lif_sig[sig_pos_spk_pre_inj] == cost_diag + 1))
v_gt_inh_inj = step_size - t_inj_spk[inh_inj]
self.assertTrue(np.all(v_scif[inh_inj] == v_gt_inh_inj))
self.assertTrue(np.all(v_lif_wta[wta_spk_rfct_interrupt] == 1))
self.assertTrue(np.all(
v_lif_sig[sig_spk_rfct_interrupt] == cost_diag + wt * step_size))
v_lif_sig[sig_spk_rfct_interrupt] == cost_diag + 1))
# Test post-inhibitory-injection SCIF voltage and spiking
idx_lst = [inj_times[2] + (theta // step_size) - 1 + j * total_period
for j in range(num_epochs)]
Expand Down
36 changes: 26 additions & 10 deletions tests/lava/utils/test_weightutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ def test_optimize_raises_error_when_weights_out_of_bounds(self) -> None:
optimize_weight_bits(weights=np.array([-257]),
sign_mode=SignMode.EXCITATORY)

with self.assertRaises(ValueError):
optimize_weight_bits(weights=np.array([-256]),
sign_mode=SignMode.INHIBITORY)

with self.assertRaises(ValueError):
optimize_weight_bits(weights=np.array([1]),
sign_mode=SignMode.INHIBITORY)

with self.assertRaises(ValueError):
optimize_weight_bits(weights=np.array([-257]),
sign_mode=SignMode.MIXED)

with self.assertRaises(ValueError):
optimize_weight_bits(weights=np.array([257]),
sign_mode=SignMode.MIXED)

def test_optimize_weight_bits_excitatory_8bit(self) -> None:
weights = np.arange(0, 255, 1, dtype=int)
sign_mode = SignMode.EXCITATORY
Expand All @@ -60,7 +76,7 @@ def test_optimize_weight_bits_excitatory_7bit(self) -> None:
self.assertEqual(optimized.num_weight_bits, 7)

def test_optimize_weight_bits_inhibitory_8bit(self) -> None:
weights = np.arange(-256, -1, 1, dtype=int)
weights = np.arange(-255, 0, 1, dtype=int)
sign_mode = SignMode.INHIBITORY

optimized = optimize_weight_bits(weights=weights,
Expand All @@ -72,7 +88,7 @@ def test_optimize_weight_bits_inhibitory_8bit(self) -> None:
self.assertEqual(optimized.num_weight_bits, 8)

def test_optimize_weight_bits_inhibitory_7bit(self) -> None:
weights = np.arange(-256, -1, 2, dtype=int)
weights = np.arange(-254, 0, 2, dtype=int)
sign_mode = SignMode.INHIBITORY

optimized = optimize_weight_bits(weights=weights,
Expand Down Expand Up @@ -123,12 +139,12 @@ def test_optimize_weight_bits_weight_exp(self) -> None:
self.assertEqual(optimized.num_weight_bits, 3)

def test_determine_weight_exp_inhibitory_0(self) -> None:
weight_exp = _determine_weight_exp(weights=np.array([-256, -128, -1]),
weight_exp = _determine_weight_exp(weights=np.array([-255, -128, -1]),
sign_mode=SignMode.INHIBITORY)
self.assertEqual(weight_exp, 0)

def test_determine_weight_exp_inhibitory_1(self) -> None:
weight_exp = _determine_weight_exp(weights=np.array([-512, -256, -1]),
weight_exp = _determine_weight_exp(weights=np.array([-510, -256, -1]),
sign_mode=SignMode.INHIBITORY)
self.assertEqual(weight_exp, 1)

Expand Down Expand Up @@ -285,13 +301,13 @@ def test_truncate_weights_inhibitory_8(self) -> None:

def test_truncate_weights_inhibitory_7(self) -> None:
truncated_weights = truncate_weights(
weights=np.array([-256, -255, -254, -253, -252]),
weights=np.array([-255, -254, -253, -252, -251]),
sign_mode=SignMode.INHIBITORY,
num_weight_bits=7
)
np.testing.assert_array_equal(
truncated_weights,
np.array([-256, -256, -254, -254, -252])
np.array([-254, -254, -252, -252, -250])
)

def test_truncate_weights_mixed_8(self) -> None:
Expand Down Expand Up @@ -342,24 +358,24 @@ def test_clip_weights_excitatory_7(self) -> None:

def test_clip_weights_inhibitory_8(self) -> None:
clipped_weights = clip_weights(
weights=np.array([-257, -256, -1, 0, 1]),
weights=np.array([-256, -255, -1, 0, 1]),
sign_mode=SignMode.INHIBITORY,
num_bits=8
)
np.testing.assert_array_equal(
clipped_weights,
np.array([-256, -256, -1, 0, 0])
np.array([-255, -255, -1, 0, 0])
)

def test_clip_weights_inhibitory_7(self) -> None:
clipped_weights = clip_weights(
weights=np.array([-129, -128, -1, 0, 1]),
weights=np.array([-128, -127, -1, 0, 1]),
sign_mode=SignMode.INHIBITORY,
num_bits=7
)
np.testing.assert_array_equal(
clipped_weights,
np.array([-128, -128, -1, 0, 0])
np.array([-127, -127, -1, 0, 0])
)

def test_clip_weights_mixed_8(self) -> None:
Expand Down

0 comments on commit eed625a

Please sign in to comment.