Skip to content

Commit ae9c203

Browse files
authored
feat[next][dace]: make canonical representation of field domain optional (#1476)
The baseline implementation of the DaCe backend reordered the dimensions of each field domain alphabetically. This ordering is the canonical representation of the field domain, and it has the advantage that the SDFG does not need to be regenerated for different memory layouts of the field arguments. In addition, the code for accessing a field is simple, because all field domains are assumed to follow the same layout. However, the canonical representation is an obstacle to the realization of module-level SDFGs, because it requires an additional conversion step for all array arguments before calling the SDFG. Therefore, we make the canonical representation optional. Note that this change should not have any performance impact, because the real memory layout of the field arrays is not modified.
1 parent 77a205b commit ae9c203

File tree

4 files changed

+166
-98
lines changed

4 files changed

+166
-98
lines changed

src/gt4py/next/program_processors/runners/dace_iterator/__init__.py

+58-38
Original file line numberDiff line numberDiff line change
@@ -42,38 +42,37 @@
4242
cp = None
4343

4444

45-
def get_sorted_dim_ranges(domain: common.Domain) -> Sequence[common.FiniteUnitRange]:
46-
assert common.Domain.is_finite(domain)
47-
sorted_dims = get_sorted_dims(domain.dims)
48-
return [domain.ranges[dim_index] for dim_index, _ in sorted_dims]
49-
50-
5145
""" Default build configuration in DaCe backend """
5246
_build_type = "Release"
53-
54-
55-
def convert_arg(arg: Any, sdfg_param: str):
56-
if common.is_field(arg):
57-
# field domain offsets are not supported
58-
non_zero_offsets = [
59-
(dim, dim_range)
60-
for dim, dim_range in zip(arg.domain.dims, arg.domain.ranges)
61-
if dim_range.start != 0
62-
]
63-
if non_zero_offsets:
64-
dim, dim_range = non_zero_offsets[0]
65-
raise RuntimeError(
66-
f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}."
67-
)
68-
sorted_dims = get_sorted_dims(arg.domain.dims)
69-
ndim = len(sorted_dims)
70-
dim_indices = [dim_index for dim_index, _ in sorted_dims]
71-
if isinstance(arg.ndarray, np.ndarray):
72-
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
73-
else:
74-
assert cp is not None and isinstance(arg.ndarray, cp.ndarray)
75-
return cp.moveaxis(arg.ndarray, range(ndim), dim_indices)
76-
return arg
47+
_default_on_gpu = False
48+
_default_use_field_canonical_representation = False
49+
50+
51+
def convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation: bool):
52+
if not common.is_field(arg):
53+
return arg
54+
# field domain offsets are not supported
55+
non_zero_offsets = [
56+
(dim, dim_range)
57+
for dim, dim_range in zip(arg.domain.dims, arg.domain.ranges)
58+
if dim_range.start != 0
59+
]
60+
if non_zero_offsets:
61+
dim, dim_range = non_zero_offsets[0]
62+
raise RuntimeError(
63+
f"Field '{sdfg_param}' passed as array slice with offset {dim_range.start} on dimension {dim.value}."
64+
)
65+
if not use_field_canonical_representation:
66+
return arg.ndarray
67+
# the canonical representation requires alphabetical ordering of the dimensions in field domain definition
68+
sorted_dims = get_sorted_dims(arg.domain.dims)
69+
ndim = len(sorted_dims)
70+
dim_indices = [dim_index for dim_index, _ in sorted_dims]
71+
if isinstance(arg.ndarray, np.ndarray):
72+
return np.moveaxis(arg.ndarray, range(ndim), dim_indices)
73+
else:
74+
assert cp is not None and isinstance(arg.ndarray, cp.ndarray)
75+
return cp.moveaxis(arg.ndarray, range(ndim), dim_indices)
7776

7877

