From b36142b08f96f104b5c7d797bdab99bd5000e3d5 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 30 Dec 2024 13:27:29 -0800 Subject: [PATCH] Implement `numpy.clip` (#1839) --- dace/frontend/python/replacements.py | 16 ++++++++++++++-- tests/numpy/ufunc_test.py | 12 ++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index c5b3e3b2a2..b1ed0bbf56 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -315,7 +315,7 @@ def _numpy_full(pv: ProgramVisitor, """ if isinstance(shape, Number) or symbolic.issymbolic(shape): shape = [shape] - + is_data = False if isinstance(fill_value, (Number, np.bool_)): vtype = dtypes.dtype_to_typeclass(type(fill_value)) @@ -587,7 +587,7 @@ def _arange(pv: ProgramVisitor, if any(not isinstance(s, Number) for s in [start, stop, step]): if step == 1: # Common case where ceiling is not necessary - shape = (stop - start,) + shape = (stop - start, ) else: shape = (symbolic.int_ceil(stop - start, step), ) else: @@ -1064,6 +1064,17 @@ def _min(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, axis=None): identity=dtypes.max_value(sdfg.arrays[a].dtype)) +@oprepo.replaces('numpy.clip') +def _clip(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a, a_min=None, a_max=None, **kwargs): + if a_min is None and a_max is None: + raise ValueError("clip() requires at least one of `a_min` or `a_max`") + if a_min is None: + return implement_ufunc(pv, None, sdfg, state, 'minimum', [a, a_max], kwargs)[0] + if a_max is None: + return implement_ufunc(pv, None, sdfg, state, 'maximum', [a, a_min], kwargs)[0] + return implement_ufunc(pv, None, sdfg, state, 'clip', [a, a_min, a_max], kwargs)[0] + + def _minmax2(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, a: str, b: str, ismin=True): """ Implements the min or max function with 2 scalar arguments. """ @@ -5321,6 +5332,7 @@ def _vsplit(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, ary: str, ############################################################################################################ # Fast Fourier Transform numpy package (numpy.fft) + def _real_to_complex(real_type: dace.typeclass): if real_type == dace.float32: return dace.complex64 diff --git a/tests/numpy/ufunc_test.py b/tests/numpy/ufunc_test.py index 977f2bc47e..e187247511 100644 --- a/tests/numpy/ufunc_test.py +++ b/tests/numpy/ufunc_test.py @@ -1288,6 +1288,16 @@ def test_ufunc_clip(A: dace.float32[10]): return np.clip(A, 0.2, 0.5) +@compare_numpy_output() +def test_ufunc_clip_min(A: dace.float32[10]): + return np.clip(A, 0.2, None) + + +@compare_numpy_output() +def test_ufunc_clip_max(A: dace.float32[10]): + return np.clip(A, None, a_max=0.5) + + if __name__ == "__main__": test_ufunc_add_ff() test_ufunc_subtract_ff() @@ -1523,3 +1533,5 @@ def test_ufunc_clip(A: dace.float32[10]): test_ufunc_trunc_f() test_ufunc_trunc_u() test_ufunc_clip() + test_ufunc_clip_min() + test_ufunc_clip_max()