Skip to content

Commit

Permalink
refactor: generic graph class
Browse files Browse the repository at this point in the history
  • Loading branch information
jsolaas committed Nov 10, 2023
1 parent 9629f22 commit 6f63e40
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 57 deletions.
73 changes: 73 additions & 0 deletions src/libecalc/common/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from __future__ import annotations

from typing import Dict, Generic, List, Protocol, TypeVar

import networkx as nx
from typing_extensions import Self

NodeID = str


class NodeWithID(Protocol):
@property
def id(self) -> NodeID:
...


TNode = TypeVar("TNode", bound=NodeWithID)


class Graph(Generic[TNode]):
def __init__(self):
self.graph = nx.DiGraph()
self.nodes: Dict[NodeID, TNode] = {}

def add_node(self, node: TNode) -> Self:
self.graph.add_node(node.id)
self.nodes[node.id] = node
return self

def add_edge(self, from_id: NodeID, to_id: NodeID) -> Self:
if from_id not in self.nodes or to_id not in self.nodes:
raise ValueError("Add node before adding edges")

self.graph.add_edge(from_id, to_id)
return self

def add_subgraph(self, subgraph: Graph) -> Self:
self.nodes.update(subgraph.nodes)
self.graph = nx.compose(self.graph, subgraph.graph)
return self

def get_successors(self, node_id: NodeID, recursively=False) -> List[NodeID]:
if recursively:
return [
successor_id
for successor_id in nx.dfs_tree(self.graph, source=node_id).nodes()
if successor_id != node_id
]
else:
return list(self.graph.successors(node_id))

def get_predecessor(self, node_id: NodeID) -> NodeID:
predecessors = list(self.graph.predecessors(node_id))
if len(predecessors) > 1:
raise ValueError(
f"Tried to get a single predecessor of node with several predecessors. NodeID: {node_id}, "
f"Predecessors: {', '.join(predecessors)}"
)
return predecessors[0]

@property
def root(self) -> NodeID:
return list(nx.topological_sort(self.graph))[0]

def get_node(self, node_id: NodeID) -> TNode:
return self.nodes[node_id]

@property
def sorted_node_ids(self) -> List[NodeID]:
return list(nx.topological_sort(self.graph))

def __iter__(self):
return iter(self.graph)
60 changes: 3 additions & 57 deletions src/libecalc/dto/component_graph.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List

import networkx as nx
from typing import List

from libecalc import dto
from libecalc.common.component_info.component_level import ComponentLevel
from libecalc.common.graph import Graph, NodeID
from libecalc.dto.base import ComponentType
from libecalc.dto.node_info import NodeInfo

if TYPE_CHECKING:
from libecalc.dto.components import ComponentDTO

NodeID = str


class ComponentGraph:
def __init__(self):
self.graph = nx.DiGraph()
self.nodes: Dict[NodeID, ComponentDTO] = {}

def add_node(self, node: ComponentDTO):
self.graph.add_node(node.id)
self.nodes[node.id] = node

def add_edge(self, from_id: NodeID, to_id: NodeID):
if from_id not in self.nodes or to_id not in self.nodes:
raise ValueError("Add node before adding edges")

self.graph.add_edge(from_id, to_id)

def add_subgraph(self, subgraph: ComponentGraph):
self.nodes.update(subgraph.nodes)
self.graph = nx.compose(self.graph, subgraph.graph)

def get_successors(self, node_id: NodeID, recursively=False) -> List[NodeID]:
if recursively:
return [
successor_id
for successor_id in nx.dfs_tree(self.graph, source=node_id).nodes()
if successor_id != node_id
]
else:
return list(self.graph.successors(node_id))

def get_predecessor(self, node_id: NodeID) -> NodeID:
predecessors = list(self.graph.predecessors(node_id))
if len(predecessors) > 1:
raise ValueError("Component with several parents encountered.")
return predecessors[0]

