Issue #150 CrossBackendSplitter: decouple graph splitting from SubJob yielding

Introduce ProcessGraphSplitterInterface, with a first implementation, LoadCollectionGraphSplitter, based on the existing simple graph splitting logic.
soxofaan committed Sep 17, 2024
1 parent 415d58e commit 0977fd6
Showing 4 changed files with 114 additions and 44 deletions.
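
In short: `CrossBackendSplitter` no longer takes the collection-to-backend callback directly; the splitting strategy is now injected as a `ProcessGraphSplitterInterface` implementation. A minimal before/after sketch of the constructor change (the lambda is just a placeholder backend lookup):

```python
from openeo_aggregator.partitionedjobs.crossbackend import (
    CrossBackendSplitter,
    LoadCollectionGraphSplitter,
)

# Before this commit: splitting logic was hardwired in CrossBackendSplitter.
splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "b1", always_split=True)

# After this commit: the graph splitting strategy is a pluggable object.
splitter = CrossBackendSplitter(
    graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "b1", always_split=True)
)
```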
5 changes: 4 additions & 1 deletion scripts/crossbackend-processing-poc.py
@@ -8,6 +8,7 @@
from openeo_aggregator.partitionedjobs import PartitionedJob
from openeo_aggregator.partitionedjobs.crossbackend import (
CrossBackendSplitter,
+    LoadCollectionGraphSplitter,
run_partitioned_job,
)

@@ -62,7 +63,9 @@ def backend_for_collection(collection_id) -> str:
metadata = connection.describe_collection(collection_id)
return metadata["summaries"][STAC_PROPERTY_FEDERATION_BACKENDS][0]

-    splitter = CrossBackendSplitter(backend_for_collection=backend_for_collection, always_split=True)
+    splitter = CrossBackendSplitter(
+        graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=backend_for_collection, always_split=True)
+    )
pjob: PartitionedJob = splitter.split({"process_graph": process_graph})
_log.info(f"Partitioned job: {pjob!r}")

13 changes: 9 additions & 4 deletions src/openeo_aggregator/backend.py
@@ -100,7 +100,10 @@
single_backend_collection_post_processing,
)
from openeo_aggregator.partitionedjobs import PartitionedJob
-from openeo_aggregator.partitionedjobs.crossbackend import CrossBackendSplitter
+from openeo_aggregator.partitionedjobs.crossbackend import (
+    CrossBackendSplitter,
+    LoadCollectionGraphSplitter,
+)
from openeo_aggregator.partitionedjobs.splitting import FlimsySplitter, TileGridSplitter
from openeo_aggregator.partitionedjobs.tracking import (
PartitionedJobConnection,
@@ -940,9 +943,11 @@ def backend_for_collection(collection_id) -> str:
return self._catalog.get_backends_for_collection(cid=collection_id)[0]

splitter = CrossBackendSplitter(
-            backend_for_collection=backend_for_collection,
-            # TODO: job option for `always_split` feature?
-            always_split=True,
+            graph_splitter=LoadCollectionGraphSplitter(
+                backend_for_collection=backend_for_collection,
+                # TODO: job option for `always_split` feature?
+                always_split=True,
+            )
)

pjob_id = self.partitioned_job_tracker.create_crossbackend_pjob(
115 changes: 82 additions & 33 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -1,5 +1,6 @@
from __future__ import annotations

+import abc
import collections
import copy
import dataclasses
@@ -18,6 +19,7 @@
Iterator,
List,
Mapping,
+    NamedTuple,
Optional,
Protocol,
Sequence,
@@ -45,6 +47,7 @@
_LOAD_RESULT_PLACEHOLDER = "_placeholder:"

# Some type annotation aliases to make things more self-documenting
+CollectionId = str
SubGraphId = str
NodeId = str
BackendId = str
@@ -87,6 +90,75 @@ def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId)
}


+class _SubGraphData(NamedTuple):
+    split_node: NodeId
+    node_ids: Set[NodeId]
+    backend_id: BackendId
+
+
+class _PGSplitResult(NamedTuple):
+    primary_node_ids: Set[NodeId]
+    primary_backend_id: BackendId
+    secondary_graphs: List[_SubGraphData]
+
+
+class ProcessGraphSplitterInterface(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def split(self, process_graph: FlatPG) -> _PGSplitResult:
+        """
+        Split given process graph (flat graph representation) into sub graphs.
+        Returns primary graph data (node ids and backend id)
+        and secondary graphs data (list of tuples: split node id, subgraph node ids, backend id).
+        """
+        ...
+
+
+class LoadCollectionGraphSplitter(ProcessGraphSplitterInterface):
+    """Simple process graph splitter that just splits off load_collection nodes"""
+
+    def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False):
+        # TODO: also support not having a backend_for_collection map?
+        self._backend_for_collection = backend_for_collection
+        self._always_split = always_split
+
+    def split(self, process_graph: FlatPG) -> _PGSplitResult:
+        # Extract necessary back-ends from `load_collection` usage
+        backend_per_collection: Dict[str, str] = {
+            cid: self._backend_for_collection(cid)
+            for cid in (
+                node["arguments"]["id"] for node in process_graph.values() if node["process_id"] == "load_collection"
+            )
+        }
+        backend_usage = collections.Counter(backend_per_collection.values())
+        _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}")
+
+        # TODO: more options to determine primary backend?
+        primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None
+        secondary_backends = {b for b in backend_usage if b != primary_backend}
+        _log.info(f"Backend split: {primary_backend=} {secondary_backends=}")
+
+        primary_has_load_collection = False
+        primary_graph_node_ids = set()
+        secondary_graphs: List[_SubGraphData] = []
+        for node_id, node in process_graph.items():
+            if node["process_id"] == "load_collection":
+                bid = backend_per_collection[node["arguments"]["id"]]
+                if bid == primary_backend and (not self._always_split or not primary_has_load_collection):
+                    primary_graph_node_ids.add(node_id)
+                    primary_has_load_collection = True
+                else:
+                    secondary_graphs.append(_SubGraphData(split_node=node_id, node_ids={node_id}, backend_id=bid))
+            else:
+                primary_graph_node_ids.add(node_id)
+
+        return _PGSplitResult(
+            primary_node_ids=primary_graph_node_ids,
+            primary_backend_id=primary_backend,
+            secondary_graphs=secondary_graphs,
+        )


class CrossBackendSplitter(AbstractJobSplitter):
"""
Split a process graph, to be executed across multiple back-ends,
@@ -97,14 +169,12 @@ class CrossBackendSplitter(AbstractJobSplitter):
"""

-    def __init__(self, backend_for_collection: Callable[[str], str], always_split: bool = False):
-        """
-        :param backend_for_collection: callable that determines backend id for given collection id
-        :param always_split: split all load_collections, also when on same backend
-        """
-        # TODO: just handle this `backend_for_collection` callback with a regular method?
-        self.backend_for_collection = backend_for_collection
-        self._always_split = always_split
+    def __init__(self, graph_splitter: ProcessGraphSplitterInterface):
+        self._graph_splitter = graph_splitter

def split_streaming(
self,
@@ -127,36 +197,12 @@ def split_streaming(
- dependencies as list of subgraph ids
"""

-        # Extract necessary back-ends from `load_collection` usage
-        backend_per_collection: Dict[str, str] = {
-            cid: self.backend_for_collection(cid)
-            for cid in (
-                node["arguments"]["id"] for node in process_graph.values() if node["process_id"] == "load_collection"
-            )
-        }
-        backend_usage = collections.Counter(backend_per_collection.values())
-        _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}")
-
-        # TODO: more options to determine primary backend?
-        primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None
-        secondary_backends = {b for b in backend_usage if b != primary_backend}
-        _log.info(f"Backend split: {primary_backend=} {secondary_backends=}")
-
-        primary_has_load_collection = False
-        sub_graphs: List[Tuple[NodeId, Set[NodeId], BackendId]] = []
-        for node_id, node in process_graph.items():
-            if node["process_id"] == "load_collection":
-                bid = backend_per_collection[node["arguments"]["id"]]
-                if bid == primary_backend and (not self._always_split or not primary_has_load_collection):
-                    primary_has_load_collection = True
-                else:
-                    sub_graphs.append((node_id, {node_id}, bid))
-
-        primary_graph_node_ids = set(process_graph.keys()).difference(n for _, ns, _ in sub_graphs for n in ns)
-        primary_pg = {k: process_graph[k] for k in primary_graph_node_ids}
+        graph_split_result = self._graph_splitter.split(process_graph=process_graph)
+
+        primary_pg = {k: process_graph[k] for k in graph_split_result.primary_node_ids}
primary_dependencies = []

-        for node_id, subgraph_node_ids, backend_id in sub_graphs:
+        for node_id, subgraph_node_ids, backend_id in graph_split_result.secondary_graphs:
# New secondary pg
sub_id = f"{backend_id}:{node_id}"
sub_pg = {k: v for k, v in process_graph.items() if k in subgraph_node_ids}
@@ -178,8 +224,11 @@ def split_streaming(
primary_pg.update(get_replacement(node_id=node_id, node=process_graph[node_id], subgraph_id=sub_id))
primary_dependencies.append(sub_id)

-        primary_id = main_subgraph_id
-        yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies)
+        yield (
+            main_subgraph_id,
+            SubJob(process_graph=primary_pg, backend_id=graph_split_result.primary_backend_id),
+            primary_dependencies,
+        )

def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob:
"""Split given process graph into a `PartitionedJob`"""
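
With the splitting logic behind `ProcessGraphSplitterInterface`, `split_streaming` only consumes a `_PGSplitResult` and no longer cares how the graph was partitioned. As a hypothetical illustration (not part of this commit), a custom splitter that pins the whole graph to a single fixed backend would be as simple as:

```python
class SingleBackendGraphSplitter(ProcessGraphSplitterInterface):
    """Hypothetical splitter: put the whole process graph on one fixed backend."""

    def __init__(self, backend_id: BackendId):
        self._backend_id = backend_id

    def split(self, process_graph: FlatPG) -> _PGSplitResult:
        # Everything stays in the primary graph; nothing is split off.
        return _PGSplitResult(
            primary_node_ids=set(process_graph.keys()),
            primary_backend_id=self._backend_id,
            secondary_graphs=[],
        )
```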
25 changes: 19 additions & 6 deletions tests/partitionedjobs/test_crossbackend.py
@@ -16,6 +16,7 @@
from openeo_aggregator.partitionedjobs.crossbackend import (
CrossBackendSplitter,
GraphSplitException,
+    LoadCollectionGraphSplitter,
SubGraphId,
_FrozenGraph,
_FrozenNode,
@@ -26,15 +27,19 @@
class TestCrossBackendSplitter:
def test_split_simple(self):
process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
-        splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo")
+        splitter = CrossBackendSplitter(
+            graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo")
+        )
res = splitter.split({"process_graph": process_graph})

assert res.subjobs == {"main": SubJob(process_graph, backend_id=None)}
assert res.dependencies == {}

def test_split_streaming_simple(self):
process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
-        splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo")
+        splitter = CrossBackendSplitter(
+            graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo")
+        )
res = splitter.split_streaming(process_graph)
assert isinstance(res, types.GeneratorType)
assert list(res) == [("main", SubJob(process_graph, backend_id=None), [])]
@@ -56,7 +61,9 @@ def test_split_basic(self):
"result": True,
},
}
-        splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+        splitter = CrossBackendSplitter(
+            graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+        )
res = splitter.split({"process_graph": process_graph})

assert res.subjobs == {
@@ -119,7 +126,9 @@ def test_split_streaming_basic(self):
"result": True,
},
}
-        splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+        splitter = CrossBackendSplitter(
+            graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+        )
result = splitter.split_streaming(process_graph)
assert isinstance(result, types.GeneratorType)

@@ -179,7 +188,9 @@ def test_split_streaming_get_replacement(self):
"result": True,
},
}
-        splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+        splitter = CrossBackendSplitter(
+            graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+        )

batch_jobs = {}

@@ -375,7 +386,9 @@ def test_basic(self, aggregator: _FakeAggregator):
"result": True,
},
}
-        splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+        splitter = CrossBackendSplitter(
+            graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
+        )
pjob: PartitionedJob = splitter.split({"process_graph": process_graph})

connection = openeo.Connection(aggregator.url)
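
For callers, the streaming behaviour is unchanged: `split_streaming` still yields `(subgraph_id, SubJob, dependencies)` tuples, with the primary subjob last. A rough usage sketch, assuming a flat `process_graph` dict and the `<backend>_<suffix>` collection naming convention used in the tests above:

```python
splitter = CrossBackendSplitter(
    graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
)
for sub_id, subjob, dependencies in splitter.split_streaming(process_graph):
    # e.g. ("b2:lc2", SubJob(process_graph={...}, backend_id="b2"), [])
    print(sub_id, subjob.backend_id, dependencies)
```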
