diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py
index 984917f..082b0b8 100644
--- a/src/openeo_aggregator/partitionedjobs/crossbackend.py
+++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -53,6 +53,12 @@
 BackendId = str
 
 
+# Annotation for a function that maps node information (id and its node dict)
+# to id(s) of the backend(s) that support it.
+# Returning None means that support is unconstrained (any backend is assumed to support it).
+SupportingBackendsMapper = Callable[[NodeId, dict], Union[BackendId, Iterable[BackendId], None]]
+
+
 class GraphSplitException(Exception):
     pass
 
@@ -126,6 +132,8 @@ class LoadCollectionGraphSplitter(ProcessGraphSplitterInterface):
     Simple process graph splitter that just splits off load_collection nodes.
     """
 
+    # TODO: migrate backend_for_collection to SupportingBackendsMapper format?
+
     def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False):
         # TODO: also support not having a backend_for_collection map?
         self._backend_for_collection = backend_for_collection
@@ -440,7 +448,7 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection, fai
     }
 
 
-@dataclasses.dataclass(frozen=True)
+@dataclasses.dataclass(frozen=True, init=False, eq=True)
 class _FrozenNode:
     """
     Node in a _FrozenGraph, with pointers to other nodes it depends on (needs data/input from)
@@ -451,9 +459,8 @@ class _FrozenNode:
     without having to worry about accidentally changing state.
     """
 
-    # TODO: instead of frozen dataclass: have __init__ with some type casting/validation. Or use attrs?
     # TODO: better name for this class?
-    # TODO: use NamedTuple instead of dataclass?
+    # TODO: type coercion in __init__ of frozen dataclasses is a bit ugly. Use attrs with field converters instead?
 
     # Node ids of other nodes this node depends on (aka parents)
     depends_on: frozenset[NodeId]
@@ -465,15 +472,26 @@ class _FrozenNode:
     # TODO: Move this to _FrozenGraph as responsibility?
 backend_candidates: Union[frozenset[BackendId], None]
 
+    def __init__(
+        self,
+        *,
+        depends_on: Optional[Iterable[NodeId]] = None,
+        flows_to: Optional[Iterable[NodeId]] = None,
+        backend_candidates: Union[BackendId, Iterable[BackendId], None] = None,
+    ):
+        super().__init__()
+        object.__setattr__(self, "depends_on", frozenset(depends_on or []))
+        object.__setattr__(self, "flows_to", frozenset(flows_to or []))
+        if isinstance(backend_candidates, str):
+            backend_candidates = frozenset([backend_candidates])
+        elif backend_candidates is None:
+            backend_candidates = None
+        else:
+            backend_candidates = frozenset(backend_candidates)
+        object.__setattr__(self, "backend_candidates", backend_candidates)
+
     def __repr__(self):
-        return "".join(
-            [
-                f"Node ",
-                f"@({','.join(self.backend_candidates) if self.backend_candidates else None})",
-            ]
-            + [f"<{d}" for d in self.depends_on]
-            + [f">{f}" for f in self.flows_to]
-        )
+        return f"<{type(self).__name__}({self.depends_on}, {self.flows_to}, {self.backend_candidates})>"
 
 
 class _FrozenGraph:
@@ -493,7 +511,7 @@ def __repr__(self):
         return f"<{type(self).__name__}({self._graph})>"
 
     @classmethod
-    def from_flat_graph(cls, flat_graph: FlatPG, backend_candidates_map: Dict[NodeId, Iterable[BackendId]]):
+    def from_flat_graph(cls, flat_graph: FlatPG, supporting_backends: SupportingBackendsMapper = (lambda n, d: None)):
         """
         Build _FrozenGraph from a flat process graph representation
         """
@@ -508,14 +526,9 @@ def from_flat_graph(cls, flat_graph: FlatPG, backend_candidates_map: Dict[NodeId
                 flows_to[from_node].append(node_id)
         graph = {
             node_id: _FrozenNode(
-                depends_on=frozenset(depends_on.get(node_id, [])),
-                flows_to=frozenset(flows_to.get(node_id, [])),
-                backend_candidates=(
-                    # TODO move this logic to _FrozenNode.__init__
-                    frozenset(backend_candidates_map.get(node_id))
-                    if node_id in backend_candidates_map
-                    else None
-                ),
+                depends_on=depends_on.get(node_id, []),
+                flows_to=flows_to.get(node_id, []),
+                backend_candidates=supporting_backends(node_id, node),
             )
             for node_id, node in flat_graph.items()
         }
@@ -525,7 +538,7 @@ def from_flat_graph(cls, flat_graph: FlatPG, backend_candidates_map: Dict[NodeId
     def from_edges(
         cls,
         edges: Iterable[Tuple[NodeId, NodeId]],
-        backend_candidates_map: Optional[Dict[NodeId, Iterable[BackendId]]] = None,
+        supporting_backends_mapper: SupportingBackendsMapper = (lambda n, d: None),
     ):
         """
         Simple factory to build graph from parent-child tuples for testing purposes
@@ -539,13 +552,9 @@ def from_edges(
         graph = {
             node_id: _FrozenNode(
                 # Note that we just use node id as process id. Do we have better options here?
-                depends_on=frozenset(depends_on.get(node_id, [])),
-                flows_to=frozenset(flows_to.get(node_id, [])),
-                backend_candidates=(
-                    frozenset(backend_candidates_map.get(node_id))
-                    if backend_candidates_map and node_id in backend_candidates_map
-                    else None
-                ),
+                depends_on=depends_on.get(node_id, []),
+                flows_to=flows_to.get(node_id, []),
+                backend_candidates=supporting_backends_mapper(node_id, {}),
             )
             for node_id in set(depends_on.keys()).union(flows_to.keys())
         }
@@ -700,16 +709,17 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
             raise GraphSplitException(f"Graph can not be split at {split_node_id}: not an articulation point.")
 
         up_graph = {n: self.node(n) for n in up_node_ids}
+        # Replacement of original split node: no `flows_to` links
         up_graph[split_node_id] = _FrozenNode(
             depends_on=split_node.depends_on,
-            flows_to=frozenset(),
             backend_candidates=split_node.backend_candidates,
         )
         up = _FrozenGraph(graph=up_graph)
 
         down_graph = {n: node for n, node in self.iter_nodes() if n not in up_node_ids}
+        # Replacement of original split node: no `depends_on` links
+        # and perhaps more importantly: do not copy over the original `backend_candidates`
        down_graph[split_node_id] = _FrozenNode(
-            depends_on=frozenset(),
             flows_to=split_node.flows_to,
             backend_candidates=None,
         )
@@ -783,18 +793,12 @@ class DeepGraphSplitter(ProcessGraphSplitterInterface):
     More advanced graph splitting (compared to just splitting off `load_collection` nodes)
     """
 
-    # TODO: unify:
-    # - backend_for_collection: Callable[[CollectionId], BackendId]
-    # - backend_candidates_map: Dict[NodeId, Iterable[BackendId]]
-    # Note that the nodeid-backendid mapping smells like bad decoupling
-    # as the process graph is given to split methods, while mapping to __init__
-    # TODO: validation for Iterable[BackendId] (avoid passing a single string instead of iterable of strings)
-    def __init__(self, backend_candidates_map: Dict[NodeId, Iterable[BackendId]]):
-        self._backend_candidates_map = backend_candidates_map
+    def __init__(self, supporting_backends: SupportingBackendsMapper):
+        self._supporting_backends_mapper = supporting_backends
 
     def split(self, process_graph: FlatPG) -> _PGSplitResult:
         graph = _FrozenGraph.from_flat_graph(
-            flat_graph=process_graph, backend_candidates_map=self._backend_candidates_map
+            flat_graph=process_graph, supporting_backends=self._supporting_backends_mapper
         )
         # TODO: make picking "optimal" split location set a bit more deterministic (e.g. sort first)
diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py
index 66218e4..f8f7707 100644
--- a/tests/partitionedjobs/test_crossbackend.py
+++ b/tests/partitionedjobs/test_crossbackend.py
@@ -18,6 +18,7 @@
     GraphSplitException,
     LoadCollectionGraphSplitter,
     SubGraphId,
+    SupportingBackendsMapper,
     _FrozenGraph,
     _FrozenNode,
     _PGSplitResult,
@@ -438,6 +439,36 @@ def test_basic(self, aggregator: _FakeAggregator):
         }
 
 
+class TestFrozenNode:
+    def test_default(self):
+        node = _FrozenNode()
+        assert node.depends_on == frozenset()
+        assert node.flows_to == frozenset()
+        assert node.backend_candidates is None
+
+    def test_basic(self):
+        node = _FrozenNode(depends_on=["a", "b"], flows_to=["c", "d"], backend_candidates="X")
+        assert node.depends_on == frozenset(["a", "b"])
+        assert node.flows_to == frozenset(["c", "d"])
+        assert node.backend_candidates == frozenset(["X"])
+
+    def test_eq(self):
+        assert _FrozenNode() == _FrozenNode()
+        assert _FrozenNode(
+            depends_on=["a", "b"],
+            flows_to=["c", "d"],
+            backend_candidates="X",
+        ) == _FrozenNode(
+            depends_on=("b", "a"),
+            flows_to={"d", "c"},
+            backend_candidates=["X"],
+        )
+
+
+def supporting_backends_from_node_id_dict(data: dict) -> SupportingBackendsMapper:
+    return lambda node_id, node: data.get(node_id)
+
+
 class TestFrozenGraph:
     def test_empty(self):
         graph = _FrozenGraph(graph={})
@@ -448,13 +479,12 @@ def test_from_flat_graph_basic(self):
             "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
             "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
         }
-        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"]})
+        graph = _FrozenGraph.from_flat_graph(
+            flat, supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"]})
+        )
         assert sorted(graph.iter_nodes()) == [
-            (
-                "lc1",
-                _FrozenNode(frozenset(), frozenset(["ndvi1"]), backend_candidates=frozenset(["b1"])),
-            ),
-            ("ndvi1", _FrozenNode(frozenset(["lc1"]), frozenset([]), backend_candidates=None)),
+            ("lc1", _FrozenNode(flows_to=["ndvi1"], backend_candidates="b1")),
+            ("ndvi1", _FrozenNode(depends_on=["lc1"])),
         ]
         # TODO: test from_flat_graph with more complex graphs
 
@@ -462,12 +492,12 @@ def test_from_flat_graph_basic(self):
     def test_from_edges(self):
         graph = _FrozenGraph.from_edges([("a", "c"), ("b", "d"), ("c", "e"), ("d", "e"), ("e", "f")])
         assert sorted(graph.iter_nodes()) == [
-            ("a", _FrozenNode(frozenset(), frozenset(["c"]), backend_candidates=None)),
-            ("b", _FrozenNode(frozenset(), frozenset(["d"]), backend_candidates=None)),
-            ("c", _FrozenNode(frozenset(["a"]), frozenset(["e"]), backend_candidates=None)),
-            ("d", _FrozenNode(frozenset(["b"]), frozenset(["e"]), backend_candidates=None)),
-            ("e", _FrozenNode(frozenset(["c", "d"]), frozenset("f"), backend_candidates=None)),
-            ("f", _FrozenNode(frozenset(["e"]), frozenset(), backend_candidates=None)),
+            ("a", _FrozenNode(flows_to=["c"])),
+            ("b", _FrozenNode(flows_to=["d"])),
+            ("c", _FrozenNode(depends_on=["a"], flows_to=["e"])),
+            ("d", _FrozenNode(depends_on=["b"], flows_to=["e"])),
+            ("e", _FrozenNode(depends_on=["c", "d"], flows_to=["f"])),
+            ("f", _FrozenNode(depends_on=["e"])),
         ]
 
     @pytest.mark.parametrize(
@@ -512,7 +542,7 @@ def test_get_backend_candidates_basic(self):
            #  \ /
            #   d
            [("a", "b"), ("b", "d"), ("c", "d")],
-            backend_candidates_map={"a": ["b1"], "c": ["b2"]},
+            supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": ["b1"], "c": ["b2"]}),
         )
         assert graph.get_backend_candidates_for_node("a") == {"b1"}
         assert graph.get_backend_candidates_for_node("b") == {"b1"}
@@ -535,7 +565,6 @@ def test_get_backend_candidates_none(self):
            #  \ /
            #   d
            [("a", "b"), ("b", "d"), ("c", "d")],
-            backend_candidates_map={},
         )
         assert graph.get_backend_candidates_for_node("a") is None
         assert graph.get_backend_candidates_for_node("b") is None
@@ -553,7 +582,9 @@ def test_get_backend_candidates_intersection(self):
            #  \ /
            #   f
            [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f")],
-            backend_candidates_map={"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]},
+            supporting_backends_mapper=supporting_backends_from_node_id_dict(
+                {"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]}
+            ),
         )
         assert graph.get_backend_candidates_for_node("a") == {"b1", "b2"}
         assert graph.get_backend_candidates_for_node("b") == {"b2", "b3"}
@@ -576,7 +607,9 @@ def test_find_forsaken_nodes(self):
            #  / \
            # g   h
            [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f"), ("f", "g"), ("f", "h")],
-            backend_candidates_map={"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]},
+            supporting_backends_mapper=supporting_backends_from_node_id_dict(
+                {"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]}
+            ),
         )
         assert graph.find_forsaken_nodes() == {"e", "f", "g", "h"}
 
@@ -585,7 +618,7 @@ def test_find_articulation_points_basic(self):
             "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
             "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
         }
-        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={})
+        graph = _FrozenGraph.from_flat_graph(flat)
         assert graph.find_articulation_points() == {"lc1", "ndvi1"}
 
     @pytest.mark.parametrize(
@@ -660,40 +693,45 @@ def test_find_articulation_points_basic(self):
         ],
     )
     def test_find_articulation_points(self, flat, expected):
-        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={})
+        graph = _FrozenGraph.from_flat_graph(flat)
         assert graph.find_articulation_points() == expected
 
     def test_split_at_minimal(self):
-        graph = _FrozenGraph.from_edges([("a", "b")], backend_candidates_map={"a": "A"})
+        graph = _FrozenGraph.from_edges(
+            [("a", "b")], supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": "A"})
+        )
         # Split at a
         up, down = graph.split_at("a")
         assert sorted(up.iter_nodes()) == [
-            ("a", _FrozenNode(frozenset(), frozenset(), backend_candidates=frozenset(["A"]))),
+            ("a", _FrozenNode(backend_candidates=["A"])),
         ]
         assert sorted(down.iter_nodes()) == [
-            ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=None)),
-            ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)),
+            ("a", _FrozenNode(flows_to=["b"])),
+            ("b", _FrozenNode(depends_on=["a"])),
         ]
         # Split at b
         up, down = graph.split_at("b")
         assert sorted(up.iter_nodes()) == [
-            ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=frozenset(["A"]))),
-            ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)),
+            ("a", _FrozenNode(flows_to=["b"], backend_candidates=["A"])),
+            ("b", _FrozenNode(depends_on=["a"])),
         ]
         assert sorted(down.iter_nodes()) == [
-            ("b", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)),
+            ("b", _FrozenNode()),
         ]
 
     def test_split_at_basic(self):
-        graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c")], backend_candidates_map={"a": "A"})
+        graph = _FrozenGraph.from_edges(
+            [("a", "b"), ("b", "c")],
+            supporting_backends_mapper=supporting_backends_from_node_id_dict({"a": "A"}),
+        )
         up, down = graph.split_at("b")
         assert sorted(up.iter_nodes()) == [
-            ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=frozenset(["A"]))),
-            ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)),
+            ("a", _FrozenNode(flows_to=["b"], backend_candidates=["A"])),
+            ("b", _FrozenNode(depends_on=["a"])),
        ]
         assert sorted(down.iter_nodes()) == [
-            ("b", _FrozenNode(frozenset(), frozenset(["c"]), backend_candidates=None)),
-            ("c", _FrozenNode(frozenset(["b"]), frozenset([]), backend_candidates=None)),
+            ("b", _FrozenNode(flows_to=["c"])),
+            ("c", _FrozenNode(depends_on=["b"])),
         ]
 
     def test_split_at_complex(self):
@@ -709,29 +747,37 @@ def test_split_at_complex(self):
         )
 
     def test_split_at_non_articulation_point(self):
-        graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c"), ("a", "c")])
+        graph = _FrozenGraph.from_edges(
+            # a
+            # /|
+            # b |
+            # \|
+            # c
+            [("a", "b"), ("b", "c"), ("a", "c")]
+        )
+
         with pytest.raises(GraphSplitException, match="not an articulation point"):
             _ = graph.split_at("b")
 
         # These should still work
         up, down = graph.split_at("a")
         assert sorted(up.iter_nodes()) == [
-            ("a", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)),
+            ("a", _FrozenNode()),
         ]
         assert sorted(down.iter_nodes()) == [
-            ("a", _FrozenNode(frozenset(), frozenset(["b", "c"]), backend_candidates=None)),
-            ("b", _FrozenNode(frozenset(["a"]), frozenset(["c"]), backend_candidates=None)),
-            ("c", _FrozenNode(frozenset(["a", "b"]), frozenset(), backend_candidates=None)),
+            ("a", _FrozenNode(flows_to=["b", "c"])),
+            ("b", _FrozenNode(depends_on=["a"], flows_to=["c"])),
+            ("c", _FrozenNode(depends_on=["a", "b"])),
         ]
 
         up, down = graph.split_at("c")
         assert sorted(up.iter_nodes()) == [
-            ("a", _FrozenNode(frozenset(), frozenset(["b", "c"]), backend_candidates=None)),
-            ("b", _FrozenNode(frozenset(["a"]), frozenset(["c"]), backend_candidates=None)),
-            ("c", _FrozenNode(frozenset(["a", "b"]), frozenset(), backend_candidates=None)),
+            ("a", _FrozenNode(flows_to=["b", "c"])),
+            ("b", _FrozenNode(depends_on=["a"], flows_to=["c"])),
+            ("c", _FrozenNode(depends_on=["a", "b"])),
         ]
         assert sorted(down.iter_nodes()) == [
-            ("c", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)),
+            ("c", _FrozenNode()),
         ]
 
     def test_produce_split_locations_simple(self):
@@ -743,7 +789,9 @@ def test_produce_split_locations_simple(self):
             "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
             "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
         }
-        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": "b1"})
+        graph = _FrozenGraph.from_flat_graph(
+            flat, supporting_backends=supporting_backends_from_node_id_dict({"lc1": "b1"})
+        )
         assert list(graph.produce_split_locations()) == [[]]
 
     def test_produce_split_locations_merge_basic(self):
@@ -762,7 +810,10 @@ def test_produce_split_locations_merge_basic(self):
                 "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
             },
         }
-        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
+        graph = _FrozenGraph.from_flat_graph(
+            flat,
+            supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+        )
         assert sorted(graph.produce_split_locations()) == [["lc1"], ["lc2"]]
 
     def test_produce_split_locations_merge_longer(self):
@@ -781,7 +832,10 @@ def test_produce_split_locations_merge_longer(self):
                 "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}},
             },
         }
-        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
+        graph = _FrozenGraph.from_flat_graph(
+            flat,
+            supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+        )
         assert sorted(graph.produce_split_locations(limit=2)) == [["bands1"], ["bands2"]]
         assert list(graph.produce_split_locations(limit=4)) == [["bands1"], ["bands2"], ["lc1"], ["lc2"]]
 
@@ -807,13 +861,16 @@ def test_produce_split_locations_merge_longer_triangle(self):
                 "arguments": {"cube1": {"from_node": "mask1"}, "cube2": {"from_node": "bands2"}},
             },
         }
-        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
+        graph = _FrozenGraph.from_flat_graph(
+            flat,
+            supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
+        )
         assert list(graph.produce_split_locations(limit=4)) == [["bands2"], ["mask1"], ["lc2"], ["lc1"]]
 
 
 class TestDeepGraphSplitter:
     def test_simple_no_split(self):
-        splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"]})
+        splitter = DeepGraphSplitter(supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"]}))
         flat = {
             "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
             "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
@@ -829,7 +886,9 @@ def test_simple_split(self):
         """
         Most simple split use case: two load_collections from different backends, merged.
         """
-        splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
+        splitter = DeepGraphSplitter(
+            supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]})
+        )
         flat = {
             # lc1   lc2
             #    \ /
@@ -860,7 +919,9 @@ def test_simple_deep_split(self):
         Simple deep split use case:
         two load_collections from different backends, with some additional filtering, merged.
         """
-        splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
+        splitter = DeepGraphSplitter(
+            supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]})
+        )
         flat = {
             # lc1   lc2
             #  |     |
@@ -888,7 +949,9 @@ def test_simple_deep_split(self):
         )
 
     def test_shallow_triple_split(self):
-        splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"], "lc3": ["b3"]})
+        splitter = DeepGraphSplitter(
+            supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"], "lc3": ["b3"]})
+        )
         flat = {
             # lc1  lc2  lc3
             #    \  /  /
@@ -919,7 +982,9 @@ def test_shallow_triple_split(self):
         )
 
     def test_triple_split(self):
-        splitter = DeepGraphSplitter(backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"], "lc3": ["b3"]})
+        splitter = DeepGraphSplitter(
+            supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"], "lc3": ["b3"]})
+        )
         flat = {
             # lc1  lc2  lc3
             #  |    |    |
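
For context on the `SupportingBackendsMapper` contract introduced at the top of this diff: a minimal sketch of a mapper that only constrains `load_collection` nodes and leaves every other node unconstrained. The `collection_map` lookup table and the helper name are illustrative assumptions, not part of this changeset:

    from typing import Iterable, Union

    from openeo_aggregator.partitionedjobs.crossbackend import SupportingBackendsMapper


    def mapper_from_collection_map(collection_map: dict) -> SupportingBackendsMapper:
        def mapper(node_id: str, node: dict) -> Union[str, Iterable[str], None]:
            if node.get("process_id") == "load_collection":
                # Constrain this node to the backend(s) known to host the collection.
                return collection_map.get(node["arguments"]["id"])
            # None signals unconstrained support: any backend is assumed to handle this node.
            return None

        return mapper

This is also roughly the shape that `LoadCollectionGraphSplitter.backend_for_collection` could migrate to, as the new TODO on that class suggests.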
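
The `dataclasses.dataclass(frozen=True, init=False, eq=True)` plus `object.__setattr__` combination in `_FrozenNode` is the standard workaround for doing input coercion in a frozen dataclass, since plain attribute assignment in `__init__` would raise `FrozenInstanceError`. A self-contained sketch of the pattern, with an illustrative field name:

    import dataclasses
    from typing import Iterable, Optional


    @dataclasses.dataclass(frozen=True, init=False, eq=True)
    class Frozen:
        items: frozenset

        def __init__(self, items: Optional[Iterable[str]] = None):
            # Bypass the frozen __setattr__ guard to store the coerced value.
            object.__setattr__(self, "items", frozenset(items or []))


    # eq=True still generates field-wise equality over the coerced values,
    # which is what TestFrozenNode.test_eq relies on:
    assert Frozen(["a", "b"]) == Frozen(("b", "a"))

As the updated TODO in `_FrozenNode` notes, attrs field converters would express the same coercion more declaratively.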
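
The two new comments in `split_at` pin down the invariant that the `test_split_at_*` cases exercise; schematically (annotation mine, mirroring `test_split_at_basic`):

    graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c")])
    up, down = graph.split_at("b")
    # The split node "b" ends up in both halves:
    # - in `up` it keeps its depends_on links and backend_candidates, but gets no flows_to;
    # - in `down` it keeps its flows_to links, but gets no depends_on and
    #   backend_candidates=None, so upstream constraints do not leak downstream.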
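
Net effect on the splitter API: where `DeepGraphSplitter` previously took a precomputed `backend_candidates_map` dict keyed on the node ids of one specific process graph (the decoupling smell called out in the removed TODO), it now takes a callable that is consulted per node. A usage sketch modeled on `test_simple_split`; the collection ids, backend ids and the `merge_cubes` process are illustrative:

    splitter = DeepGraphSplitter(
        supporting_backends=lambda node_id, node: {"lc1": ["b1"], "lc2": ["b2"]}.get(node_id)
    )
    flat = {
        "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
        "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
        "merge": {
            "process_id": "merge_cubes",
            "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
            "result": True,
        },
    }
    result = splitter.split(flat)  # -> _PGSplitResult, per the split() signature above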