From 1cf1016bc5166036410204cb1d2f142463250261 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 6 Mar 2024 12:48:03 +0100 Subject: [PATCH 1/4] Increase recursion limit in TraceShift pass --- src/gt4py/next/iterator/transforms/trace_shifts.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 20692d7338..9992a788bf 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses import enum +import sys from collections.abc import Callable from typing import Any, Final, Iterable, Literal @@ -263,6 +264,8 @@ def _tuple_get(index, tuple_val): } +# TODO(tehrengruber): This pass is unnecessarily very inefficient and easily exceeds the default +# recursion limit. @dataclasses.dataclass(frozen=True) class TraceShifts(PreserveLocationVisitor, NodeTranslator): shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder) @@ -329,16 +332,22 @@ def visit_StencilClosure(self, node: ir.StencilClosure): result = self.visit(node.stencil, ctx=_START_CTX)(*tracers) assert all(el is Sentinel.VALUE for el in _primitive_constituents(result)) + return node @classmethod def apply( - cls, node: ir.StencilClosure, *, inputs_only=True, save_to_annex=False + cls, node: ir.StencilClosure | ir.FencilDefinition, *, inputs_only=True, save_to_annex=False ) -> ( dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]] ): + old_recursionlimit = sys.getrecursionlimit() + sys.setrecursionlimit(100000000) + instance = cls() instance.visit(node) + sys.setrecursionlimit(old_recursionlimit) + recorded_shifts = instance.shift_recorder.recorded_shifts if save_to_annex: From 4d407b4f8a5026c0ab3eb8e6d3d18c1afa134165 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 6 Mar 2024 12:48:43 +0100 Subject: [PATCH 2/4] Fix pre-commit --- src/gt4py/next/iterator/transforms/trace_shifts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 9992a788bf..1b62a8a02e 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -357,6 +357,7 @@ def apply( ValidateRecordedShiftsAnnex().visit(node) if inputs_only: + assert isinstance(node, ir.StencilClosure) inputs_shifts = {} for inp in node.inputs: inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)] From 00ef8bcba42bb622d63d1ad0e0e644163ab6cb5f Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 6 Mar 2024 13:19:21 +0100 Subject: [PATCH 3/4] Retrigger CI From 63200afea11b572f6d38702e3e2c3fb0f43731b5 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Wed, 6 Mar 2024 15:35:37 +0100 Subject: [PATCH 4/4] Retrigger CI