Skip to content

Commit e462a2e

Browse files
authored
feat[next][dace]: Add support for lift expressions in neighbor reductions (no unrolling) (#1431)
The baseline DaCe backend forced the unrolling of neighbor reductions in the ITIR pass in order to eliminate all lift expressions. This PR adds support for lowering lift expressions in neighbor reductions, thus avoiding the need to unroll reduce expressions. The result is a more compact SDFG, which leaves the option of unrolling neighbor reductions to the optimization backend.
1 parent 0d158ad commit e462a2e

File tree

3 files changed

+205
-68
lines changed

3 files changed

+205
-68
lines changed

src/gt4py/next/program_processors/runners/dace_iterator/__init__.py

+4-16
Original file line numberDiff line numberDiff line change
@@ -69,28 +69,16 @@ def preprocess_program(
6969
program: itir.FencilDefinition,
7070
offset_provider: Mapping[str, Any],
7171
lift_mode: itir_transforms.LiftMode,
72+
unroll_reduce: bool = False,
7273
):
73-
node = itir_transforms.apply_common_transforms(
74+
return itir_transforms.apply_common_transforms(
7475
program,
7576
common_subexpression_elimination=False,
77+
force_inline_lambda_args=True,
7678
lift_mode=lift_mode,
7779
offset_provider=offset_provider,
78-
unroll_reduce=False,
80+
unroll_reduce=unroll_reduce,
7981
)
80-
# If we don't unroll, there may be lifts left in the itir which can't be lowered to SDFG.
81-
# In this case, just retry with unrolled reductions.
82-
if all([ItirToSDFG._check_no_lifts(closure) for closure in node.closures]):
83-
fencil_definition = node
84-
else:
85-
fencil_definition = itir_transforms.apply_common_transforms(
86-
program,
87-
common_subexpression_elimination=False,
88-
force_inline_lambda_args=True,
89-
lift_mode=lift_mode,
90-
offset_provider=offset_provider,
91-
unroll_reduce=True,
92-
)
93-
return fencil_definition
9482

9583

9684
def get_args(sdfg: dace.SDFG, args: Sequence[Any]) -> dict[str, Any]:

src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,24 @@ def _make_array_shape_and_strides(
124124
return shape, strides
125125

126126

127+
def _check_no_lifts(node: itir.StencilClosure):
    """
    Check that lift expressions in a stencil closure only occur inside neighbor reductions.

    The callee (``fun`` attribute) of every ``itir.FunCall`` found while walking ``node``
    is inspected in traversal order. A countdown is armed with the value 3 whenever the
    callee is named ``neighbors`` and is decremented (never below zero) after each
    inspected callee, including the ``neighbors`` callee itself. A callee named ``lift``
    is tolerated only while the countdown equals 1, i.e. when exactly one other callee
    was inspected between the ``neighbors`` callee and the ``lift`` callee.
    NOTE(review): presumably that intermediate callee is the call of the lifted stencil
    passed as data argument to ``neighbors`` -- confirm against the ITIR walk order.

    Returns
    -------
    True if no lift expression appears in the ITIR, except for lift expressions in
    neighbor reductions. False otherwise.
    """
    remaining = 0
    callees = eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun")
    for callee in callees:
        # Callees without an `id` attribute (e.g. nested calls) compare as "".
        callee_id = getattr(callee, "id", "")
        if callee_id == "neighbors":
            remaining = 3
        elif callee_id == "lift" and remaining != 1:
            return False
        # Every inspected callee consumes one step of the countdown.
        if remaining > 0:
            remaining -= 1
    return True
143+
144+
127145
class ItirToSDFG(eve.NodeVisitor):
128146
param_types: list[ts.TypeSpec]
129147
storage_types: dict[str, ts.TypeSpec]
@@ -262,7 +280,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition):
262280
def visit_StencilClosure(
263281
self, node: itir.StencilClosure, array_table: dict[str, dace.data.Array]
264282
) -> tuple[dace.SDFG, list[str], list[str]]:
265-
assert ItirToSDFG._check_no_lifts(node)
283+
assert _check_no_lifts(node)
266284

267285
# Create the closure's nested SDFG and single state.
268286
closure_sdfg = dace.SDFG(name="closure")
@@ -681,15 +699,6 @@ def _visit_domain(
681699

682700
return tuple(sorted(bounds, key=lambda item: item[0]))
683701

684-
@staticmethod
685-
def _check_no_lifts(node: itir.StencilClosure):
686-
if any(
687-
getattr(fun, "id", "") == "lift"
688-
for fun in eve.walk_values(node).if_isinstance(itir.FunCall).getattr("fun")
689-
):
690-
return False
691-
return True
692-
693702
@staticmethod
694703
def _check_shift_offsets_are_literals(node: itir.StencilClosure):
695704
fun_calls = eve.walk_values(node).if_isinstance(itir.FunCall)

src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py

+182-42
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,126 @@ def __init__(
181181
self.reduce_identity = reduce_identity
182182

183183

184+
def _visit_lift_in_neighbors_reduction(
    transformer: "PythonTaskletCodegen",
    node: itir.FunCall,
    node_args: Sequence[IteratorExpr | list[ValueExpr]],
    offset_provider: NeighborTableOffsetProvider,
    map_entry: dace.nodes.MapEntry,
    map_exit: dace.nodes.MapExit,
    neighbor_index_node: dace.nodes.AccessNode,
    neighbor_value_node: dace.nodes.AccessNode,
) -> list[ValueExpr]:
    """
    Lower a lift expression used as data argument of a neighbors call.

    The first argument of ``node`` is visited by ``transformer`` with re-indexed
    arguments; the resulting context body is inserted as a nested SDFG into the state
    of the transformer context, with its inputs routed through ``map_entry`` and its
    single output routed through ``map_exit`` into ``neighbor_value_node``.

    Parameters
    ----------
    transformer
        Code generator whose context (``body`` and ``state``) receives the nested SDFG.
    node
        Call node whose first argument (``node.args[0]``) is visited to build the nested SDFG.
    node_args
        Already lowered arguments; entries that are not ``IteratorExpr`` are indexed
        with ``[0]`` to obtain a single value expression.
    offset_provider
        Provides the origin and neighbor axis used to re-index iterator arguments.
    map_entry, map_exit
        Entry and exit nodes of the map that the memlet paths are routed through.
    neighbor_index_node
        Access node used as index on the neighbor axis for re-indexed iterators.
    neighbor_value_node
        Access node that receives the output of the nested SDFG.

    Returns
    -------
    Single-element list with a ``ValueExpr`` wrapping ``neighbor_value_node``, typed
    with the dtype of the only output of the visited expression.
    """
    neighbor_dim = offset_provider.neighbor_axis.value
    origin_dim = offset_provider.origin_axis.value

    def _lift_argument(arg):
        # Non-iterator arguments are lists of value expressions: take the first entry.
        if not isinstance(arg, IteratorExpr):
            return arg[0]
        # Iterators that are not indexed on the origin axis are passed through unchanged.
        if origin_dim not in arg.indices:
            return arg
        # Replace the origin-axis index with the neighbor index on the neighbor axis.
        shifted_indices = arg.indices.copy()
        shifted_indices.pop(origin_dim)
        shifted_indices[neighbor_dim] = neighbor_index_node
        return IteratorExpr(arg.field, shifted_indices, arg.dtype, arg.dimensions)

    lifted_args: list[IteratorExpr | ValueExpr] = [_lift_argument(arg) for arg in node_args]

    lift_context, inner_inputs, inner_outputs = transformer.visit(node.args[0], args=lifted_args)
    assert len(inner_outputs) == 1
    inner_out_connector = inner_outputs[0].value.data

    input_nodes = {}
    iterator_index_nodes = {}
    lifted_index_connectors = set()

    for inner_connector, inner_expr in inner_inputs:
        if not isinstance(inner_expr, IteratorExpr):
            assert isinstance(inner_expr, ValueExpr)
            input_nodes[inner_connector] = inner_expr.value
            continue
        # For iterators the connector entry is a pair: field connector and index connectors.
        field_connector, inner_index_table = inner_connector
        input_nodes[field_connector] = inner_expr.field
        for dim, index_connector in inner_index_table.items():
            if dim == neighbor_dim:
                lifted_index_connectors.add(index_connector)
            iterator_index_nodes[index_connector] = inner_expr.indices[dim]

    neighbor_tables = filter_neighbor_tables(transformer.offset_provider)
    connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()]

    parent_sdfg = transformer.context.body
    parent_state = transformer.context.state
    lift_sdfg = lift_context.body

    input_mapping = {}
    for data_connector, data_node in input_nodes.items():
        input_mapping[data_connector] = create_memlet_full(
            data_node.data, data_node.desc(parent_sdfg)
        )
    connectivity_mapping = {}
    for table_name in connectivity_names:
        connectivity_mapping[table_name] = create_memlet_full(
            table_name, parent_sdfg.arrays[table_name]
        )
    array_mapping = dict(input_mapping)
    array_mapping.update(connectivity_mapping)
    symbol_mapping = map_nested_sdfg_symbols(parent_sdfg, lift_sdfg, array_mapping)

    nested_sdfg_node = parent_state.add_nested_sdfg(
        lift_sdfg,
        parent_sdfg,
        inputs=set(array_mapping) | set(iterator_index_nodes),
        outputs={inner_out_connector},
        symbol_mapping=symbol_mapping,
        debuginfo=lift_sdfg.debuginfo,
    )

    # Connectivity tables get a new access node each, routed through the map entry.
    for connectivity_connector, table_memlet in connectivity_mapping.items():
        table_access = parent_state.add_access(table_memlet.data, debuginfo=lift_sdfg.debuginfo)
        parent_state.add_memlet_path(
            table_access,
            map_entry,
            nested_sdfg_node,
            dst_conn=connectivity_connector,
            memlet=table_memlet,
        )

    # Fields and values reuse their existing access nodes.
    for data_connector, data_node in input_nodes.items():
        parent_state.add_memlet_path(
            data_node,
            map_entry,
            nested_sdfg_node,
            dst_conn=data_connector,
            memlet=input_mapping[data_connector],
        )

    for index_connector, index_node in iterator_index_nodes.items():
        index_memlet = dace.Memlet(data=index_node.data, subset="0")
        if index_connector not in lifted_index_connectors:
            parent_state.add_memlet_path(
                index_node,
                map_entry,
                nested_sdfg_node,
                dst_conn=index_connector,
                memlet=index_memlet,
            )
        else:
            # Indices on the neighbor axis are connected with a direct edge,
            # without passing through the map entry.
            parent_state.add_edge(index_node, None, nested_sdfg_node, index_connector, index_memlet)

    output_memlet = dace.Memlet(data=neighbor_value_node.data, subset=",".join(map_entry.params))
    parent_state.add_memlet_path(
        nested_sdfg_node,
        map_exit,
        neighbor_value_node,
        src_conn=inner_out_connector,
        memlet=output_memlet,
    )

    return [ValueExpr(neighbor_value_node, inner_outputs[0].dtype)]
302+
303+
184304
def builtin_neighbors(
185305
transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr]
186306
) -> list[ValueExpr]:
@@ -198,7 +318,16 @@ def builtin_neighbors(
198318
"Neighbor reduction only implemented for connectivity based on neighbor tables."
199319
)
200320

201-
iterator = transformer.visit(data)
321+
lift_node = None
322+
if isinstance(data, FunCall):
323+
assert isinstance(data.fun, itir.FunCall)
324+
fun_node = data.fun
325+
if isinstance(fun_node.fun, itir.SymRef) and fun_node.fun.id == "lift":
326+
lift_node = fun_node
327+
lift_args = transformer.visit(data.args)
328+
iterator = next(filter(lambda x: isinstance(x, IteratorExpr), lift_args), None)
329+
if lift_node is None:
330+
iterator = transformer.visit(data)
202331
assert isinstance(iterator, IteratorExpr)
203332
field_desc = iterator.field.desc(transformer.context.body)
204333
origin_index_node = iterator.indices[offset_provider.origin_axis.value]
@@ -259,44 +388,56 @@ def builtin_neighbors(
259388
dace.Memlet(data=neighbor_index_var, subset="0"),
260389
)
261390

262-
data_access_tasklet = state.add_tasklet(
263-
"data_access",
264-
code="__data = __field[__idx]"
265-
+ (
266-
f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}"
267-
if offset_provider.has_skip_values
268-
else ""
269-
),
270-
inputs={"__field", "__idx"},
271-
outputs={"__data"},
272-
debuginfo=di,
273-
)
274-
# select full shape only in the neighbor-axis dimension
275-
field_subset = tuple(
276-
f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}"
277-
for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape)
278-
)
279-
state.add_memlet_path(
280-
iterator.field,
281-
me,
282-
data_access_tasklet,
283-
memlet=create_memlet_at(iterator.field.data, field_subset),
284-
dst_conn="__field",
285-
)
286-
state.add_edge(
287-
neighbor_index_node,
288-
None,
289-
data_access_tasklet,
290-
"__idx",
291-
dace.Memlet(data=neighbor_index_var, subset="0"),
292-
)
293-
state.add_memlet_path(
294-
data_access_tasklet,
295-
mx,
296-
neighbor_value_node,
297-
memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di),
298-
src_conn="__data",
299-
)
391+
if lift_node is not None:
392+
_visit_lift_in_neighbors_reduction(
393+
transformer,
394+
lift_node,
395+
lift_args,
396+
offset_provider,
397+
me,
398+
mx,
399+
neighbor_index_node,
400+
neighbor_value_node,
401+
)
402+
else:
403+
data_access_tasklet = state.add_tasklet(
404+
"data_access",
405+
code="__data = __field[__idx]"
406+
+ (
407+
f" if __idx != {neighbor_skip_value} else {transformer.context.reduce_identity.value}"
408+
if offset_provider.has_skip_values
409+
else ""
410+
),
411+
inputs={"__field", "__idx"},
412+
outputs={"__data"},
413+
debuginfo=di,
414+
)
415+
# select full shape only in the neighbor-axis dimension
416+
field_subset = tuple(
417+
f"0:{shape}" if dim == offset_provider.neighbor_axis.value else f"i_{dim}"
418+
for dim, shape in zip(sorted(iterator.dimensions), field_desc.shape)
419+
)
420+
state.add_memlet_path(
421+
iterator.field,
422+
me,
423+
data_access_tasklet,
424+
memlet=create_memlet_at(iterator.field.data, field_subset),
425+
dst_conn="__field",
426+
)
427+
state.add_edge(
428+
neighbor_index_node,
429+
None,
430+
data_access_tasklet,
431+
"__idx",
432+
dace.Memlet(data=neighbor_index_var, subset="0"),
433+
)
434+
state.add_memlet_path(
435+
data_access_tasklet,
436+
mx,
437+
neighbor_value_node,
438+
memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di),
439+
src_conn="__data",
440+
)
300441

