Skip to content

Commit

Permalink
scal2sym: Fix incorrect dimensionality in indirection removal (#1871)
Browse files Browse the repository at this point in the history
Fixes a reported failure mode of scalar to symbol promotion.
  • Loading branch information
tbennun committed Jan 14, 2025
1 parent 41e64d4 commit bb5fd17
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
17 changes: 14 additions & 3 deletions dace/transformation/passes/scalar_to_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,21 @@ def _get_requested_range(self, node: ast.Subscript, memlet_subset: subsets.Subse
arrname, tasklet_slice = astutils.subscript_to_ast_slice(node)
arrname = arrname if arrname in self.arrays else None
if len(tasklet_slice) < len(memlet_subset):
new_tasklet_slice = [(None, None, None)] * len(memlet_subset)
# Unsqueeze all index dimensions from orig_subset into tasklet_subset
for i, (start, end, _) in reversed(list(enumerate(memlet_subset.ndrange()))):
if start == end:
tasklet_slice.insert(i, (None, None, None))
j = 0
for i, (start, end, _) in enumerate(memlet_subset.ndrange()):
if start != end:
new_tasklet_slice[i] = tasklet_slice[j]
j += 1

# Sanity check
if j != len(tasklet_slice):
raise IndexError(f'Only {j} out of {len(tasklet_slice)} indices were provided in subset expression '
f'"{astutils.unparse(node)}", found during composing with memlet of subset '
f'"{memlet_subset}".')
tasklet_slice = new_tasklet_slice

tasklet_subset = subsets.Range(astutils.astrange_to_symrange(tasklet_slice, self.arrays, arrname))
return memlet_subset.compose(tasklet_subset)

Expand Down
39 changes: 39 additions & 0 deletions tests/passes/scalar_to_symbol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,43 @@ def test_reversed_order():
sdfg.compile()


@pytest.mark.parametrize('memlet_volume_n', (False, True))
def test_scalar_index_regression(memlet_volume_n):
"""
Tests a reported failure with an invalid promotion of a scalar index.
"""
N = dace.symbol('N')
volume = 1 if not memlet_volume_n else N
sdfg = dace.SDFG('tester')
sdfg.add_array('A', [10, 10, N], dace.float64)
sdfg.add_scalar('scal', dace.int64)
sdfg.add_scalar('tmp', dace.int64, transient=True)

init_state = sdfg.add_state()
t = init_state.add_tasklet('set', {}, {'t'}, 't = 1')
w = init_state.add_write('tmp')
init_state.add_edge(t, 't', w, None, dace.Memlet('tmp'))

state = sdfg.add_state_after(init_state)
r = state.add_read('scal')
rt = state.add_read('tmp')
t = state.add_tasklet('setone', {'s', 't'}, {'a'}, 'a[s + t] = -1')
w = state.add_write('A')
state.add_edge(rt, None, t, 't', dace.Memlet('tmp'))
state.add_edge(r, None, t, 's', dace.Memlet('scal'))
state.add_edge(t, 'a', w, None, dace.Memlet(data='A', subset='0, 0, 0:N', volume=volume))

sdfg.validate()
scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {})

a = np.random.rand(10, 10, 20)
scal = np.int64(5)
ref = np.copy(a)
ref[0, 0, scal + 1] = -1
sdfg(A=a, scal=scal, N=20)
assert np.allclose(a, ref)


if __name__ == '__main__':
test_find_promotable()
test_promote_simple()
Expand All @@ -783,3 +820,5 @@ def test_reversed_order():
test_ternary_expression(True)
test_double_index_bug()
test_reversed_order()
test_scalar_index_regression(False)
test_scalar_index_regression(True)

0 comments on commit bb5fd17

Please sign in to comment.