diff --git a/src/libecalc/common/graph.py b/src/libecalc/common/graph.py new file mode 100644 index 000000000..6409303da --- /dev/null +++ b/src/libecalc/common/graph.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Dict, Generic, List, Protocol, TypeVar + +import networkx as nx +from typing_extensions import Self + +NodeID = str + + +class NodeWithID(Protocol): + @property + def id(self) -> NodeID: + ... + + +TNode = TypeVar("TNode", bound=NodeWithID) + + +class Graph(Generic[TNode]): + def __init__(self): + self.graph = nx.DiGraph() + self.nodes: Dict[NodeID, TNode] = {} + + def add_node(self, node: TNode) -> Self: + self.graph.add_node(node.id) + self.nodes[node.id] = node + return self + + def add_edge(self, from_id: NodeID, to_id: NodeID) -> Self: + if from_id not in self.nodes or to_id not in self.nodes: + raise ValueError("Add node before adding edges") + + self.graph.add_edge(from_id, to_id) + return self + + def add_subgraph(self, subgraph: Graph) -> Self: + self.nodes.update(subgraph.nodes) + self.graph = nx.compose(self.graph, subgraph.graph) + return self + + def get_successors(self, node_id: NodeID, recursively=False) -> List[NodeID]: + if recursively: + return [ + successor_id + for successor_id in nx.dfs_tree(self.graph, source=node_id).nodes() + if successor_id != node_id + ] + else: + return list(self.graph.successors(node_id)) + + def get_predecessor(self, node_id: NodeID) -> NodeID: + predecessors = list(self.graph.predecessors(node_id)) + if len(predecessors) > 1: + raise ValueError( + f"Tried to get a single predecessor of node with several predecessors. NodeID: {node_id}, " + f"Predecessors: {', '.join(predecessors)}" + ) + return predecessors[0] + + @property + def root(self) -> NodeID: + return list(nx.topological_sort(self.graph))[0] + + def get_node(self, node_id: NodeID) -> TNode: + return self.nodes[node_id] + + @property + def sorted_node_ids(self) -> List[NodeID]: + return list(nx.topological_sort(self.graph)) + + def __iter__(self): + return iter(self.graph) diff --git a/src/libecalc/dto/component_graph.py b/src/libecalc/dto/component_graph.py index 1cd5775d7..339d03209 100644 --- a/src/libecalc/dto/component_graph.py +++ b/src/libecalc/dto/component_graph.py @@ -1,55 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List - -import networkx as nx +from typing import List from libecalc import dto from libecalc.common.component_info.component_level import ComponentLevel +from libecalc.common.graph import Graph, NodeID from libecalc.dto.base import ComponentType from libecalc.dto.node_info import NodeInfo -if TYPE_CHECKING: - from libecalc.dto.components import ComponentDTO - -NodeID = str - - -class ComponentGraph: - def __init__(self): - self.graph = nx.DiGraph() - self.nodes: Dict[NodeID, ComponentDTO] = {} - - def add_node(self, node: ComponentDTO): - self.graph.add_node(node.id) - self.nodes[node.id] = node - - def add_edge(self, from_id: NodeID, to_id: NodeID): - if from_id not in self.nodes or to_id not in self.nodes: - raise ValueError("Add node before adding edges") - - self.graph.add_edge(from_id, to_id) - - def add_subgraph(self, subgraph: ComponentGraph): - self.nodes.update(subgraph.nodes) - self.graph = nx.compose(self.graph, subgraph.graph) - - def get_successors(self, node_id: NodeID, recursively=False) -> List[NodeID]: - if recursively: - return [ - successor_id - for successor_id in nx.dfs_tree(self.graph, source=node_id).nodes() - if successor_id != node_id - ] - else: - return list(self.graph.successors(node_id)) - - def get_predecessor(self, node_id: NodeID) -> NodeID: - predecessors = list(self.graph.predecessors(node_id)) - if len(predecessors) > 1: - raise ValueError("Component with several parents encountered.") - return predecessors[0] +class ComponentGraph(Graph): def get_parent_installation_id(self, node_id: NodeID) -> NodeID: """ Simple helper function to get the installation of any component with id @@ -69,10 +29,6 @@ def get_parent_installation_id(self, node_id: NodeID) -> NodeID: parent_id = self.get_predecessor(node_id) return self.get_parent_installation_id(parent_id) - @property - def root(self) -> NodeID: - return list(nx.topological_sort(self.graph))[0] - def get_node_info(self, node_id: NodeID) -> NodeInfo: component_dto = self.nodes[node_id] if isinstance(component_dto, dto.Asset): @@ -97,9 +53,6 @@ def get_node_info(self, node_id: NodeID) -> NodeInfo: component_level=component_level, ) - def get_node(self, node_id: NodeID) -> ComponentDTO: - return self.nodes[node_id] - def get_node_id_by_name(self, name: str) -> NodeID: for node in self.nodes.values(): if node.name == name: @@ -109,10 +62,3 @@ def get_node_id_by_name(self, name: str) -> NodeID: def get_nodes_of_type(self, component_type: ComponentType) -> List[NodeID]: return [node.id for node in self.nodes.values() if node.component_type == component_type] - - @property - def sorted_node_ids(self) -> List[NodeID]: - return list(nx.topological_sort(self.graph)) - - def __iter__(self): - return iter(self.graph) diff --git a/src/tests/libecalc/common/test_graph.py b/src/tests/libecalc/common/test_graph.py new file mode 100644 index 000000000..f3a93e785 --- /dev/null +++ b/src/tests/libecalc/common/test_graph.py @@ -0,0 +1,102 @@ +from dataclasses import dataclass +from typing import Optional + +import pytest +from libecalc.common.graph import Graph, NodeID + + +@dataclass +class Node: + node_id: NodeID + some_data: Optional[str] = None + + @property + def id(self): + return self.node_id + + +@pytest.fixture +def graph_with_subgraph(): + subgraph = ( + Graph() + .add_node(Node(node_id="3")) + .add_node(Node(node_id="4", some_data="test")) + .add_edge(from_id="3", to_id="4") + ) + + graph = ( + Graph() + .add_subgraph(subgraph) + .add_node(Node(node_id="2")) + .add_node(Node(node_id="1")) + .add_edge(from_id="1", to_id="2") + .add_edge("2", "3") + ) + + return graph + + +class TestGraph: + def test_get_successors(self): + graph = ( + Graph() + .add_node(Node(node_id="1")) + .add_node(Node(node_id="2")) + .add_node(Node(node_id="3")) + .add_edge(from_id="1", to_id="2") + .add_edge(from_id="2", to_id="3") + ) + assert graph.get_successors("1") == ["2"] + assert graph.get_successors("1", recursively=True) == ["2", "3"] + assert graph.get_successors("2", recursively=True) == ["3"] + assert graph.get_successors("3", recursively=True) == [] + + def test_get_predecessors(self): + graph = ( + Graph() + .add_node(Node(node_id="1")) + .add_node(Node(node_id="2")) + .add_node(Node(node_id="3")) + .add_edge(from_id="1", to_id="2") + .add_edge(from_id="2", to_id="3") + ) + + assert graph.get_predecessor("3") == "2" + assert graph.get_predecessor("2") == "1" + + graph.add_edge("1", "3") + with pytest.raises(ValueError) as exc_info: + graph.get_predecessor("3") + + assert str(exc_info.value) == ( + "Tried to get a single predecessor of node with several predecessors. NodeID: " "3, Predecessors: 2, 1" + ) + + def test_subgraph(self, graph_with_subgraph): + assert len(graph_with_subgraph.nodes) == 4 + assert set(graph_with_subgraph.nodes.keys()) == {"1", "2", "3", "4"} + assert graph_with_subgraph.get_predecessor("3") == "2" + + def test_topological_sort(self, graph_with_subgraph): + assert graph_with_subgraph.sorted_node_ids == ["1", "2", "3", "4"] + + def test_root(self, graph_with_subgraph): + assert graph_with_subgraph.root == "1" + + def test_get_node_with_data(self, graph_with_subgraph): + """ + Assert returned node is the same node that we passed in. + + Args: + graph_with_subgraph: + + Returns: + + """ + node = graph_with_subgraph.get_node("4") + # TODO: fix typing + # Uncomment line below to see the type, no import needed, install mypy in poetry venv, run poetry shell, then + # run mypy on the file 'mypy '. + # reveal_type(node) + assert isinstance(node, Node) + assert node.some_data == "test"