301442
if not offset_provider.has_skip_values:
302443
return [ValueExpr(neighbor_value_node, iterator.dtype)]
@@ -377,9 +518,8 @@ def builtin_can_deref(
377518
# create tasklet to check that field indices are non-negative (-1 is invalid)
378519
args = [ValueExpr(access_node, _INDEX_DTYPE) for access_node in iterator.indices.values()]
379520
internals = [f"{arg.value.data}_v" for arg in args]
380-
expr_code = " and ".join([f"{v} >= 0" for v in internals])
521+
expr_code = " and ".join(f"{v} != {neighbor_skip_value}" for v in internals)
381522

382-
# TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution
383523
return transformer.add_expr_tasklet(
384524
list(zip(args, internals)),
385525
expr_code,
@@ -946,7 +1086,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]:
9461086
iterator = self.visit(node.args[0])
9471087
if not isinstance(iterator, IteratorExpr):
9481088
# shift cannot be applied because the argument is not iterable
949-
# TODO: remove this special case when ITIR reduce-unroll pass is able to catch it
1089+
# TODO: remove this special case when ITIR pass is able to catch it
9501090
assert isinstance(iterator, list) and len(iterator) == 1
9511091
assert isinstance(iterator[0], ValueExpr)
9521092
return iterator

0 commit comments

Comments
 (0)