@@ -140,16 +140,23 @@ def get_shape_args(
140
140
return shape_args
141
141
142
142
143
- def get_offset_args (
144
- sdfg : dace .SDFG ,
145
- args : Sequence [Any ],
146
- ) -> Mapping [str , int ]:
143
+ def get_offset_args (sdfg : dace .SDFG , args : Sequence [Any ]) -> Mapping [str , int ]:
147
144
sdfg_arrays : Mapping [str , dace .data .Array ] = sdfg .arrays
148
145
sdfg_params : Sequence [str ] = sdfg .arg_names
146
+ field_args = {param : arg for param , arg in zip (sdfg_params , args ) if common .is_field (arg )}
147
+
148
+ # assume that arrays for connectivity tables do not use offset
149
+ assert all (
150
+ drange .start == 0
151
+ for sdfg_param , arg in field_args .items ()
152
+ if sdfg_param .startswith ("__connectivity" )
153
+ for drange in arg .domain .ranges
154
+ )
155
+
149
156
return {
150
157
str (sym ): - drange .start
151
- for sdfg_param , arg in zip ( sdfg_params , args )
152
- if common . is_field ( arg )
158
+ for sdfg_param , arg in field_args . items ( )
159
+ if not sdfg_param . startswith ( "__connectivity" )
153
160
for sym , drange in zip (sdfg_arrays [sdfg_param ].offset , get_sorted_dim_ranges (arg .domain ))
154
161
}
155
162
@@ -331,6 +338,8 @@ def build_sdfg_from_itir(
331
338
symbols : dict [str , int ] = {}
332
339
device = dace .DeviceType .GPU if on_gpu else dace .DeviceType .CPU
333
340
sdfg = autoopt .auto_optimize (sdfg , device , symbols = symbols , use_gpu_storage = on_gpu )
341
+ elif on_gpu :
342
+ autoopt .apply_gpu_storage (sdfg )
334
343
335
344
if on_gpu :
336
345
sdfg .apply_gpu_transformations ()
0 commit comments