diff --git a/poetry.lock b/poetry.lock
index 532eb143..0cbbddbb 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -816,7 +816,7 @@ scipy = "^1.8.0"
 type = "git"
 url = "https://github.com/lava-nc/lava.git"
 reference = "main"
-resolved_reference = "d321a8b01b7cac017c52f196329c4e486558b8b7"
+resolved_reference = "5f7faa7ea19667bc3e0991413e4cd7edbd794220"
 
 [[package]]
 name = "linecache2"
diff --git a/src/lava/lib/dl/netx/blocks/models.py b/src/lava/lib/dl/netx/blocks/models.py
index 2b0f4d2b..8f4787d6 100644
--- a/src/lava/lib/dl/netx/blocks/models.py
+++ b/src/lava/lib/dl/netx/blocks/models.py
@@ -11,7 +11,8 @@
 from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
 from lava.magma.core.resources import CPU
 
-from lava.lib.dl.netx.blocks.process import Input, Dense, Conv
+from lava.lib.dl.netx.blocks.process import Input, ComplexInput, Dense, Conv,\
+    ComplexDense
 
 
 @requires(CPU)
@@ -47,12 +48,24 @@ def __init__(self, proc: AbstractProcess) -> None:
         super().__init__(proc)
 
 
+@implements(proc=ComplexInput, protocol=LoihiProtocol)
+class PyComplexInputModel(AbstractPyBlockModel):
+    def __init__(self, proc: AbstractProcess) -> None:
+        super().__init__(proc)
+
+
 @implements(proc=Dense, protocol=LoihiProtocol)
 class PyDenseModel(AbstractPyBlockModel):
     def __init__(self, proc: AbstractProcess) -> None:
         super().__init__(proc)
 
 
+@implements(proc=ComplexDense, protocol=LoihiProtocol)
+class PyComplexDenseModel(AbstractPyBlockModel):
+    def __init__(self, proc: AbstractProcess) -> None:
+        super().__init__(proc)
+
+
 @implements(proc=Conv, protocol=LoihiProtocol)
 class PyConvModel(AbstractPyBlockModel):
     def __init__(self, proc: AbstractProcess) -> None:
diff --git a/src/lava/lib/dl/netx/blocks/process.py b/src/lava/lib/dl/netx/blocks/process.py
index cf3651ac..f18cd223 100644
--- a/src/lava/lib/dl/netx/blocks/process.py
+++ b/src/lava/lib/dl/netx/blocks/process.py
@@ -108,7 +108,6 @@ def export_hdf5(self, handle: Union[h5py.File, h5py.Group]) -> None:
 
 class Dense(AbstractBlock):
     """Dense layer block.
-
     Parameters
     ----------
     shape : tuple or list
@@ -164,6 +163,106 @@ def export_hdf5(self, handle: Union[h5py.File, h5py.Group]) -> None:
         raise NotImplementedError
 
 
+class ComplexDense(AbstractBlock):
+    """Complex dense layer block.
+
+    Parameters
+    ----------
+    shape : tuple or list
+        shape of the layer block in (x, y, z)/WHC format.
+    neuron_params : dict, optional
+        dictionary of neuron parameters. Defaults to None.
+    weight_real : np.ndarray
+        synaptic real weight.
+    weight_imag : np.ndarray
+        synaptic imaginary weight.
+    has_graded_input : bool
+        flag for graded spikes at input. Defaults to False.
+    num_weight_bits_real : int
+        number of real weight bits. Defaults to 8.
+    num_weight_bits_imag : int
+        number of imaginary weight bits. Defaults to 8.
+    weight_exponent_real : int
+        real weight exponent value. Defaults to 0.
+    weight_exponent_imag : int
+        imaginary weight exponent value. Defaults to 0.
+    input_message_bits : int, optional
+        number of message bits in input spike. Defaults to 0 meaning unary
+        spike.
+ """ + + def __init__(self, **kwargs: Union[dict, tuple, list, int, bool]) -> None: + super().__init__(**kwargs) + + num_weight_bits_real = kwargs.pop('num_weight_bits_real', 8) + num_weight_bits_imag = kwargs.pop('num_weight_bits_imag', 8) + + weight_exponent_real = kwargs.pop('weight_exponent_real', 0) + weight_exponent_imag = kwargs.pop('weight_exponent_imag', 0) + weight_real = kwargs.pop('weight_real') + weight_imag = kwargs.pop('weight_imag') + + self.neuron = self._neuron(None) + self.real_synapse = DenseSynapse( + weights=weight_real, + weight_exp=weight_exponent_real, + num_weight_bits=num_weight_bits_real, + num_message_bits=self.input_message_bits, + ) + self.imag_synapse = DenseSynapse( + weights=weight_imag, + weight_exp=weight_exponent_imag, + num_weight_bits=num_weight_bits_imag, + num_message_bits=self.input_message_bits, + ) + + if self.shape != self.real_synapse.a_out.shape: + raise RuntimeError( + f'Expected synapse output shape to be {self.shape[-1]}, ' + f'found {self.synapse.a_out.shape}.' + ) + + self.inp = InPort(shape=self.real_synapse.s_in.shape) + self.out = OutPort(shape=self.neuron.s_out.shape) + self.inp.connect(self.real_synapse.s_in) + self.inp.connect(self.imag_synapse.s_in) + self.real_synapse.a_out.connect(self.neuron.a_real_in) + self.imag_synapse.a_out.connect(self.neuron.a_imag_in) + self.neuron.s_out.connect(self.out) + + self._clean() + + def export_hdf5(self, handle: Union[h5py.File, h5py.Group]) -> None: + raise NotImplementedError + + +class ComplexInput(AbstractBlock): + """Input layer block. + + Parameters + ---------- + shape : tuple or list + shape of the layer block in (x, y, z)/WHC format. + neuron_params : dict, optional + dictionary of neuron parameters. Defaults to None. + """ + + def __init__(self, **kwargs: Union[dict, tuple, list, int, bool]) -> None: + super().__init__(**kwargs) + self.neuron = self._neuron(None) + + self.inp = InPort(shape=self.neuron.a_real_in.shape) + self.inp.connect(self.neuron.a_real_in) + self.inp.connect(self.neuron.a_imag_in) + self.out = OutPort(shape=self.neuron.s_out.shape) + self.neuron.s_out.connect(self.out) + + self._clean() + + def export_hdf5(self, handle: Union[h5py.File, h5py.Group]) -> None: + raise NotImplementedError + + class Conv(AbstractBlock): """Conv layer block. 
diff --git a/src/lava/lib/dl/netx/hdf5.py b/src/lava/lib/dl/netx/hdf5.py
index b423e180..dec08379 100644
--- a/src/lava/lib/dl/netx/hdf5.py
+++ b/src/lava/lib/dl/netx/hdf5.py
@@ -7,6 +7,8 @@
 import warnings
 from lava.magma.core.decorator import implements
 from lava.magma.core.sync.protocols.loihi_protocol import LoihiProtocol
+from lava.proc.rf.process import RF
+from lava.proc.rf_iz.process import RF_IZ
 
 import numpy as np
 import h5py
@@ -14,9 +16,10 @@
 from lava.magma.core.process.ports.ports import InPort, OutPort
 from lava.proc.lif.process import LIF, LIFReset
 from lava.proc.sdn.process import Sigma, Delta, SigmaDelta
+from lava.lib.dl.slayer.neuron.rf import neuron_params as get_rf_params
 from lava.lib.dl.netx.utils import NetDict
 from lava.lib.dl.netx.utils import optimize_weight_bits
-from lava.lib.dl.netx.blocks.process import Input, Dense, Conv
+from lava.lib.dl.netx.blocks.process import Input, Dense, Conv, ComplexDense
 from lava.lib.dl.netx.blocks.models import AbstractPyBlockModel
 
 
@@ -170,6 +173,20 @@ def get_neuron_params(neuron_config: h5py.Group,
                             'state_exp': 6,
                             'num_message_bits': num_message_bits}
             return neuron_params
+        elif 'RF' in neuron_type:
+            if num_message_bits is None:
+                num_message_bits = 0  # default value
+            neuron_process = RF if 'PHASE' in neuron_type else RF_IZ
+            neuron_params = get_rf_params(neuron_config)
+            neuron_params = {
+                'neuron_proc': neuron_process,
+                'vth': neuron_config['vThMant'],
+                'period': neuron_params['period'],
+                'alpha': neuron_params['decay'],
+                'state_exp': 6,
+                'decay_bits': 12
+            }
+            return neuron_params
 
     @staticmethod
     def _table_str(type_str: str = '',
@@ -293,34 +310,66 @@ def create_dense(layer_config: h5py.Group,
             table entry string for process.
         """
         shape = (np.prod(layer_config['shape']),)
+
         neuron_params = Network.get_neuron_params(layer_config['neuron'],
                                                   reset_interval=reset_interval,
                                                   reset_offset=reset_offset)
-        weight = layer_config['weight']
-        if weight.ndim == 1:
-            weight = weight.reshape(shape[0], -1)
-
-        opt_weights = optimize_weight_bits(weight)
-        weight, num_weight_bits, weight_exponent, sign_mode = opt_weights
-
-        # arguments for dense block
-        params = {'shape': shape,
-                  'neuron_params': neuron_params,
-                  'weight': weight,
-                  'num_weight_bits': num_weight_bits,
-                  'weight_exponent': weight_exponent,
-                  'sign_mode': sign_mode,
-                  'input_message_bits': input_message_bits}
-
-        # optional arguments
-        if 'bias' in layer_config.keys():
-            params['bias'] = layer_config['bias']
+        # check for nested (complex) weights
+        if isinstance(layer_config['weight'], NetDict):
+            weight_real = layer_config['weight/real']
+            weight_imag = layer_config['weight/imag']
+            if weight_real.ndim == 1:
+                weight_real = weight_real.reshape(shape[0], -1)
+                weight_imag = weight_imag.reshape(shape[0], -1)
+
+            opt_weights_real = optimize_weight_bits(weight_real)
+            opt_weights_imag = optimize_weight_bits(weight_imag)
+            weight_real, num_weight_bits_real, weight_exponent_real,\
+                sign_mode_real = opt_weights_real
+            weight_imag, num_weight_bits_imag, weight_exponent_imag,\
+                sign_mode_imag = opt_weights_imag
+
+            # arguments for complex dense block
+            params = {'shape': shape,
+                      'neuron_params': neuron_params,
+                      'weight_real': weight_real,
+                      'weight_imag': weight_imag,
+                      'num_weight_bits_real': num_weight_bits_real,
+                      'num_weight_bits_imag': num_weight_bits_imag,
+                      'weight_exponent_real': weight_exponent_real,
+                      'weight_exponent_imag': weight_exponent_imag,
+                      'sign_mode_real': sign_mode_real,
+                      'sign_mode_imag': sign_mode_imag,
+                      'input_message_bits': input_message_bits}
+
+            proc = ComplexDense(**params)
+        else:
+            weight = layer_config['weight']
+            if weight.ndim == 1:
+                weight = weight.reshape(shape[0], -1)
+
+            opt_weights = optimize_weight_bits(weight)
+            weight, num_weight_bits, weight_exponent, sign_mode = opt_weights
+
+            # arguments for dense block
+            params = {'shape': shape,
+                      'neuron_params': neuron_params,
+                      'weight': weight,
+                      'num_weight_bits': num_weight_bits,
+                      'weight_exponent': weight_exponent,
+                      'sign_mode': sign_mode,
+                      'input_message_bits': input_message_bits}
+
+            # optional arguments
+            if 'bias' in layer_config.keys():
+                params['bias'] = layer_config['bias']
+
+            proc = Dense(**params)
 
         table_entry = Network._table_str(type_str='Dense',
                                          width=1,
                                          height=1,
                                          channel=shape[0],
                                          delay='delay' in layer_config.keys())
-
-        return Dense(**params), table_entry
+        return proc, table_entry
 
     @staticmethod
     def create_conv(layer_config: h5py.Group,
@@ -437,7 +486,6 @@ def _create(self) -> List[AbstractProcess]:
         reset_offset = self.reset_offset + 1  # time starts from 1 in hardware
         for i in range(num_layers):
             layer_type = layer_config[i]['type']
-
             if layer_type == 'input':
                 table = None
                 if 'neuron' in layer_config[i].keys():
@@ -510,7 +558,21 @@ def _create(self) -> List[AbstractProcess]:
                 else:
                     if len(layers) > 1:
                         layers[-2].out.connect(layers[-1].inp)
-
+            elif layer_type == 'dense_comp':
+                layer, table = self.create_dense(
+                    layer_config=layer_config[i],
+                    input_message_bits=input_message_bits
+                )
+                layers.append(layer)
+                input_message_bits = layer.output_message_bits
+                if flatten_next:
+                    layers[-2].out.transpose([2, 1, 0]).flatten().connect(
+                        layers[-1].inp
+                    )
+                    flatten_next = False
+                else:
+                    if len(layers) > 1:
+                        layers[-2].out.connect(layers[-1].inp)
 
             elif layer_type == 'average':
                 raise NotImplementedError(f'{layer_type} is not implemented.')
diff --git a/src/lava/lib/dl/netx/utils.py b/src/lava/lib/dl/netx/utils.py
index 6b71a5ff..bdb2d0f5 100644
--- a/src/lava/lib/dl/netx/utils.py
+++ b/src/lava/lib/dl/netx/utils.py
@@ -7,6 +7,7 @@
 import h5py
 import numpy as np
 from enum import IntEnum, unique
+import torch
 
 
 @unique
@@ -45,9 +46,9 @@ def __init__(
         self.array_keys = [
            'shape', 'stride', 'padding', 'dilation', 'groups', 'delay',
            'iDecay', 'refDelay', 'scaleRho', 'tauRho', 'theta', 'vDecay',
-           'vThMant', 'wgtExp',
+           'vThMant', 'wgtExp', 'sinDecay', 'cosDecay', 'complex_synapse'
         ]
-        self.copy_keys = ['weight', 'bias']
+        self.copy_keys = ['weight', 'bias', 'weight/real', 'weight/imag']
 
     def keys(self) -> h5py._hl.base.KeysViewHDF5:
         return self.f.keys()
diff --git a/src/lava/lib/dl/slayer/block/base.py b/src/lava/lib/dl/slayer/block/base.py
index ce25ee78..e505adc4 100644
--- a/src/lava/lib/dl/slayer/block/base.py
+++ b/src/lava/lib/dl/slayer/block/base.py
@@ -546,17 +546,19 @@ def weight(s):
         def delay(d):
             return torch.floor(d.delay).flatten().cpu().data.numpy()
 
-
         # dense descriptors
         handle.create_dataset(
             'type', (1, ), 'S10', ['dense'.encode('ascii', 'ignore')]
         )
+        handle.create_dataset('shape', data=np.array(self.neuron.shape))
 
         handle.create_dataset('inFeatures', data=self.synapse.in_channels)
         handle.create_dataset('outFeatures', data=self.synapse.out_channels)
 
         if self.synapse.weight_norm_enabled:
             self.synapse.disable_weight_norm()
+        if hasattr(self.synapse, 'imag'):  # complex synapse
+            handle.create_dataset('complex_synapse', data=np.array(True))
             handle.create_dataset(
                 'weight/real', data=weight(self.synapse.real)
             )
@@ -566,6 +568,7 @@ def delay(d):
                 data=weight(self.synapse.imag)
             )
         else:
+            handle.create_dataset('complex_synapse', data=np.array(False))
             handle.create_dataset('weight', data=weight(self.synapse))
 
         # bias
diff --git a/src/lava/lib/dl/slayer/neuron/rf.py b/src/lava/lib/dl/slayer/neuron/rf.py
index 260527b7..3817992f 100644
--- a/src/lava/lib/dl/slayer/neuron/rf.py
+++ b/src/lava/lib/dl/slayer/neuron/rf.py
@@ -39,8 +39,8 @@ def neuron_params(device_params, scale=1 << 6, p_scale=1 << 12):
         dictionary of neuron parameters that can be used to initialize neuron
         class.
     """
-    sin_decay = device_params['sinDecay'] / p_scale,
-    cos_decay = device_params['cosDecay'] / p_scale,
+    sin_decay = device_params['sinDecay'] / p_scale
+    cos_decay = device_params['cosDecay'] / p_scale
     decay = 1 - np.sqrt(sin_decay ** 2 + cos_decay ** 2)
     frequency = np.arctan2(sin_decay, cos_decay) / 2 / np.pi
     return {
diff --git a/src/lava/lib/dl/slayer/neuron/rf_iz.py b/src/lava/lib/dl/slayer/neuron/rf_iz.py
index 6740a2f3..893758c7 100644
--- a/src/lava/lib/dl/slayer/neuron/rf_iz.py
+++ b/src/lava/lib/dl/slayer/neuron/rf_iz.py
@@ -41,8 +41,8 @@ def neuron_params(device_params, scale=1 << 6, p_scale=1 << 12):
         dictionary of neuron parameters that can be used to initialize neuron
         class.
     """
-    sin_decay = device_params['sinDecay'] / p_scale,
-    cos_decay = device_params['cosDecay'] / p_scale,
+    sin_decay = device_params['sinDecay'] / p_scale
+    cos_decay = device_params['cosDecay'] / p_scale
     decay = 1 - np.sqrt(sin_decay ** 2 + cos_decay ** 2)
     frequency = np.arctan2(sin_decay, cos_decay) / 2 / np.pi
     return {
diff --git a/src/lava/lib/dl/slayer/synapse/complex.py b/src/lava/lib/dl/slayer/synapse/complex.py
index a3ee4bf5..302e5da1 100644
--- a/src/lava/lib/dl/slayer/synapse/complex.py
+++ b/src/lava/lib/dl/slayer/synapse/complex.py
@@ -105,6 +105,10 @@ def __init__(
             weight_scale, weight_norm, pre_hook_fx
         )
 
+        self.in_channels = self.real.in_channels
+        self.out_channels = self.real.out_channels
+        self.weight_norm_enabled = self.real.weight_norm_enabled
+
 
 class Conv(ComplexLayer):
     """Convolution complex-synapse layer.
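# Why the trailing-comma fix in rf.py/rf_iz.py matters: the commas made
# sin_decay and cos_decay one-element tuples, so `sin_decay ** 2` raised
# TypeError. A quick sanity check of the corrected conversion, using
# hypothetical device values and the function's default p_scale:
import numpy as np

p_scale = 1 << 12
device_params = {'sinDecay': 155, 'cosDecay': 3988}  # hypothetical values

sin_decay = device_params['sinDecay'] / p_scale  # a float now, not a 1-tuple
cos_decay = device_params['cosDecay'] / p_scale
decay = 1 - np.sqrt(sin_decay ** 2 + cos_decay ** 2)      # ~0.026
frequency = np.arctan2(sin_decay, cos_decay) / 2 / np.pi  # ~0.006 cycles/step
print(decay, frequency)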
diff --git a/tests/lava/lib/dl/netx/gts/complex_dense/current.npy b/tests/lava/lib/dl/netx/gts/complex_dense/current.npy
new file mode 100644
index 00000000..b903e853
Binary files /dev/null and b/tests/lava/lib/dl/netx/gts/complex_dense/current.npy differ
diff --git a/tests/lava/lib/dl/netx/gts/complex_dense/in.npy b/tests/lava/lib/dl/netx/gts/complex_dense/in.npy
new file mode 100644
index 00000000..3d8672ec
Binary files /dev/null and b/tests/lava/lib/dl/netx/gts/complex_dense/in.npy differ
diff --git a/tests/lava/lib/dl/netx/gts/complex_dense/out.npy b/tests/lava/lib/dl/netx/gts/complex_dense/out.npy
new file mode 100644
index 00000000..bb9a2ac7
Binary files /dev/null and b/tests/lava/lib/dl/netx/gts/complex_dense/out.npy differ
diff --git a/tests/lava/lib/dl/netx/gts/complex_dense/weight_img.npy b/tests/lava/lib/dl/netx/gts/complex_dense/weight_img.npy
new file mode 100644
index 00000000..dfec817f
Binary files /dev/null and b/tests/lava/lib/dl/netx/gts/complex_dense/weight_img.npy differ
diff --git a/tests/lava/lib/dl/netx/gts/complex_dense/weight_r.npy b/tests/lava/lib/dl/netx/gts/complex_dense/weight_r.npy
new file mode 100644
index 00000000..6046421e
Binary files /dev/null and b/tests/lava/lib/dl/netx/gts/complex_dense/weight_r.npy differ
diff --git a/tests/lava/lib/dl/netx/test_blocks.py b/tests/lava/lib/dl/netx/test_blocks.py
index 2664b7a6..fd968545 100644
--- a/tests/lava/lib/dl/netx/test_blocks.py
+++ b/tests/lava/lib/dl/netx/test_blocks.py
@@ -16,10 +16,13 @@
 from lava.proc.io.source import RingBuffer as SendProcess
 from lava.proc.io.sink import RingBuffer as ReceiveProcess
 from lava.proc.lif.process import LIF
+from lava.proc.rf.process import RF
+from lava.proc.rf_iz.process import RF_IZ
 from lava.proc.sdn.process import Sigma, Delta, SigmaDelta
 from lava.proc.conv import utils
 
-from lava.lib.dl.netx.blocks.process import Dense, Conv, Input
+from lava.lib.dl.netx.blocks.process import Dense, Conv, Input, ComplexDense,\
+    ComplexInput
 
 verbose = True if (('-v' in sys.argv) or ('--verbose' in sys.argv)) else False
 
@@ -224,6 +227,110 @@ def test_dense(self) -> None:
         )
 
 
+class TestRFBlocks(unittest.TestCase):
+
+    def test_input(self) -> None:
+        """Tests input RF block driven by known input."""
+        num_steps = 2000
+        rf_params = {'vth': 1.1,
+                     'period': 7,
+                     'state_exp': 6,
+                     'decay_bits': 12,
+                     'alpha': .05}
+
+        input_blk = ComplexInput(
+            shape=(200,),
+            neuron_params={'neuron_proc': RF_IZ, **rf_params},
+        )
+        source = SendProcess(data=np.load(root + '/gts/complex_dense/in.npy'))
+        source.s_out.connect(input_blk.inp)
+        sink = ReceiveProcess(shape=input_blk.out.shape, buffer=num_steps)
+        input_blk.out.connect(sink.a_in)
+
+        run_condition = RunSteps(num_steps=num_steps)
+        run_config = TestRunConfig(select_tag='fixed_pt')
+        input_blk.run(condition=run_condition, run_cfg=run_config)
+        output = sink.data.get()
+        input_blk.stop()
+
+        gt = np.load(root + '/gts/complex_dense/current.npy')
+
+        error = np.abs(output - gt).sum()
+        if verbose:
+            print('Input spike error:', error)
+            if HAVE_DISPLAY:
+                plt.figure()
+                out_ae = np.argwhere(output.reshape((-1, num_steps)) > 0)
+                gt_ae = np.argwhere(gt.reshape((-1, num_steps)) > 0)
+                plt.plot(gt_ae[:, 1],
+                         gt_ae[:, 0],
+                         '.', markersize=15, label='Ground Truth')
+                plt.plot(out_ae[:, 1], out_ae[:, 0], '.', label='Input Block')
+                plt.xlabel('Time')
+                plt.ylabel('Neuron ID')
+                plt.legend()
+                plt.show()
+
+        self.assertTrue(
+            error == 0,
+            f'Output spike and ground truth do not match for Input block. '
+            f'Found {output[output != gt] = } and {gt[output != gt] = }. '
+            f'Error was {error}.'
+        )
+
+    def test_dense(self) -> None:
+        """Tests RF dense block driven by known input."""
+        num_steps = 2000
+        rf_params = {'vth': 25,
+                     'period': 11,
+                     'state_exp': 6,
+                     'decay_bits': 12,
+                     'alpha': .05}
+
+        dense_blk = ComplexDense(
+            shape=(256,),
+            neuron_params={'neuron_proc': RF, **rf_params},
+            weight_real=np.load(root + '/gts/complex_dense/weight_r.npy'),
+            weight_imag=np.load(root + '/gts/complex_dense/weight_img.npy'),
+        )
+
+        source = SendProcess(data=np.load(root + '/gts/complex_dense/in.npy'))
+        sink = ReceiveProcess(shape=dense_blk.out.shape, buffer=num_steps)
+        source.s_out.connect(dense_blk.inp)
+        dense_blk.out.connect(sink.a_in)
+
+        run_condition = RunSteps(num_steps=num_steps)
+        run_config = TestRunConfig(select_tag='fixed_pt')
+        dense_blk.run(condition=run_condition, run_cfg=run_config)
+        s = sink.data.get()
+        dense_blk.stop()
+
+        s_gt = np.load(root + '/gts/complex_dense/out.npy')
+        s_error = np.abs(s - s_gt).sum()
+
+        if verbose:
+            print('Dense spike error:', s_error)
+            if HAVE_DISPLAY:
+                plt.figure()
+                out_ae = np.argwhere(s.reshape((-1, num_steps)) > 0)
+                gt_ae = np.argwhere(s_gt.reshape((-1, num_steps)) > 0)
+                plt.plot(gt_ae[:, 1],
+                         gt_ae[:, 0],
+                         '.', markersize=15, label='Ground Truth')
+                plt.plot(out_ae[:, 1], out_ae[:, 0], '.', label='Dense Block')
+                plt.xlabel('Time')
+                plt.ylabel('Neuron ID')
+                plt.legend()
+                plt.show()
+
+        self.assertTrue(
+            s_error == 0,
+            f'Output spike and ground truth do not match for Dense block. '
+            f'Found {s[s != s_gt] = } and {s_gt[s != s_gt] = }. '
+            f'Error was {s_error}.'
+        )
+
+
 class TestSDNBlocks(unittest.TestCase):
     def test_input(self) -> None:
         """Tests SDN input block driven by known input."""
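# End-to-end sketch of the path these tests exercise (not part of this
# patch): an HDF5 export whose dense layers carry 'weight/real' and
# 'weight/imag' datasets plus the new 'complex_synapse' flag is loaded
# through netx, which then builds ComplexDense blocks around RF/RF_IZ
# neurons. 'network.net' is a hypothetical file path.
from lava.lib.dl import netx

net = netx.hdf5.Network(net_config='network.net')
print(net)  # layer summary assembled via Network._table_str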