Skip to content

Commit

Permalink
Merge branch 'users/phschaad/python_frontend_unpacking_assignment' in…
Browse files Browse the repository at this point in the history
…to users/phschaad/cf_block_data_deps
  • Loading branch information
phschaad committed Jan 8, 2025
2 parents 9a5cac9 + 61f8973 commit 7682bde
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
13 changes: 13 additions & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3292,6 +3292,19 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
if isinstance(node.value, (ast.Tuple, ast.List)):
for n in node.value.elts:
results.extend(self._gettype(n))
elif isinstance(node.value, ast.Name) and node.value.id in self.sdfg.arrays and isinstance(
self.sdfg.arrays[node.value.id],
data.Array) and self.sdfg.arrays[node.value.id].total_size == len(elts):
# In the case where the rhs is an array (not being accessed with a slice) of exactly the same length as the
# number of elements in the lhs, the array can be expanded with a series of slice/subscript accesses to
# constant indexes (according to the number of elements in the lhs). These expansions can then be used to
# perform an unpacking assignment, similar to what Python does natively.
for i in range(len(elts)):
const_node = NumConstant(i)
ast.copy_location(const_node, node)
slice_node = ast.Subscript(node.value, const_node, node.value.ctx)
ast.copy_location(slice_node, node)
results.extend(self._gettype(slice_node))
else:
results.extend(self._gettype(node.value))

Expand Down
16 changes: 16 additions & 0 deletions tests/python_frontend/assignment_statements_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,21 @@ def test_multiple_targets_parentheses():
assert (c[0] == np.float32(2) * np.float32(np.pi))


@dace.program
def multiple_targets_unpacking(a: dace.float32[2]):
b, c = a
return b, c


def test_multiple_targets_unpacking():
a = np.zeros((2, ), dtype=np.float32)
a[0] = np.pi
a[1] = 2 * np.pi
b, c = multiple_targets_unpacking(a=a)
assert (b[0] == a[0])
assert (c[0] == a[1])


@dace.program
def starred_target(a: dace.float32[1]):
b, *c, d, e = a, 2 * a, 3 * a, 4 * a, 5 * a, 6 * a
Expand Down Expand Up @@ -209,6 +224,7 @@ def method(self):
test_single_target_tuple_with_definition()
test_multiple_targets()
test_multiple_targets_parentheses()
test_multiple_targets_unpacking()

# test_starred_target()
# test_attribute_reference()
Expand Down

0 comments on commit 7682bde

Please sign in to comment.