Skip to content

Commit 70f0f88

Browse files
authored
feat[next][dace]: use new LoopRegion construct for scan operator (#1424)
The lowering of scan operator to SDFG uses a state machine to represent a loop. This PR replaces the state machine with a LoopRegion construct introduced in dace v0.15. The LoopRegion construct is not yet supported by dace transformation, but it will in the future and it could open new optimization opportunities (e.g. K-caching).
1 parent ac0478a commit 70f0f88

File tree

2 files changed

+43
-32
lines changed

2 files changed

+43
-32
lines changed

src/gt4py/next/program_processors/runners/dace_iterator/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import dace
2121
import numpy as np
2222
from dace.codegen.compiled_sdfg import CompiledSDFG
23+
from dace.sdfg import utils as sdutils
2324
from dace.transformation.auto import auto_optimize as autoopt
2425

2526
import gt4py.next.allocators as next_allocators
@@ -293,6 +294,9 @@ def build_sdfg_from_itir(
293294
filename=frameinfo.filename,
294295
)
295296

297+
# TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct
298+
sdutils.inline_loop_blocks(sdfg)
299+
296300
# run DaCe transformations to simplify the SDFG
297301
sdfg.simplify()
298302

src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py

+39-32
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from typing import Any, Mapping, Optional, Sequence, cast
1515

1616
import dace
17+
from dace.sdfg.state import LoopRegion
1718

1819
import gt4py.eve as eve
1920
from gt4py.next import Dimension, DimensionKind, type_inference as next_typing
@@ -477,15 +478,38 @@ def _visit_scan_stencil_closure(
477478
scan_sdfg = dace.SDFG(name="scan")
478479
scan_sdfg.debuginfo = dace_debuginfo(node)
479480

480-
# create a state machine for lambda call over the scan dimension
481-
start_state = scan_sdfg.add_state("start", True)
482-
lambda_state = scan_sdfg.add_state("lambda_compute")
483-
end_state = scan_sdfg.add_state("end")
484-
485481
# the carry value of the scan operator exists only in the scope of the scan sdfg
486482
scan_carry_name = unique_var_name()
487483
scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True)
488484

485+
# create a loop region for lambda call over the scan dimension
486+
scan_loop_var = f"i_{scan_dim}"
487+
if is_forward:
488+
scan_loop = LoopRegion(
489+
label="scan",
490+
condition_expr=f"{scan_loop_var} < {scan_ub_str}",
491+
loop_var=scan_loop_var,
492+
initialize_expr=f"{scan_loop_var} = {scan_lb_str}",
493+
update_expr=f"{scan_loop_var} = {scan_loop_var} + 1",
494+
inverted=False,
495+
)
496+
else:
497+
scan_loop = LoopRegion(
498+
label="scan",
499+
condition_expr=f"{scan_loop_var} >= {scan_lb_str}",
500+
loop_var=scan_loop_var,
501+
initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1",
502+
update_expr=f"{scan_loop_var} = {scan_loop_var} - 1",
503+
inverted=False,
504+
)
505+
scan_sdfg.add_node(scan_loop)
506+
compute_state = scan_loop.add_state("lambda_compute", is_start_block=True)
507+
update_state = scan_loop.add_state("lambda_update")
508+
scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge())
509+
510+
start_state = scan_sdfg.add_state("start", is_start_block=True)
511+
scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge())
512+
489513
# tasklet for initialization of carry
490514
carry_init_tasklet = start_state.add_tasklet(
491515
"get_carry_init_value",
@@ -502,19 +526,6 @@ def _visit_scan_stencil_closure(
502526
dace.Memlet.simple(scan_carry_name, "0"),
503527
)
504528

505-
# TODO(edopao): replace state machine with dace loop construct
506-
scan_sdfg.add_loop(
507-
start_state,
508-
lambda_state,
509-
end_state,
510-
loop_var=f"i_{scan_dim}",
511-
initialize_expr=f"{scan_lb_str}" if is_forward else f"{scan_ub_str} - 1",
512-
condition_expr=f"i_{scan_dim} < {scan_ub_str}"
513-
if is_forward
514-
else f"i_{scan_dim} >= {scan_lb_str}",
515-
increment_expr=f"i_{scan_dim} + 1" if is_forward else f"i_{scan_dim} - 1",
516-
)
517-
518529
# add storage to scan SDFG for inputs
519530
for name in [*input_names, *connectivity_names]:
520531
assert name not in scan_sdfg.arrays
@@ -569,7 +580,7 @@ def _visit_scan_stencil_closure(
569580
array_mapping = {**input_mapping, **connectivity_mapping}
570581
symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping)
571582

572-
scan_inner_node = lambda_state.add_nested_sdfg(
583+
scan_inner_node = compute_state.add_nested_sdfg(
573584
lambda_context.body,
574585
parent=scan_sdfg,
575586
inputs=set(lambda_input_names) | set(connectivity_names),
@@ -580,29 +591,25 @@ def _visit_scan_stencil_closure(
580591

581592
# connect scan SDFG to lambda inputs
582593
for name, memlet in array_mapping.items():
583-
access_node = lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo)
584-
lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet)
594+
access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo)
595+
compute_state.add_edge(access_node, None, scan_inner_node, name, memlet)
585596

586597
output_names = [output_name]
587598
assert len(lambda_output_names) == 1
588599
# connect lambda output to scan SDFG
589600
for name, connector in zip(output_names, lambda_output_names):
590-
lambda_state.add_edge(
601+
compute_state.add_edge(
591602
scan_inner_node,
592603
connector,
593-
lambda_state.add_access(name, debuginfo=lambda_context.body.debuginfo),
604+
compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo),
594605
None,
595-
dace.Memlet.simple(name, f"i_{scan_dim}"),
606+
dace.Memlet.simple(name, scan_loop_var),
596607
)
597608

598-
# add state to scan SDFG to update the carry value at each loop iteration
599-
lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update")
600-
lambda_update_state.add_memlet_path(
601-
lambda_update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo),
602-
lambda_update_state.add_access(
603-
scan_carry_name, debuginfo=lambda_context.body.debuginfo
604-
),
605-
memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"),
609+
update_state.add_nedge(
610+
update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo),
611+
update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo),
612+
dace.Memlet.simple(output_names[0], scan_loop_var, other_subset_str="0"),
606613
)
607614

608615
return scan_sdfg, map_ranges, scan_dim_index

0 commit comments

Comments
 (0)