From 5fb10ee42cd9677a4a0d56cdd5ae2de992312c64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niko=20B=C3=B6ckerman?= Date: Mon, 29 Jul 2024 12:08:36 +0000 Subject: [PATCH] Add Digraph tool class --- adventofcode/tooling/digraph.py | 71 +++++++++++++++++++++++++ tests/test_digraph.py | 94 +++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 adventofcode/tooling/digraph.py create mode 100644 tests/test_digraph.py diff --git a/adventofcode/tooling/digraph.py b/adventofcode/tooling/digraph.py new file mode 100644 index 0000000..46ff3a8 --- /dev/null +++ b/adventofcode/tooling/digraph.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import typing +from dataclasses import dataclass +from functools import cache +from typing import Iterable, Protocol, runtime_checkable + + +@runtime_checkable +class NodeId(typing.Hashable, Protocol): + pass + + +@dataclass(kw_only=True, slots=True) +class Digraph[Id: NodeId, N]: + nodes: dict[Id, N] # TODO: Consider replacing with a frozendict + arcs: tuple[DigraphArc[Id], ...] + + def get_arcs_to(self, node_id: Id, /) -> list[DigraphArc[Id]]: + return _get_arcs_to_node(node_id, self.arcs) + + def get_arcs_from(self, node_id: Id, /) -> list[DigraphArc[Id]]: + return _get_arcs_from_node(node_id, self.arcs) + + +class DigraphArc[Id: NodeId](Protocol): + @property + def from_(self) -> Id: ... + @property + def to(self) -> Id: ... + + +@dataclass(frozen=True, slots=True) +class Arc[Id: NodeId]: + from_: Id + to: Id + + +class DigraphCreator[Id: NodeId, N]: + def __init__(self) -> None: + self._nodes: dict[Id, N] = dict() + self._arcs: list[DigraphArc[Id]] = list() + + def add_node(self, node_id: Id, node: N, /) -> None: + if node_id in self._nodes: + raise ValueError(node_id) + self._nodes[node_id] = node + + def add_arc(self, arc: DigraphArc[Id], /) -> None: + if arc.from_ not in self._nodes: + raise ValueError(arc.from_) + if arc.to not in self._nodes: + raise ValueError(arc.to) + self._arcs.append(arc) + + def create(self) -> Digraph[Id, N]: + return Digraph(nodes=self._nodes, arcs=tuple(self._arcs)) + + +@cache +def _get_arcs_to_node[Id: NodeId]( + node_id: Id, arcs: Iterable[DigraphArc[Id]] +) -> list[DigraphArc[Id]]: + return list(arc for arc in arcs if arc.to == node_id) + + +@cache +def _get_arcs_from_node[Id: NodeId]( + node_id: Id, arcs: Iterable[DigraphArc[Id]] +) -> list[DigraphArc[Id]]: + return list(arc for arc in arcs if arc.from_ == node_id) diff --git a/tests/test_digraph.py b/tests/test_digraph.py new file mode 100644 index 0000000..ecc8d0f --- /dev/null +++ b/tests/test_digraph.py @@ -0,0 +1,94 @@ +from dataclasses import dataclass +from typing import assert_type + +from adventofcode.tooling.digraph import Arc, Digraph, DigraphCreator + + +def test_digraph_creator_simple() -> None: + creator = DigraphCreator[int, int]() + creator.add_node(1, 11) + creator.add_node(2, 22) + creator.add_arc(Arc(1, 2)) + digraph = creator.create() + assert digraph.nodes == {1: 11, 2: 22} + assert digraph.arcs == {Arc(1, 2)} + assert_type(digraph.nodes, dict[int, int]) + + +def test_digraph_creator_two_types() -> None: + creator = DigraphCreator[int | str, int | str]() + creator.add_node("a", "aa") + creator.add_node(1, 11) + creator.add_node("b", 22) + creator.add_arc(Arc("a", 1)) + creator.add_arc(Arc("a", "b")) + creator.add_arc(Arc(1, "b")) + digraph = creator.create() + assert digraph.nodes == {"a": "aa", 1: 11, "b": 22} + assert digraph.arcs == {Arc("a", 1), Arc("a", "b"), Arc(1, "b")} + assert_type(digraph.nodes, dict[int | str, int | str]) + + +def test_digraph_creator_multiple_inherited_classes() -> None: + @dataclass + class Base: + name: str + + class Child1(Base): + pass + + class Child2(Base): + pass + + creator = DigraphCreator[str, Child1 | Child2]() + creator.add_node("a", Child1("a")) + creator.add_node("b", Child2("b")) + creator.add_arc(Arc("a", "b")) + digraph = creator.create() + assert digraph.nodes == {"a": Child1("a"), "b": Child2("b")} + assert digraph.arcs == {Arc("a", "b")} + assert_type(digraph.nodes, dict[str, Child1 | Child2]) + + +def test_digraph_get_arcs() -> None: + digraph = Digraph[int, int]( + nodes={1: 11, 2: 22, 3: 33, 4: 44}, + arcs=frozenset((Arc(1, 2), Arc(1, 3), Arc(2, 3), Arc(3, 1))), + ) + assert digraph.get_arcs_from(1) == {Arc(1, 2), Arc(1, 3)} + assert digraph.get_arcs_from(2) == {Arc(2, 3)} + assert digraph.get_arcs_from(3) == {Arc(3, 1)} + assert digraph.get_arcs_from(4) == set() + assert digraph.get_arcs_to(1) == {Arc(3, 1)} + assert digraph.get_arcs_to(2) == {Arc(1, 2)} + assert digraph.get_arcs_to(3) == {Arc(1, 3), Arc(2, 3)} + assert digraph.get_arcs_to(4) == set() + + +def test_digraph_weighted_arcs() -> None: + @dataclass(frozen=True) + class WeightedArc: + from_: str + to: str + weight: int + + digraph_creator = DigraphCreator[str, int]() + digraph_creator.add_node("a", 1) + digraph_creator.add_node("b", 2) + digraph_creator.add_node("c", 3) + digraph_creator.add_arc(WeightedArc("a", "b", 3)) + digraph_creator.add_arc(WeightedArc("a", "c", 4)) + digraph_creator.add_arc(WeightedArc("b", "c", 5)) + digraph = digraph_creator.create() + assert digraph.get_arcs_from("a") == { + WeightedArc("a", "b", 3), + WeightedArc("a", "c", 4), + } + assert digraph.get_arcs_from("b") == {WeightedArc("b", "c", 5)} + assert digraph.get_arcs_from("c") == set() + assert digraph.get_arcs_to("a") == set() + assert digraph.get_arcs_to("b") == {WeightedArc("a", "b", 3)} + assert digraph.get_arcs_to("c") == { + WeightedArc("a", "c", 4), + WeightedArc("b", "c", 5), + }