Enable Sparse processes #672

Merged: 52 commits from dev/sparse_proc into main, May 2, 2023
Changes from 42 commits
Commits (52)
0debe63
added Sparse proc and init test
weidel-p Apr 5, 2023
ea09e52
added test and implementation for sparse proc model in floating precision
weidel-p Apr 5, 2023
6750191
test for graded spikes
weidel-p Apr 6, 2023
98a8393
added bit acc version and adapted weight utils
weidel-p Apr 6, 2023
63597f7
typehints
weidel-p Apr 6, 2023
d160a8e
integer weights for fixed point tests
weidel-p Apr 6, 2023
19ad0c1
draft learning
weidel-p Apr 12, 2023
5905b9c
draft get
weidel-p Apr 12, 2023
f5cc6a0
refactoring tests
Apr 13, 2023
cbbd39b
avoid saving var in varmodel
weidel-p Apr 13, 2023
bced9ed
draft set
weidel-p Apr 13, 2023
834ca28
Sparse get/set working on CPU
Apr 13, 2023
a970b2c
delay sparse process + test
SveaMeyer13 Apr 14, 2023
9e51b4d
change order of weights
weidel-p Apr 17, 2023
ad6238b
minor fix for complete sparse matrix
weidel-p Apr 17, 2023
b223d90
minor fix
weidel-p Apr 17, 2023
320b758
delay sparse model + test
SveaMeyer13 Apr 17, 2023
10f8d02
Merge branch 'dev/sparse_proc' of https://github.com/lava-nc/lava int…
SveaMeyer13 Apr 17, 2023
2c90884
get/set for float/fixed
weidel-p Apr 17, 2023
7bd35de
delay sparse fix for int input
SveaMeyer13 Apr 17, 2023
e3cacb0
Merge branch 'dev/sparse_proc' of https://github.com/lava-nc/lava int…
SveaMeyer13 Apr 17, 2023
12dc022
LearningSparse floating-pt version + test
Apr 17, 2023
d22a0fd
Merge branch 'dev/sparse_proc' of https://github.com/lava-nc/lava int…
Apr 17, 2023
2941f8e
all tests for delay dense also run for delay sparse
SveaMeyer13 Apr 18, 2023
c4b410c
Merge branch 'dev/sparse_proc' of https://github.com/lava-nc/lava int…
SveaMeyer13 Apr 18, 2023
1efe122
update test naming
SveaMeyer13 Apr 19, 2023
db8a3b4
use dot product
SveaMeyer13 Apr 19, 2023
e689067
make learning float work
weidel-p Apr 21, 2023
35f38a0
fixed pt learning
weidel-p Apr 21, 2023
c255b90
rm bit approx version
weidel-p Apr 21, 2023
4415740
lint
weidel-p Apr 21, 2023
d8c9c9b
added learning dense for bit approx
weidel-p Apr 24, 2023
8b93c84
improve calculation of wgt_dly
SveaMeyer13 Apr 24, 2023
902b186
Merge branch 'dev/sparse_proc' of https://github.com/lava-nc/lava int…
SveaMeyer13 Apr 24, 2023
f6b8c56
tests for dt and dd
weidel-p Apr 24, 2023
e7552c0
Merge branch 'dev/sparse_proc' of https://github.com/lava-nc/lava int…
weidel-p Apr 24, 2023
cd3a48d
lint
weidel-p Apr 25, 2023
d5a9275
Merge branch 'main' into dev/sparse_proc
weidel-p Apr 25, 2023
4a53076
avoid warnings
weidel-p Apr 25, 2023
03eaadf
Merge branch 'dev/sparse_proc' of https://github.com/lava-nc/lava int…
weidel-p Apr 25, 2023
cfea135
improve getting the zero matrix
SveaMeyer13 Apr 25, 2023
28cece9
lint
SveaMeyer13 Apr 25, 2023
f8200f2
minor changes
weidel-p Apr 26, 2023
f5bad9f
improve documentation
SveaMeyer13 Apr 27, 2023
1d79bcf
Merge branch 'main' into dev/sparse_proc
PhilippPlank May 2, 2023
f158b80
minor change
weidel-p May 2, 2023
ee497df
Merge branch 'dev/sparse_proc' of https://github.com/lava-nc/lava int…
weidel-p May 2, 2023
91ecf61
changes to comments
weidel-p May 2, 2023
80f2486
improve documentation
SveaMeyer13 May 2, 2023
2741c16
improve documentation
SveaMeyer13 May 2, 2023
d3bef4e
minor change
weidel-p May 2, 2023
65e6439
Merge branch 'dev/sparse_proc' of https://github.com/lava-nc/lava int…
weidel-p May 2, 2023
7 changes: 7 additions & 0 deletions src/lava/magma/compiler/builders/py_builder.py
@@ -5,6 +5,7 @@
import typing as ty

import numpy as np
from scipy.sparse import csr_matrix
from lava.magma.compiler.builders.interfaces import AbstractProcessBuilder

from lava.magma.compiler.channels.interfaces import AbstractCspPort
@@ -402,6 +403,12 @@ def build(self):
var[:] = v.value
elif issubclass(lt.cls, (int, float, str)):
var = v.value
elif issubclass(lt.cls, (csr_matrix)):
if isinstance(v.value, int):
var = csr_matrix(v.shape, dtype=lt.d_type)
var[:] = v.value
else:
var = v.value
else:
raise NotImplementedError(
"Cannot initiliaze variable "
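The new builder branch above accepts csr_matrix variables and initializes them either from a scalar or from a ready-made sparse matrix. A minimal sketch of that logic in plain scipy, with the LavaPyType/Var plumbing elided (init_sparse_var is an illustrative helper, not part of the change):

import numpy as np
from scipy.sparse import csr_matrix

def init_sparse_var(init_value, shape, dtype):
    # Sketch of the branch above: a scalar initializer yields a csr_matrix of
    # the requested shape filled with that scalar (mirroring var[:] = v.value);
    # a csr_matrix initializer is used unchanged.
    if isinstance(init_value, int):
        var = csr_matrix(shape, dtype=dtype)
        var[:] = init_value
        return var
    return init_value

zeros = init_sparse_var(0, (3, 4), np.int32)                  # all-zero sparse matrix
weights = init_sparse_var(csr_matrix(np.eye(3)), None, None)  # used as-is
print(zeros.nnz, weights.nnz)                                 # 0 3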
1 change: 1 addition & 0 deletions src/lava/magma/compiler/var_model.py
@@ -68,6 +68,7 @@ def __post_init__(self, var: Var) -> None:
self.name: str = var.name
self.shape: ty.Tuple[int, ...] = var.shape
self.proc_id: int = var.process.id
self.dtype = type(var.init)


@dataclass
137 changes: 90 additions & 47 deletions src/lava/magma/core/model/py/connection.py
@@ -3,8 +3,10 @@
# See: https://spdx.org/licenses/

from abc import abstractmethod
from lava.utils.sparse import find_with_explicit_zeros
import numpy as np
import typing
from scipy.sparse import csr_matrix

from lava.magma.core.learning.learning_rule import (
LoihiLearningRule,
@@ -827,11 +829,12 @@ def _compute_trace_histories(self) -> typing.Tuple[np.ndarray, np.ndarray]:
t_spike_y = self.ty

# most naive algorithm to decay traces
x_traces_history = np.full((t_epoch + 1,) + x_traces.shape, np.nan,
x_traces_history = np.full((t_epoch + 1,) + x_traces.shape, 0,
dtype=int)
x_traces_history[0] = x_traces
y_traces_history = np.full((t_epoch + 1,) + y_traces.shape, np.nan,
y_traces_history = np.full((t_epoch + 1,) + y_traces.shape, 0,
dtype=int)

y_traces_history[0] = y_traces

for t in range(1, t_epoch + 1):
@@ -942,10 +945,16 @@ def _apply_learning_rules(

for syn_var_name, lr_applier in self._learning_rule_appliers.items():
syn_var = getattr(self, syn_var_name).copy()
syn_var = np.left_shift(
syn_var, W_ACCUMULATOR_S - W_SYN_VAR_S[syn_var_name]
)
syn_var = lr_applier.apply(syn_var, **applier_args)
shift = W_ACCUMULATOR_S - W_SYN_VAR_S[syn_var_name]
if isinstance(syn_var, csr_matrix):
syn_var.data = syn_var.data << shift
dst, src, _ = find_with_explicit_zeros(syn_var)
syn_var[dst, src] = lr_applier.apply(syn_var,
**applier_args)[dst, src]
else:
syn_var = syn_var << shift
syn_var = lr_applier.apply(syn_var, **applier_args)

syn_var = self._saturate_synaptic_variable_accumulator(
syn_var_name, syn_var
)
@@ -954,9 +963,11 @@
syn_var,
self._conn_var_random.random_stochastic_round,
)
syn_var = np.right_shift(
syn_var, W_ACCUMULATOR_S - W_SYN_VAR_S[syn_var_name]
)

if isinstance(syn_var, csr_matrix):
syn_var.data = syn_var.data >> shift
else:
syn_var = syn_var >> shift

syn_var = self._saturate_synaptic_variable(syn_var_name, syn_var)
setattr(self, syn_var_name, syn_var)
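For sparse synaptic variables the learning rule is applied only at the stored synapses: the accumulator shift acts on .data, and the update is written back at the positions returned by find_with_explicit_zeros, so plastic synapses whose value is currently an explicit zero are not lost (plain scipy.sparse.find would drop them). A self-contained sketch under these assumptions; find_stored_entries, apply_rule and the shift value are illustrative stand-ins:

import numpy as np
from scipy.sparse import csr_matrix

def find_stored_entries(mat):
    # Hedged stand-in for lava.utils.sparse.find_with_explicit_zeros: returns
    # (row, col, value) for every stored entry, including explicit zeros that
    # scipy.sparse.find() would filter out.
    mat = mat.tocsr()
    rows = np.repeat(np.arange(mat.shape[0]), np.diff(mat.indptr))
    return rows, mat.indices.copy(), mat.data.copy()

def apply_rule(syn_var, **applier_args):
    # Toy stand-in for a learning-rule applier returning a dense update.
    return syn_var.toarray() + 64

shift = 6                                     # illustrative accumulator shift
weights = csr_matrix(np.array([[0, 3, 0],
                               [4, 0, 0]]))

dst, src, _ = find_stored_entries(weights)
weights.data = weights.data << shift          # into the accumulator domain
updated = apply_rule(weights)
weights[dst, src] = updated[dst, src]         # write back only at stored synapses
weights.data = weights.data >> shift          # back to the synaptic-variable domain
print(weights.toarray())                      # [[0 4 0], [5 0 0]]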
@@ -1015,18 +1026,18 @@ def _extract_applier_args(
return applier_args

def _saturate_synaptic_variable_accumulator(
self, synaptic_variable_name: str,
synaptic_variable_values: np.ndarray
self, syn_var_name: str,
syn_var_values: typing.Union[np.ndarray, csr_matrix]
) -> np.ndarray:
"""Saturate synaptic variable accumulator.

Checks that sign is valid.

Parameters
----------
synaptic_variable_name: str
syn_var_name: str
Synaptic variable name.
synaptic_variable_values: ndarray
syn_var_values: ndarray
Synaptic variable values to saturate.

Returns
@@ -1035,63 +1046,73 @@ def _saturate_synaptic_variable_accumulator(
Saturated synaptic variable values.
"""
# Weights
if synaptic_variable_name == "weights":
if syn_var_name == "weights":
if self.sign_mode == SignMode.MIXED:
return synaptic_variable_values
return syn_var_values
elif self.sign_mode == SignMode.EXCITATORY:
return np.maximum(0, synaptic_variable_values)
return np.maximum(0, syn_var_values)
elif self.sign_mode == SignMode.INHIBITORY:
return np.minimum(0, synaptic_variable_values)
return np.minimum(0, syn_var_values)
# Delays
elif synaptic_variable_name == "tag_2":
return np.maximum(0, synaptic_variable_values)
elif syn_var_name == "tag_2":
if isinstance(syn_var_values, csr_matrix):
syn_var_values.data[syn_var_values.data < 0] = 0
return syn_var_values
return np.maximum(0, syn_var_values)
# Tags
elif synaptic_variable_name == "tag_1":
return synaptic_variable_values
elif syn_var_name == "tag_1":
return syn_var_values
else:
raise ValueError(
f"synaptic_variable_name can be 'weights', "
f"syn_var_name can be 'weights', "
f"'tag_1', or 'tag_2'."
f"Got {synaptic_variable_name=}."
f"Got {syn_var_name=}."
)

@staticmethod
def _stochastic_round_synaptic_variable(
synaptic_variable_name: str,
synaptic_variable_values: np.ndarray,
syn_var_name: str,
syn_var_values: typing.Union[np.ndarray, csr_matrix],
random: float,
) -> np.ndarray:
) -> typing.Union[np.ndarray, csr_matrix]:
"""Stochastically round synaptic variable after learning rule
application.

Parameters
----------
synaptic_variable_name: str
syn_var_name: str
Synaptic variable name.
synaptic_variable_values: ndarray
syn_var_values: ndarray
Synaptic variable values to stochastically round.

Returns
----------
result : ndarray
Stochastically rounded synaptic variable values.
"""
exp_mant = 2 ** (W_ACCUMULATOR_U - W_SYN_VAR_U[synaptic_variable_name])
exp_mant = 2 ** (W_ACCUMULATOR_U - W_SYN_VAR_U[syn_var_name])

integer_part = synaptic_variable_values / exp_mant
if isinstance(syn_var_values, csr_matrix):
integer_part = syn_var_values.data / exp_mant
else:
integer_part = syn_var_values / exp_mant
fractional_part = integer_part % 1

integer_part = np.floor(integer_part)
integer_part = stochastic_round(integer_part, random, fractional_part)
result = (integer_part * exp_mant).astype(
synaptic_variable_values.dtype
)

return result
if isinstance(syn_var_values, csr_matrix):
syn_var_values.data = (integer_part
* exp_mant).astype(syn_var_values.dtype)
return syn_var_values
else:
return (integer_part * exp_mant).astype(
syn_var_values.dtype
)
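On the sparse path, scaling and rounding touch only the stored entries; the structural zeros are never materialized. A sketch of that step with a hedged stand-in for stochastic_round (round up with probability equal to the fractional part) and an illustrative exp_mant:

import numpy as np
from scipy.sparse import csr_matrix

def stochastic_round_sketch(integer_part, random, fractional_part):
    # Assumed behaviour of stochastic_round: round up where a uniform draw
    # falls below the fractional part, otherwise keep the floor.
    return integer_part + (random < fractional_part)

exp_mant = 2 ** 6                              # illustrative scaling factor
tag_1 = csr_matrix(np.array([[0, 130, 0],
                             [70, 0, 0]]))

scaled = tag_1.data / exp_mant                 # only stored entries are scaled
rounded = stochastic_round_sketch(np.floor(scaled),
                                  np.random.default_rng(0).random(scaled.size),
                                  scaled % 1)
tag_1.data = (rounded * exp_mant).astype(tag_1.data.dtype)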

def _saturate_synaptic_variable(
self, synaptic_variable_name: str,
synaptic_variable_values: np.ndarray
self, syn_var_name: str,
syn_var_val: typing.Union[np.ndarray, csr_matrix]
) -> np.ndarray:
"""Saturate synaptic variable.

@@ -1100,40 +1121,54 @@ def _saturate_synaptic_variable(

Parameters
----------
synaptic_variable_name: str
syn_var_name: str
Synaptic variable name.
synaptic_variable_values: ndarray
syn_var_val: ndarray
Synaptic variable values to saturate.

Returns
----------
result : ndarray
Saturated synaptic variable values.
"""

# Weights
if synaptic_variable_name == "weights":
if syn_var_name == "weights":
return clip_weights(
synaptic_variable_values,
syn_var_val,
sign_mode=self.sign_mode,
num_bits=W_WEIGHTS_U,
)
# Delays
elif synaptic_variable_name == "tag_2":
elif syn_var_name == "tag_2":
if isinstance(syn_var_val, csr_matrix):
_min = -(2 ** W_TAG_2_U) - 1
_max = (2 ** W_TAG_2_U) - 1
syn_var_val.data[syn_var_val.data < _min] = _min
syn_var_val.data[syn_var_val.data > _max] = _max
return syn_var_val

return np.clip(
synaptic_variable_values, a_min=0, a_max=2 ** W_TAG_2_U - 1
syn_var_val, a_min=0, a_max=2 ** W_TAG_2_U - 1
)
# Tags
elif synaptic_variable_name == "tag_1":
elif syn_var_name == "tag_1":
if isinstance(syn_var_val, csr_matrix):
_min = -(2 ** W_TAG_1_U) - 1
_max = (2 ** W_TAG_1_U) - 1
syn_var_val.data[syn_var_val.data < _min] = _min
syn_var_val.data[syn_var_val.data > _max] = _max
return syn_var_val
return np.clip(
synaptic_variable_values,
syn_var_val,
a_min=-(2 ** W_TAG_1_U) - 1,
a_max=2 ** W_TAG_1_U - 1,
)
else:
raise ValueError(
f"synaptic_variable_name can be 'weights', "
f"syn_var_name can be 'weights', "
f"'tag_1', or 'tag_2'."
f"Got {synaptic_variable_name=}."
f"Got {syn_var_name=}."
)
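Saturation on the sparse path follows the same idea: clamp only the .data array so the sparsity pattern stays untouched. A short sketch; the bounds are illustrative, the real limits come from W_TAG_1_U and W_TAG_2_U:

import numpy as np
from scipy.sparse import csr_matrix

def clip_stored(mat, lo, hi):
    # Clamp only the stored entries of a csr_matrix; unstored zeros are never
    # created or modified.
    mat.data = np.clip(mat.data, lo, hi)
    return mat

tag_2 = csr_matrix(np.array([[0, -5, 0],
                             [90, 0, 0]]))
clip_stored(tag_2, 0, 2 ** 6 - 1)              # e.g. clamp delays to [0, 63]
print(tag_2.toarray())                         # [[0 0 0], [63 0 0]]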


@@ -1423,7 +1458,12 @@ def _apply_learning_rules(

for syn_var_name, lr_applier in self._learning_rule_appliers.items():
syn_var = getattr(self, syn_var_name).copy()
syn_var = lr_applier.apply(syn_var, **applier_args)
if (isinstance(syn_var, csr_matrix)):
dst, src, _ = find_with_explicit_zeros(syn_var)
syn_var[dst, src] = lr_applier.apply(syn_var,
**applier_args)[dst, src]
else:
syn_var = lr_applier.apply(syn_var, **applier_args)
syn_var = self._saturate_synaptic_variable(syn_var_name, syn_var)
setattr(self, syn_var_name, syn_var)

@@ -1511,6 +1551,9 @@ def _saturate_synaptic_variable(
elif synaptic_variable_name == "tag_1":
return synaptic_variable_values
elif synaptic_variable_name == "tag_2":
if isinstance(synaptic_variable_values, csr_matrix):
synaptic_variable_values[synaptic_variable_values < 0] = 0
return synaptic_variable_values
return np.maximum(0, synaptic_variable_values)
else:
raise ValueError(
22 changes: 21 additions & 1 deletion src/lava/magma/core/model/py/model.py
@@ -5,7 +5,9 @@
import typing as ty
from abc import ABC, abstractmethod
import logging
from lava.utils.sparse import find_with_explicit_zeros
import numpy as np
from scipy.sparse import csr_matrix, find
import platform

from lava.magma.compiler.channels.pypychannel import (
@@ -125,6 +127,12 @@ def _get_var(self):
data_port.send(enum_to_np(num_items))
for value in var_iter:
data_port.send(enum_to_np(value, np.float64))
elif isinstance(var, csr_matrix):
dst, src, values = find_with_explicit_zeros(var)
num_items = var.data.size
data_port.send(enum_to_np(num_items))
for value in values:
data_port.send(enum_to_np(value, np.float64))
elif isinstance(var, str):
encoded_str = list(var.encode("ascii"))
data_port.send(enum_to_np(len(encoded_str)))
@@ -161,6 +169,19 @@ def _set_var(self):
num_items -= 1
i[...] = data_port.recv()[0]
self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE)
elif isinstance(var, csr_matrix):
# First item is number of items
num_items = int(data_port.recv()[0])

buffer = np.empty(num_items)
# Set data one by one
for i in range(num_items):
buffer[i] = data_port.recv()[0]
dst, src, _ = find(var)
var = csr_matrix((buffer, (dst, src)), var.shape)
setattr(self, var_name, var)

self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE)
elif isinstance(var, str):
# First item is number of items
num_items = int(data_port.recv()[0])
@@ -172,7 +193,6 @@
s = bytes(s).decode("ascii")
setattr(self, var_name, s)
self.process_to_service.send(MGMT_RESPONSE.SET_COMPLETE)

else:
self.process_to_service.send(MGMT_RESPONSE.ERROR)
raise RuntimeError("Unsupported type")
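The runtime channel only carries flat scalars, so a csr_matrix is shipped as its stored values and rebuilt against the known sparsity pattern on the receiving side. A sketch of that round trip outside the runtime, with a plain list standing in for the data port:

import numpy as np
from scipy.sparse import csr_matrix, find

weights = csr_matrix(np.array([[0, 3, 0],
                               [4, 0, 0]]), dtype=float)

# "get": send the number of stored items, then the items themselves
buffer = [float(weights.data.size)] + [float(v) for v in weights.data]

# "set": receive the values and rebuild the matrix on the existing pattern
num_items = int(buffer[0])
values = np.array(buffer[1:1 + num_items])
dst, src, _ = find(weights)
weights = csr_matrix((values, (dst, src)), weights.shape)
print(weights.toarray())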
21 changes: 20 additions & 1 deletion src/lava/magma/core/process/variable.py
@@ -4,6 +4,7 @@

import typing as ty
import numpy as np
from scipy.sparse import csr_matrix, spmatrix, find

from lava.magma.core.process.interfaces import (
AbstractProcessMember,
@@ -126,9 +127,22 @@ def validate_alias(self):
f"."
)

def set(self, value: ty.Union[np.ndarray, str], idx: np.ndarray = None):
def set(self,
value: ty.Union[np.ndarray, str, spmatrix],
idx: np.ndarray = None):
"""Sets value of Var. If this Var aliases another Var, then set(..) is
delegated to aliased Var."""
if isinstance(value, spmatrix):
value = value.tocsr()
if value.shape != self.init.shape or \
(value.indices != self.init.indices).any() or \
(value.indptr != self.init.indptr).any() or \
(len(find(value)[2]) != len(find(self.init)[2])):
raise ValueError("Indices and number of non-zero elements "
"must stay equal when using set on a"
"sparse matrix.")
value = find(value)[2]

if self.aliased_var is not None:
self.aliased_var.set(value, idx)
else:
@@ -156,6 +170,11 @@ def get(self, idx: np.ndarray = None) -> np.ndarray:
if isinstance(self.init, str):
# decode if var is string
return bytes(buffer.astype(int).tolist()).decode("ascii")
if isinstance(self.init, csr_matrix):
dst, src, _ = find(self.init)

ret = csr_matrix((buffer, (dst, src)), self.init.shape)
return ret
else:
return buffer
else:
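At the user API level the resulting contract is: set() may change the values of a sparse Var but not its sparsity pattern, and get() returns a csr_matrix rebuilt on that pattern. A compact sketch of the pattern check on plain csr matrices (same_pattern is an illustrative helper, not the Var method itself):

import numpy as np
from scipy.sparse import csr_matrix

def same_pattern(a, b):
    # The pattern is fixed by shape, indices and indptr; only the stored
    # values (including explicit zeros) are allowed to change.
    return (a.shape == b.shape
            and np.array_equal(a.indices, b.indices)
            and np.array_equal(a.indptr, b.indptr))

init = csr_matrix(np.array([[0, 3, 0], [4, 0, 0]]))
ok = init.copy()
ok.data = np.array([7, 9])                              # same pattern, new values
bad = csr_matrix(np.array([[1, 3, 0], [4, 0, 0]]))      # extra non-zero entry
print(same_pattern(init, ok), same_pattern(init, bad))  # True False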