Skip to content

Commit ac0478a

Browse files
authored
fix[next][dace]: Use constant shape for neighbor tables in local dimension (#1422)
The main purpose of this PR is to avoid defining shape symbols for array dimensions that are known at compile time. The local size of neighbor connectivity tables falls into this category: for each element in the origin dimension, the number of elements in the target dimension is given by the `max_neighbors` attribute of the offset provider.
1 parent 11f9c1c commit ac0478a

File tree

4 files changed

+160
-83
lines changed

4 files changed

+160
-83
lines changed

src/gt4py/next/program_processors/runners/dace_iterator/__init__.py

+58-38
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import hashlib
1515
import warnings
1616
from inspect import currentframe, getframeinfo
17+
from pathlib import Path
1718
from typing import Any, Mapping, Optional, Sequence
1819

1920
import dace
@@ -26,7 +27,7 @@
2627
import gt4py.next.program_processors.otf_compile_executor as otf_exec
2728
import gt4py.next.program_processors.processor_interface as ppi
2829
from gt4py.next import common
29-
from gt4py.next.iterator import embedded as itir_embedded, transforms as itir_transforms
30+
from gt4py.next.iterator import transforms as itir_transforms
3031
from gt4py.next.otf.compilation import cache as compilation_cache
3132
from gt4py.next.type_system import type_specifications as ts, type_translation
3233

@@ -109,23 +110,29 @@ def _ensure_is_on_device(
109110

110111

111112
def get_connectivity_args(
112-
neighbor_tables: Sequence[tuple[str, itir_embedded.NeighborTableOffsetProvider]],
113+
neighbor_tables: Mapping[str, common.NeighborTable],
113114
device: dace.dtypes.DeviceType,
114115
) -> dict[str, Any]:
115116
return {
116-
connectivity_identifier(offset): _ensure_is_on_device(table.table, device)
117-
for offset, table in neighbor_tables
117+
connectivity_identifier(offset): _ensure_is_on_device(offset_provider.table, device)
118+
for offset, offset_provider in neighbor_tables.items()
118119
}
119120

120121

121122
def get_shape_args(
122123
arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any]
123124
) -> Mapping[str, int]:
124-
return {
125-
str(sym): size
126-
for name, value in args.items()
127-
for sym, size in zip(arrays[name].shape, value.shape)
128-
}
125+
shape_args: dict[str, int] = {}
126+
for name, value in args.items():
127+
for sym, size in zip(arrays[name].shape, value.shape):
128+
if isinstance(sym, dace.symbol):
129+
assert sym.name not in shape_args
130+
shape_args[sym.name] = size
131+
elif sym != size:
132+
raise RuntimeError(
133+
f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}."
134+
)
135+
return shape_args
129136

130137

