Skip to content

Commit 6509dd9

Browse files
authored
feat[next][dace]: DaCe support for temporaries (#1351)
Temporaries are implemented in the DaCe backend as transient arrays. This PR adds extraction of temporaries and generation of the corresponding transient arrays in the SDFG representation.
1 parent e462a2e commit 6509dd9

File tree

3 files changed

+124
-10
lines changed

3 files changed

+124
-10
lines changed

src/gt4py/next/program_processors/runners/dace_iterator/__init__.py

+27-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from dace.codegen.compiled_sdfg import CompiledSDFG
2323
from dace.sdfg import utils as sdutils
2424
from dace.transformation.auto import auto_optimize as autoopt
25+
from dace.transformation.interstate import RefineNestedAccess
2526

2627
import gt4py.next.allocators as next_allocators
2728
import gt4py.next.iterator.ir as itir
@@ -71,7 +72,7 @@ def preprocess_program(
7172
lift_mode: itir_transforms.LiftMode,
7273
unroll_reduce: bool = False,
7374
):
74-
return itir_transforms.apply_common_transforms(
75+
node = itir_transforms.apply_common_transforms(
7576
program,
7677
common_subexpression_elimination=False,
7778
force_inline_lambda_args=True,
@@ -80,6 +81,21 @@ def preprocess_program(
8081
unroll_reduce=unroll_reduce,
8182
)
8283

84+
if isinstance(node, itir_transforms.global_tmps.FencilWithTemporaries):
85+
fencil_definition = node.fencil
86+
tmps = node.tmps
87+
88+
elif isinstance(node, itir.FencilDefinition):
89+
fencil_definition = node
90+
tmps = []
91+
92+
else:
93+
raise TypeError(
94+
f"Expected 'FencilDefinition' or 'FencilWithTemporaries', got '{type(program).__name__}'."
95+
)
96+
97+
return fencil_definition, tmps
98+
8399

84100
def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
85101
sdfg_params: Sequence[str] = sdfg.arg_names
@@ -160,6 +176,7 @@ def get_stride_args(
160176
def get_cache_id(
161177
build_type: str,
162178
build_for_gpu: bool,
179+
lift_mode: itir_transforms.LiftMode,
163180
program: itir.FencilDefinition,
164181
arg_types: Sequence[ts.TypeSpec],
165182
column_axis: Optional[common.Dimension],
@@ -185,6 +202,7 @@ def offset_invariants(offset):
185202
for arg in (
186203
build_type,
187204
build_for_gpu,
205+
lift_mode,
188206
program,
189207
*arg_types,
190208
column_axis,
@@ -272,17 +290,17 @@ def build_sdfg_from_itir(
272290
sdfg.validate()
273291
return sdfg
274292

275-
# TODO(edopao): As temporary fix until temporaries are supported in the DaCe Backend force
276-
# `lift_more` to `FORCE_INLINE` mode.
277-
lift_mode = itir_transforms.LiftMode.FORCE_INLINE
278293
arg_types = [type_translation.from_value(arg) for arg in args]
279294

280295
# visit ITIR and generate SDFG
281-
program = preprocess_program(program, offset_provider, lift_mode)
282-
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
296+
program, tmps = preprocess_program(program, offset_provider, lift_mode)
297+
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, tmps, column_axis)
283298
sdfg = sdfg_genenerator.visit(program)
284299
if sdfg is None:
285300
raise RuntimeError(f"Visit failed for program {program.id}.")
301+
elif tmps:
302+
# This pass is needed to avoid transformation errors in SDFG inlining, because temporaries are using offsets
303+
sdfg.apply_transformations_repeated(RefineNestedAccess)
286304

287305
for nested_sdfg in sdfg.all_sdfgs_recursive():
288306
if not nested_sdfg.debuginfo:
@@ -338,7 +356,9 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
338356

339357
arg_types = [type_translation.from_value(arg) for arg in args]
340358

341-
cache_id = get_cache_id(build_type, on_gpu, program, arg_types, column_axis, offset_provider)
359+
cache_id = get_cache_id(
360+
build_type, on_gpu, lift_mode, program, arg_types, column_axis, offset_provider
361+
)
342362
if build_cache is not None and cache_id in build_cache:
343363
# retrieve SDFG program from build cache
344364
sdfg_program = build_cache[cache_id]

src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py

+89-3
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@
1919
import gt4py.eve as eve
2020
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
2121
from gt4py.next.common import NeighborTable
22-
from gt4py.next.iterator import ir as itir, type_inference as itir_typing
23-
from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef
22+
from gt4py.next.iterator import (
23+
ir as itir,
24+
transforms as itir_transforms,
25+
type_inference as itir_typing,
26+
)
27+
from gt4py.next.iterator.ir import Expr, FunCall, Literal, Sym, SymRef
2428
from gt4py.next.type_system import type_specifications as ts, type_translation
2529

2630
from .itir_to_tasklet import (
@@ -36,6 +40,7 @@
3640
from .utility import (
3741
add_mapped_nested_sdfg,
3842
as_dace_type,
43+
as_scalar_type,
3944
connectivity_identifier,
4045
create_memlet_at,
4146
create_memlet_full,
@@ -44,6 +49,7 @@
4449
flatten_list,
4550
get_sorted_dims,
4651
map_nested_sdfg_symbols,
52+
new_array_symbols,
4753
unique_name,
4854
unique_var_name,
4955
)
@@ -154,12 +160,14 @@ def __init__(
154160
self,
155161
param_types: list[ts.TypeSpec],
156162
offset_provider: dict[str, NeighborTable],
163+
tmps: list[itir_transforms.global_tmps.Temporary],
157164
column_axis: Optional[Dimension] = None,
158165
):
159166
self.param_types = param_types
160167
self.column_axis = column_axis
161168
self.offset_provider = offset_provider
162169
self.storage_types = {}
170+
self.tmps = tmps
163171

164172
def add_storage(
165173
self,
@@ -189,6 +197,70 @@ def add_storage(
189197
raise NotImplementedError()
190198
self.storage_types[name] = type_
191199

200+
def add_storage_for_temporaries(
201+
self, node_params: list[Sym], defs_state: dace.SDFGState, program_sdfg: dace.SDFG
202+
) -> dict[str, str]:
203+
symbol_map: dict[str, TaskletExpr] = {}
204+
# The shape of temporary arrays might be defined based on scalar values passed as program arguments.
205+
# Here we collect these values in a symbol map.
206+
tmp_ids = set(tmp.id for tmp in self.tmps)
207+
for sym in node_params:
208+
if sym.id not in tmp_ids and sym.kind != "Iterator":
209+
name_ = str(sym.id)
210+
type_ = self.storage_types[name_]
211+
assert isinstance(type_, ts.ScalarType)
212+
symbol_map[name_] = SymbolExpr(name_, as_dace_type(type_))
213+
214+
tmp_symbols: dict[str, str] = {}
215+
for tmp in self.tmps:
216+
tmp_name = str(tmp.id)
217+
218+
# We visit the domain of the temporary field, passing the set of available symbols.
219+
assert isinstance(tmp.domain, itir.FunCall)
220+
self.node_types.update(itir_typing.infer_all(tmp.domain))
221+
domain_ctx = Context(program_sdfg, defs_state, symbol_map)
222+
tmp_domain = self._visit_domain(tmp.domain, domain_ctx)
223+
224+
# We build the FieldType for this temporary array.
225+
dims: list[Dimension] = []
226+
for dim, _ in tmp_domain:
227+
dims.append(
228+
Dimension(
229+
value=dim,
230+
kind=(
231+
DimensionKind.VERTICAL
232+
if self.column_axis is not None and self.column_axis.value == dim
233+
else DimensionKind.HORIZONTAL
234+
),
235+
)
236+
)
237+
assert isinstance(tmp.dtype, str)
238+
type_ = ts.FieldType(dims=dims, dtype=as_scalar_type(tmp.dtype))
239+
self.storage_types[tmp_name] = type_
240+
241+
# N.B.: skip generation of symbolic strides and just let dace assign default strides, for now.
242+
# Another option, in the future, is to use symbolic strides and apply auto-tuning or some heuristics
243+
# to assign optimal stride values.
244+
tmp_shape, _ = new_array_symbols(tmp_name, len(dims))
245+
tmp_offset = [
246+
dace.symbol(unique_name(f"{tmp_name}_offset{i}")) for i in range(len(dims))
247+
]
248+
_, tmp_array = program_sdfg.add_array(
249+
tmp_name, tmp_shape, as_dace_type(type_.dtype), offset=tmp_offset, transient=True
250+
)
251+
252+
# Loop through all dimensions to visit the symbolic expressions for array shape and offset.
253+
# These expressions are later mapped to interstate symbols.
254+
for (_, (begin, end)), offset_sym, shape_sym in zip(
255+
tmp_domain,
256+
tmp_array.offset,
257+
tmp_array.shape,
258+
):
259+
tmp_symbols[str(offset_sym)] = f"0 - {begin.value}"
260+
tmp_symbols[str(shape_sym)] = f"{end.value} - {begin.value}"
261+
262+
return tmp_symbols
263+
192264
def get_output_nodes(
193265
self, closure: itir.StencilClosure, sdfg: dace.SDFG, state: dace.SDFGState
194266
) -> dict[str, dace.nodes.AccessNode]:
@@ -204,7 +276,7 @@ def get_output_nodes(
204276
def visit_FencilDefinition(self, node: itir.FencilDefinition):
205277
program_sdfg = dace.SDFG(name=node.id)
206278
program_sdfg.debuginfo = dace_debuginfo(node)
207-
last_state = program_sdfg.add_state("program_entry", True)
279+
entry_state = program_sdfg.add_state("program_entry", is_start_block=True)
208280
self.node_types = itir_typing.infer_all(node)
209281

210282
# Filter neighbor tables from offset providers.
@@ -214,6 +286,20 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
214286
for param, type_ in zip(node.params, self.param_types):
215287
self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables)
216288

289+
if self.tmps:
290+
tmp_symbols = self.add_storage_for_temporaries(node.params, entry_state, program_sdfg)
291+
# on the first interstate edge define symbols for shape and offsets of temporary arrays
292+
last_state = program_sdfg.add_state("init_symbols_for_temporaries")
293+
program_sdfg.add_edge(
294+
entry_state,
295+
last_state,
296+
dace.InterstateEdge(
297+
assignments=tmp_symbols,
298+
),
299+
)
300+
else:
301+
last_state = entry_state
302+
217303
# Add connectivities as SDFG storages.
218304
for offset, offset_provider in neighbor_tables.items():
219305
scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype)

src/gt4py/next/program_processors/runners/dace_iterator/utility.py

+8
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,14 @@ def as_dace_type(type_: ts.ScalarType):
5151
raise ValueError(f"Scalar type '{type_}' not supported.")
5252

5353

54+
def as_scalar_type(typestr: str) -> ts.ScalarType:
55+
try:
56+
kind = getattr(ts.ScalarKind, typestr.upper())
57+
except AttributeError:
58+
raise ValueError(f"Data type {typestr} not supported.")
59+
return ts.ScalarType(kind)
60+
61+
5462
def filter_neighbor_tables(offset_provider: dict[str, Any]):
5563
return {
5664
offset: table

0 commit comments

Comments
 (0)