Issue #150 WIP more advanced pg splitting
soxofaan committed Sep 13, 2024
1 parent 65816b4 commit 9746ffc
Showing 2 changed files with 553 additions and 5 deletions.
346 changes: 341 additions & 5 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -1,11 +1,30 @@
from __future__ import annotations

import collections
import copy
import dataclasses
import datetime
import fractions
import functools
import itertools
import logging
import time
import types
from contextlib import nullcontext
from typing import (
Callable,
Dict,
Iterable,
Iterator,
List,
Mapping,
Optional,
Protocol,
Sequence,
Set,
Tuple,
Union,
)

import openeo
from openeo import BatchJob
@@ -14,7 +33,12 @@
from openeo_aggregator.constants import JOB_OPTION_FORCE_BACKEND
from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob
from openeo_aggregator.partitionedjobs.splitting import AbstractJobSplitter
from openeo_aggregator.utils import (
_UNSET,
FlatPG,
PGWithMetadata,
SkipIntermittentFailures,
)

_log = logging.getLogger(__name__)

@@ -24,6 +48,10 @@
SubGraphId = str


class GraphSplitException(Exception):
pass


class GetReplacementCallable(Protocol):
"""
Type annotation for callback functions that produce a node replacement
@@ -88,11 +116,11 @@ def split_streaming(
(e.g. main "primary" graph comes last).
The iterator approach allows working with a dynamic `get_replacement` implementation
that can be adaptive to previously produced subgraphs
(e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately).
:return: Iterator of tuples containing:
- subgraph id, it's recommended to handle it as opaque id (but usually format '{backend_id}:{node_id}')
- SubJob
- dependencies as list of subgraph ids
"""
@@ -351,3 +379,311 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection, fai
}
for sid in subjobs.keys()
}


# Type aliases to make things more self-documenting
NodeId = str
BackendId = str


@dataclasses.dataclass(frozen=True)
class _FrozenNode:
"""
Node in a _FrozenGraph, with pointers to the nodes it depends on (needs data/input from)
and the nodes it provides input to.
This is as immutable as possible (as far as Python allows), so instances can be
used and reused in iterative/recursive graph handling algorithms,
without having to worry about accidentally changing state.
"""

# TODO: instead of frozen dataclass: have __init__ with some type casting/validation. Or use attrs?

# Node ids of other nodes this node depends on (aka parents)
depends_on: frozenset[NodeId]
# Node ids of other nodes that depend on this node (aka children)
flows_to: frozenset[NodeId]

# Backend ids this node is marked to be supported on
# value None means it is unknown/unconstrained for this node
backend_candidates: Union[frozenset[BackendId], None]

def __repr__(self):
return "".join(
[
f"Node ",
f"@({','.join(self.backend_candidates) if self.backend_candidates else None})",
]
+ [f"<{d}" for d in self.depends_on]
+ [f">{f}" for f in self.flows_to]
)

def clone(self, backend_candidates: Union[frozenset[BackendId], None] = _UNSET) -> "_FrozenNode":
"""Clone this node, optionally updating backend_candidates"""
backend_candidates = self.backend_candidates if backend_candidates is _UNSET else backend_candidates
return _FrozenNode(
depends_on=self.depends_on,
flows_to=self.flows_to,
backend_candidates=backend_candidates,
)
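# Note on the _UNSET sentinel (illustrative): it distinguishes "not given" from an explicit None:
#   node.clone()                          # keeps the current backend_candidates
#   node.clone(backend_candidates=None)   # explicitly resets to unknown/unconstrained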


class _FrozenGraph:
"""
Graph of _FrozenNode objects.
"""

# TODO: find better class name: e.g. SplitGraphView, GraphSplitUtility, GraphSplitter, ...?

def __init__(self, graph: dict[NodeId, _FrozenNode]):
# Work with a read-only proxy to prevent accidental changes
# TODO: check consistency of references?
self._graph: Mapping[NodeId, _FrozenNode] = types.MappingProxyType(graph)

def __repr__(self):
return f"<{type(self).__name__}({self._graph})>"

@classmethod
def from_flat_graph(cls, flat_graph: FlatPG, backend_candidates_map: Dict[NodeId, Iterable[BackendId]]):
"""
Build _FrozenGraph from a flat process graph representation
"""
# Extract dependency links between nodes
depends_on = collections.defaultdict(list)
flows_to = collections.defaultdict(list)
for node_id, node in flat_graph.items():
for arg_value in node.get("arguments", {}).values():
if isinstance(arg_value, dict) and list(arg_value.keys()) == ["from_node"]:
from_node = arg_value["from_node"]
depends_on[node_id].append(from_node)
flows_to[from_node].append(node_id)
graph = {
node_id: _FrozenNode(
depends_on=frozenset(depends_on.get(node_id, [])),
flows_to=frozenset(flows_to.get(node_id, [])),
backend_candidates=(
# TODO move this logic to _FrozenNode.__init__
frozenset(backend_candidates_map.get(node_id))
if node_id in backend_candidates_map
else None
),
)
for node_id, node in flat_graph.items()
}
return cls(graph=graph)
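# Example (illustrative): a flat process graph like
#   {"lc": {"process_id": "load_collection", "arguments": {}},
#    "sr": {"process_id": "save_result", "arguments": {"data": {"from_node": "lc"}}}}
# yields depends_on == {"sr": ["lc"]} and flows_to == {"lc": ["sr"]},
# so node "sr" gets depends_on=frozenset({"lc"}) and node "lc" gets flows_to=frozenset({"sr"}).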

@classmethod
def from_edges(
cls,
edges: Iterable[Tuple[NodeId, NodeId]],
backend_candidates_map: Optional[Dict[NodeId, Iterable[BackendId]]] = None,
):
"""
Simple factory to build graph from parent-child tuples for testing purposes
"""
depends_on = collections.defaultdict(list)
flows_to = collections.defaultdict(list)
for parent, child in edges:
depends_on[child].append(parent)
flows_to[parent].append(child)

graph = {
node_id: _FrozenNode(
depends_on=frozenset(depends_on.get(node_id, [])),
flows_to=frozenset(flows_to.get(node_id, [])),
backend_candidates=(
frozenset(backend_candidates_map.get(node_id))
if backend_candidates_map and node_id in backend_candidates_map
else None
),
)
for node_id in set(depends_on.keys()).union(flows_to.keys())
}
return cls(graph=graph)
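# Usage sketch (illustrative):
#   graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c")])
# builds the chain a->b->c: node "a" has flows_to={"b"}, node "c" has
# depends_on={"b"}, and backend_candidates is None everywhere.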

def node(self, node_id: NodeId) -> _FrozenNode:
return self._graph[node_id]

def iter_nodes(self) -> Iterator[Tuple[NodeId, _FrozenNode]]:
"""Iterate through node_id-node pairs"""
yield from self._graph.items()

def _walk(
self, seeds: Iterable[NodeId], next_nodes: Callable[[NodeId], Iterable[NodeId]], include_seeds: bool = True
) -> Iterator[NodeId]:
"""
Walk the graph nodes breadth-first, starting from the given seed nodes and
taking steps as defined by the `next_nodes` function.
The seed nodes themselves are included in the walk unless `include_seeds` is False.
"""
if include_seeds:
visited = set()
to_visit = list(seeds)
else:
visited = set(seeds)
to_visit = [n for s in seeds for n in next_nodes(s)]

while to_visit:
node_id = to_visit.pop(0)
if node_id in visited:
continue
yield node_id
visited.add(node_id)
to_visit.extend(set(next_nodes(node_id)).difference(visited))

def walk_upstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True) -> Iterator[NodeId]:
"""
Walk upstream nodes (along `depends_on` link) starting from given seed nodes.
The walk is breadth-first; the seed nodes themselves are included unless `include_seeds` is False.
"""
return self._walk(seeds=seeds, next_nodes=lambda n: self.node(n).depends_on, include_seeds=include_seeds)

def walk_downstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True) -> Iterator[NodeId]:
"""
Walk downstream nodes (along `flows_to` link) starting from given seed nodes.
The walk is breadth-first; the seed nodes themselves are included unless `include_seeds` is False.
"""
return self._walk(seeds=seeds, next_nodes=lambda n: self.node(n).flows_to, include_seeds=include_seeds)
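# Example (illustrative), on the chain a->b->c from `from_edges` above:
#   list(graph.walk_upstream_nodes(["c"]))                        # -> ["c", "b", "a"]
#   list(graph.walk_downstream_nodes(["a"], include_seeds=False)) # -> ["b", "c"]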

def get_backend_candidates(self, node_id: NodeId) -> Union[frozenset[BackendId], None]:
"""Determine backend candidates for given node id"""
if self.node(node_id).backend_candidates is not None:
# Node has explicit backend candidates listed
return self.node(node_id).backend_candidates
elif self.node(node_id).depends_on:
# Backend support is unset: determine it (as intersection) from upstream nodes
# TODO: cache intermediate sets? (Only when caching is safe: e.g. wrapped graph is immutable/not manipulated)
upstream_candidates = (self.get_backend_candidates(n) for n in self.node(node_id).depends_on)
upstream_candidates = [c for c in upstream_candidates if c is not None]
# Guard against `reduce` on an empty sequence (all upstream nodes unconstrained)
return functools.reduce(lambda a, b: a.intersection(b), upstream_candidates) if upstream_candidates else None
else:
return None
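# Example (illustrative): a "merge" node with backend_candidates=None that
# depends on nodes constrained to {"b1", "b2"} and {"b2", "b3"} resolves to
# the intersection {"b2"}; an empty intersection marks the node as "forsaken".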

def find_forsaken_nodes(self) -> Set[NodeId]:
"""
Find nodes that have no backend candidates to process them
"""
return set(node_id for (node_id, _) in self.iter_nodes() if self.get_backend_candidates(node_id) == set())

def find_articulation_points(self) -> Set[NodeId]:
"""
Find articulation points (cut vertices) in the directed graph:
nodes that when removed would split the graph into multiple sub-graphs.
Note that, unlike in traditional graph theory, the search also includes leaf nodes
(e.g. nodes with no parents), as in this context of openEO graph splitting,
when we "cut" a node, we replace it with two disconnected new nodes
(one connecting to the original parents and one connecting to the original children).
"""
# Approach: label the start nodes (e.g. load_collection) with their id and weight 1.
# Propagate these labels along the depends-on links, but split/sum the weight according
# to the number of children/parents.
# At the end: the articulation points are the nodes where all flows have weight 1.
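# Worked example (illustrative): in a diamond graph a->b, a->c, b->d, c->d,
# start node "a" has flow weights {a: 1}; "b" and "c" each get {a: 1/2}
# ("a"'s unit weight is split over its two children); "d" sums the two halves
# back to {a: 1}. All flows at "a" and "d" have weight 1, so they are the
# articulation points, while "b" and "c" (weight 1/2) are not.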

# Mapping: node_id -> start_node_id -> flow_weight
flow_weights: Dict[NodeId, Dict[NodeId, fractions.Fraction]] = {}

# Initialize at the pure input nodes (nodes with no upstream dependencies)
for node_id, node in self.iter_nodes():
if not node.depends_on:
flow_weights[node_id] = {node_id: fractions.Fraction(1, 1)}

# Propagate flow weights using recursion + caching
def get_flow_weights(node_id: NodeId) -> Dict[NodeId, fractions.Fraction]:
nonlocal flow_weights
if node_id not in flow_weights:
flow_weights[node_id] = {}
# Calculate from upstream nodes
for upstream in self.node(node_id).depends_on:
for start_node_id, weight in get_flow_weights(upstream).items():
flow_weights[node_id].setdefault(start_node_id, fractions.Fraction(0, 1))
flow_weights[node_id][start_node_id] += weight / len(self.node(upstream).flows_to)
return flow_weights[node_id]

for node_id, node in self.iter_nodes():
get_flow_weights(node_id)

# Select articulation points: nodes where all flows have weight 1
return set(node_id for node_id, flows in flow_weights.items() if all(w == 1 for w in flows.values()))

def _split_at(self, split_node_id: NodeId) -> Tuple[_FrozenGraph, _FrozenGraph]:
"""
Split graph at given node id
"""
split_node = self.node(split_node_id)

# TODO: first verify that node_id is a valid articulation point?
# Or let this fail, e.g. in validation of _FrozenGraph.__init__?

# Walk the graph, upstream from the split node
def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
node = self.node(node_id)
if node_id == split_node_id:
return node.depends_on
else:
return node.depends_on.union(node.flows_to)

up_node_ids = set(self._walk(seeds=[split_node_id], next_nodes=next_nodes))
up_graph = {n: self.node(n) for n in up_node_ids}
up_graph[split_node_id] = _FrozenNode(
depends_on=split_node.depends_on,
flows_to=frozenset(),
backend_candidates=split_node.backend_candidates,
)
up = _FrozenGraph(graph=up_graph)

down_graph = {n: node for n, node in self.iter_nodes() if n not in up_node_ids}
down_graph[split_node_id] = _FrozenNode(
depends_on=frozenset(),
flows_to=split_node.flows_to,
backend_candidates=None,
)
down = _FrozenGraph(graph=down_graph)

return down, up
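# Example (illustrative): splitting the chain a->b->c at "b" returns
# `up` with {a, b} where the "b" copy keeps depends_on={"a"} but gets empty
# flows_to, and `down` with {b, c} where the "b" copy gets empty depends_on,
# flows_to={"c"} and backend_candidates reset to None
# (its data is assumed to come from the other subgraph).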

def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:
"""
Produce disjoint subgraphs that can be processed independently
"""
# Find nodes that have empty set of backend_candidates
forsaken_nodes = self.find_forsaken_nodes()

if forsaken_nodes:
# Sort forsaken nodes (based on forsaken parent count), to start higher up the graph
# TODO: avoid need for this sort, and just use a better scoring metric higher up?
forsaken_nodes = sorted(
forsaken_nodes, reverse=True, key=lambda n: sum(p in forsaken_nodes for p in self.node(n).depends_on)
)
# Collect nodes where we could split the graph in disjoint subgraphs
articulation_points: Set[NodeId] = set(self.find_articulation_points())

# Walk upstream from forsaken nodes to find articulation points, where we can cut
split_options = [
n
for n in self.walk_upstream_nodes(seeds=forsaken_nodes, include_seeds=False)
if n in articulation_points
]
if not split_options:
raise GraphSplitException("No split options found.")
# TODO: how to handle limit? will it scale feasibly to iterate over all possibilities at this point?
# TODO: smarter picking of split node (e.g. one with most upstream nodes)
for split_node_id in split_options[:limit]:
# Split graph at this articulation point
down, up = self._split_at(split_node_id)
if down.find_forsaken_nodes():
down_splits = list(down.produce_split_locations(limit=limit - 1))
else:
down_splits = [[]]
if up.find_forsaken_nodes():
up_splits = list(up.produce_split_locations(limit=limit - 1))
else:
up_splits = [[]]

for down_split, up_split in itertools.product(down_splits, up_splits):
yield [split_node_id] + down_split + up_split

else:
# All nodes can be handled as is, no need to split
yield []
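# Usage sketch (illustrative, assuming two backends "b1" and "b2"):
#   graph = _FrozenGraph.from_edges(
#       [("lc1", "merge"), ("lc2", "merge")],
#       backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]},
#   )
#   next(graph.produce_split_locations())  # e.g. ["lc1"]: cutting at a load
#   # node leaves each subgraph with a non-empty set of backend candidates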