Skip to content

Commit ba353d3

Browse files
authored
feat[next][dace]: Remove offsets in connectivity arrays (#1460)
Remove generation of offset symbols for connectivity arrays.
1 parent e631c7f commit ba353d3

File tree

1 file changed

+15
-6
lines changed
  • src/gt4py/next/program_processors/runners/dace_iterator

1 file changed

+15
-6
lines changed

src/gt4py/next/program_processors/runners/dace_iterator/__init__.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -140,16 +140,23 @@ def get_shape_args(
140140
return shape_args
141141

142142

143-
def get_offset_args(
144-
sdfg: dace.SDFG,
145-
args: Sequence[Any],
146-
) -> Mapping[str, int]:
143+
def get_offset_args(sdfg: dace.SDFG, args: Sequence[Any]) -> Mapping[str, int]:
147144
sdfg_arrays: Mapping[str, dace.data.Array] = sdfg.arrays
148145
sdfg_params: Sequence[str] = sdfg.arg_names
146+
field_args = {param: arg for param, arg in zip(sdfg_params, args) if common.is_field(arg)}
147+
148+
# assume that arrays for connectivity tables do not use offset
149+
assert all(
150+
drange.start == 0
151+
for sdfg_param, arg in field_args.items()
152+
if sdfg_param.startswith("__connectivity")
153+
for drange in arg.domain.ranges
154+
)
155+
149156
return {
150157
str(sym): -drange.start
151-
for sdfg_param, arg in zip(sdfg_params, args)
152-
if common.is_field(arg)
158+
for sdfg_param, arg in field_args.items()
159+
if not sdfg_param.startswith("__connectivity")
153160
for sym, drange in zip(sdfg_arrays[sdfg_param].offset, get_sorted_dim_ranges(arg.domain))
154161
}
155162

@@ -331,6 +338,8 @@ def build_sdfg_from_itir(
331338
symbols: dict[str, int] = {}
332339
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
333340
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)
341+
elif on_gpu:
342+
autoopt.apply_gpu_storage(sdfg)
334343

335344
if on_gpu:
336345
sdfg.apply_gpu_transformations()

0 commit comments

Comments
 (0)