Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement numpy.any/all, more descriptive Python frontend errors #1836

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4441,8 +4441,8 @@ def parse_target(t: Union[ast.Name, ast.Subscript]):
# Connect Python state
self._connect_pystate(tasklet, self.current_state, '__istate', '__ostate')

if return_type is None:
return []
if return_type is None: # Unknown but potentially used return value
return [dtypes.pyobject()]
else:
return return_names

Expand Down Expand Up @@ -4996,13 +4996,13 @@ def _visit_op(self, node: Union[ast.UnaryOp, ast.BinOp, ast.BoolOp], op1: ast.AS
operand2, op2type = None, None

# Type-check operands in order to provide a clear error message
if (isinstance(operand1, str) and operand1 in self.defined
and isinstance(self.defined[operand1].dtype, dtypes.pyobject)):
if (isinstance(operand1, dtypes.pyobject) or (isinstance(operand1, str) and operand1 in self.defined
and isinstance(self.defined[operand1].dtype, dtypes.pyobject))):
raise DaceSyntaxError(
self, op1, 'Trying to operate on a callback return value with an undefined type. '
f'Please add a type hint to "{operand1}" to enable using it within the program.')
if (isinstance(operand2, str) and operand2 in self.defined
and isinstance(self.defined[operand2].dtype, dtypes.pyobject)):
if (isinstance(operand2, dtypes.pyobject) or (isinstance(operand2, str) and operand2 in self.defined
and isinstance(self.defined[operand2].dtype, dtypes.pyobject))):
raise DaceSyntaxError(
self, op2, 'Trying to operate on a callback return value with an undefined type. '
f'Please add a type hint to "{operand2}" to enable using it within the program.')
Expand Down Expand Up @@ -5289,6 +5289,12 @@ def visit_Subscript(self, node: ast.Subscript, inference: bool = False):
# Try to construct memlet from subscript
node.value = ast.Name(id=array)
defined = dace.sdfg.NestedDict({**self.sdfg.arrays, **self.defined})

if arrtype is data.Scalar and array in defined and isinstance(defined[array].dtype, dtypes.pyobject):
raise DaceSyntaxError(
self, node, f'Object "{array}" is defined as a callback return value and cannot be sliced. '
'Consider adding a type hint to the variable.')

expr: MemletExpr = ParseMemlet(self, defined, node, nslice)

if inference:
Expand Down
2 changes: 1 addition & 1 deletion dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def _create_sdfg_args(self, sdfg: SDFG, args: Tuple[Any], kwargs: Dict[str, Any]
# Start with default arguments, then add other arguments
result = {**self.default_args}
# Reconstruct keyword arguments
result.update({aname: arg for aname, arg in zip(self.argnames, args)})
result.update({aname: arg for aname, arg in zip(self.argnames, args) if aname not in self.constant_args})
result.update(kwargs)

# Add closure arguments to the call
Expand Down
20 changes: 16 additions & 4 deletions dace/frontend/python/replacements.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,16 @@ def _sum_array(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, a: str):
return _reduce(pv, sdfg, state, "lambda x, y: x + y", a, axis=0, identity=0)


@oprepo.replaces('numpy.any')
def _any(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):
return _reduce(pv, sdfg, state, "lambda x, y: x or y", a, axis=axis, identity=0)


@oprepo.replaces('numpy.all')
def _all(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):
return _reduce(pv, sdfg, state, "lambda x, y: x and y", a, axis=axis, identity=0)


@oprepo.replaces('numpy.mean')
def _mean(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):

Expand All @@ -1042,26 +1052,28 @@ def _mean(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):

@oprepo.replaces('numpy.max')
@oprepo.replaces('numpy.amax')
def _max(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):
def _max(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None, initial=None):
initial = initial if initial is not None else dtypes.min_value(sdfg.arrays[a].dtype)
return _reduce(pv,
sdfg,
state,
"lambda x, y: max(x, y)",
a,
axis=axis,
identity=dtypes.min_value(sdfg.arrays[a].dtype))
identity=initial)


@oprepo.replaces('numpy.min')
@oprepo.replaces('numpy.amin')
def _min(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None):
def _min(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None, initial=None):
initial = initial if initial is not None else dtypes.max_value(sdfg.arrays[a].dtype)
return _reduce(pv,
sdfg,
state,
"lambda x, y: min(x, y)",
a,
axis=axis,
identity=dtypes.max_value(sdfg.arrays[a].dtype))
identity=initial)


def _minmax2(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, b: str, ismin=True):
Expand Down
13 changes: 13 additions & 0 deletions tests/numpy/reductions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,16 @@ def test_degenerate_reduction_implicit(A: dace.float64[1, 20]):
return np.sum(A, axis=0)


@compare_numpy_output()
def test_any(A: dace.float64[20]):
return np.any(A > 0.8, axis=0)


@compare_numpy_output()
def test_all(A: dace.float64[20]):
return np.all(A > 0.8, axis=0)


if __name__ == '__main__':

# generated with cat tests/numpy/reductions_test.py | grep -oP '(?<=^def ).*(?=\()' | awk '{print $0 "()"}'
Expand Down Expand Up @@ -271,3 +281,6 @@ def test_degenerate_reduction_implicit(A: dace.float64[1, 20]):
test_scalar_reduction()
test_degenerate_reduction_explicit()
test_degenerate_reduction_implicit()

test_any()
test_all()
35 changes: 35 additions & 0 deletions tests/python_frontend/callback_autodetect_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest
import time
from dace import config
from dace.frontend.python.common import DaceSyntaxError

N = dace.symbol('N')

Expand Down Expand Up @@ -906,6 +907,38 @@ def tester(a: dace.float64[20]):
assert np.allclose(aa, expected)


def test_disallowed_callback_in_condition():

@dace_inhibitor
def callbackfunc(arr):
return 42

@dace.program
def callback_in_condition(arr: dace.float64[20]):
if arr[0] < callbackfunc(arr):
return arr + 1
else:
return arr

with pytest.raises(DaceSyntaxError, match="Trying to operate on a callback"):
callback_in_condition.to_sdfg()


def test_disallowed_callback_slice():

@dace_inhibitor
def callbackfunc(arr):
return 42

@dace.program
def callback_in_condition(arr: dace.float64[20]):
a = callbackfunc(arr)
return arr + a[:20]

with pytest.raises(DaceSyntaxError, match="cannot be sliced"):
callback_in_condition.to_sdfg()


@pytest.mark.skip('Test requires GUI')
def test_matplotlib_with_compute():
"""
Expand Down Expand Up @@ -978,4 +1011,6 @@ def tester():
test_pyobject_return_tuple()
test_custom_generator()
test_custom_generator_with_break()
test_disallowed_callback_in_condition()
test_disallowed_callback_slice()
# test_matplotlib_with_compute()
Loading