Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing Dense inhibitory sign_mode #376

Merged
merged 5 commits into from
Sep 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 46 additions & 13 deletions src/lava/utils/weightutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def determine_sign_mode(weights: np.ndarray) -> SignMode:
The sign mode that best describes the values in the given weight
matrix.
"""
if np.max(weights) < 0:
sign_mode = SignMode.INHIBITORY
elif np.min(weights) >= 0:
if np.min(weights) >= 0:
sign_mode = SignMode.EXCITATORY
elif np.max(weights) <= 0:
sign_mode = SignMode.INHIBITORY
else:
sign_mode = SignMode.MIXED

Expand Down Expand Up @@ -70,10 +70,7 @@ def optimize_weight_bits(
OptimizedWeights
An object that wraps the optimized weight matrix and weight parameters.
"""
if np.any(weights > 255) or np.any(weights < -256):
raise ValueError(f"weights have to be between -256 and 255. Got "
f"weights between {np.min(weights)} and "
f"{np.max(weights)}.")
_validate_weights(weights, sign_mode)

weight_exp = _determine_weight_exp(weights, sign_mode)
num_weight_bits = _determine_num_weight_bits(weights, weight_exp, sign_mode)
Expand All @@ -91,6 +88,32 @@ def optimize_weight_bits(
return optimized_weights


def _validate_weights(weights: np.ndarray,
sign_mode: SignMode) -> None:
"""Validate the weight values against the given sign mode.

Parameters
----------
weights : numpy.ndarray
Weight matrix
sign_mode : SignMode
Sign mode specified for the weight matrix
"""
mixed_flag = int(sign_mode == SignMode.MIXED)
excitatory_flag = int(sign_mode == SignMode.EXCITATORY)
inhibitory_flag = int(sign_mode == SignMode.INHIBITORY)

min_weight = (-2 ** 8) * (mixed_flag + inhibitory_flag)
min_weight += inhibitory_flag
max_weight = (2 ** 8 - 1) * (mixed_flag + excitatory_flag)

if np.any(weights > max_weight) or np.any(weights < min_weight):
raise ValueError(f"weights have to be between {min_weight} and "
f"{max_weight} for {sign_mode=}. Got "
f"weights between {np.min(weights)} and "
f"{np.max(weights)}.")


def _determine_weight_exp(weights: np.ndarray,
sign_mode: SignMode) -> int:
"""Determines the weight exponent to be used to optimally represent the
Expand Down Expand Up @@ -118,7 +141,7 @@ def _determine_weight_exp(weights: np.ndarray,
neg_scale = -128 / min_weight if min_weight < 0 else np.inf
scale = np.min([pos_scale, neg_scale])
elif sign_mode == SignMode.INHIBITORY:
scale = -256 / min_weight
scale = -255 / min_weight
elif sign_mode == SignMode.EXCITATORY:
scale = 255 / max_weight

Expand Down Expand Up @@ -197,13 +220,19 @@ def truncate_weights(weights: np.ndarray,
"""
weights = np.copy(weights).astype(np.int32)

if sign_mode == SignMode.INHIBITORY:
weights = -weights

mixed_flag = int(sign_mode == SignMode.MIXED)
num_truncate_bits = max_num_weight_bits - num_weight_bits + mixed_flag

truncated_weights = np.left_shift(
np.right_shift(weights, num_truncate_bits),
num_truncate_bits).astype(np.int32)

if sign_mode == SignMode.INHIBITORY:
truncated_weights = -truncated_weights

return truncated_weights


Expand All @@ -218,8 +247,7 @@ def clip_weights(weights: np.ndarray,
weights : numpy.ndarray
Weight matrix that is to be truncated.
sign_mode : SignMode
Sign mode to use for truncation. See SignMode class for the
correct values.
Sign mode to use for truncation.
num_bits : int
Number of bits to use to clip the weights to.

Expand All @@ -231,12 +259,17 @@ def clip_weights(weights: np.ndarray,
weights = np.copy(weights).astype(np.int32)

mixed_flag = int(sign_mode == SignMode.MIXED)
excitatory_flag = int(sign_mode == SignMode.EXCITATORY)
inhibitory_flag = int(sign_mode == SignMode.INHIBITORY)

min_wgt = (-2 ** num_bits) * (mixed_flag + inhibitory_flag)
max_wgt = (2 ** num_bits - 1) * (mixed_flag + excitatory_flag)
if inhibitory_flag:
weights = -weights

min_wgt = (-2 ** num_bits) * mixed_flag
max_wgt = 2 ** num_bits - 1

clipped_weights = np.clip(weights, min_wgt, max_wgt)

if inhibitory_flag:
clipped_weights = -clipped_weights

return clipped_weights
8 changes: 5 additions & 3 deletions tests/lava/proc/scif/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def run_test(

spk_src = SpikeSource(data=np.array([[0] * num_neurons]).reshape(
num_neurons, 1).astype(int))
# TODO (MR): The weight of -1 is now being correctly encoded as -1.
# It was written assuming the weight would be truncated to -2.
dense_in = Dense(weights=(-1) * np.eye(num_neurons),
num_message_bits=16)
csp_scif = CspScif(shape=(num_neurons,),
Expand Down Expand Up @@ -408,12 +410,12 @@ def test_scif_fp_no_noise_interrupt_rfct_mid(self) -> None:
self.assertTrue(np.all(v_scif[spk_idxs_pre_inj] == neg_tau_ref))
self.assertTrue(np.all(v_lif_wta[wta_pos_spk_pre_inj] == 1))
self.assertTrue(np.all(
v_lif_sig[sig_pos_spk_pre_inj] == cost_diag + wt * step_size))
v_gt_inh_inj = (inh_inj - spk_idxs_pre_inj + 1) - t_inj_spk[inh_inj]
v_lif_sig[sig_pos_spk_pre_inj] == cost_diag + 1))
v_gt_inh_inj = step_size - t_inj_spk[inh_inj]
self.assertTrue(np.all(v_scif[inh_inj] == v_gt_inh_inj))
self.assertTrue(np.all(v_lif_wta[wta_spk_rfct_interrupt] == 1))
self.assertTrue(np.all(
v_lif_sig[sig_spk_rfct_interrupt] == cost_diag + wt * step_size))
v_lif_sig[sig_spk_rfct_interrupt] == cost_diag + 1))
# Test post-inhibitory-injection SCIF voltage and spiking
idx_lst = [inj_times[2] + (theta // step_size) - 1 + j * total_period
for j in range(num_epochs)]
Expand Down
36 changes: 26 additions & 10 deletions tests/lava/utils/test_weightutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,22 @@ def test_optimize_raises_error_when_weights_out_of_bounds(self) -> None:
optimize_weight_bits(weights=np.array([-257]),
sign_mode=SignMode.EXCITATORY)

with self.assertRaises(ValueError):
optimize_weight_bits(weights=np.array([-256]),
sign_mode=SignMode.INHIBITORY)

with self.assertRaises(ValueError):
optimize_weight_bits(weights=np.array([1]),
sign_mode=SignMode.INHIBITORY)

with self.assertRaises(ValueError):
optimize_weight_bits(weights=np.array([-257]),
sign_mode=SignMode.MIXED)

with self.assertRaises(ValueError):
optimize_weight_bits(weights=np.array([257]),
sign_mode=SignMode.MIXED)

def test_optimize_weight_bits_excitatory_8bit(self) -> None:
weights = np.arange(0, 255, 1, dtype=int)
sign_mode = SignMode.EXCITATORY
Expand All @@ -60,7 +76,7 @@ def test_optimize_weight_bits_excitatory_7bit(self) -> None:
self.assertEqual(optimized.num_weight_bits, 7)

def test_optimize_weight_bits_inhibitory_8bit(self) -> None:
weights = np.arange(-256, -1, 1, dtype=int)
weights = np.arange(-255, 0, 1, dtype=int)
sign_mode = SignMode.INHIBITORY

optimized = optimize_weight_bits(weights=weights,
Expand All @@ -72,7 +88,7 @@ def test_optimize_weight_bits_inhibitory_8bit(self) -> None:
self.assertEqual(optimized.num_weight_bits, 8)

def test_optimize_weight_bits_inhibitory_7bit(self) -> None:
weights = np.arange(-256, -1, 2, dtype=int)
weights = np.arange(-254, 0, 2, dtype=int)
sign_mode = SignMode.INHIBITORY

optimized = optimize_weight_bits(weights=weights,
Expand Down Expand Up @@ -123,12 +139,12 @@ def test_optimize_weight_bits_weight_exp(self) -> None:
self.assertEqual(optimized.num_weight_bits, 3)

def test_determine_weight_exp_inhibitory_0(self) -> None:
weight_exp = _determine_weight_exp(weights=np.array([-256, -128, -1]),
weight_exp = _determine_weight_exp(weights=np.array([-255, -128, -1]),
sign_mode=SignMode.INHIBITORY)
self.assertEqual(weight_exp, 0)

def test_determine_weight_exp_inhibitory_1(self) -> None:
weight_exp = _determine_weight_exp(weights=np.array([-512, -256, -1]),
weight_exp = _determine_weight_exp(weights=np.array([-510, -256, -1]),
sign_mode=SignMode.INHIBITORY)
self.assertEqual(weight_exp, 1)

Expand Down Expand Up @@ -285,13 +301,13 @@ def test_truncate_weights_inhibitory_8(self) -> None:

def test_truncate_weights_inhibitory_7(self) -> None:
truncated_weights = truncate_weights(
weights=np.array([-256, -255, -254, -253, -252]),
weights=np.array([-255, -254, -253, -252, -251]),
sign_mode=SignMode.INHIBITORY,
num_weight_bits=7
)
np.testing.assert_array_equal(
truncated_weights,
np.array([-256, -256, -254, -254, -252])
np.array([-254, -254, -252, -252, -250])
)

def test_truncate_weights_mixed_8(self) -> None:
Expand Down Expand Up @@ -342,24 +358,24 @@ def test_clip_weights_excitatory_7(self) -> None:

def test_clip_weights_inhibitory_8(self) -> None:
clipped_weights = clip_weights(
weights=np.array([-257, -256, -1, 0, 1]),
weights=np.array([-256, -255, -1, 0, 1]),
sign_mode=SignMode.INHIBITORY,
num_bits=8
)
np.testing.assert_array_equal(
clipped_weights,
np.array([-256, -256, -1, 0, 0])
np.array([-255, -255, -1, 0, 0])
)

def test_clip_weights_inhibitory_7(self) -> None:
clipped_weights = clip_weights(
weights=np.array([-129, -128, -1, 0, 1]),
weights=np.array([-128, -127, -1, 0, 1]),
sign_mode=SignMode.INHIBITORY,
num_bits=7
)
np.testing.assert_array_equal(
clipped_weights,
np.array([-128, -128, -1, 0, 0])
np.array([-127, -127, -1, 0, 0])
)

def test_clip_weights_mixed_8(self) -> None:
Expand Down