Skip to content

Commit

Permalink
Issue #150 Do some renaming now that design dust has settled a bit
Browse files Browse the repository at this point in the history
also add consistency check of  _GraphViewer node map
  • Loading branch information
soxofaan committed Sep 20, 2024
1 parent 32de10f commit cc55c5a
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 95 deletions.
95 changes: 56 additions & 39 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,18 +448,24 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection, fai
}


def to_frozenset(value: Union[Iterable[str], str]) -> frozenset[str]:
"""Coerce value to frozenset of strings"""
if isinstance(value, str):
value = [value]
return frozenset(value)


@dataclasses.dataclass(frozen=True, init=False, eq=True)
class _FrozenNode:
class _GVNode:
"""
Node in a _FrozenGraph, with pointers to other nodes it depends on (needs data/input from)
and nodes to which it is input to.
Node in a _GraphViewer, with pointers to other nodes it depends on (needs data/input from)
and nodes to which it provides input to.
This is as immutable as possible (as far as Python allows) to
be used and reused in iterative/recursive graph handling algorithms,
without having to worry about accidentally changing state.
This structure designed to be as immutable as possible (as far as Python allows)
to be (re)used in iterative/recursive graph handling algorithms,
without having to worry about accidentally propagating changed state to other parts of the graph.
"""

# TODO: better name for this class?
# TODO: type coercion in __init__ of frozen dataclasses is bit ugly. Use attrs with field converters instead?

# Node ids of other nodes this node depends on (aka parents)
Expand All @@ -469,51 +475,62 @@ class _FrozenNode:

# Backend ids this node is marked to be supported on
# value None means it is unknown/unconstrained for this node
# TODO: Move this to _FrozenGraph as responsibility?
# TODO: Move this to _GraphViewer as responsibility?
backend_candidates: Union[frozenset[BackendId], None]

def __init__(
self,
*,
depends_on: Optional[Iterable[NodeId]] = None,
flows_to: Optional[Iterable[NodeId]] = None,
backend_candidates: Union[BackendId, Iterable[BackendId], None] = None,
depends_on: Union[Iterable[NodeId], NodeId, None] = None,
flows_to: Union[Iterable[NodeId], NodeId, None] = None,
backend_candidates: Union[Iterable[BackendId], BackendId, None] = None,
):
super().__init__()
object.__setattr__(self, "depends_on", frozenset(depends_on or []))
object.__setattr__(self, "flows_to", frozenset(flows_to or []))
if isinstance(backend_candidates, str):
backend_candidates = frozenset([backend_candidates])
elif backend_candidates is None:
backend_candidates = None
else:
backend_candidates = frozenset(backend_candidates)
object.__setattr__(self, "depends_on", to_frozenset(depends_on or []))
object.__setattr__(self, "flows_to", to_frozenset(flows_to or []))
backend_candidates = to_frozenset(backend_candidates) if backend_candidates is not None else None
object.__setattr__(self, "backend_candidates", backend_candidates)

def __repr__(self):
return f"<{type(self).__name__}({self.depends_on}, {self.flows_to}, {self.backend_candidates})>"


class _FrozenGraph:
class _GraphViewer:
"""
Graph of _FrozenNode objects.
Internal utility to have read-only view on the topological structure of a proces graph
and track the flow of backend support.
"""

# TODO: find better class name: e.g. SplitGraphView, GraphSplitUtility, GraphSplitter, ...?
# TODO: add more logging of what is happening under the hood

def __init__(self, graph: dict[NodeId, _FrozenNode]):
def __init__(self, node_map: dict[NodeId, _GVNode]):
self._check_consistency(node_map=node_map)
# Work with a read-only proxy to prevent accidental changes
# TODO: check consistency of references?
self._graph: Mapping[NodeId, _FrozenNode] = types.MappingProxyType(graph)
self._graph: Mapping[NodeId, _GVNode] = types.MappingProxyType(node_map)

@staticmethod
def _check_consistency(node_map: dict[NodeId, _GVNode]):
"""Check (link) consistency of given node map"""
key_ids = set(node_map.keys())
linked_ids = set(k for n in node_map.values() for k in n.depends_on.union(n.flows_to))
unknown = linked_ids.difference(key_ids)
if unknown:
raise GraphSplitException(f"Inconsistent node map: {key_ids=} != {linked_ids=}: {unknown=}")
bad_links = set()
for node_id, node in node_map.items():
bad_links.update((other, node_id) for other in node.depends_on if node_id not in node_map[other].flows_to)
bad_links.update((node_id, other) for other in node.flows_to if node_id not in node_map[other].depends_on)
if bad_links:
raise GraphSplitException(f"Inconsistent node map: {bad_links=}")

