From 246a7020725b9b1288e8093b24bf25827dd3da9b Mon Sep 17 00:00:00 2001 From: bamsumit Date: Sat, 2 Jul 2022 00:22:21 -0700 Subject: [PATCH] Revert "Vectorization of conv_as_sparse ncproc and enabling slicing of nodes" --- .../magma/compiler/builders/py_builder.py | 4 +- src/lava/magma/core/nets/tables.py | 26 +---- src/lava/proc/conv/ncmodels.py | 110 ++++++++++++++---- 3 files changed, 92 insertions(+), 48 deletions(-) diff --git a/src/lava/magma/compiler/builders/py_builder.py b/src/lava/magma/compiler/builders/py_builder.py index b2c0ac2b4..994f9c456 100644 --- a/src/lava/magma/compiler/builders/py_builder.py +++ b/src/lava/magma/compiler/builders/py_builder.py @@ -1080,7 +1080,9 @@ def set_resources(self, data_ht: HomoTable = reg_view.reg[i].data[0] field_names: ty.List[str] = data_ht.field_names for f in field_names: - ev[f].append(data_ht[f].item()) + fi: int = data_ht.field_names.index(f) + data = data_ht[fi] + ev[f].append(data) reg.set(list(range(len(selected_indices))), **ev) diff --git a/src/lava/magma/core/nets/tables.py b/src/lava/magma/core/nets/tables.py index 248c018e7..073c846b0 100644 --- a/src/lava/magma/core/nets/tables.py +++ b/src/lava/magma/core/nets/tables.py @@ -101,30 +101,6 @@ def connect( "other Nodes instances and to NcOutPorts, not " f"to type '{type(target)}'.") - def __getitem__(self, idx: ty.Union[int, slice, np.ndarray]) -> Nodes: - """Slice the Node object. - - Parameters - ---------- - idx : ty.Union[int, slice, np.ndarray] - Slicing description. - - Returns - ------- - Nodes - Node representing the sliced indices only. - """ - if type(idx) == int: - idx = np.array([idx]) - elif type(idx) == slice: - idx = np.arange(start=idx.start if idx.start else 0, - stop=idx.stop, - step=idx.step) - if type(idx) == np.ndarray: - return self.htr_tbl.get_nodes(idx.flatten()) - else: - raise IndexError - class NodesContainer: """Container defined by a List of Nodes pointer objects. Allows a List @@ -706,7 +682,7 @@ def __getitem__(self, idx: ty.Union[int, slice]) -> HeteroTable: if np.isscalar(idx) or isinstance(idx, np.signedinteger): idx_temp = int(idx) grp, grp_idx = self.data_idx[idx_temp].astype(int) - data = self.data[grp][grp_idx : grp_idx + 1] + data = self.data[grp][grp_idx] table.add_homotable(data) # TODO: handle conn_idx changes elif type(idx) == slice: data_idx = self.data_idx[idx].view(np.ndarray).astype(int) diff --git a/src/lava/proc/conv/ncmodels.py b/src/lava/proc/conv/ncmodels.py index 5b8275c96..f6b0cafe8 100644 --- a/src/lava/proc/conv/ncmodels.py +++ b/src/lava/proc/conv/ncmodels.py @@ -139,40 +139,106 @@ def allocate(self, net: Net): padding=self.padding.var.get(), dilation=self.dilation.var.get(), group=self.groups.var.get()) + wgts = scipy.sparse.csc_matrix((wgt, (dst, src)), + shape=process_size) - # Sort sparse weight in the order of input dimension - idx = np.argsort(src) - src = src[idx] - dst = dst[idx] - wgt = wgt[idx] + # Allocate dendritic accumulators + dend_acc_cfg: Nodes = net.dend_acc_cfg.allocate(shape=1, + num_delay_bits=0) + dend_acc: Nodes = net.dend_acc.allocate(shape=(process_size[0],)) - # Allocate Input Axons - ax_in: Nodes = net.ax_in.allocate(shape=input_shape) ax_in_cfg: Nodes = net.ax_in_cfg.allocate( shape=1, num_message_bits=self.num_message_bits.var.get()) - - # Allocate Synapses - syn: Nodes = net.syn.allocate(shape=wgt.shape, - weights=wgt, - delays=0) syn_cfg: Nodes = net.syn_cfg.allocate_sparse(shape=1, sign=sign_mode.value, num_wgt_bits=num_wgt_bits, num_dly_bits=0, wgt_exp=wgt_exp) - # Allocate dendritic accumulators - dend_acc_cfg: Nodes = net.dend_acc_cfg.allocate(shape=1, - num_delay_bits=0) - dend_acc: Nodes = net.dend_acc.allocate(shape=(process_size[0],)) - + # Allocate input axons + ax_in_list: ty.List[Nodes] = [] + syn_list: ty.List[Nodes] = [] + col_idx_list: ty.List[np.ndarray] = [] + for i in range(process_size[1]): + ax_in: Nodes = net.ax_in.allocate(shape=(1,)) + ax_in_list.append(ax_in) + + column_slice_wgt: scipy.sparse.csc_matrix = wgts[:, i] + column_non_zero_indices: np.ndarray = column_slice_wgt.indices + + ax_in.connect(ax_in_cfg) + if len(column_non_zero_indices) == 0: + continue + + column_tensor_wgt: np.ndarray = column_slice_wgt.data + # Allocate synapses + syn: Nodes = net.syn.allocate( + shape=(len(column_tensor_wgt),), + weights=column_tensor_wgt, + delays=0) + syn_list.append(syn) + + ax_in.connect(syn) + syn.connect(syn_cfg) + # Connect synapse to corresponding dendritic accumulators + # TODO: DR/SS enable slicing of Nodes + # syn.connect(dend_acc[column_non_zero_indices]) + # print(f'{i=}') + # print(f'{len(syn.global_idx)=}') + # print(f'{len(column_non_zero_indices)=}') + # print() + # net.syn.connect_heterotable(net.dend_acc, + # syn.global_idx, + # column_non_zero_indices) + col_idx_list.append(column_non_zero_indices) + + col_idx = np.concatenate(col_idx_list) # Connect InPort of Process to neurons + # Connect NodeGroups + # ToDo: This will only connect to the latest set of syn nodes + # returned by the loop. Modify so that ax_in connects to all + # allocated Synapses. + # ax_in.connect(syn) + + # TODO: enable merging a list of nodes into a single node + # ax_in = Node.stack(ax_in_list) + # syn = Node.stack(syn_list) + # and change the next three lines + ax_in = Nodes(net.ax_in, np.arange( + process_size[1]), (process_size[1],)) self.s_in.connect(ax_in) - # Connect Nodes - ax_in.connect(ax_in_cfg) - ax_in[src].connect(syn) - syn.connect(syn_cfg) - syn.connect(dend_acc[dst]) + + # print(f'{net.syn=}') + # print(f'{net.syn.data_idx=}') + # print(f'{syn_list[0].global_idx=}') + # print(f'{syn_list[-1].global_idx=}') + # print(f'{len(net.syn.data_idx)=}') + # print(f'{len(net.syn)=}') + # print(f'{process_size=}') + # print(f'{np.prod(process_size)=}') + # syn = Nodes(net.syn, + # global_idx=np.arange(len(net.syn)), + # shape=(process_size[1], process_size[0])) + # # ax_in.connect(syn) + # syn.connect(dend_acc) + net.syn.connect_heterotable(net.dend_acc, + from_idx=np.arange(len(net.syn)), + to_idx=col_idx) # Connect output axon to OutPort of Process dend_acc.connect(dend_acc_cfg) dend_acc.connect(self.a_out) + + # for table in net.heterotables: + # print(table.name) + # print(table) + # # print(table.conn_idx.keys()) + # for row in net.syn.conn_idx[net.dend_acc]: + # print(row) + # import matplotlib.pyplot as plt + # conn_idx = net.syn.conn_idx[net.dend_acc] + # image = np.zeros(np.max(conn_idx, axis=0) + 2) + # image[conn_idx[:, 0], conn_idx[:, 1]] = 1 + # plt.savefig('conn_idx.png', dpi=600) + # plt.close() + # plt.imshow(image[-1000:]) + # plt.show()