From f0637c6dd54ee77cb5bad8198e6d31186eef0fdb Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Thu, 11 Jan 2024 09:40:09 +0000 Subject: [PATCH] fix bug --- ppsci/utils/symbolic.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/ppsci/utils/symbolic.py b/ppsci/utils/symbolic.py index 3baa3cf8f6..7bbf08b0b1 100644 --- a/ppsci/utils/symbolic.py +++ b/ppsci/utils/symbolic.py @@ -933,35 +933,38 @@ def _expr_to_callable_nodes( fused_node_seq, list ), "'fused_node_seq' should be list of 'FusedDerivativeNode'" gid0, nid0 = candidate_pos[0] - logger.debug( + print( f"Fused {len(candidate_pos)} derivatives nodes: " f"{[callable_nodes_group[i][j].expr for i, j in candidate_pos]} into" f" fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])" ) + # mark merged node + for i, (gid, nid) in enumerate(candidate_pos): + assert isinstance(callable_nodes_group[gid][nid], DerivativeNode) + callable_nodes_group[gid][nid].merged = True + # replace first mergable node with fused node sequence(packed in list) # then mask the rest merged node to None(except [gid0, nid0]) - for i, (gid, nid) in enumerate(candidate_pos): + for i, (gid, nid) in enumerate(candidate_pos[1:]): # keep the end node of each group to avoid generating empty callable # node sequence, this will not effect performance since cache strategy # in Node.forward - callable_nodes_group[gid][nid].merged = True - if i == 0: - if nid0 == len(callable_nodes_group[gid0]) - 1: - callable_nodes_group[gid0].insert(nid0, fused_node_seq) - else: - callable_nodes_group[gid0] = fused_node_seq - else: - if nid != len(callable_nodes_group[gid]) - 1: - callable_nodes_group[gid][nid] = None + if nid != len(callable_nodes_group[gid]) - 1: + callable_nodes_group[gid][nid] = None + + if nid0 == len(callable_nodes_group[gid0]) - 1: + callable_nodes_group[gid0].insert(nid0, fused_node_seq) + else: + callable_nodes_group[gid0][nid0] = fused_node_seq # re-organize callable_nodes_group, remove None element and unpack list for i in range(len(callable_nodes_group)): tmp = [] for j in range(len(callable_nodes_group[i])): - if isinstance( - callable_nodes_group[i][j], (Node, FusedDerivativeNode) - ): + if isinstance(callable_nodes_group[i][j], Node): + tmp.append(callable_nodes_group[i][j]) + elif isinstance(callable_nodes_group[i][j], FusedDerivativeNode): tmp.append(callable_nodes_group[i][j]) elif isinstance(callable_nodes_group[i][j], list) and isinstance( callable_nodes_group[i][j][0], FusedDerivativeNode @@ -971,7 +974,8 @@ def _expr_to_callable_nodes( assert ( callable_nodes_group[i][j] is None ), f"Unexpected element: {callable_nodes_group[i][j]}" - callable_nodes_group[i] = tmp + callable_nodes_group[i] = [t for t in tmp] + else: # exit while loop if no more fused break