Skip to content

Commit

Permalink
Using merge instead of replace as the result updating strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
ericpai committed Jun 28, 2023
1 parent c53e623 commit 79b218f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 37 deletions.
42 changes: 29 additions & 13 deletions mars/optimization/logical/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def _replace_subgraph(
self,
graph: Optional[EntityGraph],
nodes_to_remove: Optional[Set[EntityType]],
new_results: Optional[List[Entity]] = None,
new_results: Optional[List[Entity]],
results_to_remove: Optional[List[Entity]],
):
"""
Replace the subgraph from the self._graph represented by a list of nodes with input graph.
Expand All @@ -148,19 +149,28 @@ def _replace_subgraph(
The input graph. If it's none, no new node and edge will be added.
nodes_to_remove : Set[EntityType], optional
The nodes to be removed. All the edges connected with them are removed as well.
new_results : List[EntityType], optional, default None
The updated results of the graph. If it's None, then the results will not be updated.
new_results : List[Entity], optional
The new results to be added to the graph.
results_to_remove : List[Entity], optional
The results to be removed from the graph. If a result is not in self._graph.results, it will be ignored.
Raises
------
ReplaceSubgraphError
If the input key of the removed node's successor can't be found in the subgraph.
Or some of the nodes of the subgraph are in removed ones.
ValueError
1. If the input key of the removed node's successor can't be found in the subgraph.
2. Or some of the nodes of the subgraph are in removed ones.
3. Or the added result is not a valid output of any node in the updated graph.
"""
affected_successors = set()

output_to_node = dict()
nodes_to_remove = nodes_to_remove or set()
results_to_remove = results_to_remove or list()
new_results = new_results or list()
final_results = set(
filter(lambda x: x not in results_to_remove, self._graph.results)
)

if graph is not None:
# Add the output key -> node of the subgraph
for node in graph.iter_nodes():
Expand All @@ -169,6 +179,17 @@ def _replace_subgraph(
for output in node.outputs:
output_to_node[output.key] = node

# Add the output key -> node of the original graph
for node in self._graph.iter_nodes():
if node not in nodes_to_remove:
for output in node.outputs:
output_to_node[output.key] = node

for result in new_results:
if result.key not in output_to_node:
raise ValueError(f"Unknown result {result} to add")
final_results.update(new_results)

for node in nodes_to_remove:
for affected_successor in self._graph.iter_successors(node):
if affected_successor not in nodes_to_remove:
Expand All @@ -180,17 +201,13 @@ def _replace_subgraph(
raise ValueError(
f"The output {inp} of node {affected_successor} is missing in the subgraph"
)
# Here all the pre-check are passed, we start to replace the subgraph
for node in nodes_to_remove:
self._graph.remove_node(node)

if graph is None:
return

# Add the output key -> node of the original graph
for node in self._graph.iter_nodes():
for output in node.outputs:
output_to_node[output.key] = node

for node in graph.iter_nodes():
self._graph.add_node(node)

Expand All @@ -199,8 +216,7 @@ def _replace_subgraph(
pred_node = output_to_node[inp.key]
self._graph.add_edge(pred_node, node)

if new_results is not None:
self._graph.results = list(new_results)
self._graph.results = list(final_results)

def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType):
pred_original = self._records.get_original_entity(predecessor, predecessor)
Expand Down
63 changes: 39 additions & 24 deletions mars/optimization/logical/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Optional

import pytest


Expand All @@ -24,8 +26,8 @@ class _MockRule(OptimizationRule):
def apply(self) -> bool:
pass

def replace_subgraph(self, graph, removed_nodes, new_results=None):
self._replace_subgraph(graph, removed_nodes, new_results)
def replace_subgraph(self, graph, nodes_to_remove, new_results, results_to_remove):
self._replace_subgraph(graph, nodes_to_remove, new_results, results_to_remove)


def test_replace_tileable_subgraph():
Expand Down Expand Up @@ -78,11 +80,15 @@ def test_replace_tileable_subgraph():
c2 = g1.successors(key_to_node[s2.key])[0]
c5 = g1.successors(key_to_node[s5.key])[0]

expected_results = [v8.outputs[0]]
new_results = [v8.outputs[0]]
removed_results = [
v6.outputs[0],
v8.outputs[0], # v8.outputs[0] is not in the original results, so we ignore it.
]
r.replace_subgraph(
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, expected_results
g2, {key_to_node[op.key] for op in [v3, v4, v6]}, new_results, removed_results
)
assert g1.results == expected_results
assert g1.results == new_results

expected_nodes = {s1, c1, v1, s2, c2, v2, s5, c5, v5, v7, v8}
assert set(g1) == {key_to_node[n.key] for n in expected_nodes}
Expand Down Expand Up @@ -110,10 +116,10 @@ def test_replace_tileable_subgraph():
def test_replace_null_subgraph():
"""
Original Graph:
s1 ---> c1 ---> v1 ---> v3 <--- v2 <--- c2 <--- s2
s1 ---> c1 ---> v1 ---> v3(out) <--- v2 <--- c2 <--- s2
Target Graph:
c1 ---> v1 ---> v3 <--- v2 <--- c2
c1 ---> v1 ---> v3 <--- v2(out) <--- c2
The nodes [s1, s2] will be removed.
Subgraph is None
Expand All @@ -129,30 +135,39 @@ def test_replace_null_subgraph():
c2 = g1.successors(key_to_node[s2.key])[0]
r = _MockRule(g1, None, None)
expected_results = [v3.outputs[0]]

# delete c5 s5 will fail
with pytest.raises(ValueError) as e:
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
assert g1.results == expected_results
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
expected_edges = {
s1: [c1],
c1: [v1],
v1: [v3],
s2: [c2],
c2: [v2],
v2: [v3],
v3: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
r.replace_subgraph(
None, {key_to_node[op.key] for op in [s1, s2]}, None, [v2.outputs[0]]
)

assert g1.results == expected_results
assert set(g1) == {key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, v3}}
expected_edges = {
s1: [c1],
c1: [v1],
v1: [v3],
s2: [c2],
c2: [v2],
v2: [v3],
v3: [],
}
for pred, successors in expected_edges.items():
pred_node = key_to_node[pred.key]
assert g1.count_successors(pred_node) == len(successors)
for successor in successors:
assert g1.has_successor(pred_node, key_to_node[successor.key])

c1.inputs.clear()
c2.inputs.clear()
r.replace_subgraph(None, {key_to_node[op.key] for op in [s1, s2]})
assert g1.results == expected_results
r.replace_subgraph(
None,
{key_to_node[op.key] for op in [s1, s2]},
[v2.outputs[0]],
[v3.outputs[0]],
)
assert g1.results == [v2.outputs[0]]
assert set(g1) == {key_to_node[n.key] for n in {c1, v1, c2, v2, v3}}
expected_edges = {
c1: [v1],
Expand Down Expand Up @@ -198,7 +213,7 @@ def test_replace_subgraph_without_removing_nodes():
c2 = g1.successors(key_to_node[s2.key])[0]
c3 = g2.successors(key_to_node[s3.key])[0]
r = _MockRule(g1, None, None)
r.replace_subgraph(g2, None, expected_results)
r.replace_subgraph(g2, None, [v3.outputs[0]], None)
assert g1.results == expected_results
assert set(g1) == {
key_to_node[n.key] for n in {s1, c1, v1, s2, c2, v2, s3, c3, v3, v4}
Expand Down

0 comments on commit 79b218f

Please sign in to comment.