Skip to content

Commit

Permalink
Add PyDiGraph.neighbors_undirected (#1254)
Browse files Browse the repository at this point in the history
* Add `PyDiGraph.neighbors_undirected`

* Add reno

* add stub

* review comments

- additional test comparing w/ to_undirected
- example in docstring

* Apply suggestions from code review

---------

Co-authored-by: Ivan Carvalho <[email protected]>
  • Loading branch information
Cryoris and IvanIsCoding authored Jul 24, 2024
1 parent b6f0ff5 commit 8b5d38b
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 0 deletions.
6 changes: 6 additions & 0 deletions releasenotes/notes/neighbors-undirected-087b032745ec002d.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Added a new method :meth:`~rustworkx.PyDiGraph.neighbors_undirected` to
obtain the neighbors of a node in a directed graph, irrespective of the
edge directionality.
1 change: 1 addition & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1391,6 +1391,7 @@ class PyDiGraph(Generic[_S, _T]):
def make_symmetric(self, edge_payload_fn: Callable[[_T], _T] | None = ...) -> None: ...
def merge_nodes(self, u: int, v: int, /) -> None: ...
def neighbors(self, node: int, /) -> NodeIndices: ...
def neighbors_undirected(self, node: int, /) -> NodeIndices: ...
def node_indexes(self) -> NodeIndices: ...
def node_indices(self) -> NodeIndices: ...
def nodes(self) -> list[_S]: ...
Expand Down
32 changes: 32 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,38 @@ impl PyDiGraph {
}
}

/// Get the direction-agnostic neighbors (i.e. successors and predecessors) of a node.
///
/// This is functionally equivalent to converting the directed graph to an undirected
/// graph, and calling ``neighbors`` thereon. For example::
///
/// import rustworkx
///
/// dag = rustworkx.generators.directed_cycle_graph(num_nodes=10, bidirectional=False)
///
/// node = 3
/// neighbors = dag.neighbors_undirected(node)
/// same_neighbors = dag.to_undirected().neighbors(node)
///
/// assert sorted(neighbors) == sorted(same_neighbors)
///
/// :param int node: The index of the node to get the neighbors of
///
/// :returns: A list of the neighbor node indices
/// :rtype: NodeIndices
#[pyo3(text_signature = "(self, node, /)")]
pub fn neighbors_undirected(&self, node: usize) -> NodeIndices {
NodeIndices {
nodes: self
.graph
.neighbors_undirected(NodeIndex::new(node))
.map(|node| node.index())
.collect::<HashSet<usize>>()
.drain()
.collect(),
}
}

/// Get the successor indices of a node.
///
/// This will return a list of the node indicies for the succesors of
Expand Down
22 changes: 22 additions & 0 deletions tests/digraph/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import unittest

import rustworkx
import rustworkx.generators


class TestAdj(unittest.TestCase):
Expand Down Expand Up @@ -57,3 +58,24 @@ def test_no_neighbor(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
self.assertEqual([], dag.neighbors(node_a))

def test_undirected_neighbors(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
node_b = dag.add_child(node_a, "b", {"a": 1})

directed = dag.neighbors(node_b)
self.assertEqual([], directed)

undirected = dag.neighbors_undirected(node_b)
self.assertEqual([node_a], undirected)

def test_undirected_neighbors_cycle(self):
num_nodes = 10
dag = rustworkx.generators.directed_cycle_graph(num_nodes, bidirectional=False)
undirected_dag = dag.to_undirected()

for node in dag.node_indices():
undirected_neighbors = dag.neighbors_undirected(node)
expected_neighbors = undirected_dag.neighbors(node)
self.assertEqual(sorted(undirected_neighbors), sorted(expected_neighbors))

0 comments on commit 8b5d38b

Please sign in to comment.