Skip to content

Commit

Permalink
Merge origin/main
Browse files Browse the repository at this point in the history
  • Loading branch information
tehrengruber committed Sep 20, 2024
2 parents ecdb6d7 + 21b1dfc commit 3080a0e
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
that explains the general structure and requirements on the SDFGs.
"""

from .auto_opt import gt_auto_optimize, gt_set_iteration_order, gt_simplify
from .auto_opt import (
GT_SIMPLIFY_DEFAULT_SKIP_SET,
gt_auto_optimize,
gt_inline_nested_sdfg,
gt_set_iteration_order,
gt_simplify,
)
from .gpu_utils import GPUSetBlockSize, gt_gpu_transformation, gt_set_gpu_blocksize
from .loop_blocking import LoopBlocking
from .map_orderer import MapIterationOrder
Expand All @@ -21,6 +27,7 @@


__all__ = [
"GT_SIMPLIFY_DEFAULT_SKIP_SET",
"GPUSetBlockSize",
"LoopBlocking",
"MapIterationOrder",
Expand All @@ -29,6 +36,7 @@
"SerialMapPromoterGPU",
"gt_auto_optimize",
"gt_gpu_transformation",
"gt_inline_nested_sdfg",
"gt_set_iteration_order",
"gt_set_gpu_blocksize",
"gt_simplify",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from typing import Any, Final, Iterable, Optional, Sequence

import dace
from dace.transformation import dataflow as dace_dataflow
from dace.transformation import dataflow as dace_dataflow, passes as dace_passes
from dace.transformation.auto import auto_optimize as dace_aoptimize
from dace.transformation.passes import simplify as dace_passes_simplify

from gt4py.next import common as gtx_common
from gt4py.next.program_processors.runners.dace_fieldview import (
Expand All @@ -36,17 +35,19 @@ def gt_simplify(
sdfg: dace.SDFG,
validate: bool = True,
validate_all: bool = False,
skip: Optional[Iterable[str]] = GT_SIMPLIFY_DEFAULT_SKIP_SET,
skip: Optional[Iterable[str]] = None,
) -> Any:
"""Performs simplifications on the SDFG in place.
Instead of calling `sdfg.simplify()` directly, you should use this function,
as it is specially tuned for GridTool based SDFGs.
By default this function will run the normal DaCe simplify pass, but skip
passes listed in `GT_SIMPLIFY_DEFAULT_SKIP_SET`. If `skip` is passed it
will be forwarded to DaCe, i.e. `GT_SIMPLIFY_DEFAULT_SKIP_SET` are not
added automatically.
This function runs the DaCe simplification pass, but the following passes are
replaced:
- `InlineSDFGs`: Instead `gt_inline_nested_sdfg()` will be called.
Furthermore, by default, or if `None` is passed fro `skip` the passes listed in
`GT_SIMPLIFY_DEFAULT_SKIP_SET` will be skipped.
Args:
sdfg: The SDFG to optimize.
Expand All @@ -55,11 +56,23 @@ def gt_simplify(
skip: List of simplify passes that should not be applied, defaults
to `GT_SIMPLIFY_DEFAULT_SKIP_SET`.
"""
return dace_passes_simplify.SimplifyPass(
# Ensure that `skip` is a `set`
skip = GT_SIMPLIFY_DEFAULT_SKIP_SET if skip is None else set(skip)

if "InlineSDFGs" not in skip:
gt_inline_nested_sdfg(
sdfg=sdfg,
multistate=True,
permissive=False,
validate=validate,
validate_all=validate_all,
)

return dace_passes.SimplifyPass(
validate=validate,
validate_all=validate_all,
verbose=False,
skip=set(skip) if skip is not None else skip,
skip=(skip | {"InlineSDFGs"}),
).apply_pass(sdfg, {})


Expand Down Expand Up @@ -91,6 +104,58 @@ def gt_set_iteration_order(
)


def gt_inline_nested_sdfg(
sdfg: dace.SDFG,
multistate: bool = True,
permissive: bool = False,
validate: bool = True,
validate_all: bool = False,
) -> dace.SDFG:
"""Perform inlining of nested SDFG into their parent SDFG.
The function uses DaCe's `InlineSDFG` transformation, the same used in simplify.
However, before the inline transformation is run the function will run some
cleaning passes that allows inlining nested SDFGs.
As a side effect, the function will split stages into more states.
Args:
sdfg: The SDFG that should be processed, will be modified in place and returned.
multistate: Allow inlining of multistate nested SDFG, defaults to `True`.
permissive: Be less strict on the accepted SDFGs.
validate: Perform validation after the transformation has finished.
validate_all: Performs extensive validation.
"""
first_iteration = True
i = 0
while True:
print(f"ITERATION: {i}")
nb_preproccess = sdfg.apply_transformations_repeated(
[dace_dataflow.PruneSymbols, dace_dataflow.PruneConnectors],
validate=False,
validate_all=validate_all,
)
if (nb_preproccess == 0) and (not first_iteration):
break

# Create and configure the inline pass
inline_sdfg = dace_passes.InlineSDFGs()
inline_sdfg.progress = False
inline_sdfg.permissive = permissive
inline_sdfg.multistate = multistate

# Apply the inline pass
nb_inlines = inline_sdfg.apply_pass(sdfg, {})

# Check result, if needed and test if we can stop
if validate_all or validate:
sdfg.validate()
if nb_inlines == 0:
break
first_iteration = False

return sdfg


def gt_auto_optimize(
sdfg: dace.SDFG,
gpu: bool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ def test_ffront_lap(cartesian_case):
in_field = square(in_field)
out_field = cases.allocate(cartesian_case, lap_program, "out_field")()

# cases.verify(
# cartesian_case,
# lap_program,
# in_field,
# out_field,
# inout=out_field[1:-1, 1:-1],
# ref=lap_ref(in_field.ndarray),
# )
cases.verify(
cartesian_case,
lap_program,
in_field,
out_field,
inout=out_field[1:-1, 1:-1],
ref=lap_ref(in_field.ndarray),
)


def test_ffront_skewedlap(cartesian_case):
Expand Down

0 comments on commit 3080a0e

Please sign in to comment.