Skip to content

Commit

Permalink
fixup! fixup! Issue #150 WIP more advanced pg splitting
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Sep 16, 2024
1 parent e59043c commit 7276f0b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/openeo_aggregator/partitionedjobs/crossbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,10 @@ def get_backend_candidates(self, node_id: NodeId) -> Union[frozenset[BackendId],
# TODO: cache intermediate sets? (Only when caching is safe: e.g. wrapped graph is immutable/not manipulated)
upstream_candidates = (self.get_backend_candidates(n) for n in self.node(node_id).depends_on)
upstream_candidates = [c for c in upstream_candidates if c is not None]
return functools.reduce(lambda a, b: a.intersection(b), upstream_candidates)
if upstream_candidates:
return functools.reduce(lambda a, b: a.intersection(b), upstream_candidates)
else:
return None
else:
return None

Expand Down
51 changes: 50 additions & 1 deletion tests/partitionedjobs/test_crossbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,16 @@ def test_get_backend_candidates_basic(self):
assert graph.get_backend_candidates("c") == {"b2"}
assert graph.get_backend_candidates("d") == set()

def test_get_backend_candidates_none(self):
graph = _FrozenGraph.from_edges(
[("a", "b"), ("b", "d"), ("c", "d")],
backend_candidates_map={},
)
assert graph.get_backend_candidates("a") is None
assert graph.get_backend_candidates("b") is None
assert graph.get_backend_candidates("c") is None
assert graph.get_backend_candidates("d") is None

def test_get_backend_candidates_intersection(self):
graph = _FrozenGraph.from_edges(
[("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f")],
Expand Down Expand Up @@ -614,7 +624,7 @@ def test_produce_split_locations_simple(self):
graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": "b1"})
assert list(graph.produce_split_locations()) == [[]]

def test_produce_split_locations_basic(self):
def test_produce_split_locations_merge_basic(self):
"""
Basic produce_split_locations use case:
two load collections on different backends and a merge
Expand All @@ -629,3 +639,42 @@ def test_produce_split_locations_basic(self):
}
graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
assert sorted(graph.produce_split_locations()) == [["lc1"], ["lc2"]]

def test_produce_split_locations_merge_longer(self):
flat = {
"lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
"bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
"lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
"bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
"merge1": {
"process_id": "merge_cubes",
"arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}},
},
}
graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
assert sorted(graph.produce_split_locations(limit=2)) == [["bands1"], ["bands2"]]
assert list(graph.produce_split_locations(limit=4)) == dirty_equals.IsOneOf(
[["bands1"], ["bands2"], ["lc1"], ["lc2"]],
[["bands2"], ["bands1"], ["lc2"], ["lc1"]],
)

def test_produce_split_locations_merge_longer_triangle(self):
flat = {
"lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
"bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
"mask1": {
"process_id": "mask",
"arguments": {"data": {"from_node": "bands1"}, "mask": {"from_node": "lc1"}},
},
"lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
"bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
"merge1": {
"process_id": "merge_cubes",
"arguments": {"cube1": {"from_node": "mask1"}, "cube2": {"from_node": "bands2"}},
},
}
graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
assert list(graph.produce_split_locations(limit=4)) == dirty_equals.IsOneOf(
[["mask1"], ["bands2"], ["lc1"], ["lc2"]],
[["bands2"], ["mask1"], ["lc2"], ["lc1"]],
)

0 comments on commit 7276f0b

Please sign in to comment.