-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
178 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Dict, Generic, List, Protocol, TypeVar | ||
|
||
import networkx as nx | ||
from typing_extensions import Self | ||
|
||
NodeID = str | ||
|
||
|
||
class NodeWithID(Protocol): | ||
@property | ||
def id(self) -> NodeID: | ||
... | ||
|
||
|
||
TNode = TypeVar("TNode", bound=NodeWithID) | ||
|
||
|
||
class Graph(Generic[TNode]): | ||
def __init__(self): | ||
self.graph = nx.DiGraph() | ||
self.nodes: Dict[NodeID, TNode] = {} | ||
|
||
def add_node(self, node: TNode) -> Self: | ||
self.graph.add_node(node.id) | ||
self.nodes[node.id] = node | ||
return self | ||
|
||
def add_edge(self, from_id: NodeID, to_id: NodeID) -> Self: | ||
if from_id not in self.nodes or to_id not in self.nodes: | ||
raise ValueError("Add node before adding edges") | ||
|
||
self.graph.add_edge(from_id, to_id) | ||
return self | ||
|
||
def add_subgraph(self, subgraph: Graph) -> Self: | ||
self.nodes.update(subgraph.nodes) | ||
self.graph = nx.compose(self.graph, subgraph.graph) | ||
return self | ||
|
||
def get_successors(self, node_id: NodeID, recursively=False) -> List[NodeID]: | ||
if recursively: | ||
return [ | ||
successor_id | ||
for successor_id in nx.dfs_tree(self.graph, source=node_id).nodes() | ||
if successor_id != node_id | ||
] | ||
else: | ||
return list(self.graph.successors(node_id)) | ||
|
||
def get_predecessor(self, node_id: NodeID) -> NodeID: | ||
predecessors = list(self.graph.predecessors(node_id)) | ||
if len(predecessors) > 1: | ||
raise ValueError( | ||
f"Tried to get a single predecessor of node with several predecessors. NodeID: {node_id}, " | ||
f"Predecessors: {', '.join(predecessors)}" | ||
) | ||
return predecessors[0] | ||
|
||
@property | ||
def root(self) -> NodeID: | ||
return list(nx.topological_sort(self.graph))[0] | ||
|
||
def get_node(self, node_id: NodeID) -> TNode: | ||
return self.nodes[node_id] | ||
|
||
@property | ||
def sorted_node_ids(self) -> List[NodeID]: | ||
return list(nx.topological_sort(self.graph)) | ||
|
||
def __iter__(self): | ||
return iter(self.graph) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from dataclasses import dataclass | ||
from typing import Optional | ||
|
||
import pytest | ||
from libecalc.common.graph import Graph, NodeID | ||
|
||
|
||
@dataclass | ||
class Node: | ||
node_id: NodeID | ||
some_data: Optional[str] = None | ||
|
||
@property | ||
def id(self): | ||
return self.node_id | ||
|
||
|
||
@pytest.fixture | ||
def graph_with_subgraph(): | ||
subgraph = ( | ||
Graph() | ||
.add_node(Node(node_id="3")) | ||
.add_node(Node(node_id="4", some_data="test")) | ||
.add_edge(from_id="3", to_id="4") | ||
) | ||
|
||
graph = ( | ||
Graph() | ||
.add_subgraph(subgraph) | ||
.add_node(Node(node_id="2")) | ||
.add_node(Node(node_id="1")) | ||
.add_edge(from_id="1", to_id="2") | ||
.add_edge("2", "3") | ||
) | ||
|
||
return graph | ||
|
||
|
||
class TestGraph: | ||
def test_get_successors(self): | ||
graph = ( | ||
Graph() | ||
.add_node(Node(node_id="1")) | ||
.add_node(Node(node_id="2")) | ||
.add_node(Node(node_id="3")) | ||
.add_edge(from_id="1", to_id="2") | ||
.add_edge(from_id="2", to_id="3") | ||
) | ||
assert graph.get_successors("1") == ["2"] | ||
assert graph.get_successors("1", recursively=True) == ["2", "3"] | ||
assert graph.get_successors("2", recursively=True) == ["3"] | ||
assert graph.get_successors("3", recursively=True) == [] | ||
|
||
def test_get_predecessors(self): | ||
graph = ( | ||
Graph() | ||
.add_node(Node(node_id="1")) | ||
.add_node(Node(node_id="2")) | ||
.add_node(Node(node_id="3")) | ||
.add_edge(from_id="1", to_id="2") | ||
.add_edge(from_id="2", to_id="3") | ||
) | ||
|
||
assert graph.get_predecessor("3") == "2" | ||
assert graph.get_predecessor("2") == "1" | ||
|
||
graph.add_edge("1", "3") | ||
with pytest.raises(ValueError) as exc_info: | ||
graph.get_predecessor("3") | ||
|
||
assert str(exc_info.value) == ( | ||
"Tried to get a single predecessor of node with several predecessors. NodeID: " "3, Predecessors: 2, 1" | ||
) | ||
|
||
def test_subgraph(self, graph_with_subgraph): | ||
assert len(graph_with_subgraph.nodes) == 4 | ||
assert set(graph_with_subgraph.nodes.keys()) == {"1", "2", "3", "4"} | ||
assert graph_with_subgraph.get_predecessor("3") == "2" | ||
|
||
def test_topological_sort(self, graph_with_subgraph): | ||
assert graph_with_subgraph.sorted_node_ids == ["1", "2", "3", "4"] | ||
|
||
def test_root(self, graph_with_subgraph): | ||
assert graph_with_subgraph.root == "1" | ||
|
||
def test_get_node_with_data(self, graph_with_subgraph): | ||
""" | ||
Assert returned node is the same node that we passed in. | ||
Args: | ||
graph_with_subgraph: | ||
Returns: | ||
""" | ||
node = graph_with_subgraph.get_node("4") | ||
# TODO: fix typing | ||
# Uncomment line below to see the type, no import needed, install mypy in poetry venv, run poetry shell, then | ||
# run mypy on the file 'mypy <path to file>'. | ||
# reveal_type(node) | ||
assert isinstance(node, Node) | ||
assert node.some_data == "test" |