diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 20692d7338..1b62a8a02e 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import dataclasses import enum +import sys from collections.abc import Callable from typing import Any, Final, Iterable, Literal @@ -263,6 +264,8 @@ def _tuple_get(index, tuple_val): } +# TODO(tehrengruber): This pass is needlessly inefficient and easily exceeds the default +# recursion limit. @dataclasses.dataclass(frozen=True) class TraceShifts(PreserveLocationVisitor, NodeTranslator): shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder) @@ -329,16 +332,22 @@ def visit_StencilClosure(self, node: ir.StencilClosure): result = self.visit(node.stencil, ctx=_START_CTX)(*tracers) assert all(el is Sentinel.VALUE for el in _primitive_constituents(result)) + return node @classmethod def apply( - cls, node: ir.StencilClosure, *, inputs_only=True, save_to_annex=False + cls, node: ir.StencilClosure | ir.FencilDefinition, *, inputs_only=True, save_to_annex=False ) -> ( dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]] ): + old_recursionlimit = sys.getrecursionlimit() + sys.setrecursionlimit(100000000) + instance = cls() instance.visit(node) + sys.setrecursionlimit(old_recursionlimit) + recorded_shifts = instance.shift_recorder.recorded_shifts if save_to_annex: @@ -348,6 +357,7 @@ def apply( ValidateRecordedShiftsAnnex().visit(node) if inputs_only: + assert isinstance(node, ir.StencilClosure) inputs_shifts = {} for inp in node.inputs: inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)]