Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refine condition for prohibiting scalar return values #1838

Merged
merged 1 commit into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,11 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context

# Because of how the code generator works Scalars can not be return values.
# TODO: Remove this limitation as the CompiledSDFG contains logic for that.
if isinstance(desc, dt.Scalar) and name.startswith("__return") and not desc.transient:
raise InvalidSDFGError(f'Can not use scalar "{name}" as return value.', sdfg, None)
if (sdfg.parent is None and isinstance(desc, dt.Scalar) and name.startswith("__return")
and not desc.transient):
raise InvalidSDFGError(
f'Cannot use scalar data descriptor ("{name}") as return value of a top-level function.', sdfg,
None)

# Validate array names
if name is not None and not dtypes.validate_name(name):
Expand Down Expand Up @@ -332,14 +335,14 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
symbols[str(sym)] = sym.dtype
validate_control_flow_region(sdfg, sdfg, initialized_transients, symbols, references, **context)


except InvalidSDFGError as ex:
# If the SDFG is invalid, save it
fpath = os.path.join('_dacegraphs', 'invalid.sdfgz')
sdfg.save(fpath, exception=ex, compress=True)
ex.path = fpath
raise


def _accessible(sdfg: 'dace.sdfg.SDFG', container: str, context: Dict[str, bool]):
"""
Helper function that returns False if a data container cannot be accessed in the current SDFG context.
Expand Down
14 changes: 14 additions & 0 deletions tests/python_frontend/return_value_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,19 @@ def return_scalar():
assert return_scalar() == 5


def test_return_scalar_in_nested_function():

@dace.program
def nested_function() -> dace.int32:
return 5

@dace.program
def return_scalar():
return nested_function()

assert return_scalar() == 5


def test_return_array():

@dace.program
Expand Down Expand Up @@ -91,6 +104,7 @@ def return_void(a: dace.float64[20]):

if __name__ == '__main__':
test_return_scalar()
test_return_scalar_in_nested_function()
test_return_array()
test_return_tuple()
test_return_array_tuple()
Expand Down
Loading