diff --git a/examples/howto/plot_connectivity.py b/examples/howto/plot_connectivity.py index eec6e747f..373cfa278 100644 --- a/examples/howto/plot_connectivity.py +++ b/examples/howto/plot_connectivity.py @@ -72,6 +72,7 @@ def get_network(probability=1.0): net = jones_2009_model(add_drives_from_params=True) net.clear_connectivity() + net.clear_drives() # Pyramidal cell connections location, receptor = 'distal', 'ampa' diff --git a/hnn_core/network.py b/hnn_core/network.py index 777406d1b..33cca7ff1 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -1253,23 +1253,70 @@ def add_connection(self, src_gids, target_gids, loc, receptor, self.connectivity.append(deepcopy(conn)) - def clear_connectivity(self): - """Remove all connections defined in Network.connectivity + def _clear_connectivity(self, src_types): + """Remove connections with src_type in Network.connectivity. + + Parameters + ---------- + src_types : list + Source types of connections to be cleared + """ - connectivity = list() - for conn in self.connectivity: - if conn['src_type'] in self.external_drives.keys(): - connectivity.append(conn) - self.connectivity = connectivity - - def clear_drives(self): - """Remove all drives defined in Network.connectivity""" - connectivity = list() - for conn in self.connectivity: - if conn['src_type'] not in self.external_drives.keys(): - connectivity.append(conn) - self.external_drives = dict() - self.connectivity = connectivity + _validate_type(src_types, list, 'src_types', 'list, drives, local') + # Finding connection indices to be deleted + conn_idxs = list() + for src_type in src_types: + conn_idxs.extend(pick_connection(self, src_gids=src_type)) + + # Deleting the indices + for conn_idx in sorted(conn_idxs, reverse=True): + del self.connectivity[conn_idx] + + def clear_connectivity(self, src_types='all'): + """Clear connections between cells of the local network + + Parameters + ---------- + src_types : list | 'all' + Source types of connections to be cleared + 'all' - Clear all connections between cells of the local network + + """ + if src_types == "all": + src_types = list((src_type for src_type in self.gid_ranges.keys() + if src_type not in self.drive_names)) + _validate_type(src_types, list, 'src_types', 'list') + drive_names = self.drive_names + for src_type in src_types: + if src_type in drive_names: + raise ValueError('src_types contains %s which is an external ' + 'drive.' % (src_type,)) + self._clear_connectivity(src_types) + + def clear_drives(self, drive_names='all'): + """Remove drives defined in Network.connectivity. + + Parameters + ---------- + drive_names : list | 'all' + The drive_names to remove + """ + if drive_names == 'all': + drive_names = self.drive_names + _validate_type(drive_names, list, 'drive_names', 'list') + all_drive_names = self.drive_names + for drive_name in drive_names: + if drive_name not in all_drive_names: + raise ValueError('drive_names contains %s which is not an ' + 'external drive.' % (drive_name,)) + for drive_name in drive_names: + del self.external_drives[drive_name] + self.clear_connectivity(src_types=drive_names) + + @property + def drive_names(self): + """Returns a list containing names of all external drives.""" + return list(self.external_drives.keys()) def add_electrode_array(self, name, electrode_pos, *, conductivity=0.3, method='psa', min_distance=0.5): diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index 3a96198ac..bcf27092d 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -510,6 +510,28 @@ def test_network_drives_legacy(): n_bursty_sources) +def get_expected_connectivities(net, src_types): + """Return expected connectivities left after clearing connections. + + Parameters + ---------- + net : The network instance + src_types : list + Connection source types to be cleared + + Returns + ------- + int + Number of connections left after the deletion + operation + """ + deleted_connectivities = 0 + for src_type in src_types: + deleted_connectivities += len(pick_connection(net, + src_gids=src_type)) + return len(net.connectivity) - deleted_connectivities + + def test_network_connectivity(): """Test manipulation of local network connectivity.""" params = read_params(params_fname) @@ -781,10 +803,55 @@ def test_network_connectivity(): pick_connection(**kwargs) # Test removing connections from net.connectivity - # Needs to be updated if number of drives change in preceeding tests - net.clear_connectivity() - assert len(net.connectivity) == 4 # 2 drives x 2 target cell types + # Test invalid argument type + with pytest.raises(TypeError, match="src_types must be an instance of"): + net.clear_connectivity(src_types=10) + + # Test Clearing connections of local src_types + + # Using clear_connections to delete a drive + with pytest.raises(ValueError, + match="src_types contains evdist1 which is an " + "external drive."): + net.clear_connectivity(src_types=['evdist1']) + + # Deleting all connections with src_type as 'L2_pyramidal' + expected_connectivities = (get_expected_connectivities( + net, src_types=['L2_pyramidal'])) + net.clear_connectivity(src_types=['L2_pyramidal']) + assert len(net.connectivity) == expected_connectivities + + # Deleting all local connections + src_types = list() + external_drives = net.drive_names + for conn in net.connectivity: + if (conn['src_type'] not in src_types and + conn['src_type'] not in external_drives): + src_types.append(conn['src_type']) + expected_connectivities = (get_expected_connectivities( + net, src_types=src_types)) + net.clear_connectivity(src_types='all') + assert len(net.connectivity) == expected_connectivities + + # Testing deletion of a custom number of drives + + # Using clear_drives to delete a local connection + with pytest.raises(ValueError, + match="drive_names contains L2_pyramidal which " + "is not an external drive."): + net.clear_drives(drive_names=['L2_pyramidal']) + + # Deleting one external drive + all_drives = net.drive_names + drives_to_be_deleted = all_drives[0:1] + expected_connectivities = (get_expected_connectivities( + net, src_types=drives_to_be_deleted)) + net.clear_drives(drive_names=drives_to_be_deleted) + assert len(net.connectivity) == expected_connectivities + + # Deleting all external drives net.clear_drives() + # All internal and external connections are deleted assert len(net.connectivity) == 0 with pytest.warns(UserWarning, match='No connections'):