diff --git a/mars/optimization/logical/core.py b/mars/optimization/logical/core.py index af9b8f1e91..9f99def5d6 100644 --- a/mars/optimization/logical/core.py +++ b/mars/optimization/logical/core.py @@ -134,7 +134,7 @@ def _replace_node(self, original_node: EntityType, new_node: EntityType): def _replace_subgraph( self, graph: Optional[EntityGraph], - removed_nodes: Optional[Set[EntityType]], + nodes_to_remove: Optional[Set[EntityType]], new_results: Optional[List[Entity]] = None, ): """ @@ -146,7 +146,7 @@ def _replace_subgraph( ---------- graph : EntityGraph, optional The input graph. If it's none, no new node and edge will be added. - removed_nodes : Set[EntityType], optional + nodes_to_remove : Set[EntityType], optional The nodes to be removed. All the edges connected with them are removed as well. new_results : List[EntityType], optional, default None The updated results of the graph. If it's None, then the results will not be updated. @@ -160,18 +160,18 @@ def _replace_subgraph( affected_successors = set() output_to_node = dict() - removed_nodes = removed_nodes or set() + nodes_to_remove = nodes_to_remove or set() if graph is not None: # Add the output key -> node of the subgraph for node in graph.iter_nodes(): - if node in removed_nodes: + if node in nodes_to_remove: raise ValueError(f"The node {node} is in the removed set") for output in node.outputs: output_to_node[output.key] = node - for node in removed_nodes: + for node in nodes_to_remove: for affected_successor in self._graph.iter_successors(node): - if affected_successor not in removed_nodes: + if affected_successor not in nodes_to_remove: affected_successors.add(affected_successor) # Check whether affected successors' inputs are in subgraph for affected_successor in affected_successors: @@ -180,7 +180,7 @@ def _replace_subgraph( raise ValueError( f"The output {inp} of node {affected_successor} is missing in the subgraph" ) - for node in removed_nodes: + for node in nodes_to_remove: self._graph.remove_node(node) if graph is None: @@ -200,7 +200,7 @@ def _replace_subgraph( self._graph.add_edge(pred_node, node) if new_results is not None: - self._graph.results = new_results.copy() + self._graph.results = list(new_results) def _add_collapsable_predecessor(self, node: EntityType, predecessor: EntityType): pred_original = self._records.get_original_entity(predecessor, predecessor)