Skip to content

Commit

Permalink
Merge pull request #2601 from med-ayssar/remove-nodeCollection-from-n…
Browse files Browse the repository at this point in the history
…ode-class

Remove NodeCollection pointer from Node class
  • Loading branch information
heplesser authored Sep 13, 2023
2 parents c2258df + 89eb27c commit c04c3dc
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 42 deletions.
23 changes: 0 additions & 23 deletions nestkernel/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
#include "nest_names.h"
#include "nest_time.h"
#include "nest_types.h"
#include "node_collection.h"
#include "secondary_event.h"

// Includes from sli:
Expand Down Expand Up @@ -202,10 +201,6 @@ class Node
*/
size_t get_node_id() const;

/**
* Return lockpointer to the NodeCollection that created this node.
*/
NodeCollectionPTR get_nc() const;

/**
* Return model ID of the node.
Expand Down Expand Up @@ -879,11 +874,6 @@ class Node
private:
void set_node_id_( size_t ); //!< Set global node id

/**
* Set the original NodeCollection of this node.
*/
void set_nc_( NodeCollectionPTR );

/** Return a new dictionary datum .
*
* This function is called by get_status_base() and returns a new
Expand Down Expand Up @@ -958,8 +948,6 @@ class Node
bool frozen_; //!< node shall not be updated if true
bool initialized_; //!< state and buffers have been initialized
bool node_uses_wfr_; //!< node uses waveform relaxation method

NodeCollectionPTR nc_ptr_; //!< Original NodeCollection of this node, used to extract node-specific metadata
};

inline bool
Expand Down Expand Up @@ -1028,11 +1016,6 @@ Node::get_node_id() const
return node_id_;
}

inline NodeCollectionPTR
Node::get_nc() const
{
return nc_ptr_;
}

inline void
Node::set_node_id_( size_t i )
Expand All @@ -1041,12 +1024,6 @@ Node::set_node_id_( size_t i )
}


inline void
Node::set_nc_( NodeCollectionPTR nc_ptr )
{
nc_ptr_ = nc_ptr;
}

inline int
Node::get_model_id() const
{
Expand Down
25 changes: 25 additions & 0 deletions nestkernel/node_collection.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,17 @@ class NodeCollection
*/
virtual bool has_proxies() const = 0;

/**
* return the first stored ID (i.e, ID at index zero) inside the NodeCollection
*/
size_t get_first() const;

/**
* return the last stored ID inside the NodeCollection
*/
size_t get_last() const;


private:
unsigned long fingerprint_; //!< Unique identity of the kernel that created the NodeCollection
static NodeCollectionPTR create_();
Expand Down Expand Up @@ -667,6 +678,20 @@ NodeCollection::set_metadata( NodeCollectionMetadataPTR )
throw KernelException( "Cannot set Metadata on this type of NodeCollection." );
}

inline size_t
NodeCollection::get_first() const
{
return ( *begin() ).node_id;
}

inline size_t
NodeCollection::get_last() const
{
size_t offset = size() - 1;
return ( *( begin() + offset ) ).node_id;
}


inline nc_const_iterator&
nc_const_iterator::operator+=( const size_t n )
{
Expand Down
52 changes: 42 additions & 10 deletions nestkernel/node_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "node_manager.h"

// C++ includes:
#include <algorithm>
#include <set>

// Includes from libnestutil:
Expand All @@ -47,6 +48,7 @@ namespace nest

NodeManager::NodeManager()
: local_nodes_( 1 )
, node_collection_container_()
, wfr_nodes_vec_()
, wfr_is_used_( false )
, wfr_network_size_( 0 ) // zero to force update
Expand All @@ -61,6 +63,7 @@ NodeManager::~NodeManager()
{
// We must destruct nodes here, since devices may need to close files.
destruct_nodes_();
clear_node_collection_container();
}

void
Expand All @@ -79,6 +82,7 @@ void
NodeManager::finalize()
{
destruct_nodes_();
clear_node_collection_container();
}

void
Expand Down Expand Up @@ -135,18 +139,19 @@ NodeManager::add_node( size_t model_id, long n )
.swap( exceptions_raised_ );

auto nc_ptr = NodeCollectionPTR( new NodeCollectionPrimitive( min_node_id, max_node_id, model_id ) );
append_node_collection_( nc_ptr );

if ( model->has_proxies() )
{
add_neurons_( *model, min_node_id, max_node_id, nc_ptr );
add_neurons_( *model, min_node_id, max_node_id );
}
else if ( not model->one_node_per_process() )
{
add_devices_( *model, min_node_id, max_node_id, nc_ptr );
add_devices_( *model, min_node_id, max_node_id );
}
else
{
add_music_nodes_( *model, min_node_id, max_node_id, nc_ptr );
add_music_nodes_( *model, min_node_id, max_node_id );
}

// check if any exceptions have been raised
Expand Down Expand Up @@ -181,9 +186,8 @@ NodeManager::add_node( size_t model_id, long n )
return nc_ptr;
}


