Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Jan 11, 2024
1 parent db89a58 commit f0637c6
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions ppsci/utils/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,35 +933,38 @@ def _expr_to_callable_nodes(
fused_node_seq, list
), "'fused_node_seq' should be list of 'FusedDerivativeNode'"
gid0, nid0 = candidate_pos[0]
logger.debug(
print(
f"Fused {len(candidate_pos)} derivatives nodes: "
f"{[callable_nodes_group[i][j].expr for i, j in candidate_pos]} into"
f" fuse node sequence: {fused_node_seq} at position: ([{gid0}][{nid0}])"
)

# mark merged node
for i, (gid, nid) in enumerate(candidate_pos):
assert isinstance(callable_nodes_group[gid][nid], DerivativeNode)
callable_nodes_group[gid][nid].merged = True

# replace first mergable node with fused node sequence(packed in list)
# then mask the rest merged node to None(except [gid0, nid0])
for i, (gid, nid) in enumerate(candidate_pos):
for i, (gid, nid) in enumerate(candidate_pos[1:]):
# keep the end node of each group to avoid generating empty callable
# node sequence, this will not effect performance since cache strategy
# in Node.forward
callable_nodes_group[gid][nid].merged = True
if i == 0:
if nid0 == len(callable_nodes_group[gid0]) - 1:
callable_nodes_group[gid0].insert(nid0, fused_node_seq)
else:
callable_nodes_group[gid0] = fused_node_seq
else:
if nid != len(callable_nodes_group[gid]) - 1:
callable_nodes_group[gid][nid] = None
if nid != len(callable_nodes_group[gid]) - 1:
callable_nodes_group[gid][nid] = None

if nid0 == len(callable_nodes_group[gid0]) - 1:
callable_nodes_group[gid0].insert(nid0, fused_node_seq)
else:
callable_nodes_group[gid0][nid0] = fused_node_seq

# re-organize callable_nodes_group, remove None element and unpack list
for i in range(len(callable_nodes_group)):
tmp = []
for j in range(len(callable_nodes_group[i])):
if isinstance(
callable_nodes_group[i][j], (Node, FusedDerivativeNode)
):
if isinstance(callable_nodes_group[i][j], Node):
tmp.append(callable_nodes_group[i][j])
elif isinstance(callable_nodes_group[i][j], FusedDerivativeNode):
tmp.append(callable_nodes_group[i][j])
elif isinstance(callable_nodes_group[i][j], list) and isinstance(
callable_nodes_group[i][j][0], FusedDerivativeNode
Expand All @@ -971,7 +974,8 @@ def _expr_to_callable_nodes(
assert (
callable_nodes_group[i][j] is None
), f"Unexpected element: {callable_nodes_group[i][j]}"
callable_nodes_group[i] = tmp
callable_nodes_group[i] = [t for t in tmp]

else:
# exit while loop if no more fused
break
Expand Down

0 comments on commit f0637c6

Please sign in to comment.