Skip to content

Commit

Permalink
Issue #150/#155 initial process_id deny list support in DeepGraphSpli…
Browse files Browse the repository at this point in the history
…tter
  • Loading branch information
soxofaan committed Sep 20, 2024
1 parent 7d69369 commit d62f6a6
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 8 deletions.
3 changes: 3 additions & 0 deletions src/openeo_aggregator/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,9 @@ def supporting_backends(node_id: str, node: dict) -> Union[List[str], None]:
graph_splitter = DeepGraphSplitter(
supporting_backends=supporting_backends,
primary_backend=split_strategy.get("crossbackend", {}).get("primary_backend"),
# TODO: instead of this hardcoded deny-list, build it based on backend metadata inspection?
# TODO: make a config for this?
split_deny_list={"aggregate_spatial", "load_geojson", "load_url"},
)
else:
raise ValueError(f"Invalid graph split strategy {graph_split_method!r}")
Expand Down
36 changes: 29 additions & 7 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import openeo
from openeo import BatchJob
from openeo.util import deep_get
from openeo_driver.jobregistry import JOB_STATUS

from openeo_aggregator.constants import JOB_OPTION_FORCE_BACKEND
Expand All @@ -51,6 +52,7 @@
SubGraphId = str
NodeId = str
BackendId = str
ProcessId = str


# Annotation for a function that maps node information (id and its node dict)
Expand Down Expand Up @@ -750,10 +752,17 @@ def next_nodes(node_id: NodeId) -> Iterable[NodeId]:

return up, down

def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]:
def produce_split_locations(
self,
limit: int = 20,
allow_split: Callable[[NodeId], bool] = lambda n: True,
) -> Iterator[List[NodeId]]:
"""
Produce disjoint subgraphs that can be processed independently.
:param limit: maximum number of split locations to produce
:param allow_split: predicate to determine if a node can be split on (e.g. to deny splitting on certain nodes)
:return: iterator of node listings.
Each node listing encodes a graph split (nodes ids where to split).
A node listing is ordered with the following in mind:
Expand All @@ -765,6 +774,8 @@ def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]:
the previous split.
- etc
"""
# TODO: allow_split: possible to make this backend-dependent in some way?

# Find nodes that have empty set of backend_candidates
forsaken_nodes = self.find_forsaken_nodes()

Expand All @@ -779,13 +790,11 @@ def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]:
articulation_points: Set[NodeId] = set(self.find_articulation_points())
_log.debug(f"_GraphViewer.produce_split_locations: {articulation_points=}")

# TODO: allow/deny lists of what openEO processes can be split on? E.g. only split raster cube paths

# Walk upstream from forsaken nodes to find articulation points, where we can cut
split_options = [
n
for n in self.walk_upstream_nodes(seeds=forsaken_nodes, include_seeds=False)
if n in articulation_points
if n in articulation_points and allow_split(n)
]
_log.debug(f"_GraphViewer.produce_split_locations: {split_options=}")
if not split_options:
Expand All @@ -799,7 +808,7 @@ def produce_split_locations(self, limit: int = 20) -> Iterator[List[NodeId]]:
assert not up.find_forsaken_nodes()
# Recursively split downstream part if necessary
if down.find_forsaken_nodes():
down_splits = list(down.produce_split_locations(limit=max(limit - 1, 1)))
down_splits = list(down.produce_split_locations(limit=max(limit - 1, 1), allow_split=allow_split))
else:
down_splits = [[]]

