Skip to content

Commit

Permalink
Fix loop lifting for trailing increment assignments
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Jan 9, 2025
1 parent 97634cf commit 35aebfb
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
20 changes: 17 additions & 3 deletions dace/transformation/interstate/loop_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dace import properties
from dace.sdfg.sdfg import SDFG, InterstateEdge
from dace.sdfg.state import ControlFlowRegion, LoopRegion
from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion
from dace.transformation import transformation
from dace.transformation.interstate.loop_detection import DetectLoop

Expand Down Expand Up @@ -82,8 +82,22 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG):
added.add(e)
if e is incr_edge:
if left_over_incr_assignments != {}:
dst = loop.add_state(label + '_tail') if not inverted else e.dst
loop.add_edge(e.src, dst, InterstateEdge(assignments=left_over_incr_assignments))
assignments = left_over_incr_assignments
dst = e.dst
if e.dst is first_state:
if not update_before_condition:
left_over_incr_cond_region = ConditionalBlock(label + '_post_incr_conditional')
incr_graph = ControlFlowRegion(label + '_post_incr')
left_over_incr_cond_region.add_branch(cond_edge.data.condition, incr_graph)
incr_graph.add_edge(incr_graph.add_state(label + '_post_incr_start',
is_start_block=True),
incr_graph.add_state(label + '_post_incr_end'),
InterstateEdge(assignments=left_over_incr_assignments))
dst = left_over_incr_cond_region
assignments = {}
else:
dst = loop.add_state(label + '_tail')
loop.add_edge(e.src, dst, InterstateEdge(assignments=assignments))
elif e is cond_edge:
if not inverted:
e.data.condition = properties.CodeBlock('1')
Expand Down
44 changes: 43 additions & 1 deletion tests/transformations/interstate/loop_lifting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def test_lift_loop_llvm_canonical_while():


def test_do_while():
sdfg = SDFG('regular_for')
sdfg = SDFG('do_while')
N = dace.symbol('N')
sdfg.add_symbol('i', dace.int32)
sdfg.add_symbol('j', dace.int32)
Expand Down Expand Up @@ -209,9 +209,51 @@ def test_do_while():
assert np.allclose(A_valid, A)


def test_inverted_loop_with_additional_increment_assignment():
sdfg = SDFG('inverted_loop_with_additional_increment_assignment')
N = dace.symbol('N')
sdfg.add_scalar('i', dace.int32, transient=True)
sdfg.add_symbol('k', dace.int32)
sdfg.add_array('A', (N,), dace.int32)
a_state = sdfg.add_state('a_state', is_start_block=True)
b_state = sdfg.add_state('b_state')
c_state = sdfg.add_state('c_state')
d_state = sdfg.add_state('d_state')
sdfg.add_edge(a_state, b_state, InterstateEdge(assignments={'k': 0}))
sdfg.add_edge(b_state, c_state, InterstateEdge())
sdfg.add_edge(c_state, b_state, InterstateEdge(condition='i < N', assignments={'k': 'k + 1'}))
sdfg.add_edge(c_state, d_state, InterstateEdge(condition='i >= N'))
a_access = b_state.add_access('A')
w_tasklet = b_state.add_tasklet('t1', {}, {'out'}, 'out = 1')
b_state.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]'))
i_read = c_state.add_access('i')
i_write = c_state.add_access('i')
iw_tasklet = c_state.add_tasklet('t2', {'in1'}, {'out'}, 'out = in1 + 2')
c_state.add_edge(i_read, None, iw_tasklet, 'in1', Memlet('i[0]'))
c_state.add_edge(iw_tasklet, 'out', i_write, None, Memlet('i[0]'))
a_access_2 = d_state.add_access('A')
w_tasklet_2 = d_state.add_tasklet('t1', {}, {'out'}, 'out = k')
d_state.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]'))

N = 30
A = np.zeros((N,)).astype(np.int32)
A_valid = np.zeros((N,)).astype(np.int32)
sdfg(A=A_valid, N=N)

sdfg.apply_transformations_repeated([LoopLifting])

assert sdfg.using_explicit_control_flow == True
assert any(isinstance(x, LoopRegion) for x in sdfg.nodes())

sdfg(A=A, N=N)

assert np.allclose(A_valid, A)


if __name__ == '__main__':
test_lift_regular_for_loop()
test_lift_loop_llvm_canonical(True)
test_lift_loop_llvm_canonical(False)
test_lift_loop_llvm_canonical_while()
test_do_while()
test_inverted_loop_with_additional_increment_assignment()

0 comments on commit 35aebfb

Please sign in to comment.