@@ -181,6 +181,126 @@ def __init__(
181
181
self .reduce_identity = reduce_identity
182
182
183
183
184
+ def _visit_lift_in_neighbors_reduction (
185
+ transformer : "PythonTaskletCodegen" ,
186
+ node : itir .FunCall ,
187
+ node_args : Sequence [IteratorExpr | list [ValueExpr ]],
188
+ offset_provider : NeighborTableOffsetProvider ,
189
+ map_entry : dace .nodes .MapEntry ,
190
+ map_exit : dace .nodes .MapExit ,
191
+ neighbor_index_node : dace .nodes .AccessNode ,
192
+ neighbor_value_node : dace .nodes .AccessNode ,
193
+ ) -> list [ValueExpr ]:
194
+ neighbor_dim = offset_provider .neighbor_axis .value
195
+ origin_dim = offset_provider .origin_axis .value
196
+
197
+ lifted_args : list [IteratorExpr | ValueExpr ] = []
198
+ for arg in node_args :
199
+ if isinstance (arg , IteratorExpr ):
200
+ if origin_dim in arg .indices :
201
+ lifted_indices = arg .indices .copy ()
202
+ lifted_indices .pop (origin_dim )
203
+ lifted_indices [neighbor_dim ] = neighbor_index_node
204
+ lifted_args .append (
205
+ IteratorExpr (
206
+ arg .field ,
207
+ lifted_indices ,
208
+ arg .dtype ,
209
+ arg .dimensions ,
210
+ )
211
+ )
212
+ else :
213
+ lifted_args .append (arg )
214
+ else :
215
+ lifted_args .append (arg [0 ])
216
+
217
+ lift_context , inner_inputs , inner_outputs = transformer .visit (node .args [0 ], args = lifted_args )
218
+ assert len (inner_outputs ) == 1
219
+ inner_out_connector = inner_outputs [0 ].value .data
220
+
221
+ input_nodes = {}
222
+ iterator_index_nodes = {}
223
+ lifted_index_connectors = set ()
224
+
225
+ for x , y in inner_inputs :
226
+ if isinstance (y , IteratorExpr ):
227
+ field_connector , inner_index_table = x
228
+ input_nodes [field_connector ] = y .field
229
+ for dim , connector in inner_index_table .items ():
230
+ if dim == neighbor_dim :
231
+ lifted_index_connectors .add (connector )
232
+ iterator_index_nodes [connector ] = y .indices [dim ]
233
+ else :
234
+ assert isinstance (y , ValueExpr )
235
+ input_nodes [x ] = y .value
236
+
237
+ neighbor_tables = filter_neighbor_tables (transformer .offset_provider )
238
+ connectivity_names = [connectivity_identifier (offset ) for offset in neighbor_tables .keys ()]
239
+
240
+ parent_sdfg = transformer .context .body
241
+ parent_state = transformer .context .state
242
+
243
+ input_mapping = {
244
+ connector : create_memlet_full (node .data , node .desc (parent_sdfg ))
245
+ for connector , node in input_nodes .items ()
246
+ }
247
+ connectivity_mapping = {
248
+ name : create_memlet_full (name , parent_sdfg .arrays [name ]) for name in connectivity_names
249
+ }
250
+ array_mapping = {** input_mapping , ** connectivity_mapping }
251
+ symbol_mapping = map_nested_sdfg_symbols (parent_sdfg , lift_context .body , array_mapping )
252
+
253
+ nested_sdfg_node = parent_state .add_nested_sdfg (
254
+ lift_context .body ,
255
+ parent_sdfg ,
256
+ inputs = {* array_mapping .keys (), * iterator_index_nodes .keys ()},
257
+ outputs = {inner_out_connector },
258
+ symbol_mapping = symbol_mapping ,
259
+ debuginfo = lift_context .body .debuginfo ,
260
+ )
261
+
262
+ for connectivity_connector , memlet in connectivity_mapping .items ():
263
+ parent_state .add_memlet_path (
264
+ parent_state .add_access (memlet .data , debuginfo = lift_context .body .debuginfo ),
265
+ map_entry ,
266
+ nested_sdfg_node ,
267
+ dst_conn = connectivity_connector ,
268
+ memlet = memlet ,
269
+ )
270
+
271
+ for inner_connector , access_node in input_nodes .items ():
272
+ parent_state .add_memlet_path (
273
+ access_node ,
274
+ map_entry ,
275
+ nested_sdfg_node ,
276
+ dst_conn = inner_connector ,
277
+ memlet = input_mapping [inner_connector ],
278
+ )
279
+
280
+ for inner_connector , access_node in iterator_index_nodes .items ():
281
+ memlet = dace .Memlet (data = access_node .data , subset = "0" )
282
+ if inner_connector in lifted_index_connectors :
283
+ parent_state .add_edge (access_node , None , nested_sdfg_node , inner_connector , memlet )
284
+ else :
285
+ parent_state .add_memlet_path (
286
+ access_node ,
287
+ map_entry ,
288
+ nested_sdfg_node ,
289
+ dst_conn = inner_connector ,
290
+ memlet = memlet ,
291
+ )
292
+
293
+ parent_state .add_memlet_path (
294
+ nested_sdfg_node ,
295
+ map_exit ,
296
+ neighbor_value_node ,
297
+ src_conn = inner_out_connector ,
298
+ memlet = dace .Memlet (data = neighbor_value_node .data , subset = "," .join (map_entry .params )),
299
+ )
300
+
301
+ return [ValueExpr (neighbor_value_node , inner_outputs [0 ].dtype )]
302
+
303
+
184
304
def builtin_neighbors (
185
305
transformer : "PythonTaskletCodegen" , node : itir .Expr , node_args : list [itir .Expr ]
186
306
) -> list [ValueExpr ]:
@@ -198,7 +318,16 @@ def builtin_neighbors(
198
318
"Neighbor reduction only implemented for connectivity based on neighbor tables."
199
319
)
200
320
201
- iterator = transformer .visit (data )
321
+ lift_node = None
322
+ if isinstance (data , FunCall ):
323
+ assert isinstance (data .fun , itir .FunCall )
324
+ fun_node = data .fun
325
+ if isinstance (fun_node .fun , itir .SymRef ) and fun_node .fun .id == "lift" :
326
+ lift_node = fun_node
327
+ lift_args = transformer .visit (data .args )
328
+ iterator = next (filter (lambda x : isinstance (x , IteratorExpr ), lift_args ), None )
329
+ if lift_node is None :
330
+ iterator = transformer .visit (data )
202
331
assert isinstance (iterator , IteratorExpr )
203
332
field_desc = iterator .field .desc (transformer .context .body )
204
333
origin_index_node = iterator .indices [offset_provider .origin_axis .value ]
@@ -259,44 +388,56 @@ def builtin_neighbors(
259
388
dace .Memlet (data = neighbor_index_var , subset = "0" ),
260
389
)
261
390
262
- data_access_tasklet = state .add_tasklet (
263
- "data_access" ,
264
- code = "__data = __field[__idx]"
265
- + (
266
- f" if __idx != { neighbor_skip_value } else { transformer .context .reduce_identity .value } "
267
- if offset_provider .has_skip_values
268
- else ""
269
- ),
270
- inputs = {"__field" , "__idx" },
271
- outputs = {"__data" },
272
- debuginfo = di ,
273
- )
274
- # select full shape only in the neighbor-axis dimension
275
- field_subset = tuple (
276
- f"0:{ shape } " if dim == offset_provider .neighbor_axis .value else f"i_{ dim } "
277
- for dim , shape in zip (sorted (iterator .dimensions ), field_desc .shape )
278
- )
279
- state .add_memlet_path (
280
- iterator .field ,
281
- me ,
282
- data_access_tasklet ,
283
- memlet = create_memlet_at (iterator .field .data , field_subset ),
284
- dst_conn = "__field" ,
285
- )
286
- state .add_edge (
287
- neighbor_index_node ,
288
- None ,
289
- data_access_tasklet ,
290
- "__idx" ,
291
- dace .Memlet (data = neighbor_index_var , subset = "0" ),
292
- )
293
- state .add_memlet_path (
294
- data_access_tasklet ,
295
- mx ,
296
- neighbor_value_node ,
297
- memlet = dace .Memlet (data = neighbor_value_var , subset = neighbor_map_index , debuginfo = di ),
298
- src_conn = "__data" ,
299
- )
391
+ if lift_node is not None :
392
+ _visit_lift_in_neighbors_reduction (
393
+ transformer ,
394
+ lift_node ,
395
+ lift_args ,
396
+ offset_provider ,
397
+ me ,
398
+ mx ,
399
+ neighbor_index_node ,
400
+ neighbor_value_node ,
401
+ )
402
+ else :
403
+ data_access_tasklet = state .add_tasklet (
404
+ "data_access" ,
405
+ code = "__data = __field[__idx]"
406
+ + (
407
+ f" if __idx != { neighbor_skip_value } else { transformer .context .reduce_identity .value } "
408
+ if offset_provider .has_skip_values
409
+ else ""
410
+ ),
411
+ inputs = {"__field" , "__idx" },
412
+ outputs = {"__data" },
413
+ debuginfo = di ,
414
+ )
415
+ # select full shape only in the neighbor-axis dimension
416
+ field_subset = tuple (
417
+ f"0:{ shape } " if dim == offset_provider .neighbor_axis .value else f"i_{ dim } "
418
+ for dim , shape in zip (sorted (iterator .dimensions ), field_desc .shape )
419
+ )
420
+ state .add_memlet_path (
421
+ iterator .field ,
422
+ me ,
423
+ data_access_tasklet ,
424
+ memlet = create_memlet_at (iterator .field .data , field_subset ),
425
+ dst_conn = "__field" ,
426
+ )
427
+ state .add_edge (
428
+ neighbor_index_node ,
429
+ None ,
430
+ data_access_tasklet ,
431
+ "__idx" ,
432
+ dace .Memlet (data = neighbor_index_var , subset = "0" ),
433
+ )
434
+ state .add_memlet_path (
435
+ data_access_tasklet ,
436
+ mx ,
437
+ neighbor_value_node ,
438
+ memlet = dace .Memlet (data = neighbor_value_var , subset = neighbor_map_index , debuginfo = di ),
439
+ src_conn = "__data" ,
440
+ )
300
441
301
442
if not offset_provider .has_skip_values :
302
443
return [ValueExpr (neighbor_value_node , iterator .dtype )]
@@ -377,9 +518,8 @@ def builtin_can_deref(
377
518
# create tasklet to check that field indices are non-negative (-1 is invalid)
378
519
args = [ValueExpr (access_node , _INDEX_DTYPE ) for access_node in iterator .indices .values ()]
379
520
internals = [f"{ arg .value .data } _v" for arg in args ]
380
- expr_code = " and " .join ([ f"{ v } >= 0 " for v in internals ] )
521
+ expr_code = " and " .join (f"{ v } != { neighbor_skip_value } " for v in internals )
381
522
382
- # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution
383
523
return transformer .add_expr_tasklet (
384
524
list (zip (args , internals )),
385
525
expr_code ,
@@ -946,7 +1086,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]:
946
1086
iterator = self .visit (node .args [0 ])
947
1087
if not isinstance (iterator , IteratorExpr ):
948
1088
# shift cannot be applied because the argument is not iterable
949
- # TODO: remove this special case when ITIR reduce-unroll pass is able to catch it
1089
+ # TODO: remove this special case when ITIR pass is able to catch it
950
1090
assert isinstance (iterator , list ) and len (iterator ) == 1
951
1091
assert isinstance (iterator [0 ], ValueExpr )
952
1092
return iterator
0 commit comments