@@ -140,16 +140,23 @@ def get_shape_args(
140
140
return shape_args
141
141
142
142
143
- def get_offset_args (
144
- sdfg : dace .SDFG ,
145
- args : Sequence [Any ],
146
- ) -> Mapping [str , int ]:
143
+ def get_offset_args (sdfg : dace .SDFG , args : Sequence [Any ]) -> Mapping [str , int ]:
147
144
sdfg_arrays : Mapping [str , dace .data .Array ] = sdfg .arrays
148
145
sdfg_params : Sequence [str ] = sdfg .arg_names
146
+ fied_args = {param : arg for param , arg in zip (sdfg_params , args ) if common .is_field (arg )}
147
+
148
+ # assume that arrays for connectivity tables do not use offset
149
+ assert all (
150
+ drange .start == 0
151
+ for sdfg_param , arg in fied_args .items ()
152
+ if sdfg_param .startswith ("__connectivity" )
153
+ for drange in arg .domain .ranges
154
+ )
155
+
149
156
return {
150
157
str (sym ): - drange .start
151
- for sdfg_param , arg in zip ( sdfg_params , args )
152
- if common . is_field ( arg )
158
+ for sdfg_param , arg in fied_args . items ( )
159
+ if not sdfg_param . startswith ( "__connectivity" )
153
160
for sym , drange in zip (sdfg_arrays [sdfg_param ].offset , get_sorted_dim_ranges (arg .domain ))
154
161
}
155
162
@@ -235,15 +242,13 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, check_args: bool = False, **kwargs) ->
235
242
dace_shapes = get_shape_args (sdfg .arrays , dace_field_args )
236
243
dace_conn_shapes = get_shape_args (sdfg .arrays , dace_conn_args )
237
244
dace_strides = get_stride_args (sdfg .arrays , dace_field_args )
238
- dace_conn_strides = get_stride_args (sdfg .arrays , dace_conn_args )
239
245
dace_offsets = get_offset_args (sdfg , args )
240
246
all_args = {
241
247
** dace_args ,
242
248
** dace_conn_args ,
243
249
** dace_shapes ,
244
250
** dace_conn_shapes ,
245
251
** dace_strides ,
246
- ** dace_conn_strides ,
247
252
** dace_offsets ,
248
253
}
249
254
0 commit comments