7978
def preprocess_program(
@@ -107,9 +106,14 @@ def preprocess_program(
107106
return fencil_definition, tmps
108107

109108

110-
def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:
109+
def get_args(
110+
sdfg: dace.SDFG, args: Sequence[Any], use_field_canonical_representation: bool
111+
) -> dict[str, Any]:
111112
sdfg_params: Sequence[str] = sdfg.arg_names
112-
return {sdfg_param: convert_arg(arg, sdfg_param) for sdfg_param, arg in zip(sdfg_params, args)}
113+
return {
114+
sdfg_param: convert_arg(arg, sdfg_param, use_field_canonical_representation)
115+
for sdfg_param, arg in zip(sdfg_params, args)
116+
}
113117

114118

115119
def _ensure_is_on_device(
@@ -162,8 +166,13 @@ def get_stride_args(
162166
raise ValueError(
163167
f"Stride ({stride_size} bytes) for argument '{sym}' must be a multiple of item size ({value.itemsize} bytes)."
164168
)
165-
stride_args[str(sym)] = stride
166-
169+
if isinstance(sym, dace.symbol):
170+
assert sym.name not in stride_args
171+
stride_args[str(sym)] = stride
172+
elif sym != stride:
173+
raise RuntimeError(
174+
f"Expected stride {arrays[name].strides} for arg {name}, got {value.strides}."
175+
)
167176
return stride_args
168177

169178

@@ -221,12 +230,15 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, check_args: bool = False, **kwargs) ->
221230
sdfg: The SDFG for which we want to get the arguments.
222231
"""
223232
offset_provider = kwargs["offset_provider"]
224-
on_gpu = kwargs.get("on_gpu", False)
233+
on_gpu = kwargs.get("on_gpu", _default_on_gpu)
234+
use_field_canonical_representation = kwargs.get(
235+
"use_field_canonical_representation", _default_use_field_canonical_representation
236+
)
225237

226238
neighbor_tables = filter_neighbor_tables(offset_provider)
227239
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
228240

229-
dace_args = get_args(sdfg, args)
241+
dace_args = get_args(sdfg, args, use_field_canonical_representation)
230242
dace_field_args = {n: v for n, v in dace_args.items() if not np.isscalar(v)}
231243
dace_conn_args = get_connectivity_args(neighbor_tables, device)
232244
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
@@ -261,6 +273,7 @@ def build_sdfg_from_itir(
261273
load_sdfg_from_file: bool = False,
262274
cache_id: Optional[str] = None,
263275
save_sdfg: bool = True,
276+
use_field_canonical_representation: bool = True,
264277
) -> dace.SDFG:
265278
"""Translate a Fencil into an SDFG.
266279
@@ -275,6 +288,7 @@ def build_sdfg_from_itir(
275288
load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only.
276289
cache_id: The id of the cache entry, used to disambiguate stored sdfgs.
277290
save_sdfg: If `True` (the default), the SDFG is stored as a file and can be loaded later; this allows skipping the lowering step, but requires `load_sdfg_from_file` to be set to `True`.
291+
use_field_canonical_representation: If `True`, assume that the field dimensions are sorted alphabetically.
278292
279293
Notes:
280294
Currently only the `FORCE_INLINE` liftmode is supported and the value of `lift_mode` is ignored.
@@ -292,7 +306,9 @@ def build_sdfg_from_itir(
292306

293307
# visit ITIR and generate SDFG
294308
program, tmps = preprocess_program(program, offset_provider, lift_mode)
295-
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, tmps, column_axis)
309+
sdfg_genenerator = ItirToSDFG(
310+
arg_types, offset_provider, tmps, use_field_canonical_representation, column_axis
311+
)
296312
sdfg = sdfg_genenerator.visit(program)
297313
if sdfg is None:
298314
raise RuntimeError(f"Visit failed for program {program.id}.")
@@ -343,9 +359,12 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
343359
build_cache = kwargs.get("build_cache", None)
344360
compiler_args = kwargs.get("compiler_args", None) # `None` will take default.
345361
build_type = kwargs.get("build_type", "RelWithDebInfo")
346-
on_gpu = kwargs.get("on_gpu", False)
362+
on_gpu = kwargs.get("on_gpu", _default_on_gpu)
347363
auto_optimize = kwargs.get("auto_optimize", True)
348364
lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE)
365+
use_field_canonical_representation = kwargs.get(
366+
"use_field_canonical_representation", _default_use_field_canonical_representation
367+
)
349368
# ITIR parameters
350369
column_axis = kwargs.get("column_axis", None)
351370
offset_provider = kwargs["offset_provider"]
@@ -374,6 +393,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
374393
load_sdfg_from_file=load_sdfg_from_file,
375394
cache_id=cache_id,
376395
save_sdfg=save_sdfg,
396+
use_field_canonical_representation=use_field_canonical_representation,
377397
)
378398

379399
sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache"

0 commit comments

Comments
 (0)