Skip to content

Commit

Permalink
Merge pull request #199 from intel-innersource/revert-197-upstream_pr
Browse files Browse the repository at this point in the history
Revert "Vectorization of conv_as_sparse ncproc and enabling slicing of nodes"
  • Loading branch information
bamsumit authored Jul 2, 2022
2 parents 138beca + 246a702 commit 21c31ce
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 48 deletions.
4 changes: 3 additions & 1 deletion src/lava/magma/compiler/builders/py_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,9 @@ def set_resources(self,
data_ht: HomoTable = reg_view.reg[i].data[0]
field_names: ty.List[str] = data_ht.field_names
for f in field_names:
ev[f].append(data_ht[f].item())
fi: int = data_ht.field_names.index(f)
data = data_ht[fi]
ev[f].append(data)

reg.set(list(range(len(selected_indices))), **ev)

Expand Down
26 changes: 1 addition & 25 deletions src/lava/magma/core/nets/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,30 +101,6 @@ def connect(
"other Nodes instances and to NcOutPorts, not "
f"to type '{type(target)}'.")

def __getitem__(self, idx: ty.Union[int, slice, np.ndarray]) -> Nodes:
"""Slice the Node object.
Parameters
----------
idx : ty.Union[int, slice, np.ndarray]
Slicing description.
Returns
-------
Nodes
Node representing the sliced indices only.
"""
if type(idx) == int:
idx = np.array([idx])
elif type(idx) == slice:
idx = np.arange(start=idx.start if idx.start else 0,
stop=idx.stop,
step=idx.step)
if type(idx) == np.ndarray:
return self.htr_tbl.get_nodes(idx.flatten())
else:
raise IndexError


class NodesContainer:
"""Container defined by a List of Nodes pointer objects. Allows a List
Expand Down Expand Up @@ -706,7 +682,7 @@ def __getitem__(self, idx: ty.Union[int, slice]) -> HeteroTable:
if np.isscalar(idx) or isinstance(idx, np.signedinteger):
idx_temp = int(idx)
grp, grp_idx = self.data_idx[idx_temp].astype(int)
data = self.data[grp][grp_idx : grp_idx + 1]
data = self.data[grp][grp_idx]
table.add_homotable(data) # TODO: handle conn_idx changes
elif type(idx) == slice:
data_idx = self.data_idx[idx].view(np.ndarray).astype(int)
Expand Down
110 changes: 88 additions & 22 deletions src/lava/proc/conv/ncmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,40 +139,106 @@ def allocate(self, net: Net):
padding=self.padding.var.get(),
dilation=self.dilation.var.get(),
group=self.groups.var.get())
wgts = scipy.sparse.csc_matrix((wgt, (dst, src)),
shape=process_size)

# Sort sparse weight in the order of input dimension
idx = np.argsort(src)
src = src[idx]
dst = dst[idx]
wgt = wgt[idx]
# Allocate dendritic accumulators
dend_acc_cfg: Nodes = net.dend_acc_cfg.allocate(shape=1,
num_delay_bits=0)
dend_acc: Nodes = net.dend_acc.allocate(shape=(process_size[0],))

# Allocate Input Axons
ax_in: Nodes = net.ax_in.allocate(shape=input_shape)
ax_in_cfg: Nodes = net.ax_in_cfg.allocate(
shape=1,
num_message_bits=self.num_message_bits.var.get())

# Allocate Synapses
syn: Nodes = net.syn.allocate(shape=wgt.shape,
weights=wgt,
delays=0)
syn_cfg: Nodes = net.syn_cfg.allocate_sparse(shape=1,
sign=sign_mode.value,
num_wgt_bits=num_wgt_bits,
num_dly_bits=0,
wgt_exp=wgt_exp)
# Allocate dendritic accumulators
dend_acc_cfg: Nodes = net.dend_acc_cfg.allocate(shape=1,
num_delay_bits=0)
dend_acc: Nodes = net.dend_acc.allocate(shape=(process_size[0],))

# Allocate input axons
ax_in_list: ty.List[Nodes] = []
syn_list: ty.List[Nodes] = []
col_idx_list: ty.List[np.ndarray] = []
for i in range(process_size[1]):
ax_in: Nodes = net.ax_in.allocate(shape=(1,))
ax_in_list.append(ax_in)

column_slice_wgt: scipy.sparse.csc_matrix = wgts[:, i]
column_non_zero_indices: np.ndarray = column_slice_wgt.indices

ax_in.connect(ax_in_cfg)
if len(column_non_zero_indices) == 0:
continue

column_tensor_wgt: np.ndarray = column_slice_wgt.data
# Allocate synapses
syn: Nodes = net.syn.allocate(
shape=(len(column_tensor_wgt),),
weights=column_tensor_wgt,
delays=0)
syn_list.append(syn)

ax_in.connect(syn)
syn.connect(syn_cfg)
# Connect synapse to corresponding dendritic accumulators
# TODO: DR/SS enable slicing of Nodes
# syn.connect(dend_acc[column_non_zero_indices])
# print(f'{i=}')
# print(f'{len(syn.global_idx)=}')
# print(f'{len(column_non_zero_indices)=}')
# print()
# net.syn.connect_heterotable(net.dend_acc,
# syn.global_idx,
# column_non_zero_indices)
col_idx_list.append(column_non_zero_indices)

col_idx = np.concatenate(col_idx_list)
# Connect InPort of Process to neurons
# Connect NodeGroups
# ToDo: This will only connect to the latest set of syn nodes
# returned by the loop. Modify so that ax_in connects to all
# allocated Synapses.
# ax_in.connect(syn)

# TODO: enable merging a list of nodes into a single node
# ax_in = Node.stack(ax_in_list)
# syn = Node.stack(syn_list)
# and change the next three lines
ax_in = Nodes(net.ax_in, np.arange(
process_size[1]), (process_size[1],))
self.s_in.connect(ax_in)
# Connect Nodes
ax_in.connect(ax_in_cfg)
ax_in[src].connect(syn)
syn.connect(syn_cfg)
syn.connect(dend_acc[dst])

# print(f'{net.syn=}')
# print(f'{net.syn.data_idx=}')
# print(f'{syn_list[0].global_idx=}')
# print(f'{syn_list[-1].global_idx=}')
# print(f'{len(net.syn.data_idx)=}')
# print(f'{len(net.syn)=}')
# print(f'{process_size=}')
# print(f'{np.prod(process_size)=}')
# syn = Nodes(net.syn,
# global_idx=np.arange(len(net.syn)),
# shape=(process_size[1], process_size[0]))
# # ax_in.connect(syn)
# syn.connect(dend_acc)
net.syn.connect_heterotable(net.dend_acc,
from_idx=np.arange(len(net.syn)),
to_idx=col_idx)
# Connect output axon to OutPort of Process
dend_acc.connect(dend_acc_cfg)
dend_acc.connect(self.a_out)

# for table in net.heterotables:
# print(table.name)
# print(table)
# # print(table.conn_idx.keys())
# for row in net.syn.conn_idx[net.dend_acc]:
# print(row)
# import matplotlib.pyplot as plt
# conn_idx = net.syn.conn_idx[net.dend_acc]
# image = np.zeros(np.max(conn_idx, axis=0) + 2)
# image[conn_idx[:, 0], conn_idx[:, 1]] = 1
# plt.savefig('conn_idx.png', dpi=600)
# plt.close()
# plt.imshow(image[-1000:])
# plt.show()

0 comments on commit 21c31ce

Please sign in to comment.