diff --git a/examples/howto/plot_connectivity.py b/examples/howto/plot_connectivity.py index eec6e747f5..373cfa278b 100644 --- a/examples/howto/plot_connectivity.py +++ b/examples/howto/plot_connectivity.py @@ -72,6 +72,7 @@ def get_network(probability=1.0): net = jones_2009_model(add_drives_from_params=True) net.clear_connectivity() + net.clear_drives() # Pyramidal cell connections location, receptor = 'distal', 'ampa' diff --git a/hnn_core/network.py b/hnn_core/network.py index 1c2865d918..79b2d86a87 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -1261,23 +1261,77 @@ def add_connection(self, src_gids, target_gids, loc, receptor, self.connectivity.append(deepcopy(conn)) - def clear_connectivity(self): - """Remove all connections defined in Network.connectivity + def _get_expected_connectivities(self, src_types='all'): + """Return expected connectivities left after clearng connections. + + Parameters + ---------- + src_types : list | all + Connection source types to be cleared + + Returns + ------- + int + Number of connections left after the deletion + operation """ - connectivity = list() - for conn in self.connectivity: - if conn['src_type'] in self.external_drives.keys(): - connectivity.append(conn) - self.connectivity = connectivity - - def clear_drives(self): - """Remove all drives defined in Network.connectivity""" - connectivity = list() - for conn in self.connectivity: - if conn['src_type'] not in self.external_drives.keys(): - connectivity.append(conn) - self.external_drives = dict() - self.connectivity = connectivity + if src_types == 'all': + return 0 + deleted_connectivities = 0 + for src_type in src_types: + deleted_connectivities += len(pick_connection(self, + src_gids=src_type)) + return len(self.connectivity) - deleted_connectivities + + def clear_connectivity(self, src_types="all"): + """Remove connections with src_type in Network.connectivity. + + Parameters + ---------- + src_types : list | 'all' | 'external' | 'internal' + Source types of connections to be cleared + 'all' - Clear all connections (Default) + 'external' - Clear connections from the external drives to the + local network + 'internal' - Clear connections between cells of the local network + + """ + if src_types == "all": + src_types = list(self.gid_ranges.keys()) + elif src_types == "external": + src_types = self.drive_names + elif src_types == "internal": + src_types = list((src_type for src_type in self.gid_ranges.keys() + if src_type not in self.drive_names)) + _validate_type(src_types, list, 'src_types', 'list, drives, local') + # Finding connection indices to be deleted + conn_idxs = list() + for src_type in src_types: + conn_idxs.extend(pick_connection(self, src_gids=src_type)) + + # Deleting the indices + for conn_idx in sorted(conn_idxs, reverse=True): + del self.connectivity[conn_idx] + + def clear_drives(self, drive_names='all'): + """Remove all drives defined in Network.connectivity. + + Parameters + ---------- + drive_names : list | 'all' + The drive_names to remove + """ + if drive_names == 'all': + drive_names = self.drive_names + _validate_type(drive_names, (list,)) + for drive_name in drive_names: + del self.external_drives[drive_name] + self.clear_connectivity(src_types=drive_names) + + @property + def drive_names(self): + """Returns a list containing names of all external drives.""" + return list(self.external_drives.keys()) def add_electrode_array(self, name, electrode_pos, *, conductivity=0.3, method='psa', min_distance=0.5): diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 6a6c5da9d4..42f8303596 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -788,10 +788,43 @@ def test_network_connectivity(): pick_connection(**kwargs) # Test removing connections from net.connectivity - # Needs to be updated if number of drives change in preceeding tests - net.clear_connectivity() - assert len(net.connectivity) == 4 # 2 drives x 2 target cell types + # Test invalid argument type + with pytest.raises(TypeError, match="src_types must be an instance of"): + net.clear_connectivity(src_types=10) + + # Test Clearing connections of local src_types + + # Deleting all connections with src_type as 'L2_pyramidal' + expected_connectivities = (net._get_expected_connectivities( + src_types=['L2_pyramidal'])) + net.clear_connectivity(src_types=['L2_pyramidal']) + assert len(net.connectivity) == expected_connectivities + + # Deleting all local connections + src_types = list() + external_drives = net.drive_names + for conn in net.connectivity: + if (conn['src_type'] not in src_types and + conn['src_type'] not in external_drives): + src_types.append(conn['src_type']) + expected_connectivities = (net._get_expected_connectivities( + src_types=src_types)) + net.clear_connectivity(src_types='internal') + assert len(net.connectivity) == expected_connectivities + + # Testing deletion of a custom number of drives + + # Deleting one external drive + all_drives = net.drive_names + drives_to_be_deleted = all_drives[0:1] + expected_connectivities = (net._get_expected_connectivities( + src_types=drives_to_be_deleted)) + net.clear_drives(drive_names=drives_to_be_deleted) + assert len(net.connectivity) == expected_connectivities + + # Deleting all external drives net.clear_drives() + # All internal and external connections are deleted assert len(net.connectivity) == 0 with pytest.warns(UserWarning, match='No connections'):