From 2eed3679b513c5e308fea7a666f9a3b6bb497ec6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Thu, 19 Sep 2024 15:15:08 +0200 Subject: [PATCH] feat[dace]: Custom SDFG inline pass (#1649) Added a custom pass for inlining SDFG. The function builds upon the traditional inline pass (`InlineSDFG`), however, before it is run some cleaning steps are preformed (`PruneConectors` and `PruneSymbols`) which increase the likelihood that the inlining can be done. The cost is that state fusing is performed. Also the `gt_simplify()` function is modified, instead of the build in behaviour it will only run the GT4Py specific one. This behaviour can not be changed. --- .../transformations/__init__.py | 10 ++- .../transformations/auto_opt.py | 83 +++++++++++++++++-- 2 files changed, 83 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py index 53fa1eee05..8852dd6d2d 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/__init__.py @@ -12,7 +12,13 @@ that explains the general structure and requirements on the SDFGs. """ -from .auto_opt import gt_auto_optimize, gt_set_iteration_order, gt_simplify +from .auto_opt import ( + GT_SIMPLIFY_DEFAULT_SKIP_SET, + gt_auto_optimize, + gt_inline_nested_sdfg, + gt_set_iteration_order, + gt_simplify, +) from .gpu_utils import GPUSetBlockSize, gt_gpu_transformation, gt_set_gpu_blocksize from .loop_blocking import LoopBlocking from .map_orderer import MapIterationOrder @@ -21,6 +27,7 @@ __all__ = [ + "GT_SIMPLIFY_DEFAULT_SKIP_SET", "GPUSetBlockSize", "LoopBlocking", "MapIterationOrder", @@ -29,6 +36,7 @@ "SerialMapPromoterGPU", "gt_auto_optimize", "gt_gpu_transformation", + "gt_inline_nested_sdfg", "gt_set_iteration_order", "gt_set_gpu_blocksize", "gt_simplify", diff --git a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py index 3895f7f5e8..37cc89aa2b 100644 --- a/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py +++ b/src/gt4py/next/program_processors/runners/dace_fieldview/transformations/auto_opt.py @@ -11,9 +11,8 @@ from typing import Any, Final, Iterable, Optional, Sequence import dace -from dace.transformation import dataflow as dace_dataflow +from dace.transformation import dataflow as dace_dataflow, passes as dace_passes from dace.transformation.auto import auto_optimize as dace_aoptimize -from dace.transformation.passes import simplify as dace_passes_simplify from gt4py.next import common as gtx_common from gt4py.next.program_processors.runners.dace_fieldview import ( @@ -36,17 +35,19 @@ def gt_simplify( sdfg: dace.SDFG, validate: bool = True, validate_all: bool = False, - skip: Optional[Iterable[str]] = GT_SIMPLIFY_DEFAULT_SKIP_SET, + skip: Optional[Iterable[str]] = None, ) -> Any: """Performs simplifications on the SDFG in place. Instead of calling `sdfg.simplify()` directly, you should use this function, as it is specially tuned for GridTool based SDFGs. - By default this function will run the normal DaCe simplify pass, but skip - passes listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET`. If `skip` is passed it - will be forwarded to DaCe, i.e. `GT_SIMPLIFY_DEFAULT_SKIP_SET` are not - added automatically. + This function runs the DaCe simplification pass, but the following passes are + replaced: + - `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called. + + Furthermore, by default, or if `None` is passed fro `skip` the passes listed in + `GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped. Args: sdfg: The SDFG to optimize. @@ -55,11 +56,23 @@ def gt_simplify( skip: List of simplify passes that should not be applied, defaults to `GT_SIMPLIFY_DEFAULT_SKIP_SET`. """ - return dace_passes_simplify.SimplifyPass( + # Ensure that `skip` is a `set` + skip = GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip) + + if "InlineSDFGs" not in skip: + gt_inline_nested_sdfg( + sdfg=sdfg, + multistate=True, + permissive=False, + validate=validate, + validate_all=validate_all, + ) + + return dace_passes.SimplifyPass( validate=validate, validate_all=validate_all, verbose=False, - skip=set(skip) if skip is not None else skip, + skip=(skip | {"InlineSDFGs"}), ).apply_pass(sdfg, {}) @@ -91,6 +104,58 @@ def gt_set_iteration_order( ) +def gt_inline_nested_sdfg( + sdfg: dace.SDFG, + multistate: bool = True, + permissive: bool = False, + validate: bool = True, + validate_all: bool = False, +) -> dace.SDFG: + """Perform inlining of nested SDFG into their parent SDFG. + + The function uses DaCe's `InlineSDFG` transformation, the same used in simplify. + However, before the inline transformation is run the function will run some + cleaning passes that allows inlining nested SDFGs. + As a side effect, the function will split stages into more states. + + Args: + sdfg: The SDFG that should be processed, will be modified in place and returned. + multistate: Allow inlining of multistate nested SDFG, defaults to `True`. + permissive: Be less strict on the accepted SDFGs. + validate: Perform validation after the transformation has finished. + validate_all: Performs extensive validation. + """ + first_iteration = True + i = 0 + while True: + print(f"ITERATION: {i}") + nb_preproccess = sdfg.apply_transformations_repeated( + [dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors], + validate=False, + validate_all=validate_all, + ) + if (nb_preproccess == 0) and (not first_iteration): + break + + # Create and configure the inline pass + inline_sdfg = dace_passes.InlineSDFGs() + inline_sdfg.progress = False + inline_sdfg.permissive = permissive + inline_sdfg.multistate = multistate + + # Apply the inline pass + nb_inlines = inline_sdfg.apply_pass(sdfg, {}) + + # Check result, if needed and test if we can stop + if validate_all or validate: + sdfg.validate() + if nb_inlines == 0: + break + first_iteration = False + + return sdfg + + def gt_auto_optimize( sdfg: dace.SDFG, gpu: bool,