@@ -401,7 +401,7 @@ def builtin_neighbors(
401
401
origin_index_node ,
402
402
me ,
403
403
shift_tasklet ,
404
- memlet = dace .Memlet (data = origin_index_node .data , subset = "0" , debuginfo = di ),
404
+ memlet = dace .Memlet (data = origin_index_node .data , subset = "0" ),
405
405
dst_conn = "__idx" ,
406
406
)
407
407
state .add_edge (
@@ -469,7 +469,7 @@ def builtin_neighbors(
469
469
data_access_tasklet ,
470
470
mx ,
471
471
neighbor_value_node ,
472
- memlet = dace .Memlet (data = neighbor_value_var , subset = neighbor_map_index , debuginfo = di ),
472
+ memlet = dace .Memlet (data = neighbor_value_var , subset = neighbor_map_index ),
473
473
src_conn = "__data" ,
474
474
)
475
475
@@ -496,6 +496,7 @@ def builtin_neighbors(
496
496
{"__idx" },
497
497
{"__valid" },
498
498
f"__valid = True if __idx != { neighbor_skip_value } else False" ,
499
+ debuginfo = di ,
499
500
)
500
501
state .add_edge (
501
502
neighbor_index_node ,
@@ -545,7 +546,7 @@ def builtin_can_deref(
545
546
"_out" ,
546
547
result_node ,
547
548
None ,
548
- dace .Memlet (data = result_name , subset = "0" , debuginfo = di ),
549
+ dace .Memlet (data = result_name , subset = "0" ),
549
550
)
550
551
return [ValueExpr (result_node , dace .dtypes .bool )]
551
552
@@ -598,14 +599,14 @@ def build_if_state(arg, state):
598
599
stmt_state , tbr_state , dace .InterstateEdge (condition = f"{ stmt_node .value .data } == True" )
599
600
)
600
601
sdfg .add_edge (tbr_state , join_state , dace .InterstateEdge ())
601
- tbr_values = build_if_state (node_args [1 ], tbr_state )
602
+ tbr_values = flatten_list ( build_if_state (node_args [1 ], tbr_state ) )
602
603
#
603
604
fbr_state = sdfg .add_state ("false_branch" )
604
605
sdfg .add_edge (
605
606
stmt_state , fbr_state , dace .InterstateEdge (condition = f"{ stmt_node .value .data } == False" )
606
607
)
607
608
sdfg .add_edge (fbr_state , join_state , dace .InterstateEdge ())
608
- fbr_values = build_if_state (node_args [2 ], fbr_state )
609
+ fbr_values = flatten_list ( build_if_state (node_args [2 ], fbr_state ) )
609
610
610
611
assert isinstance (stmt_node , ValueExpr )
611
612
assert stmt_node .dtype == dace .dtypes .bool
@@ -804,7 +805,7 @@ def builtin_tuple_get(
804
805
class GatherLambdaSymbolsPass (eve .NodeVisitor ):
805
806
_sdfg : dace .SDFG
806
807
_state : dace .SDFGState
807
- _symbol_map : dict [str , TaskletExpr ]
808
+ _symbol_map : dict [str , TaskletExpr | tuple [ ValueExpr ] ]
808
809
_parent_symbol_map : dict [str , TaskletExpr ]
809
810
810
811
def __init__ (
@@ -827,7 +828,7 @@ def _add_symbol(self, param, arg):
827
828
if isinstance (arg , ValueExpr ):
828
829
# create storage in lambda sdfg
829
830
self ._sdfg .add_scalar (param , dtype = arg .dtype )
830
- # update table of lambda symbol
831
+ # update table of lambda symbols
831
832
self ._symbol_map [param ] = ValueExpr (
832
833
self ._state .add_access (param , debuginfo = self ._sdfg .debuginfo ), arg .dtype
833
834
)
@@ -839,7 +840,7 @@ def _add_symbol(self, param, arg):
839
840
index_names = {dim : f"__{ param } _i_{ dim } " for dim in arg .indices .keys ()}
840
841
for _ , index_name in index_names .items ():
841
842
self ._sdfg .add_scalar (index_name , dtype = _INDEX_DTYPE )
842
- # update table of lambda symbol
843
+ # update table of lambda symbols
843
844
field = self ._state .add_access (param , debuginfo = self ._sdfg .debuginfo )
844
845
indices = {
845
846
dim : self ._state .add_access (index_arg , debuginfo = self ._sdfg .debuginfo )
@@ -850,6 +851,17 @@ def _add_symbol(self, param, arg):
850
851
assert isinstance (arg , SymbolExpr )
851
852
self ._symbol_map [param ] = arg
852
853
854
+ def _add_tuple (self , param , args ):
855
+ nodes = []
856
+ # create storage in lambda sdfg for each tuple element
857
+ for arg in args :
858
+ var = unique_var_name ()
859
+ self ._sdfg .add_scalar (var , dtype = arg .dtype )
860
+ arg_node = self ._state .add_access (var , debuginfo = self ._sdfg .debuginfo )
861
+ nodes .append (ValueExpr (arg_node , arg .dtype ))
862
+ # update table of lambda symbols
863
+ self ._symbol_map [param ] = tuple (nodes )
864
+
853
865
def visit_SymRef (self , node : itir .SymRef ):
854
866
name = str (node .id )
855
867
if name in self ._parent_symbol_map and name not in self ._symbol_map :
@@ -858,9 +870,13 @@ def visit_SymRef(self, node: itir.SymRef):
858
870
859
871
def visit_Lambda (self , node : itir .Lambda , args : Optional [Sequence [TaskletExpr ]] = None ):
860
872
if args is not None :
861
- assert len (node .params ) == len (args )
862
- for param , arg in zip (node .params , args ):
863
- self ._add_symbol (str (param .id ), arg )
873
+ if len (node .params ) == len (args ):
874
+ for param , arg in zip (node .params , args ):
875
+ self ._add_symbol (str (param .id ), arg )
876
+ else :
877
+ # implicitly make tuple
878
+ assert len (node .params ) == 1
879
+ self ._add_tuple (str (node .params [0 ].id ), args )
864
880
self .visit (node .expr )
865
881
866
882
@@ -937,7 +953,7 @@ def visit_Lambda(
937
953
# Create the SDFG for the lambda's body
938
954
lambda_sdfg = dace .SDFG (func_name )
939
955
lambda_sdfg .debuginfo = dace_debuginfo (node , self .context .body .debuginfo )
940
- lambda_state = lambda_sdfg .add_state (f"{ func_name } _entry " , True )
956
+ lambda_state = lambda_sdfg .add_state (f"{ func_name } _body " , is_start_block = True )
941
957
942
958
lambda_symbols_pass = GatherLambdaSymbolsPass (
943
959
lambda_sdfg , lambda_state , self .context .symbol_map
@@ -947,9 +963,13 @@ def visit_Lambda(
947
963
# Add for input nodes for lambda symbols
948
964
inputs : list [tuple [str , ValueExpr ] | tuple [tuple [str , dict ], IteratorExpr ]] = []
949
965
for sym , input_node in lambda_symbols_pass .symbol_refs .items ():
950
- arg = next ((arg for param , arg in zip (node .params , args ) if param .id == sym ), None )
951
- if arg :
952
- outer_node = arg
966
+ params = [str (p .id ) for p in node .params ]
967
+ try :
968
+ param_index = params .index (sym )
969
+ except ValueError :
970
+ param_index = - 1
971
+ if param_index >= 0 :
972
+ outer_node = args [param_index ]
953
973
else :
954
974
# the symbol is not found among lambda arguments, then it is inherited from parent scope
955
975
outer_node = self .context .symbol_map [sym ]
@@ -962,6 +982,13 @@ def visit_Lambda(
962
982
elif isinstance (input_node , ValueExpr ):
963
983
assert isinstance (outer_node , ValueExpr )
964
984
inputs .append ((sym , outer_node ))
985
+ elif isinstance (input_node , tuple ):
986
+ assert param_index >= 0
987
+ for i , input_node_i in enumerate (input_node ):
988
+ arg_i = args [param_index + i ]
989
+ assert isinstance (arg_i , ValueExpr )
990
+ assert isinstance (input_node_i , ValueExpr )
991
+ inputs .append ((input_node_i .value .data , arg_i ))
965
992
966
993
# Add connectivities as arrays
967
994
for name in connectivity_names :
@@ -1530,7 +1557,7 @@ def add_expr_tasklet(
1530
1557
)
1531
1558
self .context .state .add_edge (arg .value , None , expr_tasklet , internal , memlet )
1532
1559
1533
- memlet = dace .Memlet (data = result_access .data , subset = "0" , debuginfo = di )
1560
+ memlet = dace .Memlet (data = result_access .data , subset = "0" )
1534
1561
self .context .state .add_edge (expr_tasklet , "__result" , result_access , None , memlet )
1535
1562
1536
1563
return [ValueExpr (result_access , result_type )]
0 commit comments