Issue #150 Use a callable supporting_backends in DeepGraphSplitter
Instead of a dictionary

Also move type coercion logic to _FrozenNode __init__ to eliminate boilerplate code
soxofaan committed Sep 20, 2024
1 parent b0efad8 commit 32de10f
Showing 2 changed files with 156 additions and 87 deletions.
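For context, a minimal sketch of the new callable-based API introduced by this commit. The collection/backend ids and the example flat process graph are illustrative assumptions, not part of the commit:

    from openeo_aggregator.partitionedjobs.crossbackend import DeepGraphSplitter

    # Hypothetical mapping of collection id to the backend(s) hosting it
    COLLECTION_BACKENDS = {"S2_L2A": ["b1"], "S1_GRD": ["b2", "b3"]}

    def supporting_backends(node_id: str, node: dict):
        # Constrain load_collection nodes to the backends hosting their collection;
        # return None for all other nodes (support is unconstrained).
        if node.get("process_id") == "load_collection":
            return COLLECTION_BACKENDS.get(node["arguments"]["id"])
        return None

    # Minimal flat process graph for illustration
    flat_process_graph = {
        "lc1": {"process_id": "load_collection", "arguments": {"id": "S2_L2A"}},
        "sr1": {"process_id": "save_result", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
    }

    splitter = DeepGraphSplitter(supporting_backends=supporting_backends)
    result = splitter.split(process_graph=flat_process_graph)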
82 changes: 43 additions & 39 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -53,6 +53,12 @@
 BackendId = str


+# Annotation for a function that maps node information (id and its node dict)
+# to id(s) of the backend(s) that support it.
+# Returning None means that support is unconstrained (any backend is assumed to support it).
+SupportingBackendsMapper = Callable[[NodeId, dict], Union[BackendId, Iterable[BackendId], None]]
+
+
 class GraphSplitException(Exception):
     pass

@@ -126,6 +132,8 @@ class LoadCollectionGraphSplitter(ProcessGraphSplitterInterface):
     Simple process graph splitter that just splits off load_collection nodes.
     """

+    # TODO: migrate backend_for_collection to SupportingBackendsMapper format?
+
     def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False):
         # TODO: also support not not having a backend_for_collection map?
         self._backend_for_collection = backend_for_collection
Expand Down Expand Up @@ -440,7 +448,7 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection, fai
}


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, init=False, eq=True)
class _FrozenNode:
"""
Node in a _FrozenGraph, with pointers to other nodes it depends on (needs data/input from)
@@ -451,9 +459,8 @@ class _FrozenNode:
     without having to worry about accidentally changing state.
     """

-    # TODO: instead of frozen dataclass: have __init__ with some type casting/validation. Or use attrs?
     # TODO: better name for this class?
-    # TODO: use NamedTuple instead of dataclass?
+    # TODO: type coercion in __init__ of frozen dataclasses is bit ugly. Use attrs with field converters instead?

     # Node ids of other nodes this node depends on (aka parents)
     depends_on: frozenset[NodeId]
@@ -465,15 +472,26 @@ class _FrozenNode:
     # TODO: Move this to _FrozenGraph as responsibility?
     backend_candidates: Union[frozenset[BackendId], None]

+    def __init__(
+        self,
+        *,
+        depends_on: Optional[Iterable[NodeId]] = None,
+        flows_to: Optional[Iterable[NodeId]] = None,
+        backend_candidates: Union[BackendId, Iterable[BackendId], None] = None,
+    ):
+        super().__init__()
+        object.__setattr__(self, "depends_on", frozenset(depends_on or []))
+        object.__setattr__(self, "flows_to", frozenset(flows_to or []))
+        if isinstance(backend_candidates, str):
+            backend_candidates = frozenset([backend_candidates])
+        elif backend_candidates is None:
+            backend_candidates = None
+        else:
+            backend_candidates = frozenset(backend_candidates)
+        object.__setattr__(self, "backend_candidates", backend_candidates)
+
     def __repr__(self):
-        return "".join(
-            [
-                f"Node ",
-                f"@({','.join(self.backend_candidates) if self.backend_candidates else None})",
-            ]
-            + [f"<{d}" for d in self.depends_on]
-            + [f">{f}" for f in self.flows_to]
-        )
+        return f"<{type(self).__name__}({self.depends_on}, {self.flows_to}, {self.backend_candidates})>"


 class _FrozenGraph:
@@ -493,7 +511,7 @@ def __repr__(self):
         return f"<{type(self).__name__}({self._graph})>"

     @classmethod
-    def from_flat_graph(cls, flat_graph: FlatPG, backend_candidates_map: Dict[NodeId, Iterable[BackendId]]):
+    def from_flat_graph(cls, flat_graph: FlatPG, supporting_backends: SupportingBackendsMapper = (lambda n, d: None)):
         """
         Build _FrozenGraph from a flat process graph representation
         """
@@ -508,14 +526,9 @@ def from_flat_graph(cls, flat_graph: FlatPG, backend_candidates_map: Dict[NodeId
                    flows_to[from_node].append(node_id)
        graph = {
            node_id: _FrozenNode(
-                depends_on=frozenset(depends_on.get(node_id, [])),
-                flows_to=frozenset(flows_to.get(node_id, [])),
-                backend_candidates=(
-                    # TODO move this logic to _FrozenNode.__init__
-                    frozenset(backend_candidates_map.get(node_id))
-                    if node_id in backend_candidates_map
-                    else None
-                ),
+                depends_on=depends_on.get(node_id, []),
+                flows_to=flows_to.get(node_id, []),
+                backend_candidates=supporting_backends(node_id, node),
            )
            for node_id, node in flat_graph.items()
        }
@@ -525,7 +538,7 @@ def from_flat_graph(cls, flat_graph: FlatPG, backend_candidates_map: Dict[NodeId
     def from_edges(
         cls,
         edges: Iterable[Tuple[NodeId, NodeId]],
-        backend_candidates_map: Optional[Dict[NodeId, Iterable[BackendId]]] = None,
+        supporting_backends_mapper: SupportingBackendsMapper = (lambda n, d: None),
     ):
         """
         Simple factory to build graph from parent-child tuples for testing purposes
@@ -539,13 +552,9 @@ def from_edges(
        graph = {
            node_id: _FrozenNode(
                # Note that we just use node id as process id. Do we have better options here?
-                depends_on=frozenset(depends_on.get(node_id, [])),
-                flows_to=frozenset(flows_to.get(node_id, [])),
-                backend_candidates=(
-                    frozenset(backend_candidates_map.get(node_id))
-                    if backend_candidates_map and node_id in backend_candidates_map
-                    else None
-                ),
+                depends_on=depends_on.get(node_id, []),
+                flows_to=flows_to.get(node_id, []),
+                backend_candidates=supporting_backends_mapper(node_id, {}),
            )
            for node_id in set(depends_on.keys()).union(flows_to.keys())
        }
@@ -700,16 +709,17 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
            raise GraphSplitException(f"Graph can not be split at {split_node_id}: not an articulation point.")

        up_graph = {n: self.node(n) for n in up_node_ids}
+        # Replacement of original split node: no `flows_to` links
        up_graph[split_node_id] = _FrozenNode(
            depends_on=split_node.depends_on,
-            flows_to=frozenset(),
            backend_candidates=split_node.backend_candidates,
        )
        up = _FrozenGraph(graph=up_graph)

        down_graph = {n: node for n, node in self.iter_nodes() if n not in up_node_ids}
+        # Replacement of original split node: no `depends_on` links
+        # and perhaps more importantly: do not copy over the original `backend_candidates``
        down_graph[split_node_id] = _FrozenNode(
-            depends_on=frozenset(),
            flows_to=split_node.flows_to,
            backend_candidates=None,
        )
@@ -783,18 +793,12 @@ class DeepGraphSplitter(ProcessGraphSplitterInterface):
     More advanced graph splitting (compared to just splitting off `load_collection` nodes)
     """

-    # TODO: unify:
-    #   - backend_for_collection: Callable[[CollectionId], BackendId]
-    #   - backend_candidates_map: Dict[NodeId, Iterable[BackendId]]
-    #   Note that the nodeid-backendid mapping smells like bad decoupling
-    #   as the process graph is given to split methods, while mapping to __init__
-    # TODO: validation for Iterable[BackendId] (avoid passing a single string instead of iterable of strings)
-    def __init__(self, backend_candidates_map: Dict[NodeId, Iterable[BackendId]]):
-        self._backend_candidates_map = backend_candidates_map
+    def __init__(self, supporting_backends: SupportingBackendsMapper):
+        self._supporting_backends_mapper = supporting_backends

     def split(self, process_graph: FlatPG) -> _PGSplitResult:
         graph = _FrozenGraph.from_flat_graph(
-            flat_graph=process_graph, backend_candidates_map=self._backend_candidates_map
+            flat_graph=process_graph, supporting_backends=self._supporting_backends_mapper
         )

         # TODO: make picking "optimal" split location set a bit more deterministic (e.g. sort first)
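A quick illustration of the relocated type coercion, assuming an interaction with the internal _FrozenNode class exactly as defined above (the node/backend ids are made up):

    node = _FrozenNode(depends_on=["a", "b"], flows_to=["c"], backend_candidates="b1")
    assert node.depends_on == frozenset({"a", "b"})
    assert node.flows_to == frozenset({"c"})
    # A single backend id string is coerced to a one-element frozenset;
    # None stays None (support is unconstrained).
    assert node.backend_candidates == frozenset({"b1"})
    assert _FrozenNode(flows_to=["c"]).backend_candidates is None

Callers like from_flat_graph no longer need to wrap arguments in frozenset themselves, which is exactly the boilerplate the commit message refers to.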