-
Notifications
You must be signed in to change notification settings - Fork 370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add array indexing to NodeCollections #1485
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for adding these additional features to better integrate NEST with NumPy arrays!
I have added some minor comments inline.
Another comment not strictly related to your changes.
When working with lists or arrays, if an index outside of the NodeCollection is used, then the right error is thrown:
n = nest.Create('iaf_psc_alpha', 10)
n[[1, 5, 10]]
IndexError: pos points outside of the NodeCollection
nest-simulator/nestkernel/node_collection.h
Line 813 in 8bb59ac
throw std::out_of_range( "pos points outside of the NodeCollection" ); |
This does not happen when using integers or slices as indexes:
n = nest.Create('iaf_psc_alpha', 10)
n[10]
# OR
n[1:12]
~/workspace/nest-simulator/b/lib/python3.6/site-packages/nest/lib/hl_api_types.py in getitem(self, key)
216 step = 1 if key.step is None else key.step
217
--> 218 return sli_func('Take', self._datum, [start, stop, step])
219 elif isinstance(key, (int, numpy.integer)):
220 return sli_func('Take', self._datum, [key + (key >= 0)])~/workspace/nest-simulator/b/lib/python3.6/site-packages/nest/ll_api.py in sli_func(s, *args, **kwargs)
181 sli_push(args) # push array of arguments on SLI stack
182 sli_push(s) # push command string
--> 183 sli_run(slifun) # SLI support code to execute s on args
184 r = sli_pop() # return value is an array
185~/workspace/nest-simulator/b/lib/python3.6/site-packages/nest/ll_api.py in catching_sli_run(cmd)
131
132 exceptionCls = getattr(kernel.NESTErrors, errorname)
--> 133 raise exceptionCls(commandname, message)
134
135
NESTErrors.BadParameter: ('BadParameter in Take_g_a: stop <= size() required.', 'BadParameter', 'Take_g_a', ': stop <= size() required.')
I do not know if this PR is the place where to fix this inconsistency. Otherwise, I can create an ad-hoc PR for this.
pynest/pynestkernel.pyx
Outdated
new_nc_datum = node_collection_array_index(nc_datum, array_long_ptr, len(array)) | ||
return sli_datum_to_object(new_nc_datum) | ||
else: | ||
raise TypeError('array must be a NumPyArray of ints or bools, got {}'.format(array.dtype)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo. "NumPyArray" -> "NumPy array"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
|
||
if not (isinstance(node_collection, SLIDatum) and (<SLIDatum> node_collection).dtype == SLI_TYPE_NODECOLLECTION.decode()): | ||
raise TypeError('node_collection must be a NodeCollection, got {}'.format(type(node_collection))) | ||
if not isinstance(array, numpy.ndarray): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we check here that array.ndim == 1
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, I have added it.
pynest/pynestkernel.pyx
Outdated
if not (isinstance(node_collection, SLIDatum) and (<SLIDatum> node_collection).dtype == SLI_TYPE_NODECOLLECTION.decode()): | ||
raise TypeError('node_collection must be a NodeCollection, got {}'.format(type(node_collection))) | ||
if not isinstance(array, numpy.ndarray): | ||
raise TypeError('array must be a NumPy array of ints or bools, got {}'.format(type(array))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the above comment is addressed, you could add:
if not array.ndim == 1:
raise TypeError('array must be a 1-dimensional NumPy array, got {}-dimensional NumPy array'.format(array.ndim))
std::vector< index > node_ids; | ||
node_ids.reserve( n ); | ||
|
||
for ( auto node_ptr = array; node_ptr != array + n; ++node_ptr ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason why you are using auto
instead of long*
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The idea is that we don't care about the type of node_ptr
as long as it's the same as the type of array
. If the type of array
is changed, it only has to change in the definition of the function. Ensuring that they always have the same type can avoid some nasty bugs down the line if the type of array
is changed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, got it.
node_ids.reserve( n ); | ||
|
||
auto nc_it = node_collection->begin(); | ||
for ( auto node_ptr = array; node_ptr != array + n; ++node_ptr, ++nc_it ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above: Is there a reason why you are using auto
instead of bool*
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above.
nestkernel/node_collection.cpp
Outdated
@@ -158,7 +170,7 @@ NodeCollection::create_( const std::vector< index >& node_ids ) | |||
|
|||
std::vector< NodeCollectionPrimitive > parts; | |||
|
|||
index old_node_id = 0; | |||
index old_node_id = current_first; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is not relative to the array indexing introduced with this PR, right?
I assume it solves the bug that was not checking the repetition of the first element in the list. Something like nest.NodeCollection([2, 2, 3, 4])
was not throwing errors before.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
True, it's not relevant for the array indexing. I can take it out and create a separate PR for it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes are now in #1536.
@hakonsbm: I would only review once you addressed the comments by @alberto-antonietti and fixed the conflicts, OK? |
@alberto-antonietti Thanks for your review! I have addressed most of your comments, still working on the last one.
I agree the error message could be more informative here. If you create a PR for it, I'll be happy to review it. |
It is moved to a separate branch.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the changes, I am going to take care of the PR for the error message.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm! 👍
This PR adds the ability to index NodeCollections with arrays of integers or bools, similar to advanced indexing in NumPy. The result is a new
NodeCollection
with the selected nodes. Note that aNodeCollection
can only contain unique node IDs. Array indexing supports lists, tuples, and NumPy arrays as indices.Example: