diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py index f09b720..398aaeb 100644 --- a/src/openeo_aggregator/backend.py +++ b/src/openeo_aggregator/backend.py @@ -983,6 +983,9 @@ def supporting_backends(node_id: str, node: dict) -> Union[List[str], None]: graph_splitter = DeepGraphSplitter( supporting_backends=supporting_backends, primary_backend=split_strategy.get("crossbackend", {}).get("primary_backend"), + # TODO: instead of this hardcoded deny-list, build it based on backend metadata inspection? + # TODO: make a config for this? + split_deny_list={"aggregate_spatial", "load_geojson", "load_url"}, ) else: raise ValueError(f"Invalid graph split strategy {graph_split_method!r}") diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index 2be30c6..076b021 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -30,6 +30,7 @@ import openeo from openeo import BatchJob +from openeo.util import deep_get from openeo_driver.jobregistry import JOB_STATUS from openeo_aggregator.constants import JOB_OPTION_FORCE_BACKEND @@ -51,6 +52,7 @@ SubGraphId = str NodeId = str BackendId = str +ProcessId = str # Annotation for a function that maps node information (id and its node dict) @@ -750,10 +752,17 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]: return up, down - def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]: + def produce_split_locations( + self, + limit: int = 20, + allow_split: Callable[[NodeId], bool] = lambda n: True, + ) -> Iterator[List[NodeId]]: """ Produce disjoint subgraphs that can be processed independently. + :param limit: maximum number of split locations to produce + :param allow_split: predicate to determine if a node can be split on (e.g. to deny splitting on certain nodes) + :return: iterator of node listings. Each node listing encodes a graph split (nodes ids where to split). A node listing is ordered with the following in mind: @@ -765,6 +774,8 @@ def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]: the previous split. - etc """ + # TODO: allow_split: possible to make this backend-dependent in some way? + # Find nodes that have empty set of backend_candidates forsaken_nodes = self.find_forsaken_nodes() @@ -779,13 +790,11 @@ def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]: articulation_points: Set[NodeId] = set(self.find_articulation_points()) _log.debug(f"_GraphViewer.produce_split_locations: {articulation_points=}") - # TODO: allow/deny lists of what openEO processes can be split on? E.g. only split raster cube paths - # Walk upstream from forsaken nodes to find articulation points, where we can cut split_options = [ n for n in self.walk_upstream_nodes(seeds=forsaken_nodes, include_seeds=False) - if n in articulation_points + if n in articulation_points and allow_split(n) ] _log.debug(f"_GraphViewer.produce_split_locations: {split_options=}") if not split_options: @@ -799,7 +808,7 @@ def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]: assert not up.find_forsaken_nodes() # Recursively split downstream part if necessary if down.find_forsaken_nodes(): - down_splits = list(down.produce_split_locations(limit=max(limit - 1, 1))) + down_splits = list(down.produce_split_locations(limit=max(limit - 1, 1), allow_split=allow_split)) else: down_splits = [[]] @@ -834,11 +843,20 @@ def split_at_multiple(self, split_nodes: List[NodeId]) -> Dict[Union[NodeId, Non class DeepGraphSplitter(ProcessGraphSplitterInterface): """ More advanced graph splitting (compared to just splitting off `load_collection` nodes) + + :param split_deny_list: list of process ids that should not be split on """ - def __init__(self, supporting_backends: SupportingBackendsMapper, primary_backend: Optional[BackendId] = None): + def __init__( + self, + supporting_backends: SupportingBackendsMapper, + primary_backend: Optional[BackendId] = None, + split_deny_list: Iterable[ProcessId] = (), + ): self._supporting_backends_mapper = supporting_backends self._primary_backend = primary_backend + # TODO also support other deny mechanisms, e.g. callable instead of a deny list? + self._split_deny_list = set(split_deny_list) def _pick_backend(self, backend_candidates: Union[frozenset[BackendId], None]) -> BackendId: if backend_candidates is None: @@ -855,7 +873,11 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult: flat_graph=process_graph, supporting_backends=self._supporting_backends_mapper ) - for split_nodes in graph.produce_split_locations(): + def allow_split(node_id: NodeId) -> bool: + process_id = deep_get(process_graph, node_id, "process_id", default=None) + return process_id not in self._split_deny_list + + for split_nodes in graph.produce_split_locations(allow_split=allow_split): _log.debug(f"DeepGraphSplitter.split: evaluating split nodes: {split_nodes=}") split_views = graph.split_at_multiple(split_nodes=split_nodes) diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index 0be8664..adf0e65 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -962,9 +962,36 @@ def test_produce_split_locations_merge_longer_triangle(self): ) assert list(graph.produce_split_locations(limit=4)) == [["bands2"], ["mask1"], ["lc2"], ["lc1"]] + def test_produce_split_locations_allow_split(self): + """Usage of custom allow_list predicate""" + flat = { + # lc1 lc2 + # | | + # bands1 bands2 + # \ / + # merge1 + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}}, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}}, + }, + } + graph = _GraphViewer.from_flat_graph( + flat, + supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}), + ) + assert list(graph.produce_split_locations()) == [["bands1"], ["bands2"], ["lc1"], ["lc2"]] + assert list(graph.produce_split_locations(allow_split=lambda n: n not in {"bands1", "lc2"})) == [ + ["bands2"], + ["lc1"], + ] + class TestDeepGraphSplitter: - def test_simple_no_split(self): + def test_no_split(self): splitter = DeepGraphSplitter(supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"]})) flat = { "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, @@ -1150,3 +1177,46 @@ def test_split_with_primary_backend(self, primary_backend, secondary_graph): primary_backend_id=primary_backend, secondary_graphs=[secondary_graph], ) + + @pytest.mark.parametrize( + ["split_deny_list", "split_node", "primary_node_ids", "secondary_node_ids"], + [ + ({}, "temporal2", {"lc1", "bands1", "temporal2", "merge"}, {"lc2", "temporal2"}), + ({"filter_bands", "filter_temporal"}, "lc2", {"lc1", "lc2", "bands1", "temporal2", "merge"}, {"lc2"}), + ], + ) + def test_split_deny_list(self, split_deny_list, split_node, primary_node_ids, secondary_node_ids): + """ + Simple deep split use case: + two load_collections from different backends, with some additional filtering, merged. + """ + splitter = DeepGraphSplitter( + supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}), + primary_backend="b1", + split_deny_list=split_deny_list, + ) + flat = { + # lc1 lc2 + # | | + # bands1 temporal2 + # \ / + # merge + "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}}, + "temporal2": { + "process_id": "filter_temporal", + "arguments": {"data": {"from_node": "lc2"}, "extent": ["2022", "2023"]}, + }, + "merge": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}}, + "result": True, + }, + } + result = splitter.split(flat) + assert result == _PGSplitResult( + primary_node_ids=primary_node_ids, + primary_backend_id="b1", + secondary_graphs=[_PGSplitSubGraph(split_node=split_node, node_ids=secondary_node_ids, backend_id="b2")], + )