42
42
cp = None
43
43
44
44
45
- def get_sorted_dim_ranges (domain : common .Domain ) -> Sequence [common .FiniteUnitRange ]:
46
- assert common .Domain .is_finite (domain )
47
- sorted_dims = get_sorted_dims (domain .dims )
48
- return [domain .ranges [dim_index ] for dim_index , _ in sorted_dims ]
49
-
50
-
51
45
""" Default build configuration in DaCe backend """
52
46
_build_type = "Release"
53
-
54
-
55
- def convert_arg (arg : Any , sdfg_param : str ):
56
- if common .is_field (arg ):
57
- # field domain offsets are not supported
58
- non_zero_offsets = [
59
- (dim , dim_range )
60
- for dim , dim_range in zip (arg .domain .dims , arg .domain .ranges )
61
- if dim_range .start != 0
62
- ]
63
- if non_zero_offsets :
64
- dim , dim_range = non_zero_offsets [0 ]
65
- raise RuntimeError (
66
- f"Field '{ sdfg_param } ' passed as array slice with offset { dim_range .start } on dimension { dim .value } ."
67
- )
68
- sorted_dims = get_sorted_dims (arg .domain .dims )
69
- ndim = len (sorted_dims )
70
- dim_indices = [dim_index for dim_index , _ in sorted_dims ]
71
- if isinstance (arg .ndarray , np .ndarray ):
72
- return np .moveaxis (arg .ndarray , range (ndim ), dim_indices )
73
- else :
74
- assert cp is not None and isinstance (arg .ndarray , cp .ndarray )
75
- return cp .moveaxis (arg .ndarray , range (ndim ), dim_indices )
76
- return arg
47
+ _default_on_gpu = False
48
+ _default_use_field_canonical_representation = False
49
+
50
+
51
+ def convert_arg (arg : Any , sdfg_param : str , use_field_canonical_representation : bool ):
52
+ if not common .is_field (arg ):
53
+ return arg
54
+ # field domain offsets are not supported
55
+ non_zero_offsets = [
56
+ (dim , dim_range )
57
+ for dim , dim_range in zip (arg .domain .dims , arg .domain .ranges )
58
+ if dim_range .start != 0
59
+ ]
60
+ if non_zero_offsets :
61
+ dim , dim_range = non_zero_offsets [0 ]
62
+ raise RuntimeError (
63
+ f"Field '{ sdfg_param } ' passed as array slice with offset { dim_range .start } on dimension { dim .value } ."
64
+ )
65
+ if not use_field_canonical_representation :
66
+ return arg .ndarray
67
+ # the canonical representation requires alphabetical ordering of the dimensions in field domain definition
68
+ sorted_dims = get_sorted_dims (arg .domain .dims )
69
+ ndim = len (sorted_dims )
70
+ dim_indices = [dim_index for dim_index , _ in sorted_dims ]
71
+ if isinstance (arg .ndarray , np .ndarray ):
72
+ return np .moveaxis (arg .ndarray , range (ndim ), dim_indices )
73
+ else :
74
+ assert cp is not None and isinstance (arg .ndarray , cp .ndarray )
75
+ return cp .moveaxis (arg .ndarray , range (ndim ), dim_indices )
77
76
78
77
79
78
def preprocess_program (
@@ -107,9 +106,14 @@ def preprocess_program(
107
106
return fencil_definition , tmps
108
107
109
108
110
- def get_args (sdfg : dace .SDFG , args : Sequence [Any ]) -> dict [str , Any ]:
109
+ def get_args (
110
+ sdfg : dace .SDFG , args : Sequence [Any ], use_field_canonical_representation : bool
111
+ ) -> dict [str , Any ]:
111
112
sdfg_params : Sequence [str ] = sdfg .arg_names
112
- return {sdfg_param : convert_arg (arg , sdfg_param ) for sdfg_param , arg in zip (sdfg_params , args )}
113
+ return {
114
+ sdfg_param : convert_arg (arg , sdfg_param , use_field_canonical_representation )
115
+ for sdfg_param , arg in zip (sdfg_params , args )
116
+ }
113
117
114
118
115
119
def _ensure_is_on_device (
@@ -162,8 +166,13 @@ def get_stride_args(
162
166
raise ValueError (
163
167
f"Stride ({ stride_size } bytes) for argument '{ sym } ' must be a multiple of item size ({ value .itemsize } bytes)."
164
168
)
165
- stride_args [str (sym )] = stride
166
-
169
+ if isinstance (sym , dace .symbol ):
170
+ assert sym .name not in stride_args
171
+ stride_args [str (sym )] = stride
172
+ elif sym != stride :
173
+ raise RuntimeError (
174
+ f"Expected stride { arrays [name ].strides } for arg { name } , got { value .strides } ."
175
+ )
167
176
return stride_args
168
177
169
178
@@ -221,12 +230,15 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, check_args: bool = False, **kwargs) ->
221
230
sdfg: The SDFG for which we want to get the arguments.
222
231
"""
223
232
offset_provider = kwargs ["offset_provider" ]
224
- on_gpu = kwargs .get ("on_gpu" , False )
233
+ on_gpu = kwargs .get ("on_gpu" , _default_on_gpu )
234
+ use_field_canonical_representation = kwargs .get (
235
+ "use_field_canonical_representation" , _default_use_field_canonical_representation
236
+ )
225
237
226
238
neighbor_tables = filter_neighbor_tables (offset_provider )
227
239
device = dace .DeviceType .GPU if on_gpu else dace .DeviceType .CPU
228
240
229
- dace_args = get_args (sdfg , args )
241
+ dace_args = get_args (sdfg , args , use_field_canonical_representation )
230
242
dace_field_args = {n : v for n , v in dace_args .items () if not np .isscalar (v )}
231
243
dace_conn_args = get_connectivity_args (neighbor_tables , device )
232
244
dace_shapes = get_shape_args (sdfg .arrays , dace_field_args )
@@ -261,6 +273,7 @@ def build_sdfg_from_itir(
261
273
load_sdfg_from_file : bool = False ,
262
274
cache_id : Optional [str ] = None ,
263
275
save_sdfg : bool = True ,
276
+ use_field_canonical_representation : bool = True ,
264
277
) -> dace .SDFG :
265
278
"""Translate a Fencil into an SDFG.
266
279
@@ -275,6 +288,7 @@ def build_sdfg_from_itir(
275
288
load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only.
276
289
cache_id: The id of the cache entry, used to disambiguate stored sdfgs.
277
290
save_sdfg: If `True`, the default the SDFG is stored as a file and can be loaded, this allows to skip the lowering step, requires `load_sdfg_from_file` set to `True`.
291
+ use_field_canonical_representation: If `True`, assume that the fields dimensions are sorted alphabetically.
278
292
279
293
Notes:
280
294
Currently only the `FORCE_INLINE` liftmode is supported and the value of `lift_mode` is ignored.
@@ -292,7 +306,9 @@ def build_sdfg_from_itir(
292
306
293
307
# visit ITIR and generate SDFG
294
308
program , tmps = preprocess_program (program , offset_provider , lift_mode )
295
- sdfg_genenerator = ItirToSDFG (arg_types , offset_provider , tmps , column_axis )
309
+ sdfg_genenerator = ItirToSDFG (
310
+ arg_types , offset_provider , tmps , use_field_canonical_representation , column_axis
311
+ )
296
312
sdfg = sdfg_genenerator .visit (program )
297
313
if sdfg is None :
298
314
raise RuntimeError (f"Visit failed for program { program .id } ." )
@@ -343,9 +359,12 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
343
359
build_cache = kwargs .get ("build_cache" , None )
344
360
compiler_args = kwargs .get ("compiler_args" , None ) # `None` will take default.
345
361
build_type = kwargs .get ("build_type" , "RelWithDebInfo" )
346
- on_gpu = kwargs .get ("on_gpu" , False )
362
+ on_gpu = kwargs .get ("on_gpu" , _default_on_gpu )
347
363
auto_optimize = kwargs .get ("auto_optimize" , True )
348
364
lift_mode = kwargs .get ("lift_mode" , itir_transforms .LiftMode .FORCE_INLINE )
365
+ use_field_canonical_representation = kwargs .get (
366
+ "use_field_canonical_representation" , _default_use_field_canonical_representation
367
+ )
349
368
# ITIR parameters
350
369
column_axis = kwargs .get ("column_axis" , None )
351
370
offset_provider = kwargs ["offset_provider" ]
@@ -374,6 +393,7 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
374
393
load_sdfg_from_file = load_sdfg_from_file ,
375
394
cache_id = cache_id ,
376
395
save_sdfg = save_sdfg ,
396
+ use_field_canonical_representation = use_field_canonical_representation ,
377
397
)
378
398
379
399
sdfg .build_folder = compilation_cache ._session_cache_dir_path / ".dacecache"
0 commit comments