Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[dace]: Custom SDFG inline pass #1649

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
that explains the general structure and requirements on the SDFGs.
"""

from .auto_opt import gt_auto_optimize, gt_set_iteration_order, gt_simplify
from .auto_opt import gt_auto_optimize, gt_inline_nested_sdfg, gt_set_iteration_order, gt_simplify
from .gpu_utils import GPUSetBlockSize, gt_gpu_transformation, gt_set_gpu_blocksize
from .loop_blocking import LoopBlocking
from .map_orderer import MapIterationOrder
Expand All @@ -29,6 +29,7 @@
"SerialMapPromoterGPU",
"gt_auto_optimize",
"gt_gpu_transformation",
"gt_inline_nested_sdfg",
"gt_set_iteration_order",
"gt_set_gpu_blocksize",
"gt_simplify",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from typing import Any, Final, Iterable, Optional, Sequence

import dace
from dace.transformation import dataflow as dace_dataflow
from dace.transformation import dataflow as dace_dataflow, passes as dace_passes
from dace.transformation.auto import auto_optimize as dace_aoptimize
from dace.transformation.passes import simplify as dace_passes_simplify

from gt4py.next import common as gtx_common
from gt4py.next.program_processors.runners.dace_fieldview import (
Expand All @@ -36,17 +35,19 @@ def gt_simplify(
sdfg: dace.SDFG,
validate: bool = True,
validate_all: bool = False,
skip: Optional[Iterable[str]] = GT_SIMPLIFY_DEFAULT_SKIP_SET,
skip: Iterable[str] = GT_SIMPLIFY_DEFAULT_SKIP_SET,
) -> Any:
"""Performs simplifications on the SDFG in place.

Instead of calling `sdfg.simplify()` directly, you should use this function,
as it is specially tuned for GridTool based SDFGs.

By default this function will run the normal DaCe simplify pass, but skip
passes listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET`. If `skip` is passed it
will be forwarded to DaCe, i.e. `GT_SIMPLIFY_DEFAULT_SKIP_SET` are not
added automatically.
passes listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET`. If `skip` is given it will
not be modified, i.e. `GT_SIMPLIFY_DEFAULT_SKIP_SET` is not added by default.

Passes that are replaced:
- `InlineSDFGs`: Instead the `gt_inline_nested_sdfg()` will be used.

Args:
sdfg: The SDFG to optimize.
Expand All @@ -55,11 +56,23 @@ def gt_simplify(
skip: List of simplify passes that should not be applied, defaults
to `GT_SIMPLIFY_DEFAULT_SKIP_SET`.
"""
return dace_passes_simplify.SimplifyPass(
# Ensure that `skip` is a `set`
skip = set(skip)

if "InlineSDFGs" not in skip:
gt_inline_nested_sdfg(
sdfg=sdfg,
multistate=True,
permissive=False,
validate=validate,
validate_all=validate_all,
)

return dace_passes.SimplifyPass(
validate=validate,
validate_all=validate_all,
verbose=False,
skip=set(skip) if skip is not None else skip,
skip=set(skip) | {"InlineSDFGs"},
).apply_pass(sdfg, {})


Expand Down Expand Up @@ -91,6 +104,58 @@ def gt_set_iteration_order(
)


def gt_inline_nested_sdfg(
sdfg: dace.SDFG,
multistate: bool = True,
permissive: bool = False,
validate: bool = True,
validate_all: bool = False,
) -> dace.SDFG:
"""Perform inlining of nested SDFG into their parent SDFG.
philip-paul-mueller marked this conversation as resolved.
Show resolved Hide resolved

The function uses DaCe's `InlineSDFG` transformation, the same used in simplify.
However, before the inline transformation is run the function will run some
cleaning passes that allows inlining nested SDFGs.
As a side effect, the function will split stages into more states.

Args:
sdfg: The SDFG that should be processed, will be modified in place and returned.
multistate: Allow inlining of multistate nested SDFG, defaults to `True`.
permissive: Be less strict on the accepted SDFGs.
validate: Perform validation after the transformation has finished.
validate_all: Performs extensive validation.
"""
first_iteration = True
i = 0
while True:
print(f"ITERATION: {i}")
nb_preproccess = sdfg.apply_transformations_repeated(
[dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors],
validate=False,
validate_all=validate_all,
)
if (nb_preproccess == 0) and (not first_iteration):
break

# Create and configure the inline pass
inline_sdfg = dace_passes.InlineSDFGs()
inline_sdfg.progress = False
inline_sdfg.permissive = permissive
inline_sdfg.multistate = multistate

# Apply the inline pass
nb_inlines = inline_sdfg.apply_pass(sdfg, {})

# Check result, if needed and test if we can stop
if validate_all or validate:
sdfg.validate()
if nb_inlines == 0:
break
first_iteration = False

return sdfg


def gt_auto_optimize(
sdfg: dace.SDFG,
gpu: bool,
Expand Down
Loading