Expand Down Expand Up @@ -834,11 +843,20 @@ def split_at_multiple(self, split_nodes: List[NodeId]) -> Dict[Union[NodeId, Non
class DeepGraphSplitter(ProcessGraphSplitterInterface):
"""
More advanced graph splitting (compared to just splitting off `load_collection` nodes)
:param split_deny_list: list of process ids that should not be split on
"""

def __init__(self, supporting_backends: SupportingBackendsMapper, primary_backend: Optional[BackendId] = None):
def __init__(
self,
supporting_backends: SupportingBackendsMapper,
primary_backend: Optional[BackendId] = None,
split_deny_list: Iterable[ProcessId] = (),
):
self._supporting_backends_mapper = supporting_backends
self._primary_backend = primary_backend
# TODO also support other deny mechanisms, e.g. callable instead of a deny list?
self._split_deny_list = set(split_deny_list)

def _pick_backend(self, backend_candidates: Union[frozenset[BackendId], None]) -> BackendId:
if backend_candidates is None:
Expand All @@ -855,7 +873,11 @@ def split(self, process_graph: FlatPG) -> _PGSplitResult:
flat_graph=process_graph, supporting_backends=self._supporting_backends_mapper
)

for split_nodes in graph.produce_split_locations():
def allow_split(node_id: NodeId) -> bool:
process_id = deep_get(process_graph, node_id, "process_id", default=None)
return process_id not in self._split_deny_list

for split_nodes in graph.produce_split_locations(allow_split=allow_split):
_log.debug(f"DeepGraphSplitter.split: evaluating split nodes: {split_nodes=}")

split_views = graph.split_at_multiple(split_nodes=split_nodes)
Expand Down
72 changes: 71 additions & 1 deletion tests/partitionedjobs/test_crossbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,9 +962,36 @@ def test_produce_split_locations_merge_longer_triangle(self):
)
assert list(graph.produce_split_locations(limit=4)) == [["bands2"], ["mask1"], ["lc2"], ["lc1"]]

def test_produce_split_locations_allow_split(self):
"""Usage of custom allow_list predicate"""
flat = {
# lc1 lc2
# | |
# bands1 bands2
# \ /
# merge1
"lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
"bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
"lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
"bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
"merge1": {
"process_id": "merge_cubes",
"arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}},
},
}
graph = _GraphViewer.from_flat_graph(
flat,
supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
)
assert list(graph.produce_split_locations()) == [["bands1"], ["bands2"], ["lc1"], ["lc2"]]
assert list(graph.produce_split_locations(allow_split=lambda n: n not in {"bands1", "lc2"})) == [
["bands2"],
["lc1"],
]


class TestDeepGraphSplitter:
def test_simple_no_split(self):
def test_no_split(self):
splitter = DeepGraphSplitter(supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"]}))
flat = {
"lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
Expand Down Expand Up @@ -1150,3 +1177,46 @@ def test_split_with_primary_backend(self, primary_backend, secondary_graph):
primary_backend_id=primary_backend,
secondary_graphs=[secondary_graph],
)

@pytest.mark.parametrize(
["split_deny_list", "split_node", "primary_node_ids", "secondary_node_ids"],
[
({}, "temporal2", {"lc1", "bands1", "temporal2", "merge"}, {"lc2", "temporal2"}),
({"filter_bands", "filter_temporal"}, "lc2", {"lc1", "lc2", "bands1", "temporal2", "merge"}, {"lc2"}),
],
)
def test_split_deny_list(self, split_deny_list, split_node, primary_node_ids, secondary_node_ids):
"""
Simple deep split use case:
two load_collections from different backends, with some additional filtering, merged.
"""
splitter = DeepGraphSplitter(
supporting_backends=supporting_backends_from_node_id_dict({"lc1": ["b1"], "lc2": ["b2"]}),
primary_backend="b1",
split_deny_list=split_deny_list,
)
flat = {
# lc1 lc2
# | |
# bands1 temporal2
# \ /
# merge
"lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
"lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
"bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
"temporal2": {
"process_id": "filter_temporal",
"arguments": {"data": {"from_node": "lc2"}, "extent": ["2022", "2023"]},
},
"merge": {
"process_id": "merge_cubes",
"arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "temporal2"}},
"result": True,
},
}
result = splitter.split(flat)
assert result == _PGSplitResult(
primary_node_ids=primary_node_ids,
primary_backend_id="b1",
secondary_graphs=[_PGSplitSubGraph(split_node=split_node, node_ids=secondary_node_ids, backend_id="b2")],
)

0 comments on commit d62f6a6

Please sign in to comment.