diff --git a/array_api_strict/_elementwise_functions.py b/array_api_strict/_elementwise_functions.py index 761caff..7dc6c5c 100644 --- a/array_api_strict/_elementwise_functions.py +++ b/array_api_strict/_elementwise_functions.py @@ -855,6 +855,8 @@ def sign(x: Array, /) -> Array: """ if x.dtype not in _numeric_dtypes: raise TypeError("Only numeric dtypes are allowed in sign") + if x.dtype in _complex_floating_dtypes: + return x/abs(x) return Array._new(np.sign(x._array), device=x.device)