void
NodeManager::add_neurons_( Model& model, size_t min_node_id, size_t max_node_id, NodeCollectionPTR nc_ptr )
NodeManager::add_neurons_( Model& model, size_t min_node_id, size_t max_node_id )
{
const size_t num_vps = kernel().vp_manager.get_num_virtual_processes();
// Upper limit for number of neurons per thread; in practice, either
Expand All @@ -210,7 +214,6 @@ NodeManager::add_neurons_( Model& model, size_t min_node_id, size_t max_node_id,
{
Node* node = model.create( t );
node->set_node_id_( node_id );
node->set_nc_( nc_ptr );
node->set_model_id( model.get_model_id() );
node->set_thread( t );
node->set_vp( vp );
Expand All @@ -231,7 +234,7 @@ NodeManager::add_neurons_( Model& model, size_t min_node_id, size_t max_node_id,
}

void
NodeManager::add_devices_( Model& model, size_t min_node_id, size_t max_node_id, NodeCollectionPTR nc_ptr )
NodeManager::add_devices_( Model& model, size_t min_node_id, size_t max_node_id )
{
const size_t n_per_thread = max_node_id - min_node_id + 1;

Expand All @@ -249,7 +252,6 @@ NodeManager::add_devices_( Model& model, size_t min_node_id, size_t max_node_id,

Node* node = model.create( t );
node->set_node_id_( node_id );
node->set_nc_( nc_ptr );
node->set_model_id( model.get_model_id() );
node->set_thread( t );
node->set_vp( kernel().vp_manager.thread_to_vp( t ) );
Expand All @@ -270,7 +272,7 @@ NodeManager::add_devices_( Model& model, size_t min_node_id, size_t max_node_id,
}

void
NodeManager::add_music_nodes_( Model& model, size_t min_node_id, size_t max_node_id, NodeCollectionPTR nc_ptr )
NodeManager::add_music_nodes_( Model& model, size_t min_node_id, size_t max_node_id )
{
#pragma omp parallel
{
Expand All @@ -286,7 +288,6 @@ NodeManager::add_music_nodes_( Model& model, size_t min_node_id, size_t max_node

Node* node = model.create( 0 );
node->set_node_id_( node_id );
node->set_nc_( nc_ptr );
node->set_model_id( model.get_model_id() );
node->set_thread( 0 );
node->set_vp( kernel().vp_manager.thread_to_vp( 0 ) );
Expand All @@ -307,6 +308,37 @@ NodeManager::add_music_nodes_( Model& model, size_t min_node_id, size_t max_node
} // omp parallel
}

NodeCollectionPTR
NodeManager::node_id_to_node_collection( const size_t node_id ) const
{
// find the largest ID in node_collection_last_ that is still smaller than node_id
auto it = std::lower_bound( node_collection_last_.begin(), node_collection_last_.end(), node_id );

// compute the position of the nodeCollection based on the position of the ID found above
size_t pos = it - node_collection_last_.begin();
return node_collection_container_.at( pos );
}

NodeCollectionPTR
NodeManager::node_id_to_node_collection( Node* node ) const
{
return node_id_to_node_collection( node->get_node_id() );
}

void
NodeManager::append_node_collection_( NodeCollectionPTR ncp )
{
node_collection_container_.push_back( ncp );
node_collection_last_.push_back( ncp->get_last() );
}

void
NodeManager::clear_node_collection_container()
{
node_collection_container_.clear();
node_collection_last_.clear();
}

NodeCollectionPTR
NodeManager::get_nodes( const DictionaryDatum& params, const bool local_only )
{
Expand Down
36 changes: 33 additions & 3 deletions nestkernel/node_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,20 @@ class NodeManager : public ManagerInterface
bool have_nodes_changed() const;
void set_have_nodes_changed( const bool changed );

/**
* @brief Map the node ID to its original primitive NodeCollection object.
* @param node_id The node ID
* @return The primitive NodeCollection object containing the node ID that falls in [first, last)
*/
NodeCollectionPTR node_id_to_node_collection( const size_t node_id ) const;

/**
* @brief Map the node to its original primitive NodeCollection object.
* @param node Node instance
* @return The primitive NodeCollection object containing the node with node ID falls in [first, last)
*/
NodeCollectionPTR node_id_to_node_collection( Node* node ) const;

private:
/**
* Initialize the network data structures.
Expand Down Expand Up @@ -280,7 +294,7 @@ class NodeManager : public ManagerInterface
* @param min_node_id node ID of first neuron to create.
* @param max_node_id node ID of last neuron to create (inclusive).
*/
void add_neurons_( Model& model, size_t min_node_id, size_t max_node_id, NodeCollectionPTR nc_ptr );
void add_neurons_( Model& model, size_t min_node_id, size_t max_node_id );

/**
* Add device nodes.
Expand All @@ -291,7 +305,7 @@ class NodeManager : public ManagerInterface
* @param min_node_id node ID of first neuron to create.
* @param max_node_id node ID of last neuron to create (inclusive).
*/
void add_devices_( Model& model, size_t min_node_id, size_t max_node_id, NodeCollectionPTR nc_ptr );
void add_devices_( Model& model, size_t min_node_id, size_t max_node_id );

/**
* Add MUSIC nodes.
Expand All @@ -303,7 +317,15 @@ class NodeManager : public ManagerInterface
* @param min_node_id node ID of first neuron to create.
* @param max_node_id node ID of last neuron to create (inclusive).
*/
void add_music_nodes_( Model& model, size_t min_node_id, size_t max_node_id, NodeCollectionPTR nc_ptr );
void add_music_nodes_( Model& model, size_t min_node_id, size_t max_node_id );

/**
* @brief Append the NodeCollection instance into the NodeManager::nodeCollection_container.
* @param ncp The NodeCollection instance.
*/
void append_node_collection_( NodeCollectionPTR ncp );

void clear_node_collection_container();

private:
/**
Expand All @@ -312,6 +334,14 @@ class NodeManager : public ManagerInterface
*/
std::vector< SparseNodeArray > local_nodes_;

std::vector< NodeCollectionPTR > node_collection_container_; //!< a vector of the original/primitive NodeCollection

std::vector< size_t >
node_collection_last_; //!< Store the ID of the last element in each NodeCollection instance.
//!< Mainly, the node_collection_last_ must be the same size as node_collection_container,
//!< where each element at position i in the nodeCollection_last_ is the last node ID
//!< stored in the node_collection_container_ at position i.

std::vector< std::vector< Node* > > wfr_nodes_vec_; //!< Nodelists for unfrozen nodes that
//!< use the waveform relaxation method
bool wfr_is_used_; //!< there is at least one node that uses
Expand Down
2 changes: 1 addition & 1 deletion nestkernel/parameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ NodePosParameter::get_node_pos_( Node* node ) const
{
throw KernelException( "NodePosParameter: not node" );
}
NodeCollectionPTR nc = node->get_nc();
NodeCollectionPTR nc = kernel().node_manager.node_id_to_node_collection( node );
if ( not nc.get() )
{
throw KernelException( "NodePosParameter: not nc" );
Expand Down
7 changes: 2 additions & 5 deletions nestkernel/spatial.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,12 @@ get_position( NodeCollectionPTR layer_nc )
std::vector< double >
get_position( const size_t node_id )
{
Node* node = kernel().node_manager.get_node_or_proxy( node_id );

if ( not kernel().node_manager.is_local_node_id( node_id ) )
{
throw KernelException( "GetPosition is currently implemented for local nodes only." );
}

NodeCollectionPTR nc = node->get_nc();
NodeCollectionPTR nc = kernel().node_manager.node_id_to_node_collection( node_id );
NodeCollectionMetadataPTR meta = nc->get_metadata();

if ( not meta )
Expand Down Expand Up @@ -341,9 +339,8 @@ distance( const ArrayDatum conns )
throw KernelException( "Distance is currently implemented for local nodes only." );
}

Node* trgt_node = kernel().node_manager.get_node_or_proxy( trgt );

NodeCollectionPTR trgt_nc = trgt_node->get_nc();
NodeCollectionPTR trgt_nc = kernel().node_manager.node_id_to_node_collection( trgt );
NodeCollectionMetadataPTR meta = trgt_nc->get_metadata();

// distance is NaN if source, target is not spatially distributed
Expand Down

0 comments on commit c04c3dc

Please sign in to comment.