Skip to content

Commit

Permalink
Merge pull request #1485 from hakonsbm/nest3/nc_array_indexing
Browse files Browse the repository at this point in the history
Add array indexing to NodeCollections
  • Loading branch information
heplesser authored Apr 24, 2020
2 parents 3b22045 + a3cc3ed commit 15986ce
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 1 deletion.
34 changes: 34 additions & 0 deletions nestkernel/nest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,4 +546,38 @@ apply( const ParameterDatum& param, const DictionaryDatum& positions )
return param->apply( source_nc, target_tkns );
}

Datum*
node_collection_array_index( const Datum* datum, const long* array, unsigned long n )
{
const NodeCollectionDatum node_collection = *dynamic_cast< const NodeCollectionDatum* >( datum );
assert( node_collection->size() >= n );
std::vector< index > node_ids;
node_ids.reserve( n );

for ( auto node_ptr = array; node_ptr != array + n; ++node_ptr )
{
node_ids.push_back( node_collection->operator[]( *node_ptr ) );
}
return new NodeCollectionDatum( NodeCollection::create( node_ids ) );
}

Datum*
node_collection_array_index( const Datum* datum, const bool* array, unsigned long n )
{
const NodeCollectionDatum node_collection = *dynamic_cast< const NodeCollectionDatum* >( datum );
assert( node_collection->size() == n );
std::vector< index > node_ids;
node_ids.reserve( n );

auto nc_it = node_collection->begin();
for ( auto node_ptr = array; node_ptr != array + n; ++node_ptr, ++nc_it )
{
if ( *node_ptr )
{
node_ids.push_back( ( *nc_it ).node_id );
}
}
return new NodeCollectionDatum( NodeCollection::create( node_ids ) );
}

} // namespace nest
3 changes: 3 additions & 0 deletions nestkernel/nest.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,9 @@ double get_value( const ParameterDatum& param );
bool is_spatial( const ParameterDatum& param );
std::vector< double > apply( const ParameterDatum& param, const NodeCollectionDatum& nc );
std::vector< double > apply( const ParameterDatum& param, const DictionaryDatum& positions );

Datum* node_collection_array_index( const Datum* datum, const long* array, unsigned long n );
Datum* node_collection_array_index( const Datum* datum, const bool* array, unsigned long n );
}


Expand Down
12 changes: 12 additions & 0 deletions nestkernel/node_collection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,18 @@ NodeCollection::create( const index node_id )
return NodeCollection::create_( { node_id } );
}

NodeCollectionPTR
NodeCollection::create( const std::vector< index >& node_ids_vector )
{
if ( node_ids_vector.size() == 0 )
{
return NodeCollection::create_();
}
auto node_ids = node_ids_vector; // Create a copy to be able to sort
std::sort( node_ids.begin(), node_ids.end() );
return NodeCollection::create_( node_ids );
}

NodeCollectionPTR
NodeCollection::create_()
{
Expand Down
9 changes: 9 additions & 0 deletions nestkernel/node_collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,15 @@ class NodeCollection
*/
static NodeCollectionPTR create( const index node_id );

/**
* Create a NodeCollection from an array of node IDs. Results in a primitive if the
* node IDs are homogeneous and contiguous, or a composite otherwise.
*
* @param node_ids Array of node IDs from which to create the NodeCollection
* @return a NodeCollection pointer to the created NodeCollection
*/
static NodeCollectionPTR create( const std::vector< index >& node_ids );

/**
* Check to see if the fingerprint of the NodeCollection matches that of the
* kernel.
Expand Down
31 changes: 30 additions & 1 deletion pynest/nest/lib/hl_api_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,37 @@ def __getitem__(self, key):
if abs(key + (key >= 0)) > self.__len__():
raise IndexError('index value outside of the NodeCollection')
return sli_func('Take', self._datum, [key + (key >= 0)])
elif isinstance(key, (list, tuple)):
if len(key) == 0:
return NodeCollection([])
# Must check if elements are bool first, because bool inherits from int
if all(isinstance(x, bool) for x in key):
if len(key) != len(self):
raise IndexError('Bool index array must be the same length as NodeCollection')
np_key = numpy.array(key, dtype=numpy.bool)
# Checking that elements are not instances of bool too, because bool inherits from int
elif all(isinstance(x, int) and not isinstance(x, bool) for x in key):
np_key = numpy.array(key, dtype=numpy.uint64)
if len(numpy.unique(np_key)) != len(np_key):
raise ValueError('All node IDs in a NodeCollection have to be unique')
else:
raise TypeError('Indices must be integers or bools')
return take_array_index(self._datum, np_key)
elif isinstance(key, numpy.ndarray):
if len(key) == 0:
return NodeCollection([])
if len(key.shape) != 1:
raise TypeError('NumPy indices must one-dimensional')
is_booltype = numpy.issubdtype(key.dtype, numpy.dtype(bool).type)
if not (is_booltype or numpy.issubdtype(key.dtype, numpy.integer)):
raise TypeError('NumPy indices must be an array of integers or bools')
if is_booltype and len(key) != len(self):
raise IndexError('Bool index array must be the same length as NodeCollection')
if not is_booltype and len(numpy.unique(key)) != len(key):
raise ValueError('All node IDs in a NodeCollection have to be unique')
return take_array_index(self._datum, key)
else:
raise IndexError('only integers and slices are valid indices')
raise IndexError('only integers, slices, lists, tuples, and numpy arrays are valid indices')

def __contains__(self, node_id):
return sli_func('MemberQ', self._datum, node_id)
Expand Down
2 changes: 2 additions & 0 deletions pynest/nest/ll_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,15 @@
'sps',
'sr',
'stack_checker',
'take_array_index',
]


engine = kernel.NESTEngine()

sli_push = sps = engine.push
sli_pop = spp = engine.pop
take_array_index = engine.take_array_index
connect_arrays = engine.connect_arrays


Expand Down
56 changes: 56 additions & 0 deletions pynest/nest/tests/test_NodeCollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,62 @@ def test_Create_accepts_empty_params_dict(self):
"""
nest.Create('iaf_psc_delta', params={})

def test_array_indexing(self):
"""NodeCollection array indexing"""
n = nest.Create('iaf_psc_alpha', 10)
cases = [[1, 2],
[2, 5],
[0, 2, 5, 7, 9],
(5, 2),
[]
]
fail_cases = [([5, 10, 15], IndexError), # Index not in NodeCollection
([2, 5.5], TypeError), # Not all indices are ints
([[2, 4], [6, 8]], TypeError), # Too many dimensions
([2, 2], ValueError), # Non-unique elements
]
if HAVE_NUMPY:
cases += [np.array(c) for c in cases]
fail_cases += [(np.array(c), e) for c, e in fail_cases]
for case in cases:
print(type(case), case)
ref = [i + 1 for i in case]
ref.sort()
sliced = n[case]
self.assertEqual(sliced.tolist(), ref)
for case, err in fail_cases:
print(type(case), case)
with self.assertRaises(err):
sliced = n[case]

def test_array_indexing_bools(self):
"""NodeCollection array indexing with bools"""
n = nest.Create('iaf_psc_alpha', 5)
cases = [[True for _ in range(len(n))],
[False for _ in range(len(n))],
[True, False, True, False, True],
]
fail_cases = [([True for _ in range(len(n)-1)], IndexError), # Too few bools
([True for _ in range(len(n)+1)], IndexError), # Too many bools
([[True, False], [True, False]], TypeError), # Too many dimensions
([True, False, 2.5, False, True], TypeError), # Not all indices are bools
([1, False, 1, False, 1], TypeError), # Mixing bools and ints
]
if HAVE_NUMPY:
cases += [np.array(c) for c in cases]
# Cutting off fail_cases before cases that mix bools and ints,
# because converting them to NumPy arrays converts bools to ints.
fail_cases += [(np.array(c), e) for c, e in fail_cases[:-2]]
for case in cases:
print(type(case), case)
ref = [i for i, b in zip(range(1, 11), case) if b]
sliced = n[case]
self.assertEqual(sliced.tolist(), ref)
for case, err in fail_cases:
print(type(case), case)
with self.assertRaises(err):
sliced = n[case]


def suite():
suite = unittest.makeSuite(TestNodeCollection, 'test')
Expand Down
2 changes: 2 additions & 0 deletions pynest/pynestkernel.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ cdef extern from "neststartup.h":
void c_set_communicator "set_communicator" (object) with gil

cdef extern from "nest.h" namespace "nest":
Datum* node_collection_array_index(const Datum* node_collection, const long* array, unsigned long n) except +
Datum* node_collection_array_index(const Datum* node_collection, const cbool* array, unsigned long n) except +
void connect_arrays( long* sources, long* targets, double* weights, double* delays, vector[string]& p_keys, double* p_values, size_t n, string syn_model ) except +

cdef extern from *:
Expand Down
38 changes: 38 additions & 0 deletions pynest/pynestkernel.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,44 @@ cdef class NESTEngine(object):

return ret

def take_array_index(self, node_collection, array):
if self.pEngine is NULL:
raise NESTErrors.PyNESTError("engine uninitialized")

if not (isinstance(node_collection, SLIDatum) and (<SLIDatum> node_collection).dtype == SLI_TYPE_NODECOLLECTION.decode()):
raise TypeError('node_collection must be a NodeCollection, got {}'.format(type(node_collection)))
if not isinstance(array, numpy.ndarray):
raise TypeError('array must be a 1-dimensional NumPy array of ints or bools, got {}'.format(type(array)))
if not array.ndim == 1:
raise TypeError('array must be a 1-dimensional NumPy array, got {}-dimensional NumPy array'.format(array.ndim))

# Get pointers to the first element in the Numpy array
cdef long[:] array_long_mv
cdef long* array_long_ptr

cdef cbool[:] array_bool_mv
cdef cbool* array_bool_ptr

cdef Datum* nc_datum = python_object_to_datum(node_collection)

try:
if array.dtype == numpy.bool:
# Boolean C-type arrays are not supported in NumPy, so we use an 8-bit integer array
array_bool_mv = numpy.ascontiguousarray(array, dtype=numpy.uint8)
array_bool_ptr = &array_bool_mv[0]
new_nc_datum = node_collection_array_index(nc_datum, array_bool_ptr, len(array))
return sli_datum_to_object(new_nc_datum)
elif numpy.issubdtype(array.dtype, numpy.integer):
array_long_mv = numpy.ascontiguousarray(array, dtype=numpy.long)
array_long_ptr = &array_long_mv[0]
new_nc_datum = node_collection_array_index(nc_datum, array_long_ptr, len(array))
return sli_datum_to_object(new_nc_datum)
else:
raise TypeError('array must be a NumPy array of ints or bools, got {}'.format(array.dtype))
except RuntimeError as e:
exceptionCls = getattr(NESTErrors, str(e))
raise exceptionCls('take_array_index', '') from None

def connect_arrays(self, sources, targets, weights, delays, synapse_model, syn_param_keys, syn_param_values):
"""Calls connect_arrays function, bypassing SLI to expose pointers to the NumPy arrays"""
if self.pEngine is NULL:
Expand Down

0 comments on commit 15986ce

Please sign in to comment.