Skip to content

Commit df19dcf

Browse files
committed
[dace] Remove stride and offset symbols for connectity arrays
1 parent e0a8734 commit df19dcf

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

src/gt4py/next/program_processors/runners/dace_iterator/__init__.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -140,16 +140,23 @@ def get_shape_args(
140140
return shape_args
141141

142142

143-
def get_offset_args(
144-
sdfg: dace.SDFG,
145-
args: Sequence[Any],
146-
) -> Mapping[str, int]:
143+
def get_offset_args(sdfg: dace.SDFG, args: Sequence[Any]) -> Mapping[str, int]:
147144
sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays
148145
sdfg_params: Sequence[str] = sdfg.arg_names
146+
fied_args = {param: arg for param, arg in zip(sdfg_params, args) if common.is_field(arg)}
147+
148+
# assume that arrays for connectivity tables do not use offset
149+
assert all(
150+
drange.start == 0
151+
for sdfg_param, arg in fied_args.items()
152+
if sdfg_param.startswith("__connectivity")
153+
for drange in arg.domain.ranges
154+
)
155+
149156
return {
150157
str(sym): -drange.start
151-
for sdfg_param, arg in zip(sdfg_params, args)
152-
if common.is_field(arg)
158+
for sdfg_param, arg in fied_args.items()
159+
if not sdfg_param.startswith("__connectivity")
153160
for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain))
154161
}
155162

@@ -235,15 +242,13 @@ def get_sdfg_args(sdfg: dace.SDFG, *args, check_args: bool = False, **kwargs) ->
235242
dace_shapes = get_shape_args(sdfg.arrays, dace_field_args)
236243
dace_conn_shapes = get_shape_args(sdfg.arrays, dace_conn_args)
237244
dace_strides = get_stride_args(sdfg.arrays, dace_field_args)
238-
dace_conn_strides = get_stride_args(sdfg.arrays, dace_conn_args)
239245
dace_offsets = get_offset_args(sdfg, args)
240246
all_args = {
241247
**dace_args,
242248
**dace_conn_args,
243249
**dace_shapes,
244250
**dace_conn_shapes,
245251
**dace_strides,
246-
**dace_conn_strides,
247252
**dace_offsets,
248253
}
249254

src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def add_storage(
189189
name,
190190
shape=shape,
191191
strides=strides,
192-
offset=(offset if has_offset else None),
192+
offset=(offset if has_offset and name not in neighbor_tables else None),
193193
dtype=dtype,
194194
)
195195

0 commit comments

Comments
 (0)