From 35aebfb3134f32e9e846d6c0b12d8d8e907ac6f3 Mon Sep 17 00:00:00 2001
From: Philipp Schaad
Date: Thu, 9 Jan 2025 10:40:29 +0100
Subject: [PATCH] Fix loop lifting for trailing increment assignments

---
Reviewer note (ignored by git-am): this patch was restored from a copy whose
newlines and indentation had been collapsed. Line boundaries follow the +/-
markers, hunk line counts and the diffstat; leading whitespace was rebuilt from
Python block structure and should be confirmed against the target tree (or
applied with `git apply --ignore-whitespace`).

 .../transformation/interstate/loop_lifting.py | 20 +++++++--
 .../interstate/loop_lifting_test.py           | 44 ++++++++++++++++++-
 2 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py
index 746910964c..f67876016f 100644
--- a/dace/transformation/interstate/loop_lifting.py
+++ b/dace/transformation/interstate/loop_lifting.py
@@ -2,7 +2,7 @@
 
 from dace import properties
 from dace.sdfg.sdfg import SDFG, InterstateEdge
-from dace.sdfg.state import ControlFlowRegion, LoopRegion
+from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion
 from dace.transformation import transformation
 from dace.transformation.interstate.loop_detection import DetectLoop
 
@@ -82,8 +82,22 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG):
                     added.add(e)
                     if e is incr_edge:
                         if left_over_incr_assignments != {}:
-                            dst = loop.add_state(label + '_tail') if not inverted else e.dst
-                            loop.add_edge(e.src, dst, InterstateEdge(assignments=left_over_incr_assignments))
+                            assignments = left_over_incr_assignments
+                            dst = e.dst
+                            if e.dst is first_state:
+                                if not update_before_condition:
+                                    left_over_incr_cond_region = ConditionalBlock(label + '_post_incr_conditional')
+                                    incr_graph = ControlFlowRegion(label + '_post_incr')
+                                    left_over_incr_cond_region.add_branch(cond_edge.data.condition, incr_graph)
+                                    incr_graph.add_edge(incr_graph.add_state(label + '_post_incr_start',
+                                                                             is_start_block=True),
+                                                        incr_graph.add_state(label + '_post_incr_end'),
+                                                        InterstateEdge(assignments=left_over_incr_assignments))
+                                    dst = left_over_incr_cond_region
+                                    assignments = {}
+                                else:
+                                    dst = loop.add_state(label + '_tail')
+                            loop.add_edge(e.src, dst, InterstateEdge(assignments=assignments))
                     elif e is cond_edge:
                         if not inverted:
                             e.data.condition = properties.CodeBlock('1')
diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py
index 676512f5f6..cb25005457 100644
--- a/tests/transformations/interstate/loop_lifting_test.py
+++ b/tests/transformations/interstate/loop_lifting_test.py
@@ -167,7 +167,7 @@ def test_lift_loop_llvm_canonical_while():
 
 
 def test_do_while():
-    sdfg = SDFG('regular_for')
+    sdfg = SDFG('do_while')
     N = dace.symbol('N')
     sdfg.add_symbol('i', dace.int32)
     sdfg.add_symbol('j', dace.int32)
@@ -209,9 +209,51 @@ def test_do_while():
     assert np.allclose(A_valid, A)
 
 
+def test_inverted_loop_with_additional_increment_assignment():
+    sdfg = SDFG('inverted_loop_with_additional_increment_assignment')
+    N = dace.symbol('N')
+    sdfg.add_scalar('i', dace.int32, transient=True)
+    sdfg.add_symbol('k', dace.int32)
+    sdfg.add_array('A', (N,), dace.int32)
+    a_state = sdfg.add_state('a_state', is_start_block=True)
+    b_state = sdfg.add_state('b_state')
+    c_state = sdfg.add_state('c_state')
+    d_state = sdfg.add_state('d_state')
+    sdfg.add_edge(a_state, b_state, InterstateEdge(assignments={'k': 0}))
+    sdfg.add_edge(b_state, c_state, InterstateEdge())
+    sdfg.add_edge(c_state, b_state, InterstateEdge(condition='i < N', assignments={'k': 'k + 1'}))
+    sdfg.add_edge(c_state, d_state, InterstateEdge(condition='i >= N'))
+    a_access = b_state.add_access('A')
+    w_tasklet = b_state.add_tasklet('t1', {}, {'out'}, 'out = 1')
+    b_state.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]'))
+    i_read = c_state.add_access('i')
+    i_write = c_state.add_access('i')
+    iw_tasklet = c_state.add_tasklet('t2', {'in1'}, {'out'}, 'out = in1 + 2')
+    c_state.add_edge(i_read, None, iw_tasklet, 'in1', Memlet('i[0]'))
+    c_state.add_edge(iw_tasklet, 'out', i_write, None, Memlet('i[0]'))
+    a_access_2 = d_state.add_access('A')
+    w_tasklet_2 = d_state.add_tasklet('t1', {}, {'out'}, 'out = k')
+    d_state.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]'))
+
+    N = 30
+    A = np.zeros((N,)).astype(np.int32)
+    A_valid = np.zeros((N,)).astype(np.int32)
+    sdfg(A=A_valid, N=N)
+
+    sdfg.apply_transformations_repeated([LoopLifting])
+
+    assert sdfg.using_explicit_control_flow == True
+    assert any(isinstance(x, LoopRegion) for x in sdfg.nodes())
+
+    sdfg(A=A, N=N)
+
+    assert np.allclose(A_valid, A)
+
+
 if __name__ == '__main__':
     test_lift_regular_for_loop()
     test_lift_loop_llvm_canonical(True)
     test_lift_loop_llvm_canonical(False)
     test_lift_loop_llvm_canonical_while()
     test_do_while()
+    test_inverted_loop_with_additional_increment_assignment()