131138
def get_offset_args(
@@ -158,34 +165,41 @@ def get_stride_args(
158165
return stride_args
159166

160167

161-
_build_cache_cpu: dict[str, CompiledSDFG] = {}
162-
_build_cache_gpu: dict[str, CompiledSDFG] = {}
168+
_build_cache: dict[str, CompiledSDFG] = {}
163169

164170

165171
def get_cache_id(
172+
build_type: str,
173+
build_for_gpu: bool,
166174
program: itir.FencilDefinition,
167175
arg_types: Sequence[ts.TypeSpec],
168176
column_axis: Optional[common.Dimension],
169177
offset_provider: Mapping[str, Any],
170178
) -> str:
171-
max_neighbors = [
172-
(k, v.max_neighbors)
173-
for k, v in offset_provider.items()
174-
if isinstance(
175-
v,
176-
(
177-
itir_embedded.NeighborTableOffsetProvider,
178-
itir_embedded.StridedNeighborOffsetProvider,
179-
),
180-
)
179+
def offset_invariants(offset):
180+
if isinstance(offset, common.Connectivity):
181+
return (
182+
offset.origin_axis,
183+
offset.neighbor_axis,
184+
offset.has_skip_values,
185+
offset.max_neighbors,
186+
)
187+
if isinstance(offset, common.Dimension):
188+
return (offset,)
189+
return tuple()
190+
191+
offset_cache_keys = [
192+
(name, *offset_invariants(offset)) for name, offset in offset_provider.items()
181193
]
182194
cache_id_args = [
183195
str(arg)
184196
for arg in (
197+
build_type,
198+
build_for_gpu,
185199
program,
186200
*arg_types,
187201
column_axis,
188-
*max_neighbors,
202+
*offset_cache_keys,
189203
)
190204
]
191205
m = hashlib.sha256()
@@ -262,7 +276,7 @@ def build_sdfg_from_itir(
262276
# visit ITIR and generate SDFG
263277
program = preprocess_program(program, offset_provider, lift_mode)
264278
sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis)
265-
sdfg = sdfg_genenerator.visit(program)
279+
sdfg: dace.SDFG = sdfg_genenerator.visit(program)
266280
if sdfg is None:
267281
raise RuntimeError(f"Visit failed for program {program.id}.")
268282

@@ -284,8 +298,8 @@ def build_sdfg_from_itir(
284298

285299
# run DaCe auto-optimization heuristics
286300
if auto_optimize:
287-
# TODO: Investigate how symbol definitions improve autoopt transformations,
288-
# in which case the cache table should take the symbols map into account.
301+
# TODO: Investigate performance improvement from SDFG specialization with constant symbols,
302+
# for array shape and strides, although this would imply JIT compilation.
289303
symbols: dict[str, int] = {}
290304
device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
291305
sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)
@@ -307,25 +321,31 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
307321
# ITIR parameters
308322
column_axis = kwargs.get("column_axis", None)
309323
offset_provider = kwargs["offset_provider"]
324+
# debug option to store SDFGs on filesystem and skip lowering ITIR to SDFG at each run
325+
skip_itir_lowering_to_sdfg = kwargs.get("skip_itir_lowering_to_sdfg", False)
310326

311327
arg_types = [type_translation.from_value(arg) for arg in args]
312328

313-
cache_id = get_cache_id(program, arg_types, column_axis, offset_provider)
329+
cache_id = get_cache_id(build_type, on_gpu, program, arg_types, column_axis, offset_provider)
314330
if build_cache is not None and cache_id in build_cache:
315331
# retrieve SDFG program from build cache
316332
sdfg_program = build_cache[cache_id]
317333
sdfg = sdfg_program.sdfg
318-
319334
else:
320-
sdfg = build_sdfg_from_itir(
321-
program,
322-
*args,
323-
offset_provider=offset_provider,
324-
auto_optimize=auto_optimize,
325-
on_gpu=on_gpu,
326-
column_axis=column_axis,
327-
lift_mode=lift_mode,
328-
)
335+
sdfg_filename = f"_dacegraphs/gt4py/{cache_id}/{program.id}.sdfg"
336+
if not (skip_itir_lowering_to_sdfg and Path(sdfg_filename).exists()):
337+
sdfg = build_sdfg_from_itir(
338+
program,
339+
*args,
340+
offset_provider=offset_provider,
341+
auto_optimize=auto_optimize,
342+
on_gpu=on_gpu,
343+
column_axis=column_axis,
344+
lift_mode=lift_mode,
345+
)
346+
sdfg.save(sdfg_filename)
347+
else:
348+
sdfg = dace.SDFG.from_file(sdfg_filename)
329349

330350
sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache"
331351
with dace.config.temporary_config():
@@ -361,7 +381,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
361381
program,
362382
*args,
363383
**kwargs,
364-
build_cache=_build_cache_cpu,
384+
build_cache=_build_cache,
365385
build_type=_build_type,
366386
compiler_args=compiler_args,
367387
on_gpu=False,
@@ -380,7 +400,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None:
380400
program,
381401
*args,
382402
**kwargs,
383-
build_cache=_build_cache_gpu,
403+
build_cache=_build_cache,
384404
build_type=_build_type,
385405
on_gpu=True,
386406
)

src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py

+66-19
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
1212
#
1313
# SPDX-License-Identifier: GPL-3.0-or-later
14-
from typing import Any, Optional, cast
14+
from typing import Any, Mapping, Optional, Sequence, cast
1515

1616
import dace
1717

1818
import gt4py.eve as eve
1919
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
20+
from gt4py.next.common import NeighborTable
2021
from gt4py.next.iterator import ir as itir, type_inference as itir_typing
21-
from gt4py.next.iterator.embedded import NeighborTableOffsetProvider
2222
from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef
2323
from gt4py.next.type_system import type_specifications as ts, type_translation
2424

@@ -43,13 +43,12 @@
4343
flatten_list,
4444
get_sorted_dims,
4545
map_nested_sdfg_symbols,
46-
new_array_symbols,
4746
unique_name,
4847
unique_var_name,
4948
)
5049

5150

52-
def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
51+
def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
5352
"""
5453
Parse stencil expression to extract the scan arguments.
5554
@@ -68,7 +67,7 @@ def get_scan_args(stencil: Expr) -> tuple[bool, Literal]:
6867
return is_forward.value == "True", init_carry
6968

7069

71-
def get_scan_dim(
70+
def _get_scan_dim(
7271
column_axis: Dimension,
7372
storage_types: dict[str, ts.TypeSpec],
7473
output: SymRef,
@@ -93,6 +92,35 @@ def get_scan_dim(
9392
)
9493

9594

95+
def _make_array_shape_and_strides(
96+
name: str,
97+
dims: Sequence[Dimension],
98+
neighbor_tables: Mapping[str, NeighborTable],
99+
sort_dims: bool,
100+
) -> tuple[list[dace.symbol], list[dace.symbol]]:
101+
"""
102+
Parse field dimensions and allocate symbols for array shape and strides.
103+
104+
For local dimensions, the size is known at compile-time and therefore
105+
the corresponding array shape dimension is set to an integer literal value.
106+
107+
Returns
108+
-------
109+
tuple(shape, strides)
110+
The output tuple fields are arrays of dace symbolic expressions.
111+
"""
112+
dtype = dace.int64
113+
sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims
114+
shape = [
115+
neighbor_tables[dim.value].max_neighbors
116+
if dim.kind == DimensionKind.LOCAL
117+
else dace.symbol(unique_name(f"{name}_shape{i}"), dtype)
118+
for i, dim in enumerate(sorted_dims)
119+
]
120+
strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)]
121+
return shape, strides
122+
123+
96124
class ItirToSDFG(eve.NodeVisitor):
97125
param_types: list[ts.TypeSpec]
98126
storage_types: dict[str, ts.TypeSpec]
@@ -104,17 +132,27 @@ class ItirToSDFG(eve.NodeVisitor):
104132
def __init__(
105133
self,
106134
param_types: list[ts.TypeSpec],
107-
offset_provider: dict[str, NeighborTableOffsetProvider],
135+
offset_provider: dict[str, NeighborTable],
108136
column_axis: Optional[Dimension] = None,
109137
):
110138
self.param_types = param_types
111139
self.column_axis = column_axis
112140
self.offset_provider = offset_provider
113141
self.storage_types = {}
114142

115-
def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True):
143+
def add_storage(
144+
self,
145+
sdfg: dace.SDFG,
146+
name: str,
147+
type_: ts.TypeSpec,
148+
neighbor_tables: Mapping[str, NeighborTable],
149+
has_offset: bool = True,
150+
sort_dimensions: bool = True,
151+
):
116152
if isinstance(type_, ts.FieldType):
117-
shape, strides = new_array_symbols(name, len(type_.dims))
153+
shape, strides = _make_array_shape_and_strides(
154+
name, type_.dims, neighbor_tables, sort_dimensions
155+
)
118156
offset = (
119157
[dace.symbol(unique_name(f"{name}_offset{i}_")) for i in range(len(type_.dims))]
120158
if has_offset
@@ -153,14 +191,23 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
153191

154192
# Add program parameters as SDFG storages.
155193
for param, type_ in zip(node.params, self.param_types):
156-
self.add_storage(program_sdfg, str(param.id), type_)
194+
self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables)
157195

158196
# Add connectivities as SDFG storages.
159-
for offset, table in neighbor_tables:
160-
scalar_kind = type_translation.get_scalar_kind(table.table.dtype)
161-
local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL)
162-
type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind))
163-
self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False)
197+
for offset, offset_provider in neighbor_tables.items():
198+
scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype)
199+
local_dim = Dimension(offset, kind=DimensionKind.LOCAL)
200+
type_ = ts.FieldType(
201+
[offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind)
202+
)
203+
self.add_storage(
204+
program_sdfg,
205+
connectivity_identifier(offset),
206+
type_,
207+
neighbor_tables,
208+
has_offset=False,
209+
sort_dimensions=False,
210+
)
164211

165212
# Create a nested SDFG for all stencil closures.
166213
for closure in node.closures:
@@ -222,7 +269,7 @@ def visit_StencilClosure(
222269

223270
input_names = [str(inp.id) for inp in node.inputs]
224271
neighbor_tables = filter_neighbor_tables(self.offset_provider)
225-
connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
272+
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]
226273

227274
output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state)
228275
output_names = [k for k, _ in output_nodes.items()]
@@ -400,11 +447,11 @@ def _visit_scan_stencil_closure(
400447
output_name: str,
401448
) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]:
402449
# extract scan arguments
403-
is_forward, init_carry_value = get_scan_args(node.stencil)
450+
is_forward, init_carry_value = _get_scan_args(node.stencil)
404451
# select the scan dimension based on program argument for column axis
405452
assert self.column_axis
406453
assert isinstance(node.output, SymRef)
407-
scan_dim, scan_dim_index, scan_dtype = get_scan_dim(
454+
scan_dim, scan_dim_index, scan_dtype = _get_scan_dim(
408455
self.column_axis,
409456
self.storage_types,
410457
node.output,
@@ -570,7 +617,7 @@ def _visit_parallel_stencil_closure(
570617
) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]:
571618
neighbor_tables = filter_neighbor_tables(self.offset_provider)
572619
input_names = [str(inp.id) for inp in node.inputs]
573-
conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables]
620+
connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]
574621

575622
# find the scan dimension, same as output dimension, and exclude it from the map domain
576623
map_ranges = {}
@@ -583,7 +630,7 @@ def _visit_parallel_stencil_closure(
583630
index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain}
584631

585632
input_arrays = [(name, self.storage_types[name]) for name in input_names]
586-
connectivity_arrays = [(array_table[name], name) for name in conn_names]
633+
connectivity_arrays = [(array_table[name], name) for name in connectivity_names]
587634

588635
context, results = closure_to_tasklet_sdfg(
589636
node,

0 commit comments

Comments
 (0)