Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for packing assignments in the python frontend #1854

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 61 additions & 4 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3296,7 +3296,65 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
results.extend(self._gettype(node.value))

if len(results) != len(elts):
raise DaceSyntaxError(self, node, 'Function returns %d values but %d provided' % (len(results), len(elts)))
if len(elts) == 1 and len(results) > 1 and isinstance(elts[0], ast.Name):
Copy link
Collaborator

@tbennun tbennun Jan 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would have self.visit-ed to get the name from other kinds of expressions (need to take care to visit the rhs once though)

# If multiple results are being assigned to one element, attempt to perform a packing assignment,
# i.e., similar to Python. This constructs a tuple / array of the correct size for the lhs according to
# the number of elements on the rhs, and then assigns to individual array / tuple positions using the
# correct slice accesses. If the datacontainer on the lhs is not defined yet, it is created here.
# If it already exists, only succeed if the size matches the number of elements on the rhs, similar to
# Python. All elements on the rhs must have a common datatype for this to work.
elt = elts[0]
desc = None
if elt.id in self.sdfg.arrays:
desc = self.sdfg.arrays[elt.id]
if desc is not None and not isinstance(desc, data.Array):
raise DaceSyntaxError(
self, node, 'Cannot assign %d function return values to %s due to incompatible type' %
(len(results), elt.id))
elif desc is not None and desc.total_size != len(results):
raise DaceSyntaxError(
self, node, 'Cannot assign %d function return values to a data container of size %s' %
(len(results), str(desc.total_size)))

# Determine the result data type and make sure there is only one.
res_dtype = None
for res, _ in results:
if not (isinstance(res, str) and res in self.sdfg.arrays):
res_dtype = None
break
res_data = self.sdfg.arrays[res]
if res_dtype is None:
res_dtype = res_data.dtype
elif res_dtype != res_data.dtype:
res_dtype = None
break
if res_dtype is None:
raise DaceSyntaxError(
self, node,
'Cannot determine common result datatype for %d function return values' % (len(results)))

res_name = elt.id
if desc is None:
# If no data container exists yet, create it.
res_name, desc = self.sdfg.add_transient(res_name, (len(results), ), res_dtype)
self.variables[res_name] = res_name

# Create the correct slice accesses.
new_elts = []
for i in range(len(results)):
name_node = ast.Name(res_name, elt.ctx)
ast.copy_location(name_node, elt)
const_node = NumConstant(i)
ast.copy_location(const_node, elt)
slice_node = ast.Subscript(name_node, const_node, elt.ctx)
ast.copy_location(slice_node, elt)
new_elts.append(slice_node)

elts = new_elts
else:
raise DaceSyntaxError(
self, node,
'Function returns %d values but assigning to %d expected values' % (len(results), len(elts)))

defined_vars = {**self.variables, **self.scope_vars}
defined_arrays = dace.sdfg.NestedDict({**self.sdfg.arrays, **self.scope_arrays})
Expand Down Expand Up @@ -3515,7 +3573,6 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False):
if boolarr is not None and indirect_indices:
raise IndexError('Boolean array indexing cannot be combined with indirect access')


if self.nested and not new_data and not visited_target:
new_name, new_rng = self._add_write_access(name, rng, target)
# Local symbol or local data dependent
Expand Down Expand Up @@ -5028,12 +5085,12 @@ def _visit_op(self, node: Union[ast.UnaryOp, ast.BinOp, ast.BoolOp], op1: ast.AS

# Type-check operands in order to provide a clear error message
if (isinstance(operand1, dtypes.pyobject) or (isinstance(operand1, str) and operand1 in self.defined
and isinstance(self.defined[operand1].dtype, dtypes.pyobject))):
and isinstance(self.defined[operand1].dtype, dtypes.pyobject))):
raise DaceSyntaxError(
self, op1, 'Trying to operate on a callback return value with an undefined type. '
f'Please add a type hint to "{operand1}" to enable using it within the program.')
if (isinstance(operand2, dtypes.pyobject) or (isinstance(operand2, str) and operand2 in self.defined
and isinstance(self.defined[operand2].dtype, dtypes.pyobject))):
and isinstance(self.defined[operand2].dtype, dtypes.pyobject))):
raise DaceSyntaxError(
self, op2, 'Trying to operate on a callback return value with an undefined type. '
f'Please add a type hint to "{operand2}" to enable using it within the program.')
Expand Down
34 changes: 34 additions & 0 deletions tests/python_frontend/assignment_statements_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,38 @@ def test_single_target_parentheses():
assert (b[0] == np.float32(np.pi))


@dace.program
def single_target_tuple(a: dace.float32[1], b: dace.float32[1], c: dace.float32[2]):
c = (a, b)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should not change the contents of c though



def test_single_target_tuple():
a = np.zeros((1, ), dtype=np.float32)
b = np.zeros((1, ), dtype=np.float32)
c = np.zeros((2, ), dtype=np.float32)
a[0] = np.pi
b[0] = 2 * np.pi
single_target_tuple(a=a, b=b, c=c)
assert (c[0] == a[0])
assert (c[1] == b[0])


@dace.program
def single_target_tuple_with_definition(a: dace.float32[1], b: dace.float32[1]):
c = (a, b)
return c


def test_single_target_tuple_with_definition():
a = np.zeros((1, ), dtype=np.float32)
b = np.zeros((1, ), dtype=np.float32)
a[0] = np.pi
b[0] = 2 * np.pi
c = single_target_tuple_with_definition(a=a, b=b)
assert (c[0] == a[0])
assert (c[1] == b[0])


@dace.program
def multiple_targets(a: dace.float32[1]):
b, c = a, 2 * a
Expand Down Expand Up @@ -173,6 +205,8 @@ def method(self):
if __name__ == "__main__":
test_single_target()
test_single_target_parentheses()
test_single_target_tuple()
test_single_target_tuple_with_definition()
test_multiple_targets()
test_multiple_targets_parentheses()

Expand Down
Loading