Skip to content

Commit

Permalink
Add Digraph tool class
Browse files Browse the repository at this point in the history
  • Loading branch information
nikobockerman committed Jul 29, 2024
1 parent 6b79de1 commit bdbdcc8
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 0 deletions.
71 changes: 71 additions & 0 deletions adventofcode/tooling/digraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from __future__ import annotations

import typing
from dataclasses import dataclass
from functools import cache
from typing import Iterable, Protocol, runtime_checkable


@runtime_checkable
class NodeId(typing.Hashable, Protocol):
pass


@dataclass(kw_only=True, slots=True)
class Digraph[Id: NodeId, N]:
nodes: dict[Id, N] # TODO: Consider replacing with a frozendict
arcs: tuple[DigraphArc[Id], ...]

def get_arcs_to(self, node_id: Id, /) -> list[DigraphArc[Id]]:
return _get_arcs_to_node(node_id, self.arcs)

def get_arcs_from(self, node_id: Id, /) -> list[DigraphArc[Id]]:
return _get_arcs_from_node(node_id, self.arcs)


class DigraphArc[Id: NodeId](Protocol):
@property
def from_(self) -> Id: ...
@property
def to(self) -> Id: ...


@dataclass(frozen=True, slots=True)
class Arc[Id: NodeId]:
from_: Id
to: Id


class DigraphCreator[Id: NodeId, N]:
def __init__(self) -> None:
self._nodes: dict[Id, N] = dict()
self._arcs: list[DigraphArc[Id]] = list()

def add_node(self, node_id: Id, node: N, /) -> None:
if node_id in self._nodes:
raise ValueError(node_id)
self._nodes[node_id] = node

def add_arc(self, arc: DigraphArc[Id], /) -> None:
if arc.from_ not in self._nodes:
raise ValueError(arc.from_)
if arc.to not in self._nodes:
raise ValueError(arc.to)
self._arcs.append(arc)

def create(self) -> Digraph[Id, N]:
return Digraph(nodes=self._nodes, arcs=tuple(self._arcs))


@cache
def _get_arcs_to_node[Id: NodeId](
node_id: Id, arcs: Iterable[DigraphArc[Id]]
) -> list[DigraphArc[Id]]:
return list(arc for arc in arcs if arc.to == node_id)


@cache
def _get_arcs_from_node[Id: NodeId](
node_id: Id, arcs: Iterable[DigraphArc[Id]]
) -> list[DigraphArc[Id]]:
return list(arc for arc in arcs if arc.from_ == node_id)
94 changes: 94 additions & 0 deletions tests/test_digraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from dataclasses import dataclass
from typing import assert_type

from adventofcode.tooling.digraph import Arc, Digraph, DigraphCreator


def test_digraph_creator_simple() -> None:
creator = DigraphCreator[int, int]()
creator.add_node(1, 11)
creator.add_node(2, 22)
creator.add_arc(Arc(1, 2))
digraph = creator.create()
assert digraph.nodes == {1: 11, 2: 22}
assert digraph.arcs == (Arc(1, 2),)
assert_type(digraph.nodes, dict[int, int])


def test_digraph_creator_two_types() -> None:
creator = DigraphCreator[int | str, int | str]()
creator.add_node("a", "aa")
creator.add_node(1, 11)
creator.add_node("b", 22)
creator.add_arc(Arc("a", 1))
creator.add_arc(Arc("a", "b"))
creator.add_arc(Arc(1, "b"))
digraph = creator.create()
assert digraph.nodes == {"a": "aa", 1: 11, "b": 22}
assert digraph.arcs == (Arc("a", 1), Arc("a", "b"), Arc(1, "b"))
assert_type(digraph.nodes, dict[int | str, int | str])


def test_digraph_creator_multiple_inherited_classes() -> None:
@dataclass
class Base:
name: str

class Child1(Base):
pass

class Child2(Base):
pass

creator = DigraphCreator[str, Child1 | Child2]()
creator.add_node("a", Child1("a"))
creator.add_node("b", Child2("b"))
creator.add_arc(Arc("a", "b"))
digraph = creator.create()
assert digraph.nodes == {"a": Child1("a"), "b": Child2("b")}
assert digraph.arcs == (Arc("a", "b"),)
assert_type(digraph.nodes, dict[str, Child1 | Child2])


def test_digraph_get_arcs() -> None:
digraph = Digraph[int, int](
nodes={1: 11, 2: 22, 3: 33, 4: 44},
arcs=tuple((Arc(1, 2), Arc(1, 3), Arc(2, 3), Arc(3, 1))),
)
assert digraph.get_arcs_from(1) == [Arc(1, 2), Arc(1, 3)]
assert digraph.get_arcs_from(2) == [Arc(2, 3)]
assert digraph.get_arcs_from(3) == [Arc(3, 1)]
assert digraph.get_arcs_from(4) == []
assert digraph.get_arcs_to(1) == [Arc(3, 1)]
assert digraph.get_arcs_to(2) == [Arc(1, 2)]
assert digraph.get_arcs_to(3) == [Arc(1, 3), Arc(2, 3)]
assert digraph.get_arcs_to(4) == []


def test_digraph_weighted_arcs() -> None:
@dataclass(frozen=True)
class WeightedArc:
from_: str
to: str
weight: int

digraph_creator = DigraphCreator[str, int]()
digraph_creator.add_node("a", 1)
digraph_creator.add_node("b", 2)
digraph_creator.add_node("c", 3)
digraph_creator.add_arc(WeightedArc("a", "b", 3))
digraph_creator.add_arc(WeightedArc("a", "c", 4))
digraph_creator.add_arc(WeightedArc("b", "c", 5))
digraph = digraph_creator.create()
assert digraph.get_arcs_from("a") == [
WeightedArc("a", "b", 3),
WeightedArc("a", "c", 4),
]
assert digraph.get_arcs_from("b") == [WeightedArc("b", "c", 5)]
assert digraph.get_arcs_from("c") == []
assert digraph.get_arcs_to("a") == []
assert digraph.get_arcs_to("b") == [WeightedArc("a", "b", 3)]
assert digraph.get_arcs_to("c") == [
WeightedArc("a", "c", 4),
WeightedArc("b", "c", 5),
]

0 comments on commit bdbdcc8

Please sign in to comment.