14
14
from typing import Any , Mapping , Optional , Sequence , cast
15
15
16
16
import dace
17
+ from dace .sdfg .state import LoopRegion
17
18
18
19
import gt4py .eve as eve
19
20
from gt4py .next import Dimension , DimensionKind , type_inference as next_typing
@@ -477,15 +478,38 @@ def _visit_scan_stencil_closure(
477
478
scan_sdfg = dace .SDFG (name = "scan" )
478
479
scan_sdfg .debuginfo = dace_debuginfo (node )
479
480
480
- # create a state machine for lambda call over the scan dimension
481
- start_state = scan_sdfg .add_state ("start" , True )
482
- lambda_state = scan_sdfg .add_state ("lambda_compute" )
483
- end_state = scan_sdfg .add_state ("end" )
484
-
485
481
# the carry value of the scan operator exists only in the scope of the scan sdfg
486
482
scan_carry_name = unique_var_name ()
487
483
scan_sdfg .add_scalar (scan_carry_name , dtype = as_dace_type (scan_dtype ), transient = True )
488
484
485
+ # create a loop region for lambda call over the scan dimension
486
+ scan_loop_var = f"i_{ scan_dim } "
487
+ if is_forward :
488
+ scan_loop = LoopRegion (
489
+ label = "scan" ,
490
+ condition_expr = f"{ scan_loop_var } < { scan_ub_str } " ,
491
+ loop_var = scan_loop_var ,
492
+ initialize_expr = f"{ scan_loop_var } = { scan_lb_str } " ,
493
+ update_expr = f"{ scan_loop_var } = { scan_loop_var } + 1" ,
494
+ inverted = False ,
495
+ )
496
+ else :
497
+ scan_loop = LoopRegion (
498
+ label = "scan" ,
499
+ condition_expr = f"{ scan_loop_var } >= { scan_lb_str } " ,
500
+ loop_var = scan_loop_var ,
501
+ initialize_expr = f"{ scan_loop_var } = { scan_ub_str } - 1" ,
502
+ update_expr = f"{ scan_loop_var } = { scan_loop_var } - 1" ,
503
+ inverted = False ,
504
+ )
505
+ scan_sdfg .add_node (scan_loop )
506
+ compute_state = scan_loop .add_state ("lambda_compute" , is_start_block = True )
507
+ update_state = scan_loop .add_state ("lambda_update" )
508
+ scan_loop .add_edge (compute_state , update_state , dace .InterstateEdge ())
509
+
510
+ start_state = scan_sdfg .add_state ("start" , is_start_block = True )
511
+ scan_sdfg .add_edge (start_state , scan_loop , dace .InterstateEdge ())
512
+
489
513
# tasklet for initialization of carry
490
514
carry_init_tasklet = start_state .add_tasklet (
491
515
"get_carry_init_value" ,
@@ -502,19 +526,6 @@ def _visit_scan_stencil_closure(
502
526
dace .Memlet .simple (scan_carry_name , "0" ),
503
527
)
504
528
505
- # TODO(edopao): replace state machine with dace loop construct
506
- scan_sdfg .add_loop (
507
- start_state ,
508
- lambda_state ,
509
- end_state ,
510
- loop_var = f"i_{ scan_dim } " ,
511
- initialize_expr = f"{ scan_lb_str } " if is_forward else f"{ scan_ub_str } - 1" ,
512
- condition_expr = f"i_{ scan_dim } < { scan_ub_str } "
513
- if is_forward
514
- else f"i_{ scan_dim } >= { scan_lb_str } " ,
515
- increment_expr = f"i_{ scan_dim } + 1" if is_forward else f"i_{ scan_dim } - 1" ,
516
- )
517
-
518
529
# add storage to scan SDFG for inputs
519
530
for name in [* input_names , * connectivity_names ]:
520
531
assert name not in scan_sdfg .arrays
@@ -569,7 +580,7 @@ def _visit_scan_stencil_closure(
569
580
array_mapping = {** input_mapping , ** connectivity_mapping }
570
581
symbol_mapping = map_nested_sdfg_symbols (scan_sdfg , lambda_context .body , array_mapping )
571
582
572
- scan_inner_node = lambda_state .add_nested_sdfg (
583
+ scan_inner_node = compute_state .add_nested_sdfg (
573
584
lambda_context .body ,
574
585
parent = scan_sdfg ,
575
586
inputs = set (lambda_input_names ) | set (connectivity_names ),
@@ -580,29 +591,25 @@ def _visit_scan_stencil_closure(
580
591
581
592
# connect scan SDFG to lambda inputs
582
593
for name , memlet in array_mapping .items ():
583
- access_node = lambda_state .add_access (name , debuginfo = lambda_context .body .debuginfo )
584
- lambda_state .add_edge (access_node , None , scan_inner_node , name , memlet )
594
+ access_node = compute_state .add_access (name , debuginfo = lambda_context .body .debuginfo )
595
+ compute_state .add_edge (access_node , None , scan_inner_node , name , memlet )
585
596
586
597
output_names = [output_name ]
587
598
assert len (lambda_output_names ) == 1
588
599
# connect lambda output to scan SDFG
589
600
for name , connector in zip (output_names , lambda_output_names ):
590
- lambda_state .add_edge (
601
+ compute_state .add_edge (
591
602
scan_inner_node ,
592
603
connector ,
593
- lambda_state .add_access (name , debuginfo = lambda_context .body .debuginfo ),
604
+ compute_state .add_access (name , debuginfo = lambda_context .body .debuginfo ),
594
605
None ,
595
- dace .Memlet .simple (name , f"i_ { scan_dim } " ),
606
+ dace .Memlet .simple (name , scan_loop_var ),
596
607
)
597
608
598
- # add state to scan SDFG to update the carry value at each loop iteration
599
- lambda_update_state = scan_sdfg .add_state_after (lambda_state , "lambda_update" )
600
- lambda_update_state .add_memlet_path (
601
- lambda_update_state .add_access (output_name , debuginfo = lambda_context .body .debuginfo ),
602
- lambda_update_state .add_access (
603
- scan_carry_name , debuginfo = lambda_context .body .debuginfo
604
- ),
605
- memlet = dace .Memlet .simple (output_names [0 ], f"i_{ scan_dim } " , other_subset_str = "0" ),
609
+ update_state .add_nedge (
610
+ update_state .add_access (output_name , debuginfo = lambda_context .body .debuginfo ),
611
+ update_state .add_access (scan_carry_name , debuginfo = lambda_context .body .debuginfo ),
612
+ dace .Memlet .simple (output_names [0 ], scan_loop_var , other_subset_str = "0" ),
606
613
)
607
614
608
615
return scan_sdfg , map_ranges , scan_dim_index
0 commit comments