From 8785f3910f966001f053351a8b6aee1955c229e2 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 28 Jul 2024 19:04:34 +0200 Subject: [PATCH 1/7] Support for itir.Program in TraceShift --- .../next/iterator/transforms/global_tmps.py | 16 +- .../inline_center_deref_lift_vars.py | 6 +- .../next/iterator/transforms/trace_shifts.py | 126 ++++--- .../iterator_tests/test_type_inference.py | 1 + .../transforms_tests/test_trace_shifts.py | 309 ++++++------------ 5 files changed, 204 insertions(+), 254 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index 135379b72c..ffcb9caf7a 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -177,14 +177,13 @@ class SimpleTemporaryExtractionHeuristics: closure: ir.StencilClosure - @functools.cached_property - def closure_shifts( - self, - ) -> dict[int, set[tuple[ir.OffsetLiteral, ...]]]: - return trace_shifts.TraceShifts.apply(self.closure, inputs_only=False) # type: ignore[return-value] # TODO fix weird `apply` overloads + def __post_init__(self) -> None: + trace_shifts.trace_stencil( + self.closure.stencil, num_args=len(self.closure.inputs), save_to_annex=True + ) def __call__(self, expr: ir.Expr) -> bool: - shifts = self.closure_shifts[id(expr)] + shifts = expr.annex.recorded_shifts if len(shifts) > 1: return True return False @@ -523,8 +522,9 @@ def update_domains( closures.append(closure) - local_shifts = trace_shifts.TraceShifts.apply(closure) - for param, shift_chains in local_shifts.items(): + local_shifts = trace_shifts.trace_stencil(closure.stencil, num_args=len(closure.inputs)) + for param_sym, shift_chains in zip(closure.inputs, local_shifts): + param = param_sym.id assert isinstance(param, str) consumed_domains: list[SymbolicDomain] = ( [SymbolicDomain.from_expr(domains[param])] if param in domains else [] diff --git a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py index 2b6bcf3c9d..ce21f38403 100644 --- a/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py +++ b/src/gt4py/next/iterator/transforms/inline_center_deref_lift_vars.py @@ -20,9 +20,9 @@ from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms import trace_shifts from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda from gt4py.next.iterator.transforms.inline_lifts import InlineLifts -from gt4py.next.iterator.transforms.trace_shifts import TraceShifts, copy_recorded_shifts def is_center_derefed_only(node: itir.Node) -> bool: @@ -64,7 +64,7 @@ def apply(cls, node: itir.FencilDefinition, uids: Optional[eve_utils.UIDGenerato def visit_StencilClosure(self, node: itir.StencilClosure, **kwargs): # TODO(tehrengruber): move the analysis out of this pass and just make it a requirement # such that we don't need to run in multiple times if multiple passes use it. - TraceShifts.apply(node, save_to_annex=True) + trace_shifts.trace_stencil(node.stencil, num_args=len(node.inputs), save_to_annex=True) return self.generic_visit(node, **kwargs) def visit_FunCall(self, node: itir.FunCall, **kwargs): @@ -80,7 +80,7 @@ def visit_FunCall(self, node: itir.FunCall, **kwargs): eligible_params[i] = True bound_arg_name = self.uids.sequential_id(prefix="_icdlv") capture_lift = im.promote_to_const_iterator(bound_arg_name) - copy_recorded_shifts(from_=param, to=capture_lift) + trace_shifts.copy_recorded_shifts(from_=param, to=capture_lift) new_args.append(capture_lift) # since we deref an applied lift here we can (but don't need to) immediately # inline diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 17a55be4a2..99ed6f72ab 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -15,11 +15,12 @@ import enum import sys from collections.abc import Callable -from typing import Any, Final, Iterable, Literal +from typing import Any, Final, Iterable, Literal, Optional from gt4py import eve from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -82,7 +83,7 @@ def __call__(self, inp: ir.Expr | ir.Sym, offsets: tuple[ir.OffsetLiteral, ...]) # for performance reasons (`isinstance` is slow otherwise) we don't use abc here -class IteratorTracer: +class Tracer: def deref(self): raise NotImplementedError() @@ -91,13 +92,13 @@ def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): @dataclasses.dataclass(frozen=True) -class IteratorArgTracer(IteratorTracer): +class ArgTracer(Tracer): arg: ir.Expr | ir.Sym shift_recorder: ShiftRecorder | ForwardingShiftRecorder offsets: tuple[ir.OffsetLiteral, ...] = () def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): - return IteratorArgTracer( + return ArgTracer( arg=self.arg, shift_recorder=self.shift_recorder, offsets=self.offsets + tuple(offsets) ) @@ -109,8 +110,8 @@ def deref(self): # This class is only needed because we currently allow conditionals on iterators. Since this is # not supported in the C++ backend it can likely be removed again in the future. @dataclasses.dataclass(frozen=True) -class CombinedTracer(IteratorTracer): - its: tuple[IteratorTracer, ...] +class CombinedTracer(Tracer): + its: tuple[Tracer, ...] def shift(self, offsets: tuple[ir.OffsetLiteral, ...]): return CombinedTracer(tuple(_shift(*offsets)(it) for it in self.its)) @@ -143,16 +144,16 @@ def _can_deref(x): def _shift(*offsets): def apply(arg): - assert isinstance(arg, IteratorTracer) + assert isinstance(arg, Tracer) return arg.shift(offsets) return apply @dataclasses.dataclass(frozen=True) -class AppliedLift(IteratorTracer): +class AppliedLift(Tracer): stencil: Callable - its: tuple[IteratorTracer, ...] + its: tuple[Tracer, ...] def shift(self, offsets): return AppliedLift(self.stencil, tuple(_shift(*offsets)(it) for it in self.its)) @@ -163,7 +164,7 @@ def deref(self): def _lift(f): def apply(*its): - if not all(isinstance(it, IteratorTracer) for it in its): + if not all(isinstance(it, Tracer) for it in its): raise AssertionError("All arguments must be iterators.") return AppliedLift(f, its) @@ -190,20 +191,18 @@ def apply(*args): def _primitive_constituents( - val: Literal[Sentinel.VALUE] | IteratorTracer | tuple, -) -> Iterable[Literal[Sentinel.VALUE] | IteratorTracer]: - if val is Sentinel.VALUE or isinstance(val, IteratorTracer): + val: Literal[Sentinel.VALUE] | Tracer | tuple, +) -> Iterable[Literal[Sentinel.VALUE] | Tracer]: + if val is Sentinel.VALUE or isinstance(val, Tracer): yield val elif isinstance(val, tuple): for el in val: if isinstance(el, tuple): yield from _primitive_constituents(el) - elif el is Sentinel.VALUE or isinstance(el, IteratorTracer): + elif el is Sentinel.VALUE or isinstance(el, Tracer): yield el else: - raise AssertionError( - "Expected a `Sentinel.VALUE`, `IteratorTracer` or tuple thereof." - ) + raise AssertionError("Expected a `Sentinel.VALUE`, `Tracer` or tuple thereof.") else: raise ValueError() @@ -226,9 +225,7 @@ def _if(cond: Literal[Sentinel.VALUE], true_branch, false_branch): result.append(_if(Sentinel.VALUE, el_true_branch, el_false_branch)) return tuple(result) - is_iterator_arg = tuple( - isinstance(arg, IteratorTracer) for arg in (cond, true_branch, false_branch) - ) + is_iterator_arg = tuple(isinstance(arg, Tracer) for arg in (cond, true_branch, false_branch)) if is_iterator_arg == (False, True, True): return CombinedTracer((true_branch, false_branch)) assert is_iterator_arg == (False, False, False) and all( @@ -248,7 +245,15 @@ def _tuple_get(index, tuple_val): return Sentinel.VALUE +def _as_fieldop(stencil, domain=None): + def applied_as_fieldop(*args): + return stencil(*args) + + return applied_as_fieldop + + _START_CTX: Final = { + "as_fieldop": _as_fieldop, "deref": _deref, "can_deref": _can_deref, "shift": _shift, @@ -292,11 +297,11 @@ def visit_FunCall(self, node: ir.FunCall, *, ctx: dict[str, Any]) -> Any: def visit(self, node, **kwargs): result = super().visit(node, **kwargs) - if isinstance(result, IteratorTracer): + if isinstance(result, Tracer): assert isinstance(node, (ir.Sym, ir.Expr)) self.shift_recorder.register_node(node) - result = IteratorArgTracer( + result = ArgTracer( arg=node, shift_recorder=ForwardingShiftRecorder(result, self.shift_recorder) ) return result @@ -305,10 +310,10 @@ def visit_Lambda(self, node: ir.Lambda, *, ctx: dict[str, Any]) -> Callable: def fun(*args): new_args = [] for param, arg in zip(node.params, args, strict=True): - if isinstance(arg, IteratorTracer): + if isinstance(arg, Tracer): self.shift_recorder.register_node(param) new_args.append( - IteratorArgTracer( + ArgTracer( arg=param, shift_recorder=ForwardingShiftRecorder(arg, self.shift_recorder), ) @@ -322,46 +327,84 @@ def fun(*args): return fun + # FIXME[#1582](tehrengruber): remove after refactoring to GTIR def visit_StencilClosure(self, node: ir.StencilClosure): tracers = [] for inp in node.inputs: self.shift_recorder.register_node(inp) - tracers.append(IteratorArgTracer(arg=inp, shift_recorder=self.shift_recorder)) + tracers.append(ArgTracer(arg=inp, shift_recorder=self.shift_recorder)) result = self.visit(node.stencil, ctx=_START_CTX)(*tracers) assert all(el is Sentinel.VALUE for el in _primitive_constituents(result)) return node + def visit_SetAt(self, node: ir.SetAt, *, ctx: dict[str, Any]) -> None: + self.visit(node.expr, ctx=ctx) + + def initialize_context(self, inputs: Iterable[ir.Sym | ir.SymRef]) -> dict[str, Any]: + ctx: dict[str, Any] = {**_START_CTX} + for inp in inputs: + self.shift_recorder.register_node(inp) + ctx[inp.id] = ArgTracer(arg=inp, shift_recorder=self.shift_recorder) + return ctx + @classmethod - def apply( - cls, node: ir.StencilClosure | ir.FencilDefinition, *, inputs_only=True, save_to_annex=False - ) -> ( - dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]] + def trace_stencil( + cls, stencil: ir.Expr, *, num_args: Optional[int] = None, save_to_annex: bool = False ): + if isinstance(stencil, ir.Lambda): + assert num_args is None or num_args == len(stencil.params) + num_args = len(stencil.params) + if not isinstance(num_args, int): + raise ValueError("Stencil must be an 'itir.Lambda' or `num_args` is given.") + assert isinstance(num_args, int) + + args = [im.ref(f"__arg{i}") for i in range(num_args)] + + instance = cls() + ctx = instance.initialize_context(args) + instance.visit(im.call(stencil)(*args), ctx=ctx) + + recorded_shifts = instance.shift_recorder.recorded_shifts + + param_shifts = [] + for arg in args: + param_shifts.append(recorded_shifts[id(arg)]) + + if save_to_annex: + _save_to_annex(stencil, recorded_shifts) + + return param_shifts + + @classmethod + def trace_program(cls, program: ir.Program, save_to_annex=False): old_recursionlimit = sys.getrecursionlimit() sys.setrecursionlimit(100000000) instance = cls() - instance.visit(node) + ctx = instance.initialize_context(program.params) + + for stmt in program.body: + assert isinstance(stmt, ir.SetAt) + instance.visit(stmt, ctx=ctx) sys.setrecursionlimit(old_recursionlimit) recorded_shifts = instance.shift_recorder.recorded_shifts if save_to_annex: - _save_to_annex(node, recorded_shifts) + _save_to_annex(program, recorded_shifts) + + param_shifts: dict[str, set[tuple[ir.OffsetLiteral, ...]]] = {} + for param in program.params: + param_shifts[str(param.id)] = recorded_shifts[id(param)] - if __debug__: - ValidateRecordedShiftsAnnex().visit(node) + return param_shifts - if inputs_only: - assert isinstance(node, ir.StencilClosure) - inputs_shifts = {} - for inp in node.inputs: - inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)] - return inputs_shifts - return recorded_shifts +trace_program = TraceShifts.trace_program + +trace_stencil = TraceShifts.trace_stencil def _save_to_annex( @@ -370,3 +413,6 @@ def _save_to_annex( for child_node in node.pre_walk_values(): if id(child_node) in recorded_shifts: child_node.annex.recorded_shifts = recorded_shifts[id(child_node)] + + if __debug__: + ValidateRecordedShiftsAnnex().visit(node) diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 7b5d3e6a2f..1263ef65c5 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -233,6 +233,7 @@ def test_aliased_function(): assert result.type == int_type + def test_late_offset_axis(): mesh = simple_mesh() diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index 1d1a2dc89d..fad4097f47 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -11,169 +11,101 @@ # distribution for a copy of the license or check . # # SPDX-License-Identifier: GPL-3.0-or-later - from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.transforms.trace_shifts import Sentinel, TraceShifts -def test_trivial(): - testee = ir.StencilClosure( - stencil=ir.SymRef(id="deref"), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), +def test_trivial_stencil(): + expected = [{()}] + + actual = TraceShifts.trace_stencil(im.ref("deref"), num_args=1) + assert actual == expected + + +def test_trivial_program(): + set_at = ir.SetAt( + expr=im.as_fieldop("deref")("inp"), + domain=im.call("cartesian_domain")(), + target=im.ref("out"), + ) + testee = ir.Program( + id="testee", + function_definitions=[], + params=[im.sym("inp"), im.sym("out")], + declarations=[], + body=[set_at], ) - expected = {"inp": {()}} + expected = {"inp": {()}, "out": set()} - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_program(testee) assert actual == expected def test_shift(): - testee = ir.StencilClosure( - stencil=ir.Lambda( - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1)], - ), - args=[ir.SymRef(id="x")], - ) - ], - ), - params=[ir.Sym(id="x")], - ), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}} + testee = im.lambda_("inp")(im.deref(im.shift("I", 1)("inp"))) + expected = [{(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_lift(): - testee = ir.StencilClosure( - stencil=ir.Lambda( - expr=ir.FunCall( - fun=ir.SymRef(id="deref"), - args=[ - ir.FunCall( - fun=ir.FunCall(fun=ir.SymRef(id="lift"), args=[ir.SymRef(id="deref")]), - args=[ - ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="shift"), - args=[ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1)], - ), - args=[ir.SymRef(id="x")], - ) - ], - ) - ], - ), - params=[ir.Sym(id="x")], - ), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}} + testee = im.lambda_("inp")(im.deref(im.lift("deref")(im.shift("I", 1)("inp")))) + expected = [{(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_neighbors(): - testee = ir.StencilClosure( - stencil=ir.Lambda( - expr=ir.FunCall( - fun=ir.SymRef(id="neighbors"), args=[ir.OffsetLiteral(value="O"), ir.SymRef(id="x")] - ), - params=[ir.Sym(id="x")], - ), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {(ir.OffsetLiteral(value="O"), Sentinel.ALL_NEIGHBORS)}} + testee = im.lambda_("inp")(im.neighbors("O", "inp")) + expected = [{(ir.OffsetLiteral(value="O"), Sentinel.ALL_NEIGHBORS)}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_reduce(): - testee = ir.StencilClosure( - # λ(inp) → reduce(plus, 0.)(·inp) - stencil=ir.Lambda( - params=[ir.Sym(id="inp")], - expr=ir.FunCall( - fun=ir.FunCall( - fun=ir.SymRef(id="reduce"), - args=[ir.SymRef(id="plus"), im.literal_from_value(0.0)], - ), - args=[ir.FunCall(fun=ir.SymRef(id="deref"), args=[ir.SymRef(id="inp")])], - ), - ), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {()}} + # λ(inp) → reduce(plus, 0.)(·inp) + testee = im.lambda_("inp")(im.call(im.call("reduce")("plus", 0.0))(im.deref("inp"))) + expected = [{()}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_shifted_literal(): "Test shifting an applied lift of a stencil returning a constant / literal works." - testee = ir.StencilClosure( - # λ(x) → ·⟪Iₒ, 1ₒ⟫((↑(λ() → 1))()) - stencil=im.lambda_("x")(im.deref(im.shift("I", 1)(im.lift(im.lambda_()(1))()))), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": set()} + testee = im.lambda_("inp")(im.deref(im.shift("I", 1)(im.lift(im.lambda_()(1))()))) + expected = [set()] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_tuple_get(): - testee = ir.StencilClosure( - # λ(x, y) → ·{x, y}[1] - stencil=im.lambda_("x", "y")(im.deref(im.tuple_get(1, im.make_tuple("x", "y")))), - inputs=[ir.SymRef(id="inp1"), ir.SymRef(id="inp2")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp1": set(), "inp2": {()}} # never derefed # once derefed - - actual = TraceShifts.apply(testee) + # λ(x, y) → ·{x, y}[1] + testee = im.lambda_("x", "y")(im.deref(im.tuple_get(1, im.make_tuple("x", "y")))) + expected = [ + set(), # never derefed + {()}, # once derefed + ] + + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_trace_non_closure_input_arg(): x, y = im.sym("x"), im.sym("y") - testee = ir.StencilClosure( - # λ(x) → (λ(y) → ·⟪Iₒ, 1ₒ⟫(y))(⟪Iₒ, 2ₒ⟫(x)) - stencil=im.lambda_(x)( - im.call(im.lambda_(y)(im.deref(im.shift("I", 1)("y"))))(im.shift("I", 2)("x")) - ), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + # λ(x) → (λ(y) → ·⟪Iₒ, 1ₒ⟫(y))(⟪Iₒ, 2ₒ⟫(x)) + testee = im.lambda_(x)( + im.call(im.lambda_(y)(im.deref(im.shift("I", 1)("y"))))(im.shift("I", 2)("x")) ) - actual = TraceShifts.apply(testee, inputs_only=False) + actual = TraceShifts.trace_stencil(testee, save_to_annex=True) - assert actual[id(x)] == { + assert x.annex.recorded_shifts == { ( ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=2), @@ -181,78 +113,59 @@ def test_trace_non_closure_input_arg(): ir.OffsetLiteral(value=1), ) } - assert actual[id(y)] == {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))} + assert y.annex.recorded_shifts == {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))} def test_inner_iterator(): inner_shift = im.shift("I", 1)("x") - testee = ir.StencilClosure( - # λ(x) → ·⟪Iₒ, 1ₒ⟫(⟪Iₒ, 1ₒ⟫(x)) - stencil=im.lambda_("x")(im.deref(im.shift("I", 1)(inner_shift))), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) + # λ(x) → ·⟪Iₒ, 1ₒ⟫(⟪Iₒ, 1ₒ⟫(x)) + testee = im.lambda_("x")(im.deref(im.shift("I", 1)(inner_shift))) expected = {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))} - actual = TraceShifts.apply(testee, inputs_only=False) - assert actual[id(inner_shift)] == expected + actual = TraceShifts.trace_stencil(testee, save_to_annex=True) + assert inner_shift.annex.recorded_shifts == expected def test_tuple_get_on_closure_input(): - testee = ir.StencilClosure( - # λ(x) → (·⟪Iₒ, 1ₒ⟫(x))[0] - stencil=im.lambda_("x")(im.tuple_get(0, im.deref(im.shift("I", 1)("x")))), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) - expected = {"inp": {(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}} + # λ(x) → (·⟪Iₒ, 1ₒ⟫(x))[0] + testee = im.lambda_("x")(im.tuple_get(0, im.deref(im.shift("I", 1)("x")))) + expected = [{(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_if_tuple_branch_broadcasting(): - testee = ir.StencilClosure( - # λ(cond, inp) → (if ·cond then ·inp else {1, 2})[1] - stencil=im.lambda_("cond", "inp")( - im.tuple_get( - 1, - im.if_( - im.deref("cond"), - im.deref("inp"), - im.make_tuple(im.literal_from_value(1), im.literal_from_value(2)), - ), - ) - ), - inputs=[ir.SymRef(id="cond"), ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + # λ(cond, inp) → (if ·cond then ·inp else {1, 2})[1] + testee = im.lambda_("cond", "inp")( + im.tuple_get( + 1, + im.if_( + im.deref("cond"), + im.deref("inp"), + im.make_tuple(im.literal_from_value(1), im.literal_from_value(2)), + ), + ) ) - expected = {"cond": {()}, "inp": {()}} + expected = [ + {()}, # cond + {()}, # inp + ] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_if_of_iterators(): - testee = ir.StencilClosure( - # λ(cond, x) → ·⟪Iₒ, 1ₒ⟫(if ·cond then ⟪Iₒ, 2ₒ⟫(x) else ⟪Iₒ, 3ₒ⟫(x)) - stencil=im.lambda_("cond", "x")( - im.deref( - im.shift("I", 1)( - im.if_(im.deref("cond"), im.shift("I", 2)("x"), im.shift("I", 3)("x")) - ) - ) - ), - inputs=[ir.SymRef(id="cond"), ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + # λ(cond, x) → ·⟪Iₒ, 1ₒ⟫(if ·cond then ⟪Iₒ, 2ₒ⟫(x) else ⟪Iₒ, 3ₒ⟫(x)) + testee = im.lambda_("cond", "x")( + im.deref( + im.shift("I", 1)(im.if_(im.deref("cond"), im.shift("I", 2)("x"), im.shift("I", 3)("x"))) + ) ) - expected = { - "cond": {()}, - "inp": { + expected = [ + {()}, # cond + { # inp ( ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=2), @@ -266,37 +179,32 @@ def test_if_of_iterators(): ir.OffsetLiteral(value=1), ), }, - } + ] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected def test_if_of_tuples_of_iterators(): - testee = ir.StencilClosure( - # λ(cond, x) → - # ·⟪Iₒ, 1ₒ⟫((if ·cond then {⟪Iₒ, 2ₒ⟫(x), ⟪Iₒ, 3ₒ⟫(x)} else {⟪Iₒ, 4ₒ⟫(x), ⟪Iₒ, 5ₒ⟫(x)})[0]) - stencil=im.lambda_("cond", "x")( - im.deref( - im.shift("I", 1)( - im.tuple_get( - 0, - im.if_( - im.deref("cond"), - im.make_tuple(im.shift("I", 2)("x"), im.shift("I", 3)("x")), - im.make_tuple(im.shift("I", 4)("x"), im.shift("I", 5)("x")), - ), - ) + # λ(cond, x) → + # ·⟪Iₒ, 1ₒ⟫((if ·cond then {⟪Iₒ, 2ₒ⟫(x), ⟪Iₒ, 3ₒ⟫(x)} else {⟪Iₒ, 4ₒ⟫(x), ⟪Iₒ, 5ₒ⟫(x)})[0]) + testee = im.lambda_("cond", "x")( + im.deref( + im.shift("I", 1)( + im.tuple_get( + 0, + im.if_( + im.deref("cond"), + im.make_tuple(im.shift("I", 2)("x"), im.shift("I", 3)("x")), + im.make_tuple(im.shift("I", 4)("x"), im.shift("I", 5)("x")), + ), ) ) - ), - inputs=[ir.SymRef(id="cond"), ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), + ) ) - expected = { - "cond": {()}, - "inp": { + expected = [ + {()}, # cond + { # inp ( ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=2), @@ -310,9 +218,9 @@ def test_if_of_tuples_of_iterators(): ir.OffsetLiteral(value=1), ), }, - } + ] - actual = TraceShifts.apply(testee) + actual = TraceShifts.trace_stencil(testee) assert actual == expected @@ -321,13 +229,8 @@ def test_non_derefed_iterator(): Test that even if an iterator is not derefed the resulting dict has an (empty) entry for it. """ non_derefed_it = im.shift("I", 1)("x") - testee = ir.StencilClosure( - # λ(x) → (λ(non_derefed_it) → ·x)(⟪Iₒ, 1ₒ⟫(x)) - stencil=im.lambda_("x")(im.let("non_derefed_it", non_derefed_it)(im.deref("x"))), - inputs=[ir.SymRef(id="inp")], - output=ir.SymRef(id="out"), - domain=ir.FunCall(fun=ir.SymRef(id="cartesian_domain"), args=[]), - ) + # λ(x) → (λ(non_derefed_it) → ·x)(⟪Iₒ, 1ₒ⟫(x)) + testee = im.lambda_("x")(im.let("non_derefed_it", non_derefed_it)(im.deref("x"))) - actual = TraceShifts.apply(testee, inputs_only=False) - assert actual[id(non_derefed_it)] == set() + actual = TraceShifts.trace_stencil(testee, save_to_annex=True) + assert non_derefed_it.annex.recorded_shifts == set() From d0808d8f4786697f1d8a85ecf4b4e22741c335b4 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 28 Jul 2024 19:49:18 +0200 Subject: [PATCH 2/7] Cleanup --- .../next/iterator/transforms/trace_shifts.py | 35 +++---------------- .../runners/dace_iterator/__init__.py | 12 ++++--- .../iterator_tests/test_type_inference.py | 1 - .../transforms_tests/test_trace_shifts.py | 19 ---------- 4 files changed, 12 insertions(+), 55 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 99ed6f72ab..baec4a4579 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -338,9 +338,6 @@ def visit_StencilClosure(self, node: ir.StencilClosure): assert all(el is Sentinel.VALUE for el in _primitive_constituents(result)) return node - def visit_SetAt(self, node: ir.SetAt, *, ctx: dict[str, Any]) -> None: - self.visit(node.expr, ctx=ctx) - def initialize_context(self, inputs: Iterable[ir.Sym | ir.SymRef]) -> dict[str, Any]: ctx: dict[str, Any] = {**_START_CTX} for inp in inputs: @@ -361,10 +358,15 @@ def trace_stencil( args = [im.ref(f"__arg{i}") for i in range(num_args)] + old_recursionlimit = sys.getrecursionlimit() + sys.setrecursionlimit(100000000) + instance = cls() ctx = instance.initialize_context(args) instance.visit(im.call(stencil)(*args), ctx=ctx) + sys.setrecursionlimit(old_recursionlimit) + recorded_shifts = instance.shift_recorder.recorded_shifts param_shifts = [] @@ -376,33 +378,6 @@ def trace_stencil( return param_shifts - @classmethod - def trace_program(cls, program: ir.Program, save_to_annex=False): - old_recursionlimit = sys.getrecursionlimit() - sys.setrecursionlimit(100000000) - - instance = cls() - ctx = instance.initialize_context(program.params) - - for stmt in program.body: - assert isinstance(stmt, ir.SetAt) - instance.visit(stmt, ctx=ctx) - - sys.setrecursionlimit(old_recursionlimit) - - recorded_shifts = instance.shift_recorder.recorded_shifts - - if save_to_annex: - _save_to_annex(program, recorded_shifts) - - param_shifts: dict[str, set[tuple[ir.OffsetLiteral, ...]]] = {} - for param in program.params: - param_shifts[str(param.id)] = recorded_shifts[id(param)] - - return param_shifts - - -trace_program = TraceShifts.trace_program trace_stencil = TraceShifts.trace_stencil diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index eeca53ea0d..230224e2da 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -398,13 +398,15 @@ def __sdfg__(self, *args, **kwargs) -> dace.sdfg.sdfg.SDFG: self.itir, offset_provider=offset_provider ) for closure in itir_tmp.closures: # type: ignore[union-attr] - shifts = itir_transforms.trace_shifts.TraceShifts.apply(closure) - for k, v in shifts.items(): - if not isinstance(k, str): + params_shifts = itir_transforms.trace_shifts.trace_stencil( + closure.stencil, num_args=len(closure.inputs) + ) + for param, shifts in zip(closure.inputs, params_shifts): + if not isinstance(param.id, str): continue - if k not in sdfg.gt4py_program_input_fields: + if param.id not in sdfg.gt4py_program_input_fields: continue - sdfg.offset_providers_per_input_field.setdefault(k, []).extend(list(v)) + sdfg.offset_providers_per_input_field.setdefault(param.id, []).extend(list(shifts)) return sdfg diff --git a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py index 1263ef65c5..7b5d3e6a2f 100644 --- a/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py +++ b/tests/next_tests/unit_tests/iterator_tests/test_type_inference.py @@ -233,7 +233,6 @@ def test_aliased_function(): assert result.type == int_type - def test_late_offset_axis(): mesh = simple_mesh() diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py index fad4097f47..369b1d7c15 100644 --- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py +++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_trace_shifts.py @@ -23,25 +23,6 @@ def test_trivial_stencil(): assert actual == expected -def test_trivial_program(): - set_at = ir.SetAt( - expr=im.as_fieldop("deref")("inp"), - domain=im.call("cartesian_domain")(), - target=im.ref("out"), - ) - testee = ir.Program( - id="testee", - function_definitions=[], - params=[im.sym("inp"), im.sym("out")], - declarations=[], - body=[set_at], - ) - expected = {"inp": {()}, "out": set()} - - actual = TraceShifts.trace_program(testee) - assert actual == expected - - def test_shift(): testee = im.lambda_("inp")(im.deref(im.shift("I", 1)("inp"))) expected = [{(ir.OffsetLiteral(value="I"), ir.OffsetLiteral(value=1))}] From 5a973e038ede5bc9d4c9c6a9c2cbfacda0973f50 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Sun, 28 Jul 2024 20:19:18 +0200 Subject: [PATCH 3/7] Cleanup --- src/gt4py/next/iterator/transforms/trace_shifts.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index baec4a4579..930f244025 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -327,17 +327,6 @@ def fun(*args): return fun - # FIXME[#1582](tehrengruber): remove after refactoring to GTIR - def visit_StencilClosure(self, node: ir.StencilClosure): - tracers = [] - for inp in node.inputs: - self.shift_recorder.register_node(inp) - tracers.append(ArgTracer(arg=inp, shift_recorder=self.shift_recorder)) - - result = self.visit(node.stencil, ctx=_START_CTX)(*tracers) - assert all(el is Sentinel.VALUE for el in _primitive_constituents(result)) - return node - def initialize_context(self, inputs: Iterable[ir.Sym | ir.SymRef]) -> dict[str, Any]: ctx: dict[str, Any] = {**_START_CTX} for inp in inputs: From 6ff3338f5fb53da96cbf34882eccc404a9f75e17 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 16 Sep 2024 15:55:24 +0200 Subject: [PATCH 4/7] Address review comments --- .../next/iterator/transforms/trace_shifts.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 930f244025..d2dee97b63 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -327,17 +327,11 @@ def fun(*args): return fun - def initialize_context(self, inputs: Iterable[ir.Sym | ir.SymRef]) -> dict[str, Any]: - ctx: dict[str, Any] = {**_START_CTX} - for inp in inputs: - self.shift_recorder.register_node(inp) - ctx[inp.id] = ArgTracer(arg=inp, shift_recorder=self.shift_recorder) - return ctx - @classmethod def trace_stencil( cls, stencil: ir.Expr, *, num_args: Optional[int] = None, save_to_annex: bool = False ): + # If we get a lambda we can deduce the number of arguments. if isinstance(stencil, ir.Lambda): assert num_args is None or num_args == len(stencil.params) num_args = len(stencil.params) @@ -351,7 +345,14 @@ def trace_stencil( sys.setrecursionlimit(100000000) instance = cls() - ctx = instance.initialize_context(args) + + # initialize context with all built-ins and the iterator argument tracers + ctx: dict[str, Any] = {**_START_CTX} + for arg in args: + instance.shift_recorder.register_node(arg) + ctx[arg.id] = ArgTracer(arg=arg, shift_recorder=instance.shift_recorder) + + # actually trace stencil instance.visit(im.call(stencil)(*args), ctx=ctx) sys.setrecursionlimit(old_recursionlimit) From b180be98acc81b9437ba8888b818d6ad8fc62808 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Mon, 16 Sep 2024 15:57:57 +0200 Subject: [PATCH 5/7] Address review comments --- src/gt4py/next/iterator/transforms/trace_shifts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index d2dee97b63..af078987bc 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -346,7 +346,7 @@ def trace_stencil( instance = cls() - # initialize context with all built-ins and the iterator argument tracers + # initialize shift recorder & context with all built-ins and the iterator argument tracers ctx: dict[str, Any] = {**_START_CTX} for arg in args: instance.shift_recorder.register_node(arg) From 8f5f7dcc63ed21d5d653f607fc4d286acaa5b659 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 17 Sep 2024 18:06:13 +0200 Subject: [PATCH 6/7] Fix infer_domain --- .../next/iterator/transforms/infer_domain.py | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index 2d465eb3b0..bfb38c8162 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -12,7 +12,7 @@ from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms.global_tmps import AUTO_DOMAIN, SymbolicDomain, domain_union -from gt4py.next.iterator.transforms.trace_shifts import TraceShifts +from gt4py.next.iterator.transforms import trace_shifts def _merge_domains( @@ -27,20 +27,6 @@ def _merge_domains( return new_domains - -# FIXME[#1582](tehrengruber): Use new TraceShift API when #1592 is merged. -def trace_shifts( - stencil: itir.Expr, input_ids: list[str], domain: itir.Expr -) -> dict[str, set[tuple[itir.OffsetLiteral, ...]]]: - node = itir.StencilClosure( - stencil=stencil, - inputs=[im.ref(id_) for id_ in input_ids], - output=im.ref("__dummy"), - domain=domain, - ) - return TraceShifts.apply(node, inputs_only=True) # type: ignore[return-value] # ensured by inputs_only=True - - def extract_shifts_and_translate_domains( stencil: itir.Expr, input_ids: list[str], @@ -48,11 +34,12 @@ def extract_shifts_and_translate_domains( offset_provider: Dict[str, Dimension], accessed_domains: Dict[str, SymbolicDomain], ): - shifts_results = trace_shifts(stencil, input_ids, SymbolicDomain.as_expr(target_domain)) - - for in_field_id in input_ids: - shifts_list = shifts_results[in_field_id] + shifts_results = trace_shifts.trace_stencil( + stencil, + num_args=len(input_ids) + ) + for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): new_domains = [ SymbolicDomain.translate(target_domain, shift, offset_provider) for shift in shifts_list ] From 77c900bf602b43235d3850a71465f3d980beaa96 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Tue, 17 Sep 2024 18:06:52 +0200 Subject: [PATCH 7/7] Fix format --- src/gt4py/next/iterator/transforms/infer_domain.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index bfb38c8162..e05c58e157 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -11,8 +11,8 @@ from gt4py.next.common import Dimension from gt4py.next.iterator import ir as itir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im -from gt4py.next.iterator.transforms.global_tmps import AUTO_DOMAIN, SymbolicDomain, domain_union from gt4py.next.iterator.transforms import trace_shifts +from gt4py.next.iterator.transforms.global_tmps import AUTO_DOMAIN, SymbolicDomain, domain_union def _merge_domains( @@ -27,6 +27,7 @@ def _merge_domains( return new_domains + def extract_shifts_and_translate_domains( stencil: itir.Expr, input_ids: list[str], @@ -34,10 +35,7 @@ def extract_shifts_and_translate_domains( offset_provider: Dict[str, Dimension], accessed_domains: Dict[str, SymbolicDomain], ): - shifts_results = trace_shifts.trace_stencil( - stencil, - num_args=len(input_ids) - ) + shifts_results = trace_shifts.trace_stencil(stencil, num_args=len(input_ids)) for in_field_id, shifts_list in zip(input_ids, shifts_results, strict=True): new_domains = [