Skip to content

Commit

Permalink
LearningDense bit-accurate (#812)
Browse files Browse the repository at this point in the history
* minor change in dependency computation

* updating stochastic round type hint

* small fix in clip_weights

* progress in making tests pass

* fixing Sparse init

* trying tests

* adapting init method of LearningDense Process

---------

Co-authored-by: PhilippPlank <[email protected]>
  • Loading branch information
gkarray and PhilippPlank authored Jan 8, 2024
1 parent 19ef851 commit ce5c755
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 25 deletions.
11 changes: 6 additions & 5 deletions src/lava/magma/core/learning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,24 @@ def stochastic_round(values: np.ndarray,
return (values + (random_numbers < probabilities).astype(int)).astype(int)


def apply_mask(item: ty.Union[np.ndarray, int], nb_bits: int) \
        -> ty.Union[np.ndarray, int]:
    """Get nb_bits least-significant bits.

    Works element-wise on integer ndarrays as well as on plain ints,
    since ``&`` broadcasts over arrays.

    Parameters
    ----------
    item : np.ndarray or int
        Item to apply mask to.
    nb_bits : int
        Number of LSBs to keep.

    Returns
    -------
    result : np.ndarray or int
        Least-significant bits.
    """
    # ~0 is all ones; shifting left by nb_bits and inverting yields a
    # mask with exactly the nb_bits lowest bits set.
    mask = ~(~0 << nb_bits)
    return item & mask


def float_to_literal(learning_parameter: float) -> str:
Expand Down
4 changes: 2 additions & 2 deletions src/lava/magma/core/model/py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ def _update_synaptic_variable_random(self) -> None:
pass

def _update_dependencies(self) -> None:
    """Update spike-dependency flags from the spike timers.

    Sets x0/y0 to True exactly where the corresponding pre-/post-synaptic
    spike timers (tx/ty) are positive.
    """
    # Overwrite the flags (rather than OR-ing True into them) so that
    # entries set on previous time steps are cleared when the timer
    # is no longer positive.
    self.x0 = self.tx > 0
    self.y0 = self.ty > 0

@abstractmethod
def _compute_trace_histories(self) -> typing.Tuple[np.ndarray, np.ndarray]:
Expand Down
13 changes: 13 additions & 0 deletions src/lava/proc/dense/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ class LearningDense(LearningConnectionProcess, Dense):
def __init__(self,
*,
weights: np.ndarray,
tag_2: ty.Optional[np.ndarray] = None,
tag_1: ty.Optional[np.ndarray] = None,
name: ty.Optional[str] = None,
num_message_bits: ty.Optional[int] = 0,
log_config: ty.Optional[LogConfig] = None,
Expand All @@ -164,6 +166,8 @@ def __init__(self,
learning_rule.x1_impulse = 0

super().__init__(weights=weights,
tag_2=tag_2,
tag_1=tag_1,
shape=weights.shape,
name=name,
num_message_bits=num_message_bits,
Expand All @@ -172,6 +176,15 @@ def __init__(self,
graded_spike_cfg=graded_spike_cfg,
**kwargs)

if tag_2 is None:
tag_2 = np.zeros(weights.shape)

if tag_1 is None:
tag_1 = np.zeros(weights.shape)

self.tag_2.init = tag_2.copy()
self.tag_1.init = tag_1.copy()


class DelayDense(Dense):
def __init__(self,
Expand Down
42 changes: 26 additions & 16 deletions src/lava/proc/sparse/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,26 +69,26 @@ def __init__(self,
log_config=log_config,
**kwargs)

weights = self._create_csr_matrix_from_weights(weights)
weights = self._create_csr_matrix(weights)
shape = weights.shape

# Ports
self.s_in = InPort(shape=(shape[1],))
self.a_out = OutPort(shape=(shape[0],))

# Variables
self.weights = Var(shape=shape, init=weights)
self.weights = Var(shape=shape, init=weights.copy())
self.a_buff = Var(shape=(shape[0],), init=0)
self.num_message_bits = Var(shape=(1,), init=num_message_bits)

@staticmethod
def _create_csr_matrix_from_weights(weights):
def _create_csr_matrix(matrix):
# Transform weights to csr matrix
if isinstance(weights, np.ndarray):
weights = csr_matrix(weights)
if isinstance(matrix, np.ndarray):
matrix = csr_matrix(matrix)
else:
weights = weights.tocsr()
return weights
matrix = matrix.tocsr()
return matrix


class LearningSparse(LearningConnectionProcess, Sparse):
Expand Down Expand Up @@ -160,6 +160,8 @@ class LearningSparse(LearningConnectionProcess, Sparse):
def __init__(self,
*,
weights: ty.Union[spmatrix, np.ndarray],
tag_2: ty.Optional[ty.Union[spmatrix, np.ndarray]] = None,
tag_1: ty.Optional[ty.Union[spmatrix, np.ndarray]] = None,
name: ty.Optional[str] = None,
num_message_bits: ty.Optional[int] = 0,
log_config: ty.Optional[LogConfig] = None,
Expand All @@ -172,6 +174,8 @@ def __init__(self,
learning_rule.x1_impulse = 0

super().__init__(weights=weights,
tag_2=tag_2,
tag_1=tag_1,
shape=weights.shape,
num_message_bits=num_message_bits,
name=name,
Expand All @@ -180,17 +184,23 @@ def __init__(self,
graded_spike_cfg=graded_spike_cfg,
**kwargs)

weights = self._create_csr_matrix_from_weights(weights)
shape = weights.shape
if tag_2 is None:
tag_2 = np.zeros(weights.shape)

# Ports
self.s_in = InPort(shape=(shape[1],))
self.a_out = OutPort(shape=(shape[0],))
if tag_1 is None:
tag_1 = np.zeros(weights.shape)

# Variables
self.weights = Var(shape=shape, init=weights)
self.a_buff = Var(shape=(shape[0],), init=0)
self.num_message_bits = Var(shape=(1,), init=num_message_bits)
tag_2 = self._create_csr_matrix(tag_2)
tag_1 = self._create_csr_matrix(tag_1)

self.tag_2.init = tag_2.copy()
self.tag_1.init = tag_1.copy()

self.proc_params["x_idx_active_syn_vars"] = {
"weights": weights.nonzero()[1],
"tag_2": tag_2.nonzero()[1],
"tag_1": tag_1.nonzero()[1]
}


class DelaySparse(Sparse):
Expand Down
9 changes: 7 additions & 2 deletions src/lava/utils/weightutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,9 @@ def truncate_weights(weights: ty.Union[np.ndarray, spmatrix],

def clip_weights(weights: ty.Union[np.ndarray, spmatrix],
sign_mode: SignMode,
num_bits: int) -> ty.Union[np.ndarray, spmatrix]:
num_bits: int,
learning_simulation: ty.Optional[bool] = False) \
-> ty.Union[np.ndarray, spmatrix]:
"""Truncate the least significant bits of the weight matrix given the
sign mode and number of weight bits.
Expand All @@ -261,6 +263,9 @@ def clip_weights(weights: ty.Union[np.ndarray, spmatrix],
Sign mode to use for truncation.
num_bits : int
Number of bits to use to clip the weights to.
learning_simulation : bool, optional
Boolean flag, specifying if this method is used in context of learning
(in simulation).
Returns
-------
Expand All @@ -276,7 +281,7 @@ def clip_weights(weights: ty.Union[np.ndarray, spmatrix],
weights = -weights

min_wgt = (-2 ** num_bits) * mixed_flag
max_wgt = 2 ** num_bits - 1
max_wgt = 2 ** num_bits - 1 - learning_simulation * mixed_flag

if isinstance(weights, np.ndarray):
weights = np.clip(weights, min_wgt, max_wgt)
Expand Down

0 comments on commit ce5c755

Please sign in to comment.