def __repr__(self):
return f"<{type(self).__name__}({self._graph})>"

@classmethod
def from_flat_graph(cls, flat_graph: FlatPG, supporting_backends: SupportingBackendsMapper = (lambda n, d: None)):
"""
Build _FrozenGraph from a flat process graph representation
Build _GraphViewer from a flat process graph representation
"""
# Extract dependency links between nodes
depends_on = collections.defaultdict(list)
Expand All @@ -525,14 +542,14 @@ def from_flat_graph(cls, flat_graph: FlatPG, supporting_backends: SupportingBack
depends_on[node_id].append(from_node)
flows_to[from_node].append(node_id)
graph = {
node_id: _FrozenNode(
node_id: _GVNode(
depends_on=depends_on.get(node_id, []),
flows_to=flows_to.get(node_id, []),
backend_candidates=supporting_backends(node_id, node),
)
for node_id, node in flat_graph.items()
}
return cls(graph=graph)
return cls(node_map=graph)

@classmethod
def from_edges(
Expand All @@ -550,22 +567,22 @@ def from_edges(
flows_to[parent].append(child)

graph = {
node_id: _FrozenNode(
node_id: _GVNode(
# Note that we just use node id as process id. Do we have better options here?
depends_on=depends_on.get(node_id, []),
flows_to=flows_to.get(node_id, []),
backend_candidates=supporting_backends_mapper(node_id, {}),
)
for node_id in set(depends_on.keys()).union(flows_to.keys())
}
return cls(graph=graph)
return cls(node_map=graph)

def node(self, node_id: NodeId) -> _FrozenNode:
def node(self, node_id: NodeId) -> _GVNode:
if node_id not in self._graph:
raise GraphSplitException(f"Invalid node id {node_id}.")
return self._graph[node_id]

def iter_nodes(self) -> Iterator[Tuple[NodeId, _FrozenNode]]:
def iter_nodes(self) -> Iterator[Tuple[NodeId, _GVNode]]:
"""Iterate through node_id-node pairs"""
yield from self._graph.items()

Expand Down Expand Up @@ -686,12 +703,12 @@ def get_flow_weights(node_id: NodeId) -> Dict[NodeId, fractions.Fraction]:
# Select articulation points: nodes where all flows have weight 1
return set(node_id for node_id, flows in flow_weights.items() if all(w == 1 for w in flows.values()))

def split_at(self, split_node_id: NodeId) -> Tuple[_FrozenGraph, _FrozenGraph]:
def split_at(self, split_node_id: NodeId) -> Tuple[_GraphViewer, _GraphViewer]:
"""
Split graph at given node id (must be articulation point),
creating two new graphs, containing original nodes and adaptation of the split node.
:return: two _FrozenGraph objects: the upstream subgraph and the downstream subgraph
:return: two _GraphViewer objects: the upstream subgraph and the downstream subgraph
"""
split_node = self.node(split_node_id)

Expand All @@ -710,20 +727,20 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]:

up_graph = {n: self.node(n) for n in up_node_ids}
# Replacement of original split node: no `flows_to` links
up_graph[split_node_id] = _FrozenNode(
up_graph[split_node_id] = _GVNode(
depends_on=split_node.depends_on,
backend_candidates=split_node.backend_candidates,
)
up = _FrozenGraph(graph=up_graph)
up = _GraphViewer(node_map=up_graph)

down_graph = {n: node for n, node in self.iter_nodes() if n not in up_node_ids}
# Replacement of original split node: no `depends_on` links
# and perhaps more importantly: do not copy over the original `backend_candidates``
down_graph[split_node_id] = _FrozenNode(
down_graph[split_node_id] = _GVNode(
flows_to=split_node.flows_to,
backend_candidates=None,
)
down = _FrozenGraph(graph=down_graph)
down = _GraphViewer(node_map=down_graph)

return up, down

Expand Down Expand Up @@ -797,7 +814,7 @@ def __init__(self, supporting_backends: SupportingBackendsMapper):
self._supporting_backends_mapper = supporting_backends

def split(self, process_graph: FlatPG) -> _PGSplitResult:
graph = _FrozenGraph.from_flat_graph(
graph = _GraphViewer.from_flat_graph(
flat_graph=process_graph, supporting_backends=self._supporting_backends_mapper
)

Expand Down
Loading

0 comments on commit cc55c5a

Please sign in to comment.