Skip to content

Commit 0a9be51

Browse files
authored
fix[next][dace]: Fixes to DaCe backend to support latest ITIR (#1499)
Fixes in DaCe backend to support latest ITIR: - Add support for tuple argument to lambda functions. - Flatten list of expressions in if-statememts Minor code cleanup: skip debug information for memlets.
1 parent b26e6a3 commit 0a9be51

File tree

1 file changed

+43
-16
lines changed

1 file changed

+43
-16
lines changed

src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py

+43-16
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ def builtin_neighbors(
401401
origin_index_node,
402402
me,
403403
shift_tasklet,
404-
memlet=dace.Memlet(data=origin_index_node.data, subset="0", debuginfo=di),
404+
memlet=dace.Memlet(data=origin_index_node.data, subset="0"),
405405
dst_conn="__idx",
406406
)
407407
state.add_edge(
@@ -469,7 +469,7 @@ def builtin_neighbors(
469469
data_access_tasklet,
470470
mx,
471471
neighbor_value_node,
472-
memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index, debuginfo=di),
472+
memlet=dace.Memlet(data=neighbor_value_var, subset=neighbor_map_index),
473473
src_conn="__data",
474474
)
475475

@@ -496,6 +496,7 @@ def builtin_neighbors(
496496
{"__idx"},
497497
{"__valid"},
498498
f"__valid = True if __idx != {neighbor_skip_value} else False",
499+
debuginfo=di,
499500
)
500501
state.add_edge(
501502
neighbor_index_node,
@@ -545,7 +546,7 @@ def builtin_can_deref(
545546
"_out",
546547
result_node,
547548
None,
548-
dace.Memlet(data=result_name, subset="0", debuginfo=di),
549+
dace.Memlet(data=result_name, subset="0"),
549550
)
550551
return [ValueExpr(result_node, dace.dtypes.bool)]
551552

@@ -598,14 +599,14 @@ def build_if_state(arg, state):
598599
stmt_state, tbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == True")
599600
)
600601
sdfg.add_edge(tbr_state, join_state, dace.InterstateEdge())
601-
tbr_values = build_if_state(node_args[1], tbr_state)
602+
tbr_values = flatten_list(build_if_state(node_args[1], tbr_state))
602603
#
603604
fbr_state = sdfg.add_state("false_branch")
604605
sdfg.add_edge(
605606
stmt_state, fbr_state, dace.InterstateEdge(condition=f"{stmt_node.value.data} == False")
606607
)
607608
sdfg.add_edge(fbr_state, join_state, dace.InterstateEdge())
608-
fbr_values = build_if_state(node_args[2], fbr_state)
609+
fbr_values = flatten_list(build_if_state(node_args[2], fbr_state))
609610

610611
assert isinstance(stmt_node, ValueExpr)
611612
assert stmt_node.dtype == dace.dtypes.bool
@@ -804,7 +805,7 @@ def builtin_tuple_get(
804805
class GatherLambdaSymbolsPass(eve.NodeVisitor):
805806
_sdfg: dace.SDFG
806807
_state: dace.SDFGState
807-
_symbol_map: dict[str, TaskletExpr]
808+
_symbol_map: dict[str, TaskletExpr | tuple[ValueExpr]]
808809
_parent_symbol_map: dict[str, TaskletExpr]
809810

810811
def __init__(
@@ -827,7 +828,7 @@ def _add_symbol(self, param, arg):
827828
if isinstance(arg, ValueExpr):
828829
# create storage in lambda sdfg
829830
self._sdfg.add_scalar(param, dtype=arg.dtype)
830-
# update table of lambda symbol
831+
# update table of lambda symbols
831832
self._symbol_map[param] = ValueExpr(
832833
self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype
833834
)
@@ -839,7 +840,7 @@ def _add_symbol(self, param, arg):
839840
index_names = {dim: f"__{param}_i_{dim}" for dim in arg.indices.keys()}
840841
for _, index_name in index_names.items():
841842
self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE)
842-
# update table of lambda symbol
843+
# update table of lambda symbols
843844
field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo)
844845
indices = {
845846
dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo)
@@ -850,6 +851,17 @@ def _add_symbol(self, param, arg):
850851
assert isinstance(arg, SymbolExpr)
851852
self._symbol_map[param] = arg
852853

854+
def _add_tuple(self, param, args):
855+
nodes = []
856+
# create storage in lambda sdfg for each tuple element
857+
for arg in args:
858+
var = unique_var_name()
859+
self._sdfg.add_scalar(var, dtype=arg.dtype)
860+
arg_node = self._state.add_access(var, debuginfo=self._sdfg.debuginfo)
861+
nodes.append(ValueExpr(arg_node, arg.dtype))
862+
# update table of lambda symbols
863+
self._symbol_map[param] = tuple(nodes)
864+
853865
def visit_SymRef(self, node: itir.SymRef):
854866
name = str(node.id)
855867
if name in self._parent_symbol_map and name not in self._symbol_map:
@@ -858,9 +870,13 @@ def visit_SymRef(self, node: itir.SymRef):
858870

859871
def visit_Lambda(self, node: itir.Lambda, args: Optional[Sequence[TaskletExpr]] = None):
860872
if args is not None:
861-
assert len(node.params) == len(args)
862-
for param, arg in zip(node.params, args):
863-
self._add_symbol(str(param.id), arg)
873+
if len(node.params) == len(args):
874+
for param, arg in zip(node.params, args):
875+
self._add_symbol(str(param.id), arg)
876+
else:
877+
# implicitly make tuple
878+
assert len(node.params) == 1
879+
self._add_tuple(str(node.params[0].id), args)
864880
self.visit(node.expr)
865881

866882

@@ -937,7 +953,7 @@ def visit_Lambda(
937953
# Create the SDFG for the lambda's body
938954
lambda_sdfg = dace.SDFG(func_name)
939955
lambda_sdfg.debuginfo = dace_debuginfo(node, self.context.body.debuginfo)
940-
lambda_state = lambda_sdfg.add_state(f"{func_name}_entry", True)
956+
lambda_state = lambda_sdfg.add_state(f"{func_name}_body", is_start_block=True)
941957

942958
lambda_symbols_pass = GatherLambdaSymbolsPass(
943959
lambda_sdfg, lambda_state, self.context.symbol_map
@@ -947,9 +963,13 @@ def visit_Lambda(
947963
# Add for input nodes for lambda symbols
948964
inputs: list[tuple[str, ValueExpr] | tuple[tuple[str, dict], IteratorExpr]] = []
949965
for sym, input_node in lambda_symbols_pass.symbol_refs.items():
950-
arg = next((arg for param, arg in zip(node.params, args) if param.id == sym), None)
951-
if arg:
952-
outer_node = arg
966+
params = [str(p.id) for p in node.params]
967+
try:
968+
param_index = params.index(sym)
969+
except ValueError:
970+
param_index = -1
971+
if param_index >= 0:
972+
outer_node = args[param_index]
953973
else:
954974
# the symbol is not found among lambda arguments, then it is inherited from parent scope
955975
outer_node = self.context.symbol_map[sym]
@@ -962,6 +982,13 @@ def visit_Lambda(
962982
elif isinstance(input_node, ValueExpr):
963983
assert isinstance(outer_node, ValueExpr)
964984
inputs.append((sym, outer_node))
985+
elif isinstance(input_node, tuple):
986+
assert param_index >= 0
987+
for i, input_node_i in enumerate(input_node):
988+
arg_i = args[param_index + i]
989+
assert isinstance(arg_i, ValueExpr)
990+
assert isinstance(input_node_i, ValueExpr)
991+
inputs.append((input_node_i.value.data, arg_i))
965992

966993
# Add connectivities as arrays
967994
for name in connectivity_names:
@@ -1530,7 +1557,7 @@ def add_expr_tasklet(
15301557
)
15311558
self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet)
15321559

1533-
memlet = dace.Memlet(data=result_access.data, subset="0", debuginfo=di)
1560+
memlet = dace.Memlet(data=result_access.data, subset="0")
15341561
self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet)
15351562

15361563
return [ValueExpr(result_access, result_type)]

0 commit comments

Comments
 (0)