From 2027868f5722adac0eedb122d23888eba2367ef6 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Wed, 4 Sep 2024 11:23:22 +0200 Subject: [PATCH] Build clean nestedSDFG without unused data connectors --- .../runners/dace_fieldview/gtir_to_sdfg.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py index 75a5aa07e3..8d23aa42b0 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/gtir_to_sdfg.py @@ -19,7 +19,6 @@ from typing import Any, Dict, List, Optional, Protocol, Sequence, Set, Tuple, Union import dace -import dace.transformation.dataflow as dace_dataflow from gt4py import eve from gt4py.eve import concepts @@ -395,13 +394,21 @@ def visit_Lambda( lambda_symbols = self.global_symbols | { pname: type_ for pname, (_, type_) in lambda_args_mapping.items() } + # obtain the set of symbols that are used in the lambda node and all its child nodes + used_symbols = {str(sym.id) for sym in eve.walk_values(node).if_isinstance(gtir.SymRef)} nsdfg = dace.SDFG(f"{sdfg.label}_nested") nstate = nsdfg.add_state("lambda") + # add sdfg storage for the symbols that need to be passed as input parameters, + # that is only the symbols that are used in the context of the lambda node self._add_sdfg_params( nsdfg, - [gtir.Sym(id=p_name, type=p_type) for p_name, p_type in lambda_symbols.items()], + [ + gtir.Sym(id=p_name, type=p_type) + for p_name, p_type in lambda_symbols.items() + if p_name in used_symbols + ], ) lambda_nodes = GTIRToSDFG(self.offset_provider, lambda_symbols.copy()).visit( @@ -527,9 +534,5 @@ def build_sdfg_from_gtir( sdfg = sdfg_genenerator.visit(program) assert isinstance(sdfg, dace.SDFG) - # nested-SDFGs for let-lambda may contain unused symbols, in which case - # we can remove unnecesssary data connectors (not done by dace simplify pass) - sdfg.apply_transformations_repeated(dace_dataflow.PruneConnectors) - sdfg.simplify() return sdfg