class ComponentGraph(Graph):
def get_parent_installation_id(self, node_id: NodeID) -> NodeID:
"""
Simple helper function to get the installation of any component with id
Expand All @@ -69,10 +29,6 @@ def get_parent_installation_id(self, node_id: NodeID) -> NodeID:
parent_id = self.get_predecessor(node_id)
return self.get_parent_installation_id(parent_id)

@property
def root(self) -> NodeID:
return list(nx.topological_sort(self.graph))[0]

def get_node_info(self, node_id: NodeID) -> NodeInfo:
component_dto = self.nodes[node_id]
if isinstance(component_dto, dto.Asset):
Expand All @@ -97,9 +53,6 @@ def get_node_info(self, node_id: NodeID) -> NodeInfo:
component_level=component_level,
)

def get_node(self, node_id: NodeID) -> ComponentDTO:
return self.nodes[node_id]

def get_node_id_by_name(self, name: str) -> NodeID:
for node in self.nodes.values():
if node.name == name:
Expand All @@ -109,10 +62,3 @@ def get_node_id_by_name(self, name: str) -> NodeID:

def get_nodes_of_type(self, component_type: ComponentType) -> List[NodeID]:
return [node.id for node in self.nodes.values() if node.component_type == component_type]

@property
def sorted_node_ids(self) -> List[NodeID]:
return list(nx.topological_sort(self.graph))

def __iter__(self):
return iter(self.graph)
102 changes: 102 additions & 0 deletions src/tests/libecalc/common/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from dataclasses import dataclass
from typing import Optional

import pytest
from libecalc.common.graph import Graph, NodeID


@dataclass
class Node:
node_id: NodeID
some_data: Optional[str] = None

@property
def id(self):
return self.node_id


@pytest.fixture
def graph_with_subgraph():
subgraph = (
Graph()
.add_node(Node(node_id="3"))
.add_node(Node(node_id="4", some_data="test"))
.add_edge(from_id="3", to_id="4")
)

graph = (
Graph()
.add_subgraph(subgraph)
.add_node(Node(node_id="2"))
.add_node(Node(node_id="1"))
.add_edge(from_id="1", to_id="2")
.add_edge("2", "3")
)

return graph


class TestGraph:
def test_get_successors(self):
graph = (
Graph()
.add_node(Node(node_id="1"))
.add_node(Node(node_id="2"))
.add_node(Node(node_id="3"))
.add_edge(from_id="1", to_id="2")
.add_edge(from_id="2", to_id="3")
)
assert graph.get_successors("1") == ["2"]
assert graph.get_successors("1", recursively=True) == ["2", "3"]
assert graph.get_successors("2", recursively=True) == ["3"]
assert graph.get_successors("3", recursively=True) == []

def test_get_predecessors(self):
graph = (
Graph()
.add_node(Node(node_id="1"))
.add_node(Node(node_id="2"))
.add_node(Node(node_id="3"))
.add_edge(from_id="1", to_id="2")
.add_edge(from_id="2", to_id="3")
)

assert graph.get_predecessor("3") == "2"
assert graph.get_predecessor("2") == "1"

graph.add_edge("1", "3")
with pytest.raises(ValueError) as exc_info:
graph.get_predecessor("3")

assert str(exc_info.value) == (
"Tried to get a single predecessor of node with several predecessors. NodeID: " "3, Predecessors: 2, 1"
)

def test_subgraph(self, graph_with_subgraph):
assert len(graph_with_subgraph.nodes) == 4
assert set(graph_with_subgraph.nodes.keys()) == {"1", "2", "3", "4"}
assert graph_with_subgraph.get_predecessor("3") == "2"

def test_topological_sort(self, graph_with_subgraph):
assert graph_with_subgraph.sorted_node_ids == ["1", "2", "3", "4"]

def test_root(self, graph_with_subgraph):
assert graph_with_subgraph.root == "1"

def test_get_node_with_data(self, graph_with_subgraph):
"""
Assert returned node is the same node that we passed in.
Args:
graph_with_subgraph:
Returns:
"""
node = graph_with_subgraph.get_node("4")
# TODO: fix typing
# Uncomment line below to see the type, no import needed, install mypy in poetry venv, run poetry shell, then
# run mypy on the file 'mypy <path to file>'.
# reveal_type(node)
assert isinstance(node, Node)
assert node.some_data == "test"

0 comments on commit 6f63e40

Please sign in to comment.