Skip to content

Commit 02a4839

Browse files
authored
fix[next][dace]: accept runtime lift_mode as argument to dace backend (#1481)
Fix previous PR #1477. Like for gtfn backend, we still need to accept the runtime `lift_mode` passed as keyword argument to the backend.
1 parent 628a33b commit 02a4839

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

src/gt4py/next/program_processors/runners/dace_iterator/workflow.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import dataclasses
18+
import warnings
1819
from typing import Callable, Optional, cast
1920

2021
import dace
@@ -68,7 +69,22 @@ def __call__(
6869

6970
# ITIR parameters
7071
column_axis: Optional[Dimension] = inp.kwargs.get("column_axis", None)
71-
offset_provider = inp.kwargs["offset_provider"]
72+
offset_provider: dict[str, common.Dimension | common.Connectivity] = inp.kwargs[
73+
"offset_provider"
74+
]
75+
runtime_lift_mode: Optional[LiftMode] = inp.kwargs.get("lift_mode", None)
76+
77+
# TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added
78+
# to the interface of all (or at least all of concern) backends, but instead should be
79+
# configured in the backend itself (like it is here), until then we respect the argument
80+
# here and warn the user if it differs from the one configured.
81+
lift_mode = runtime_lift_mode or self.lift_mode
82+
if runtime_lift_mode and runtime_lift_mode != self.lift_mode:
83+
warnings.warn(
84+
f"DaCe Backend was configured for LiftMode `{self.lift_mode!s}`, but "
85+
f"overriden to be {runtime_lift_mode!s} at runtime.",
86+
stacklevel=2,
87+
)
7288

7389
sdfg = build_sdfg_from_itir(
7490
program,
@@ -77,7 +93,7 @@ def __call__(
7793
auto_optimize=self.auto_optimize,
7894
on_gpu=on_gpu,
7995
column_axis=column_axis,
80-
lift_mode=self.lift_mode,
96+
lift_mode=lift_mode,
8197
load_sdfg_from_file=False,
8298
save_sdfg=False,
8399
use_field_canonical_representation=self.use_field_canonical_representation,

src/gt4py/next/program_processors/runners/gtfn.py

+2
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int:
111111
content_hash(tuple(from_value(arg) for arg in otf_closure.args)),
112112
id(offset_provider) if offset_provider else None,
113113
otf_closure.kwargs.get("column_axis", None),
114+
# TODO(tehrengruber): Remove `lift_mode` from call interface.
115+
otf_closure.kwargs.get("lift_mode", None),
114116
))
115117

116118

0 commit comments

Comments
 (0)