Skip to content

Commit e51ab30

Browse files
authored
bug[next]: Increase recursion limit in TraceShift pass (#1482)
The TraceShift pass has a very high recursion depth. This PR temporarily increases the allowed recursion depth during the pass. Fixes compilation failures in Icon4Py.
1 parent 66f8447 commit e51ab30

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

src/gt4py/next/iterator/transforms/trace_shifts.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# SPDX-License-Identifier: GPL-3.0-or-later
1414
import dataclasses
1515
import enum
16+
import sys
1617
from collections.abc import Callable
1718
from typing import Any, Final, Iterable, Literal
1819

@@ -263,6 +264,8 @@ def _tuple_get(index, tuple_val):
263264
}
264265

265266

267+
# TODO(tehrengruber): This pass is unnecessarily very inefficient and easily exceeds the default
268+
# recursion limit.
266269
@dataclasses.dataclass(frozen=True)
267270
class TraceShifts(PreserveLocationVisitor, NodeTranslator):
268271
shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder)
@@ -329,16 +332,22 @@ def visit_StencilClosure(self, node: ir.StencilClosure):
329332

330333
result = self.visit(node.stencil, ctx=_START_CTX)(*tracers)
331334
assert all(el is Sentinel.VALUE for el in _primitive_constituents(result))
335+
return node
332336

333337
@classmethod
334338
def apply(
335-
cls, node: ir.StencilClosure, *, inputs_only=True, save_to_annex=False
339+
cls, node: ir.StencilClosure | ir.FencilDefinition, *, inputs_only=True, save_to_annex=False
336340
) -> (
337341
dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]]
338342
):
343+
old_recursionlimit = sys.getrecursionlimit()
344+
sys.setrecursionlimit(100000000)
345+
339346
instance = cls()
340347
instance.visit(node)
341348

349+
sys.setrecursionlimit(old_recursionlimit)
350+
342351
recorded_shifts = instance.shift_recorder.recorded_shifts
343352

344353
if save_to_annex:
@@ -348,6 +357,7 @@ def apply(
348357
ValidateRecordedShiftsAnnex().visit(node)
349358

350359
if inputs_only:
360+
assert isinstance(node, ir.StencilClosure)
351361
inputs_shifts = {}
352362
for inp in node.inputs:
353363
inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)]

0 commit comments

Comments
 (0)