From 9dd84f6fc0b0e0b2f1110a0745a1f4a15bd28f4e Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Mon, 7 Aug 2023 17:01:37 +0100
Subject: [PATCH 01/24] Add complex handling to tanh

---
 ivy/data_classes/array/elementwise.py         | 13 ++++++++++---
 ivy/data_classes/container/elementwise.py     | 10 +++++++++-
 ivy/functional/backends/paddle/elementwise.py |  6 +++++-
 ivy/functional/ivy/elementwise.py             |  5 ++++-
 ivy/stateful/activations.py                   | 18 ++++++++++++++----
 .../test_core/test_elementwise.py             |  4 +++-
 6 files changed, 45 insertions(+), 11 deletions(-)

diff --git a/ivy/data_classes/array/elementwise.py b/ivy/data_classes/array/elementwise.py
index ce851189ca05e..845cffadd4788 100644
--- a/ivy/data_classes/array/elementwise.py
+++ b/ivy/data_classes/array/elementwise.py
@@ -1,6 +1,6 @@
 # global
 import abc
-from typing import Optional, Union
+from typing import Optional, Union, Literal

 # local
 import ivy
@@ -2552,7 +2552,12 @@ def tan(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array:
         """
         return ivy.tan(self._data, out=out)

-    def tanh(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array:
+    def tanh(
+        self: ivy.Array,
+        *,
+        out: Optional[ivy.Array] = None,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    ) -> ivy.Array:
         """
         ivy.Array instance method variant of ivy.tanh. This method simply wraps the
         function, and so the docstring for ivy.tanh also applies to this method with
@@ -2566,6 +2571,8 @@
         out
             optional output, for writing the result to. It must have a shape that the
             inputs broadcast to.
+        complex_mode
+            optional specifier for how to handle complex data types.

         Returns
         -------
@@ -2581,7 +2588,7 @@
         >>> print(y)
         ivy.array([0., 0.762, 0.964])
         """
-        return ivy.tanh(self._data, out=out)
+        return ivy.tanh(self._data, out=out, complex_mode=complex_mode)

     def trunc(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array:
         """
diff --git a/ivy/data_classes/container/elementwise.py b/ivy/data_classes/container/elementwise.py
index 8b085ca013b32..95466a022d727 100644
--- a/ivy/data_classes/container/elementwise.py
+++ b/ivy/data_classes/container/elementwise.py
@@ -1,5 +1,5 @@
 # global
-from typing import Optional, Union, List, Dict
+from typing import Optional, Union, List, Dict, Literal

 # local
 import ivy
@@ -8273,6 +8273,7 @@ def _static_tanh(
         prune_unapplied: Union[bool, ivy.Container] = False,
         map_sequences: Union[bool, ivy.Container] = False,
         out: Optional[ivy.Container] = None,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
     ) -> ivy.Container:
         """
         ivy.Container static method variant of ivy.tanh. This method simply wraps the
@@ -8298,6 +8299,8 @@
         out
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
+        complex_mode
+            optional specifier for how to handle complex data types.

         Returns
         -------
@@ -8324,6 +8327,7 @@
             prune_unapplied=prune_unapplied,
             map_sequences=map_sequences,
             out=out,
+            complex_mode=complex_mode,
         )

     def tanh(
@@ -8334,6 +8338,7 @@
         prune_unapplied: Union[bool, ivy.Container] = False,
         map_sequences: Union[bool, ivy.Container] = False,
         out: Optional[ivy.Container] = None,
+        complex_mode: Literal["split", "magnitude", "jax"] = "jax",
     ) -> ivy.Container:
         """
         ivy.Container instance method variant of ivy.tanh. This method simply wraps the
@@ -8359,6 +8364,8 @@ def tanh(
         out
             optional output container, for writing the result to. It must have a shape
             that the inputs broadcast to.
+        complex_mode
+            optional specifier for how to handle complex data types.

         Returns
         -------
@@ -8385,6 +8392,7 @@
             prune_unapplied=prune_unapplied,
             map_sequences=map_sequences,
             out=out,
+            complex_mode=complex_mode,
         )

     @staticmethod
diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py
index 19f210543fc55..4a757a5df0b5b 100644
--- a/ivy/functional/backends/paddle/elementwise.py
+++ b/ivy/functional/backends/paddle/elementwise.py
@@ -627,7 +627,7 @@ def not_equal(

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def tanh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -640,6 +640,10 @@
         paddle.float16,
     ]:
         return paddle.tanh(x.astype("float32")).astype(x.dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        tanh_a = paddle.tanh(paddle.real(x))
+        tan_b = paddle.tan(paddle.imag(x))
+        return (tanh_a + 1j * tan_b) / (1 + 1j * (tanh_a * tan_b))
     return paddle.tanh(x)
diff --git a/ivy/functional/ivy/elementwise.py b/ivy/functional/ivy/elementwise.py
index 2f56f3ab9c1df..89e7833719a75 100644
--- a/ivy/functional/ivy/elementwise.py
+++ b/ivy/functional/ivy/elementwise.py
@@ -1,6 +1,6 @@
 # global
 from numbers import Number
-from typing import Optional, Union
+from typing import Optional, Union, Literal

 # local
 import ivy
@@ -12,6 +12,7 @@
     handle_array_like_without_promotion,
     inputs_to_ivy_arrays,
     handle_device_shifting,
+    handle_complex_input,
 )
 from ivy.utils.exceptions import handle_exceptions

@@ -6290,11 +6291,13 @@ def tan(
 @to_native_arrays_and_back
 @handle_array_function
 @handle_device_shifting
+@handle_complex_input
 def tanh(
     x: Union[ivy.Array, ivy.NativeArray],
     /,
     *,
     out: Optional[ivy.Array] = None,
+    complex_mode: Literal["split", "magnitude", "jax"] = "jax",
 ) -> ivy.Array:
     """
     Calculate an implementation-dependent approximation to the hyperbolic tangent,
diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py
index 0958770452912..c85a6086a1da5 100644
--- a/ivy/stateful/activations.py
+++ b/ivy/stateful/activations.py
@@ -254,24 +254,34 @@

 class Tanh(Module):
-    def __init__(self):
-        """Apply the TANH activation function."""
+    def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"):
+        """
+        Apply the TANH activation function.
+
+        Parameters
+        ----------
+        complex_mode
+            Specifies how to handle complex input.
+        """
+        self._complex_mode = complex_mode
         Module.__init__(self)

-    def _forward(self, x):
+    def _forward(self, x, complex_mode=None):
         """

         Parameters
         ----------
         x
             Inputs to process *[batch_shape, d]*.
+        complex_mode
+            Specifies how to handle complex input.

         Returns
         -------
         ret
             The outputs following the TANH activation *[batch_shape, d]*
         """
-        return ivy.tanh(x)
+        return ivy.tanh(x, complex_mode=ivy.default(complex_mode, self._complex_mode))

 class ReLU6(Module):
diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
index e7b3dad977677..a483811847ac9 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
@@ -1489,7 +1489,9 @@ def test_tan(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # tanh
 @handle_test(
     fn_tree="functional.ivy.tanh",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_tanh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x

From 20680ccb4a3763727d63fcaab23f30760367a174 Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Tue, 8 Aug 2023 11:07:03 +0100
Subject: [PATCH 02/24] Modify gelu to allow complex in paddle

Also modify it to pass some more tests. There still seem to be issues
with approximations being mismatched between different backends, but
fixing that will require implementing erf for complex inputs too, which
should be its own commit

---
 ivy/functional/backends/paddle/activations.py   | 13 +------------
 ivy/functional/ivy/activations.py               | 17 +++++++++++++----
 ivy/functional/ivy/elementwise.py               |  3 +++
 .../test_nn/test_non_linear_activations.py      |  9 +++------
 .../test_functional/test_nn/test_activations.py |  5 +----
 5 files changed, 21 insertions(+), 26 deletions(-)

diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py
index 185218223ce33..c42e9f3931a12 100644
--- a/ivy/functional/backends/paddle/activations.py
+++ b/ivy/functional/backends/paddle/activations.py
@@ -59,7 +59,7 @@ def leaky_relu(

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex128", "complex64")}}, backend_version
+    {"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
 )
 def gelu(
     x: paddle.Tensor,
@@ -69,17 +69,6 @@ def gelu(
     /,
     *,
     approximate: bool = False,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     if x.dtype in unsupported_dtypes:
-        if paddle.is_complex(x):
-            if approximate:
-                return (
-                    0.5
-                    * x
-                    * (
-                        1
-                        + paddle_backend.tanh(0.7978845608 * (x + 0.044715 * x * x * x))
-                    )
-                )
-            return 0.5 * x * (1 + paddle_backend.erf(x / paddle_backend.sqrt(2)))
         return F.gelu(x.cast("float32"), approximate=approximate).cast(x.dtype)
     return F.gelu(x, approximate=approximate)
diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py
index d16155e073eeb..c68c7805efbea 100644
--- a/ivy/functional/ivy/activations.py
+++ b/ivy/functional/ivy/activations.py
@@ -27,10 +27,19 @@ def _gelu_jax_like(
 ) -> ivy.Array:
     # We don't have the exact implementation
     # cuz the erf function doesn't work on complex numbers
-    sqrt_2_over_pi = ivy.sqrt(2 / ivy.pi).astype(x.dtype)
-    x_pw = ivy.pow(x, 3)
-    cdf = 0.5 * (1.0 + ivy.tanh(sqrt_2_over_pi * (x + 0.044715 * x_pw)))
-    return x * cdf
+    # Magic number #1 is sqrt(2/pi)
+    # Magic number #2 is from https://arxiv.org/abs/1606.08415v5
+    return ivy.multiply(
+        ivy.multiply(x, 0.5),
+        ivy.add(
+            1.0,
+            ivy.tanh(
+                ivy.multiply(
+                    0.7978845608, ivy.add(x, ivy.multiply(0.044715, ivy.pow(x, 3)))
+                )
+            ),
+        ),
+    )

 @handle_exceptions
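Note: as a standalone illustration (plain Python, not part of this patch, with
function names of my own choosing), the tanh curve that _gelu_jax_like now
hard-codes tracks the exact erf-based GELU closely on real inputs; 0.7978845608
approximates sqrt(2/pi) and 0.044715 comes from https://arxiv.org/abs/1606.08415v5:

    import math

    def gelu_exact(x: float) -> float:
        # exact GELU: x * Phi(x), where Phi is the standard normal CDF
        return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

    def gelu_tanh(x: float) -> float:
        # tanh approximation; 0.7978845608 ~= sqrt(2 / pi)
        return 0.5 * x * (1.0 + math.tanh(0.7978845608 * (x + 0.044715 * x**3)))

    for v in (-3.0, -1.0, 0.0, 1.0, 3.0):
        print(f"{v:+.1f}: exact={gelu_exact(v):+.6f}  tanh={gelu_tanh(v):+.6f}")

The small residual mismatch between the two curves is what the commit message
above refers to when different backends disagree on the approximation.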
diff --git a/ivy/functional/ivy/elementwise.py b/ivy/functional/ivy/elementwise.py
index 89e7833719a75..801cda4f71618 100644
--- a/ivy/functional/ivy/elementwise.py
+++ b/ivy/functional/ivy/elementwise.py
@@ -6369,6 +6369,9 @@ def tanh(
     out
         optional output, for writing the result to. It must have a shape that the
         inputs broadcast to.
+    complex_mode
+        optional specifier for how to handle complex data types. See
+        `ivy.func_wrapper.handle_complex_input` for more detail.

     Returns
     -------
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
index f71a331e618fe..78311a8d5f279 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
@@ -170,9 +170,6 @@ def test_jax_leaky_relu(
     fn_tree="jax.nn.gelu",
     dtype_and_x=helpers.dtype_and_values(
         available_dtypes=helpers.get_dtypes("float_and_complex"),
-        large_abs_safety_factor=1,
-        small_abs_safety_factor=1,
-        safety_factor_scale="linear",
         min_value=-1e4,
         max_value=1e4,
     ),
@@ -191,7 +188,7 @@ def test_jax_gelu(
 ):
     input_dtype, x = dtype_and_x
     # As erf function doesn't support complex dtype
-    if "complex" in str(x[0].dtype):
+    if "complex" in str(input_dtype):
         approximate = True
     helpers.test_frontend_function(
         input_dtypes=input_dtype,
@@ -200,8 +197,8 @@ def test_jax_gelu(
         test_flags=test_flags,
         fn_tree=fn_tree,
         on_device=on_device,
-        rtol=1e-02,
-        atol=1e-02,
+        rtol=1e-2,
+        atol=1e-2,
         x=x[0],
         approximate=approximate,
     )
diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
index 0cc9653edeeaa..ef7865378e61a 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
@@ -63,9 +63,6 @@ def test_leaky_relu(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_d
     fn_tree="functional.ivy.gelu",
     dtype_and_x=helpers.dtype_and_values(
         available_dtypes=helpers.get_dtypes("float_and_complex"),
-        large_abs_safety_factor=1,
-        small_abs_safety_factor=1,
-        safety_factor_scale="linear",
         min_value=-1e4,
         max_value=1e4,
     ),
@@ -73,7 +70,7 @@ def test_leaky_relu(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_d
 )
 def test_gelu(*, dtype_and_x, approximate, test_flags, backend_fw, fn_name, on_device):
     dtype, x = dtype_and_x
-    if "complex" in str(x[0].dtype):
+    if "complex" in str(dtype):
         approximate = True
     helpers.test_function(
         input_dtypes=dtype,

From 93dc7db8184c0f71277f80022bb5c90eedce70bd Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Wed, 9 Aug 2023 12:57:14 +0100
Subject: [PATCH 03/24] Refactor paddle backend pow and exp

To improve support of complex numbers. pow was having stability issues
due to exp(log(r)), while exp had no native complex number support so
this was added (and eventually the call to it from pow was dropped, but
it's still useful to have).
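For reference, the polar-form identity behind the new pow code can be checked
with the standard library alone (an illustrative sketch, not part of the
diff): with x1 = r*e^(i*theta) and x2 = a + b*i, the magnitude of x1**x2 is
r**a * e**(-b*theta) and its argument is b*ln(r) + a*theta, matching the
res_mag/res_ang split used in the backend.

    import cmath
    import math

    x1, x2 = 1.5 - 2.0j, 0.75 + 0.25j
    r, theta = abs(x1), cmath.phase(x1)    # polar form of the base
    a, b = x2.real, x2.imag
    res_mag = r**a * math.exp(-b * theta)  # |x1**x2|
    res_ang = b * math.log(r) + a * theta  # arg(x1**x2)
    assert cmath.isclose(cmath.rect(res_mag, res_ang), x1**x2)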
Also adjusted the tests of both, and of gelu

---
 ivy/functional/backends/paddle/elementwise.py | 24 +++++++++---------
 ivy/functional/ivy/activations.py             |  5 ++--
 .../test_core/test_elementwise.py             |  4 +++-
 3 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py
index 4a757a5df0b5b..a0d88a3fe906a 100644
--- a/ivy/functional/backends/paddle/elementwise.py
+++ b/ivy/functional/backends/paddle/elementwise.py
@@ -720,20 +720,17 @@ def pow(
         paddle.int16,
         paddle.uint8,
         paddle.float16,
-        paddle.complex64,
-        paddle.complex128,
         paddle.bool,
     ]:
-        if paddle.is_complex(x1):
-            # https://math.stackexchange.com/questions/476968/complex-power-of-a-complex-number
-            r = paddle.abs(x1)
-            theta = paddle.angle(x1)
-            power = x2 * paddle.complex(paddle.log(r), theta)
-            result = paddle.exp(power.real()) * paddle.complex(
-                paddle.cos(power.imag()), paddle.sin(power.imag())
-            )
-            return result
         return paddle.pow(x1.astype("float32"), x2.astype("float32")).astype(ret_dtype)
+    if x1.dtype in [paddle.complex64, paddle.complex128]:
+        # https://math.stackexchange.com/questions/476968/complex-power-of-a-complex-number
+        r = paddle.abs(x1)
+        theta = paddle.angle(x1)
+        res_mag = paddle.pow(r, x2.real()) / paddle.exp(x2.imag() * theta)
+        res_ang = paddle.log(r) * x2.imag() + theta * x2.real()
+        result = res_mag * paddle.complex(paddle.cos(res_ang), paddle.sin(res_ang))
+        return result.astype(ret_dtype)
     return paddle.pow(x1, x2)

@@ -979,6 +976,11 @@ def log(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.T

 def exp(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
     if x.dtype in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]:
         return paddle.exp(x)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        return paddle.multiply(
+            paddle.exp(x.real()),
+            paddle.complex(paddle.cos(x.imag()), paddle.sin(x.imag())),
+        )
     return pow(math.e, x).astype(x.dtype)
diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py
index c68c7805efbea..b1876b676a543 100644
--- a/ivy/functional/ivy/activations.py
+++ b/ivy/functional/ivy/activations.py
@@ -21,7 +21,7 @@ def _gelu_jax_like(
     x: Union[ivy.Array, ivy.NativeArray],
     /,
     *,
-    fn_original=None,
+    fn_original: Optional[Callable] = None,
     approximate: bool = False,
     out: Optional[ivy.Array] = None,
 ) -> ivy.Array:
@@ -35,7 +35,8 @@
             ivy.tanh(
                 ivy.multiply(
-                    0.7978845608, ivy.add(x, ivy.multiply(0.044715, ivy.pow(x, 3)))
+                    ivy.sqrt(2 / ivy.pi).astype(x.dtype),
+                    ivy.add(x, ivy.multiply(0.044715, ivy.pow(x, 3))),
                 )
             ),
         ),
diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
index a483811847ac9..423fd32df0508 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
@@ -511,7 +511,9 @@ def test_equal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # exp
 @handle_test(
     fn_tree="functional.ivy.exp",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_exp(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x

From 0de9919340c8ab404b1619dd07f4d28987a88182 Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Thu, 10 Aug 2023 12:12:40 +0100
Subject: [PATCH 04/24] Add complex number handling to trig functions in
 paddle backend

---
 ivy/functional/backends/paddle/elementwise.py | 40 ++++++++++++++++---
 .../test_core/test_elementwise.py             | 24 ++++++++---
 2 files changed, 52 insertions(+), 12 deletions(-)

diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py
index a0d88a3fe906a..44cbfad8f1829 100644
--- a/ivy/functional/backends/paddle/elementwise.py
+++ b/ivy/functional/backends/paddle/elementwise.py
@@ -196,7 +196,7 @@ def floor(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool", "bfloat16")}},
     backend_version,
 )
 def asin(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -278,7 +278,7 @@ def sqrt(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def cosh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -292,6 +292,12 @@
     ]:
         ret_dtype = x.dtype
         return paddle.cosh(x.astype("float32")).astype(ret_dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        re = x.real()
+        im = x.imag()
+        return paddle.complex(
+            paddle.cosh(re) * paddle.cos(im), paddle.sinh(re) * paddle.sin(im)
+        )
     return paddle.cosh(x)

@@ -402,7 +408,7 @@ def multiply(

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def cos(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -416,6 +422,13 @@
     ]:
         ret_dtype = x.dtype
         return paddle.cos(x.astype("float32")).astype(ret_dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        re = x.real()
+        im = x.imag()
+        return paddle.complex(
+            paddle.cos(re) * paddle.cosh(im),
+            -paddle.sin(re) * paddle.sinh(im),
+        )
     return paddle.cos(x)

@@ -588,7 +601,7 @@ def acosh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def sin(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -601,6 +614,12 @@
         paddle.float16,
     ]:
         return paddle.sin(x.astype("float32")).astype(x.dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        re = x.real()
+        im = x.imag()
+        return paddle.complex(
+            paddle.sin(re) * paddle.cosh(im), paddle.cos(re) * paddle.sinh(im)
+        )
     return paddle.sin(x)

@@ -672,7 +691,7 @@ def bitwise_or(

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def sinh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -686,6 +705,12 @@
     ]:
         ret_dtype = x.dtype
         return paddle.sinh(x.astype("float32")).astype(ret_dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        re = x.real()
+        im = x.imag()
+        return paddle.complex(
+            paddle.sinh(re) * paddle.cos(im), paddle.cosh(re) * paddle.sin(im)
+        )
     return paddle.sinh(x)

@@ -907,7 +932,7 @@ def real(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def tan(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -921,6 +946,9 @@
     ]:
         ret_dtype = x.dtype
         return paddle.tan(x.astype("float32")).astype(ret_dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        tanh_ix = paddle_backend.tanh(paddle.complex(-x.imag(), x.real()))
+        return paddle.complex(tanh_ix.imag(), -tanh_ix.real())
     return paddle.tan(x)
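Note: a standalone cross-check of the identities used above (plain Python,
not part of this patch). The complex branches implement the angle-sum
expansions, for example sin(a + bi) = sin(a)cosh(b) + i*cos(a)sinh(b), and
tan is reduced to tanh via tan(z) = -i*tanh(iz):

    import cmath
    import math

    z = 0.7 - 1.3j
    a, b = z.real, z.imag
    # expansion used for the complex branch of sin
    expanded = complex(math.sin(a) * math.cosh(b), math.cos(a) * math.sinh(b))
    assert cmath.isclose(cmath.sin(z), expanded)
    # reduction used for the complex branch of tan
    assert cmath.isclose(cmath.tan(z), -1j * cmath.tanh(1j * z))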
diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
index 423fd32df0508..430aaedc9aefc 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
@@ -429,7 +429,9 @@ def test_ceil(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # cos
 @handle_test(
     fn_tree="functional.ivy.cos",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_cos(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x
@@ -446,7 +448,9 @@ def test_cos(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # cosh
 @handle_test(
     fn_tree="functional.ivy.cosh",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
+    ),
 )
 def test_cosh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x
@@ -827,7 +831,9 @@ def test_less_equal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # log
 @handle_test(
     fn_tree="functional.ivy.log",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_log(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x
@@ -1368,7 +1374,9 @@ def test_sign(*, dtype_and_x, np_variant, test_flags, backend_fw, fn_name, on_de

 # sin
 @handle_test(
     fn_tree="functional.ivy.sin",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_sin(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x
@@ -1385,7 +1393,9 @@ def test_sin(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # sinh
 @handle_test(
     fn_tree="functional.ivy.sinh",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_sinh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x
@@ -1472,7 +1482,9 @@ def test_subtract(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_dev

 # tan
 @handle_test(
     fn_tree="functional.ivy.tan",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_tan(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x

From cada95d895c697171cd797bc4b9b9987ace9184b Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Thu, 10 Aug 2023 15:31:31 +0100
Subject: [PATCH 05/24] Implement complex number handling for inverse trig

Implemented for asin, acos, atan, asinh, acosh, atanh. Also, modified
tests for functions (e.g. exp, log, sqrt) that already support complex
numbers across backends so that complex numbers will also be tested.
This introduces a few new failures, mostly related to numerical
stability in JAX, but this is only due to greater test coverage.

---
 ivy/functional/backends/paddle/elementwise.py | 46 ++++++++++++++++---
 .../test_core/test_elementwise.py             | 34 +++++++++-----
 2 files changed, 61 insertions(+), 19 deletions(-)

diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py
index 44cbfad8f1829..35f913cd11435 100644
--- a/ivy/functional/backends/paddle/elementwise.py
+++ b/ivy/functional/backends/paddle/elementwise.py
@@ -196,7 +196,7 @@ def floor(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool", "bfloat16")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def asin(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -210,11 +210,14 @@
     ]:
         ret_dtype = x.dtype
         return paddle.asin(x.astype("float32")).astype(ret_dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        asinh_iz = paddle_backend.asinh(paddle.complex(-x.imag(), x.real()))
+        return paddle.complex(asinh_iz.imag(), -asinh_iz.real())
     return paddle.asin(x)

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def asinh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -228,6 +231,14 @@
     ]:
         ret_dtype = x.dtype
         return paddle.asinh(x.astype("float32")).astype(ret_dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        # From https://github.com/python/cpython/blob/39ef93edb9802dccdb6555d4209ac2e60875a011/Modules/cmathmodule.c#L276  # noqa
+        s1 = paddle_backend.sqrt(paddle.complex(1 + x.imag(), -x.real()))
+        s2 = paddle_backend.sqrt(paddle.complex(1 - x.imag(), x.real()))
+        return paddle.complex(
+            paddle.asinh(s1.real() * s2.imag() - s2.real() * s1.imag()),
+            paddle.atan2(x.imag(), s1.real() * s2.real() - s1.imag() * s2.imag()),
+        )
     return paddle.asinh(x)

@@ -356,7 +367,7 @@ def log1p(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle
         paddle.bool,
     ]:
         if paddle.is_complex(x):
-            return paddle.complex(paddle.log1p(paddle.abs(x)), paddle.angle(x + 1))
+            return paddle_backend.log(x + 1)
         return paddle.log1p(x.astype("float32")).astype(x.dtype)
     return paddle.log1p(x)

@@ -512,7 +523,7 @@ def greater_equal(

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def acos(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -525,6 +536,14 @@
         paddle.float16,
     ]:
         return paddle.acos(x.astype("float32")).astype(x.dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        # From https://github.com/python/cpython/blob/39ef93edb9802dccdb6555d4209ac2e60875a011/Modules/cmathmodule.c#L178  # noqa
+        s1 = paddle_backend.sqrt(1 - x)
+        s2 = paddle_backend.sqrt(1 + x)
+        return paddle.complex(
+            2.0 * paddle.atan2(s1.real(), s2.real()),
+            paddle.asinh(s2.real() * s1.imag() - s2.imag() * s1.real()),
+        )
     return paddle.acos(x)

@@ -584,7 +603,7 @@ def logical_or(

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def acosh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -597,6 +616,14 @@
         paddle.float16,
     ]:
         return paddle.acosh(x.astype("float32")).astype(x.dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        # From https://github.com/python/cpython/blob/39ef93edb9802dccdb6555d4209ac2e60875a011/Modules/cmathmodule.c#L221  # noqa
+        s1 = paddle_backend.sqrt(paddle.complex(x.real() - 1, x.imag()))
+        s2 = paddle_backend.sqrt(paddle.complex(x.real() + 1, x.imag()))
+        return paddle.complex(
+            paddle.asinh(s1.real() * s2.real() + s1.imag() * s2.imag()),
+            2.0 * paddle.atan2(s1.imag(), s2.real()),
+        )
     return paddle.acosh(x)

@@ -953,7 +980,7 @@

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def atan(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -967,6 +994,9 @@
     ]:
         ret_dtype = x.dtype
         return paddle.atan(x.astype("float32")).astype(ret_dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        atanh_iz = paddle_backend.atanh(paddle.complex(-x.imag(), x.real()))
+        return paddle.complex(atanh_iz.imag(), -atanh_iz.real())
     return paddle.atan(x)

@@ -1068,7 +1098,7 @@ def remainder(

 @with_unsupported_device_and_dtypes(
-    {"2.5.1 and below": {"cpu": ("complex64", "complex128", "bool")}},
+    {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}},
     backend_version,
 )
 def atanh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor:
@@ -1082,6 +1112,8 @@
     ]:
         ret_dtype = x.dtype
         return paddle.atanh(x.astype("float32")).astype(ret_dtype)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        return 0.5 * (paddle_backend.log(1 + x) - paddle_backend.log(1 - x))
     return paddle.atanh(x)
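Note: the reductions used above can be sanity-checked standalone (plain
Python, not part of this patch): asin and atan come from their hyperbolic
counterparts via asin(z) = -i*asinh(iz) and atan(z) = -i*atanh(iz), and the
complex branch of atanh uses the log form directly:

    import cmath

    z = 0.3 + 0.6j
    assert cmath.isclose(cmath.asin(z), -1j * cmath.asinh(1j * z))
    assert cmath.isclose(cmath.atan(z), -1j * cmath.atanh(1j * z))
    # log form used by the complex branch of atanh
    assert cmath.isclose(cmath.atanh(z), 0.5 * (cmath.log(1 + z) - cmath.log(1 - z)))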
diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
index 430aaedc9aefc..e4e7088142307 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
@@ -57,7 +57,7 @@ def test_abs(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.acosh",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         min_value=1,
         large_abs_safety_factor=4,
         small_abs_safety_factor=4,
     ),
 )
 def test_acosh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
@@ -81,7 +81,7 @@ def test_acosh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.acos",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         large_abs_safety_factor=4,
         small_abs_safety_factor=4,
     ),
 )
@@ -133,7 +133,7 @@ def test_add(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.asin",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         large_abs_safety_factor=4,
         small_abs_safety_factor=4,
     ),
 )
 def test_asin(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
@@ -156,7 +156,7 @@ def test_asin(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.asinh",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         large_abs_safety_factor=4,
         small_abs_safety_factor=4,
     ),
 )
 def test_asinh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
@@ -178,7 +178,9 @@ def test_asinh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # atan
 @handle_test(
     fn_tree="functional.ivy.atan",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_atan(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x
@@ -225,7 +227,9 @@ def test_atan2(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # atanh
 @handle_test(
     fn_tree="functional.ivy.atanh",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_atanh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x
@@ -535,7 +539,7 @@ def test_exp(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.exp2",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         min_value=-10,
         max_value=10,
         min_num_dims=1,
     ),
 )
 def test_exp2(dtype_and_x, test_flags, backend_fw, fn_name, on_device):
@@ -560,7 +564,9 @@ def test_exp2(dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # expm1
 @handle_test(
     fn_tree="functional.ivy.expm1",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_expm1(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x
@@ -853,7 +859,7 @@ def test_log(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.log1p",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"),
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
         small_abs_safety_factor=2,
         safety_factor_scale="log",
     ),
 )
 def test_log1p(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
@@ -875,7 +881,9 @@ def test_log1p(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # log2
 @handle_test(
     fn_tree="functional.ivy.log2",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_log2(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x
@@ -895,7 +903,9 @@ def test_log2(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

 # log10
 @handle_test(
     fn_tree="functional.ivy.log10",
-    dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("float")),
+    dtype_and_x=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("float_and_complex")
+    ),
 )
 def test_log10(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     input_dtype, x = dtype_and_x
@@ -1434,7 +1444,7 @@ def test_square(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.sqrt",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float"), allow_inf=False
+        available_dtypes=helpers.get_dtypes("float_and_complex"), allow_inf=False
     ),
 )
 def test_sqrt(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

From f7be90dc3efc648801deb29b1aa6d365284d49ad Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Mon, 14 Aug 2023 17:10:22 +0100
Subject: [PATCH 06/24] Fix some failing tests

That included refactoring GeLU so the calculation of the complex
approximation is done in the backend rather than the Ivy API.
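The complex branch added to the paddle backend's square below follows
directly from (a + b*i)**2 = (a*a - b*b) + 2*a*b*i; a minimal standalone
check (illustrative only, not part of the diff):

    import cmath

    z = 3.0 - 2.0j
    assert cmath.isclose(z * z, complex(z.real**2 - z.imag**2, 2 * z.real * z.imag))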
Also restricts the domain of the numpy square function to prevent
overflows, and updates the paddle version on paddle frontend's expm1

---
 ivy/functional/backends/paddle/activations.py     |  6 ++++++
 ivy/functional/backends/paddle/elementwise.py     |  5 +++++
 ivy/functional/backends/tensorflow/activations.py |  2 ++
 ivy/functional/frontends/paddle/tensor/math.py    |  2 +-
 ivy/functional/ivy/activations.py                 | 15 +--------------
 .../test_miscellaneous.py                         |  3 +++
 .../test_functional/test_nn/test_activations.py   |  2 --
 7 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py
index c42e9f3931a12..99316a97230c0 100644
--- a/ivy/functional/backends/paddle/activations.py
+++ b/ivy/functional/backends/paddle/activations.py
@@ -68,6 +68,12 @@ def gelu(
     approximate: bool = False,
     out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        return (
+            0.5
+            * x
+            * (1 + paddle_backend.tanh(0.7978845608 * (x + 0.044715 * x * x * x)))
+        )
     if x.dtype in unsupported_dtypes:
         return F.gelu(x.cast("float32"), approximate=approximate).cast(x.dtype)
     return F.gelu(x, approximate=approximate)
diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py
index 2694629153dbb..b7cd713147c61 100644
--- a/ivy/functional/backends/paddle/elementwise.py
+++ b/ivy/functional/backends/paddle/elementwise.py
@@ -757,6 +757,11 @@ def square(
 ) -> paddle.Tensor:
     if x.dtype in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]:
         return paddle.square(x)
+    if x.dtype in [paddle.complex64, paddle.complex128]:
+        return paddle.complex(
+            paddle.square(paddle.real(x)) - paddle.square(paddle.imag(x)),
+            2.0 * paddle.real(x) * paddle.imag(x),
+        )
     return paddle_backend.pow(x, 2).astype(x.dtype)
diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py
index 5ba82d96269f6..c30240316d8ae 100644
--- a/ivy/functional/backends/tensorflow/activations.py
+++ b/ivy/functional/backends/tensorflow/activations.py
@@ -20,6 +20,8 @@
 def gelu(
     x: Tensor, /, *, approximate: bool = False, out: Optional[Tensor] = None
 ) -> Tensor:
+    if x.dtype in [tf.complex64, tf.complex128]:
+        return 0.5 * x * (1 + tf.math.tanh(0.7978845608 * (x + 0.044715 * x * x * x)))
     return tf.nn.gelu(x, approximate)
diff --git a/ivy/functional/frontends/paddle/tensor/math.py b/ivy/functional/frontends/paddle/tensor/math.py
index ecd9ddd0bb74c..bd89adf12110a 100644
--- a/ivy/functional/frontends/paddle/tensor/math.py
+++ b/ivy/functional/frontends/paddle/tensor/math.py
@@ -215,7 +215,7 @@ def lgamma(x, name=None):
     return ivy.lgamma(x)

-@with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle")
+@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle")
 @to_ivy_arrays_and_back
 def exp(x, name=None):
     return ivy.exp(x)
diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py
index 2baac9ea687c2..ae380ee33c65d 100644
--- a/ivy/functional/ivy/activations.py
+++ b/ivy/functional/ivy/activations.py
@@ -28,20 +28,7 @@ def _gelu_jax_like(
 ) -> ivy.Array:
     # We don't have the exact implementation
     # cuz the erf function doesn't work on complex numbers
-    # Magic number #1 is sqrt(2/pi)
-    # Magic number #2 is from https://arxiv.org/abs/1606.08415v5
-    return ivy.multiply(
-        ivy.multiply(x, 0.5),
-        ivy.add(
-            1.0,
-            ivy.tanh(
-                ivy.multiply(
-                    ivy.sqrt(2 / ivy.pi).astype(x.dtype),
-                    ivy.add(x, ivy.multiply(0.044715, ivy.pow(x, 3))),
-                )
-            ),
-        ),
-    )
+    return fn_original(x, approximate=True, out=out)

 @handle_exceptions
diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_miscellaneous.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_miscellaneous.py
index 9e20eba4b779d..6601bc8ad381e 100644
--- a/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_miscellaneous.py
+++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_mathematical_functions/test_miscellaneous.py
@@ -232,6 +232,9 @@ def test_numpy_reciprocal(
     arr_func=[
         lambda: helpers.dtype_and_values(
             available_dtypes=helpers.get_dtypes("numeric"),
+            large_abs_safety_factor=2,
+            small_abs_safety_factor=2,
+            safety_factor_scale="log",
         )
     ],
 ),
diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
index ef7865378e61a..99f5f8db37222 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
@@ -70,8 +70,6 @@
 )
 def test_gelu(*, dtype_and_x, approximate, test_flags, backend_fw, fn_name, on_device):
     dtype, x = dtype_and_x
-    if "complex" in str(dtype):
-        approximate = True
     helpers.test_function(
         input_dtypes=dtype,

From ee7370d629a58f949d7890c681d901044378bdd8 Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Tue, 15 Aug 2023 16:07:09 +0100
Subject: [PATCH 07/24] Fix some test failures

Restrict test cases to avoid generation of too large or too small
numbers, modify the JAX backend log1p because it gives incorrect values
for some complex numbers.
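For background (standalone sketch, not part of the diff): for complex w the
principal log decomposes as log(w) = ln|w| + i*arg(w), and in the absence of
a complex-aware log1p, computing log(1 + x) directly is the straightforward
fallback:

    import cmath
    import math

    z = -0.25 + 0.5j
    w = 1 + z
    assert cmath.isclose(cmath.log(w), complex(math.log(abs(w)), cmath.phase(w)))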
---
 ivy/functional/backends/jax/elementwise.py          |  4 ++++
 .../test_jax/test_nn/test_non_linear_activations.py |  1 +
 .../test_functional/test_core/test_elementwise.py   | 12 +++++++++---
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/ivy/functional/backends/jax/elementwise.py b/ivy/functional/backends/jax/elementwise.py
index d309a4ac52c36..8a959bf557bb7 100644
--- a/ivy/functional/backends/jax/elementwise.py
+++ b/ivy/functional/backends/jax/elementwise.py
@@ -304,6 +304,10 @@ def log10(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:

 def log1p(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
+    if x.dtype in ["complex64", "complex128"]:
+        # For certain complex numbers, jnp.log1p(z) != jnp.log(1+z)
+        # I don't know why
+        return jnp.log(1 + x)
     return jnp.log1p(x)
diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
index e55b9ca0cc06c..b7b63a8c6cb13 100644
--- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
+++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py
@@ -172,6 +172,7 @@ def test_jax_leaky_relu(
         available_dtypes=helpers.get_dtypes("float_and_complex"),
         min_value=-1e4,
         max_value=1e4,
+        abs_smallest_val=1e-3,
     ),
     approximate=st.booleans(),
     test_with_out=st.just(False),
diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
index e4e7088142307..0e17bd792e6bf 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
@@ -565,7 +565,11 @@
 @handle_test(
     fn_tree="functional.ivy.expm1",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float_and_complex")
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
+        # Can't use linear or log safety factor, since the function is exponential,
+        # next best option is a hardcoded maximum that won't break any data type.
+        # expm1 is designed for very small values anyway
+        max_value=20.0,
     ),
 )
 def test_expm1(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
@@ -838,7 +842,8 @@ def test_less_equal(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.log",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float_and_complex")
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
+        safety_factor_scale="log",
     ),
 )
 def test_log(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
@@ -904,7 +909,8 @@ def test_log2(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.log10",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float_and_complex")
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
+        safety_factor_scale="log",
     ),
 )
 def test_log10(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

From 96daad1f4d9a3d31cad36f5bfe6ba98bf4ccf897 Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Tue, 15 Aug 2023 17:34:22 +0100
Subject: [PATCH 08/24] Fix more test failures

Restrict the domain of more tests to avoid test cases that introduce
instability

---
 .../test_functional/test_core/test_elementwise.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
index 0e17bd792e6bf..765b0b246f1ca 100644
--- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
+++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py
@@ -59,8 +59,9 @@ def test_abs(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
     dtype_and_x=helpers.dtype_and_values(
         available_dtypes=helpers.get_dtypes("float_and_complex"),
         min_value=1,
-        large_abs_safety_factor=4,
-        small_abs_safety_factor=4,
+        large_abs_safety_factor=2.1,
+        small_abs_safety_factor=2.1,
+        safety_factor_scale="log",
     ),
 )
 def test_acosh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
@@ -887,7 +888,8 @@ def test_log1p(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):
 @handle_test(
     fn_tree="functional.ivy.log2",
     dtype_and_x=helpers.dtype_and_values(
-        available_dtypes=helpers.get_dtypes("float_and_complex")
+        available_dtypes=helpers.get_dtypes("float_and_complex"),
+        safety_factor_scale="log",
     ),
 )
 def test_log2(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device):

From 36cefa620b50bce4f8a6931fe91c3be21e60f5e0 Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Wed, 16 Aug 2023 11:06:11 +0100
Subject: [PATCH 09/24] Add complex_mode to test_tanh, fix associated bugs

Adding `complex_mode` to the test of tanh exposed a bug with the
handle_complex_input decorator, which is also fixed in this commit.
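For reference, the two non-jax modes the decorator implements can be written
out standalone (plain Python, with math.tanh standing in for the wrapped
function; illustrative only): "split" applies the real function to the real
and imaginary parts independently, while "magnitude" applies it to |z| and
keeps the phase.

    import cmath
    import math

    def tanh_split(z: complex) -> complex:
        # "split": f(Re z) + i*f(Im z)
        return complex(math.tanh(z.real), math.tanh(z.imag))

    def tanh_magnitude(z: complex) -> complex:
        # "magnitude": f(|z|) * e^(i*arg z)
        return math.tanh(abs(z)) * cmath.exp(1j * cmath.phase(z))

    z = 0.5 + 1.2j
    print(tanh_split(z), tanh_magnitude(z))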
Also removes `complex_mode` argument from the `_forward` method of some stateful API classes (as per #21902) and adds some context for the magic numbers in paddle.gelu --- ivy/func_wrapper.py | 17 ++++++---- ivy/functional/backends/paddle/activations.py | 5 ++- ivy/stateful/activations.py | 33 ++++++------------- .../test_core/test_elementwise.py | 4 ++- 4 files changed, 28 insertions(+), 31 deletions(-) diff --git a/ivy/func_wrapper.py b/ivy/func_wrapper.py index 47c6aa77a7a33..7bd3cc3baadd5 100644 --- a/ivy/func_wrapper.py +++ b/ivy/func_wrapper.py @@ -1486,16 +1486,21 @@ def _handle_complex_input( jax_like = fn.jax_like if hasattr(fn, "jax_like") else "entire" if complex_mode == "split" or (complex_mode == "jax" and jax_like == "split"): - real_inp = ivy.real(inp) - imag_inp = ivy.imag(inp) - return fn(real_inp, *args, **kwargs) + 1j * fn(imag_inp, *args, **kwargs) + real_inp = ivy.real(inp).data + imag_inp = ivy.imag(inp).data + return ivy.add( + fn(real_inp, *args, **kwargs), + ivy.multiply(1j, fn(imag_inp, *args, **kwargs)), + ) elif complex_mode == "magnitude" or ( complex_mode == "jax" and jax_like == "magnitude" ): - mag_inp = ivy.abs(inp) - angle_inp = ivy.angle(inp) - return fn(mag_inp, *args, **kwargs) * ivy.exp(1j * angle_inp) + mag_inp = ivy.abs(inp).data + angle_inp = ivy.angle(inp).data + return ivy.multiply( + fn(mag_inp, *args, **kwargs), ivy.exp(ivy.multiply(1j, angle_inp)) + ) elif complex_mode == "jax" and jax_like == "entire": return fn(inp, *args, **kwargs) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 99316a97230c0..44231fc701f25 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -69,10 +69,13 @@ def gelu( out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: if x.dtype in [paddle.complex64, paddle.complex128]: + sqrt_2_over_pi = 0.7978845608 + # the other magic number comes directly from the formula in + # https://doi.org/10.48550/arXiv.1606.08415 return ( 0.5 * x - * (1 + paddle_backend.tanh(0.7978845608 * (x + 0.044715 * x * x * x))) + * (1 + paddle_backend.tanh(sqrt_2_over_pi * (x + 0.044715 * x * x * x))) ) if x.dtype in unsupported_dtypes: return F.gelu(x.cast("float32"), approximate=approximate).cast(x.dtype) diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 84dd94dc9d966..d8514305a29fd 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -18,14 +18,7 @@ def __init__( self._complex_mode = complex_mode Module.__init__(self) - def _forward( - self, - x, - /, - *, - approximate=None, - complex_mode=None, - ): + def _forward(self, x): """ Perform forward pass of the GELU activation. 
@@ -41,8 +34,8 @@ def _forward( """ return ivy.gelu( x, - approximate=ivy.default(approximate, self._approximate), - complex_mode=ivy.default(complex_mode, self._complex_mode), + approximate=self._approximate, + complex_mode=self._complex_mode, ) @@ -78,7 +71,7 @@ def __init__( self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, complex_mode=None): + def _forward(self, x): """ Parameters @@ -91,7 +84,7 @@ def _forward(self, x, complex_mode=None): ret The outputs following the RELU activation *[batch_shape, d]* """ - return ivy.relu(x, complex_mode=ivy.default(complex_mode, self._complex_mode)) + return ivy.relu(x, self._complex_mode) class LeakyReLU(Module): @@ -114,17 +107,13 @@ def __init__( self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, *, alpha=None, complex_mode=None): + def _forward(self, x): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. - alpha - Negative slope for ReLU. - complex_mode - Specifies how to handle complex input. Returns ------- @@ -133,8 +122,8 @@ def _forward(self, x, *, alpha=None, complex_mode=None): """ return ivy.leaky_relu( x, - alpha=ivy.default(alpha, self._alpha), - complex_mode=ivy.default(complex_mode, self._complex_mode), + alpha=self._alpha, + complex_mode=self._complex_mode, ) @@ -287,22 +276,20 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): self._complex_mode = complex_mode Module.__init__(self) - def _forward(self, x, complex_mode=None): + def _forward(self, x): """ Parameters ---------- x Inputs to process *[batch_shape, d]*. - complex_mode - Specifies how to handle complex input. Returns ------- ret The outputs following the TANH activation *[batch_shape, d]* """ - return ivy.tanh(x, complex_mode=ivy.default(complex_mode, self._complex_mode)) + return ivy.tanh(x, self._complex_mode) class ReLU6(Module): diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py index 765b0b246f1ca..693e39bffb6fe 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py @@ -1524,8 +1524,9 @@ def test_tan(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_complex") ), + complex_mode=st.sampled_from(["jax", "split", "magnitude"]), ) -def test_tanh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_tanh(*, dtype_and_x, complex_mode, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x helpers.test_function( input_dtypes=input_dtype, @@ -1536,6 +1537,7 @@ def test_tanh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): rtol_=1e-1, atol_=1e-2, x=x[0], + complex_mode=complex_mode, ) From 78f158b81a1ba0bfd4a640bd553adf411a67737e Mon Sep 17 00:00:00 2001 From: Joe Shepherd <72616151+jshepherd01@users.noreply.github.com> Date: Wed, 16 Aug 2023 11:39:27 +0100 Subject: [PATCH 10/24] Change logic for testing for complex input Suggested from code review. 
Uses paddle's `is_complex()` function Co-authored-by: Mahmoud Ashraf --- ivy/functional/backends/paddle/activations.py | 2 +- ivy/functional/backends/paddle/elementwise.py | 26 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 44231fc701f25..6bd37aac6c15f 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -68,7 +68,7 @@ def gelu( approximate: bool = False, out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): sqrt_2_over_pi = 0.7978845608 # the other magic number comes directly from the formula in # https://doi.org/10.48550/arXiv.1606.08415 diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py index b7cd713147c61..d4e8fda2cddaa 100644 --- a/ivy/functional/backends/paddle/elementwise.py +++ b/ivy/functional/backends/paddle/elementwise.py @@ -210,7 +210,7 @@ def asin(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. ]: ret_dtype = x.dtype return paddle.asin(x.astype("float32")).astype(ret_dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): asinh_iz = paddle_backend.asinh(paddle.complex(-x.imag(), x.real())) return paddle.complex(asinh_iz.imag(), -asinh_iz.real()) return paddle.asin(x) @@ -231,7 +231,7 @@ def asinh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle ]: ret_dtype = x.dtype return paddle.asinh(x.astype("float32")).astype(ret_dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): # From https://github.com/python/cpython/blob/39ef93edb9802dccdb6555d4209ac2e60875a011/Modules/cmathmodule.c#L276 # noqa s1 = paddle_backend.sqrt(paddle.complex(1 + x.imag(), -x.real())) s2 = paddle_backend.sqrt(paddle.complex(1 - x.imag(), x.real())) @@ -304,7 +304,7 @@ def cosh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. ]: ret_dtype = x.dtype return paddle.cosh(x.astype("float32")).astype(ret_dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): re = x.real() im = x.imag() return paddle.complex( @@ -434,7 +434,7 @@ def cos(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.T ]: ret_dtype = x.dtype return paddle.cos(x.astype("float32")).astype(ret_dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): re = x.real() im = x.imag() return paddle.complex( @@ -537,7 +537,7 @@ def acos(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. 
paddle.float16, ]: return paddle.acos(x.astype("float32")).astype(x.dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): # From https://github.com/python/cpython/blob/39ef93edb9802dccdb6555d4209ac2e60875a011/Modules/cmathmodule.c#L178 # noqa s1 = paddle_backend.sqrt(1 - x) s2 = paddle_backend.sqrt(1 + x) @@ -617,7 +617,7 @@ def acosh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle paddle.float16, ]: return paddle.acosh(x.astype("float32")).astype(x.dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): # From https://github.com/python/cpython/blob/39ef93edb9802dccdb6555d4209ac2e60875a011/Modules/cmathmodule.c#L221 # noqa s1 = paddle_backend.sqrt(paddle.complex(x.real() - 1, x.imag())) s2 = paddle_backend.sqrt(paddle.complex(x.real() + 1, x.imag())) @@ -642,7 +642,7 @@ def sin(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.T paddle.float16, ]: return paddle.sin(x.astype("float32")).astype(x.dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): re = x.real() im = x.imag() return paddle.complex( @@ -687,7 +687,7 @@ def tanh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. paddle.float16, ]: return paddle.tanh(x.astype("float32")).astype(x.dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): tanh_a = paddle.tanh(paddle.real(x)) tan_b = paddle.tan(paddle.imag(x)) return (tanh_a + 1j * tan_b) / (1 + 1j * (tanh_a * tan_b)) @@ -733,7 +733,7 @@ def sinh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle. ]: ret_dtype = x.dtype return paddle.sinh(x.astype("float32")).astype(ret_dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): re = x.real() im = x.imag() return paddle.complex( @@ -757,7 +757,7 @@ def square( ) -> paddle.Tensor: if x.dtype in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]: return paddle.square(x) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): return paddle.complex( paddle.square(paddle.real(x)) - paddle.square(paddle.imag(x)), 2.0 * paddle.real(x) * paddle.imag(x), @@ -979,7 +979,7 @@ def tan(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.T ]: ret_dtype = x.dtype return paddle.tan(x.astype("float32")).astype(ret_dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): tanh_ix = paddle_backend.tanh(paddle.complex(-x.imag(), x.real())) return paddle.complex(tanh_ix.imag(), -tanh_ix.real()) return paddle.tan(x) @@ -1040,7 +1040,7 @@ def log(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.T def exp(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: if x.dtype in [paddle.int32, paddle.int64, paddle.float32, paddle.float64]: return paddle.exp(x) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): return paddle.multiply( paddle.exp(x.real()), paddle.complex(paddle.cos(x.imag()), paddle.sin(x.imag())), @@ -1118,7 +1118,7 @@ def atanh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle ]: ret_dtype = x.dtype return paddle.atanh(x.astype("float32")).astype(ret_dtype) - if x.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x): return 0.5 * (paddle_backend.log(1 + x) - paddle_backend.log(1 - x)) return paddle.atanh(x) From d61dca171c991c07abe70fde355fa084c2b33c71 Mon Sep 17 00:00:00 2001 From: Joe Shepherd 
Date: Wed, 16 Aug 2023 13:58:52 +0100 Subject: [PATCH 11/24] Fix line length issue The linter was throwing an error about a line in a docstring being too long --- ivy/functional/ivy/elementwise.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ivy/functional/ivy/elementwise.py b/ivy/functional/ivy/elementwise.py index 331d4a63fe25f..abe2db096835b 100644 --- a/ivy/functional/ivy/elementwise.py +++ b/ivy/functional/ivy/elementwise.py @@ -3152,7 +3152,8 @@ def greater_equal( With a mix of :class:`ivy.Array` and :class:`ivy.Container` inputs: >>> x = ivy.array([[5.1, 2.3, -3.6]]) - >>> y = ivy.Container(a=ivy.array([[4.], [5.], [6.]]), b=ivy.array([[5.], [6.], [7.]])) + >>> y = ivy.Container(a=ivy.array([[4.], [5.], [6.]]), + ... b=ivy.array([[5.], [6.], [7.]])) >>> z = ivy.greater_equal(x, y) >>> print(z) { From 0de40cc4e32de20c9e8eb570a27a3e81ed1e7423 Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Wed, 16 Aug 2023 16:47:33 +0100 Subject: [PATCH 12/24] Minor changes to pow and exp Change `pow` to use the same `is_complex` logic as other functions, and `exp` to explicitly call `pow` via `paddle_backend` rather than from the same file directly --- ivy/functional/backends/paddle/elementwise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py index d4e8fda2cddaa..5baa5663eefa0 100644 --- a/ivy/functional/backends/paddle/elementwise.py +++ b/ivy/functional/backends/paddle/elementwise.py @@ -781,7 +781,7 @@ def pow( paddle.bool, ]: return paddle.pow(x1.astype("float32"), x2.astype("float32")).astype(ret_dtype) - if x1.dtype in [paddle.complex64, paddle.complex128]: + if paddle.is_complex(x1): # https://math.stackexchange.com/questions/476968/complex-power-of-a-complex-number r = paddle.abs(x1) theta = paddle.angle(x1) @@ -1045,7 +1045,7 @@ def exp(x: paddle.Tensor, /, *, out: Optional[paddle.T paddle.exp(x.real()), paddle.complex(paddle.cos(x.imag()), paddle.sin(x.imag())), ) - return pow(math.e, x).astype(x.dtype) + return paddle_backend.pow(math.e, x).astype(x.dtype) From 7ce83433c47445794ede64395a1cf2cb3534497b Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Thu, 17 Aug 2023 10:46:26 +0100 Subject: [PATCH 13/24] Implemented suggested changes Standardised the `complex_mode` docstring across wrapper functions, added `complex_mode` to `array.gelu`, fixed the `complex_mode` parameter in the stateful API, removed special case handling in `jax_backend.log1p`, and restricted its tests to manageable values. --- ivy/data_classes/array/activations.py | 14 +++++++++++--- ivy/data_classes/array/elementwise.py | 3 ++- ivy/data_classes/container/activations.py | 18 ++++++++++++------ ivy/data_classes/container/elementwise.py | 6 ++++-- ivy/functional/backends/jax/elementwise.py | 4 ---- ivy/functional/ivy/activations.py | 13 +++++++------ ivy/functional/ivy/elementwise.py | 2 +- ivy/stateful/activations.py | 4 ++-- .../test_nn/test_non_linear_activations.py | 2 +- .../test_core/test_elementwise.py | 1 + 10 files changed, 41 insertions(+), 26 deletions(-) diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py index 33fa7e1844cb1..c33e995045c32 100644 --- a/ivy/data_classes/array/activations.py +++ b/ivy/data_classes/array/activations.py @@ -30,7 +30,8 @@ def relu( optional output array, for writing the result to. It must have a shape that the inputs broadcast to.
complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -69,7 +70,8 @@ def leaky_relu( optional output array, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -93,6 +95,7 @@ def gelu( *, approximate: bool = False, out: Optional[ivy.Array] = None, + complex_mode: Literal["split", "magnitude", "jax"] = "jax", ) -> ivy.Array: """ ivy.Array instance method variant of ivy.gelu. This method simply wraps the @@ -108,6 +111,9 @@ def gelu( out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. + complex_mode + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -121,7 +127,9 @@ def gelu( >>> print(y) ivy.array([-0.138, -0.165, 1.4]) """ - return ivy.gelu(self._data, approximate=approximate, out=out) + return ivy.gelu( + self._data, approximate=approximate, out=out, complex_mode=complex_mode + ) def sigmoid(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: """ diff --git a/ivy/data_classes/array/elementwise.py b/ivy/data_classes/array/elementwise.py index ff8255d8394fb..645407a9edb29 100644 --- a/ivy/data_classes/array/elementwise.py +++ b/ivy/data_classes/array/elementwise.py @@ -2572,7 +2572,8 @@ def tanh( optional output, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py index 104ae63ea31af..ea5e9f6f7ca21 100644 --- a/ivy/data_classes/container/activations.py +++ b/ivy/data_classes/container/activations.py @@ -45,7 +45,8 @@ def _static_relu( optional output container, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -109,7 +110,8 @@ def relu( optional output container, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -176,7 +178,8 @@ def _static_leaky_relu( optional output container, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -243,7 +246,8 @@ def leaky_relu( optional output container, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. 
See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -310,7 +314,8 @@ def _static_gelu( optional output container, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -376,7 +381,8 @@ def gelu( optional output container, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- diff --git a/ivy/data_classes/container/elementwise.py b/ivy/data_classes/container/elementwise.py index c5f0a7e0f5679..815b98429f6ac 100644 --- a/ivy/data_classes/container/elementwise.py +++ b/ivy/data_classes/container/elementwise.py @@ -8301,7 +8301,8 @@ def _static_tanh( optional output container, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -8366,7 +8367,8 @@ def tanh( optional output container, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode - optional specifier for how to handle complex data types. + optional specifier for how to handle complex data types. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- diff --git a/ivy/functional/backends/jax/elementwise.py b/ivy/functional/backends/jax/elementwise.py index 8a959bf557bb7..d309a4ac52c36 100644 --- a/ivy/functional/backends/jax/elementwise.py +++ b/ivy/functional/backends/jax/elementwise.py @@ -304,10 +304,6 @@ def log10(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: def log1p(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: - if x.dtype in ["complex64", "complex128"]: - # For certain complex numbers, jnp.log1p(z) != jnp.log(1+z) - # I don't know why - return jnp.log(1 + x) return jnp.log1p(x) diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index ae380ee33c65d..27529e2985437 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -56,13 +56,14 @@ def gelu( x Input array. approximate - Whether to approximate, default is ``True``. + Whether to approximate, default is ``True``. An approximation is always used if + the input array is complex. out optional output array, for writing the result to. It must have a shape that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -156,7 +157,7 @@ def leaky_relu( inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- @@ -314,8 +315,8 @@ def relu( If the input is complex, then by default each element is set to zero if either its real part is strictly negative or if its real part is zero and its - imaginary part is negative. 
- This behaviour can be changed by specifying a different `complex_mode`. + imaginary part is negative. This behaviour can be changed by specifying a different + `complex_mode`. Parameters ---------- @@ -326,7 +327,7 @@ def relu( inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- diff --git a/ivy/functional/ivy/elementwise.py b/ivy/functional/ivy/elementwise.py index 63f5c7c4c60d4..27b3032d1553a 100644 --- a/ivy/functional/ivy/elementwise.py +++ b/ivy/functional/ivy/elementwise.py @@ -6451,7 +6451,7 @@ def tanh( broadcast to. complex_mode optional specifier for how to handle complex data types. See - `ivy.func_wrapper.handle_complex_input` for more detail. + ``ivy.func_wrapper.handle_complex_input`` for more detail. Returns ------- diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index d8514305a29fd..3691b56b3bccf 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -84,7 +84,7 @@ def _forward(self, x): ret The outputs following the RELU activation *[batch_shape, d]* """ - return ivy.relu(x, self._complex_mode) + return ivy.relu(x, complex_mode=self._complex_mode) class LeakyReLU(Module): @@ -289,7 +289,7 @@ def _forward(self, x): ret The outputs following the TANH activation *[batch_shape, d]* """ - return ivy.tanh(x, self._complex_mode) + return ivy.tanh(x, complex_mode=self._complex_mode) class ReLU6(Module): diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index b7b63a8c6cb13..2a6baeb015d05 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -189,7 +189,7 @@ def test_jax_gelu( ): input_dtype, x = dtype_and_x # As erf function doesn't support complex dtype - if "complex" in str(input_dtype): + if "complex" in str(input_dtype[0]): approximate = True helpers.test_frontend_function( input_dtypes=input_dtype, diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py index 693e39bffb6fe..534c834897cda 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py @@ -867,6 +867,7 @@ def test_log(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_complex"), small_abs_safety_factor=2, + large_abs_safety_factor=2.1, safety_factor_scale="log", ), ) From e182d4f317e00d288f54db338cd40ab75d6998ae Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Thu, 17 Aug 2023 10:51:33 +0100 Subject: [PATCH 14/24] Remove unnecessary `str()` operation `input_dtype[0]` is already a string so I don't need to turn it into one --- .../test_jax/test_nn/test_non_linear_activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 2a6baeb015d05..714d94d9ff44c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ 
b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -189,7 +189,7 @@ def test_jax_gelu( ): input_dtype, x = dtype_and_x # As erf function doesn't support complex dtype - if "complex" in str(input_dtype[0]): + if "complex" in input_dtype[0]: approximate = True helpers.test_frontend_function( input_dtypes=input_dtype, From 3cc05de159556493366e6c5856e0ef2453dc9a53 Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Thu, 17 Aug 2023 11:45:07 +0100 Subject: [PATCH 15/24] Fix docstrings and tests for complex activations Fix docstrings on the stateful API for ReLU and GELU, add complex_mode to tests for ReLU, GELU, and LeakyReLU --- ivy/stateful/activations.py | 30 +++++++++++++++---- .../test_nn/test_activations.py | 23 ++++++++++++-- 2 files changed, 45 insertions(+), 8 deletions(-) diff --git a/ivy/stateful/activations.py b/ivy/stateful/activations.py index 3691b56b3bccf..b3775647d2516 100644 --- a/ivy/stateful/activations.py +++ b/ivy/stateful/activations.py @@ -13,7 +13,17 @@ def __init__( approximate: bool = False, complex_mode: Literal["split", "magnitude", "jax"] = "jax", ): - """Apply the GELU activation function.""" + """ + Apply the GELU activation function. + + Parameters + ---------- + approximate + whether to use the gelu approximation algorithm or exact formulation. + complex_mode + Specifies how to handle complex input. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. + """ self._approximate = approximate self._complex_mode = complex_mode Module.__init__(self) @@ -67,7 +77,15 @@ def __init__( self, complex_mode: Literal["split", "magnitude", "jax"] = "jax", ): - """Apply the RELU activation function.""" + """ + Apply the RELU activation function. + + Parameters + ---------- + complex_mode + Specifies how to handle complex input. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. + """ self._complex_mode = complex_mode Module.__init__(self) @@ -99,9 +117,10 @@ def __init__( Parameters ---------- alpha - Negative slope for ReLU. + Negative slope for ReLU. complex_mode - Specifies how to handle complex input. + Specifies how to handle complex input. See + ``ivy.func_wrapper.handle_complex_input`` for more detail. """ self._alpha = alpha self._complex_mode = complex_mode Module.__init__(self) @@ -271,7 +290,8 @@ def __init__(self, complex_mode: Literal["split", "magnitude", "jax"] = "jax"): Parameters ---------- complex_mode - Specifies how to handle complex input. + Specifies how to handle complex input. See + ``ivy.func_wrapper.handle_complex_input`` for more detail.
""" self._complex_mode = complex_mode Module.__init__(self) diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py index 99f5f8db37222..9d816bee06056 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py @@ -17,8 +17,9 @@ small_abs_safety_factor=8, safety_factor_scale="log", ), + complex_mode=st.sampled_from(["jax", "split", "magnitude"]), ) -def test_relu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): +def test_relu(*, dtype_and_x, complex_mode, test_flags, backend_fw, fn_name, on_device): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -27,6 +28,7 @@ def test_relu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): fn_name=fn_name, on_device=on_device, x=x[0], + complex_mode=complex_mode, ) @@ -42,8 +44,11 @@ def test_relu(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): safety_factor_scale="log", ), alpha=st.floats(min_value=-1e-4, max_value=1e-4), + complex_mode=st.sampled_from(["jax", "split", "magnitude"]), ) -def test_leaky_relu(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_device): +def test_leaky_relu( + *, dtype_and_x, alpha, complex_mode, test_flags, backend_fw, fn_name, on_device +): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -55,6 +60,7 @@ def test_leaky_relu(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_d atol_=1e-2, x=x[0], alpha=alpha, + complex_mode=complex_mode, ) @@ -67,8 +73,18 @@ def test_leaky_relu(*, dtype_and_x, alpha, test_flags, backend_fw, fn_name, on_d max_value=1e4, ), approximate=st.booleans(), + complex_mode=st.sampled_from(["jax", "split", "magnitude"]), ) -def test_gelu(*, dtype_and_x, approximate, test_flags, backend_fw, fn_name, on_device): +def test_gelu( + *, + dtype_and_x, + approximate, + complex_mode, + test_flags, + backend_fw, + fn_name, + on_device, +): dtype, x = dtype_and_x helpers.test_function( input_dtypes=dtype, @@ -80,6 +96,7 @@ def test_gelu(*, dtype_and_x, approximate, test_flags, backend_fw, fn_name, on_d rtol_=1e-2, x=x[0], approximate=approximate, + complex_mode=complex_mode, ) From a6e228b3cd5dca8fa39f52c8aaa9dc85904a8fc2 Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Fri, 18 Aug 2023 12:50:22 +0100 Subject: [PATCH 16/24] Add complex_mode to backend call signatures For all activation functions which currently accept it: tanh, relu, leaky_relu, and gelu. 
--- ivy/functional/backends/jax/activations.py | 12 ++++++++++-- ivy/functional/backends/jax/elementwise.py | 4 +++- ivy/functional/backends/mxnet/activations.py | 15 ++++++++++++--- ivy/functional/backends/mxnet/elementwise.py | 1 + ivy/functional/backends/numpy/activations.py | 18 +++++++++++++++--- ivy/functional/backends/numpy/elementwise.py | 4 +++- ivy/functional/backends/paddle/activations.py | 6 +++++- ivy/functional/backends/paddle/elementwise.py | 4 +++- .../backends/tensorflow/activations.py | 16 +++++++++++++--- .../backends/tensorflow/elementwise.py | 1 + ivy/functional/backends/torch/activations.py | 6 +++++- ivy/functional/backends/torch/elementwise.py | 4 +++- 12 files changed, 74 insertions(+), 17 deletions(-) diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py index d7705c9847d14..c6dac4872e11a 100644 --- a/ivy/functional/backends/jax/activations.py +++ b/ivy/functional/backends/jax/activations.py @@ -18,17 +18,25 @@ def gelu( *, approximate: bool = False, out: Optional[JaxArray] = None, + complex_mode="jax", ) -> JaxArray: return jax.nn.gelu(x, approximate) def leaky_relu( - x: JaxArray, /, *, alpha: float = 0.2, out: Optional[JaxArray] = None + x: JaxArray, + /, + *, + alpha: float = 0.2, + out: Optional[JaxArray] = None, + complex_mode="jax", ) -> JaxArray: return jnp.asarray(jnp.where(x > 0, x, jnp.multiply(x, alpha)), x.dtype) -def relu(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def relu( + x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" +) -> JaxArray: return jnp.maximum(x, 0) diff --git a/ivy/functional/backends/jax/elementwise.py b/ivy/functional/backends/jax/elementwise.py index d309a4ac52c36..67dd0657df5ac 100644 --- a/ivy/functional/backends/jax/elementwise.py +++ b/ivy/functional/backends/jax/elementwise.py @@ -502,7 +502,9 @@ def tan(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: return jnp.tan(x) -def tanh(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: +def tanh( + x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" +) -> JaxArray: return jnp.tanh(x) diff --git a/ivy/functional/backends/mxnet/activations.py b/ivy/functional/backends/mxnet/activations.py index 9b6cfdcc9467f..48f95e3029f6b 100644 --- a/ivy/functional/backends/mxnet/activations.py +++ b/ivy/functional/backends/mxnet/activations.py @@ -11,7 +11,14 @@ from typing import Optional, Union -def gelu(x: None, /, *, approximate: bool = False, out: Optional[None] = None) -> None: +def gelu( + x: None, + /, + *, + approximate: bool = False, + out: Optional[None] = None, + complex_mode="jax", +) -> None: if approximate: return ( 0.5 * x * (1 + mx.nd.tanh(((2 / np.pi) ** 0.5) * (x + 0.044715 * x**3))) @@ -19,11 +26,13 @@ def gelu(x: None, /, *, approximate: bool = False, out: Optional[None] = None) - return mx.nd.LeakyReLU(x, act_type="gelu") -def leaky_relu(x: None, /, *, alpha: float = 0.2, out: Optional[None] = None) -> None: +def leaky_relu( + x: None, /, *, alpha: float = 0.2, out: Optional[None] = None, complex_mode="jax" +) -> None: return mx.nd.LeakyReLU(x, slope=alpha) -def relu(x: None, /, *, out: Optional[None] = None) -> None: +def relu(x: None, /, *, out: Optional[None] = None, complex_mode="jax") -> None: return mx.nd.relu(x) diff --git a/ivy/functional/backends/mxnet/elementwise.py b/ivy/functional/backends/mxnet/elementwise.py index e05fe3d7a67b2..2f262db28d7ca 100644 --- a/ivy/functional/backends/mxnet/elementwise.py +++ 
b/ivy/functional/backends/mxnet/elementwise.py @@ -599,6 +599,7 @@ def tanh( /, *, out: Optional[Union[(None, mx.ndarray.NDArray)]] = None, + complex_mode="jax", ) -> Union[(None, mx.ndarray.NDArray)]: return mx.nd.tanh(x) diff --git a/ivy/functional/backends/numpy/activations.py b/ivy/functional/backends/numpy/activations.py index d2ee4c0675de3..4b052baee3b75 100644 --- a/ivy/functional/backends/numpy/activations.py +++ b/ivy/functional/backends/numpy/activations.py @@ -10,7 +10,9 @@ @_scalar_output_to_0d_array -def relu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def relu( + x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" +) -> np.ndarray: return np.maximum(x, 0, out=out, dtype=x.dtype) @@ -18,14 +20,24 @@ def relu(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: def leaky_relu( - x: np.ndarray, /, *, alpha: float = 0.2, out: Optional[np.ndarray] = None + x: np.ndarray, + /, + *, + alpha: float = 0.2, + out: Optional[np.ndarray] = None, + complex_mode="jax", ) -> np.ndarray: return np.asarray(np.where(x > 0, x, np.multiply(x, alpha)), x.dtype) @_scalar_output_to_0d_array def gelu( - x: np.ndarray, /, *, approximate: bool = False, out: Optional[np.ndarray] = None + x: np.ndarray, + /, + *, + approximate: bool = False, + out: Optional[np.ndarray] = None, + complex_mode="jax", ) -> np.ndarray: if approximate: ret = 0.5 * x * (1 + np.tanh(0.7978845608 * (x + 0.044715 * x * x * x))) diff --git a/ivy/functional/backends/numpy/elementwise.py b/ivy/functional/backends/numpy/elementwise.py index ccc37b3298076..441f470fd5d94 100644 --- a/ivy/functional/backends/numpy/elementwise.py +++ b/ivy/functional/backends/numpy/elementwise.py @@ -753,7 +753,9 @@ def tan(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: @_scalar_output_to_0d_array -def tanh(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: +def tanh( + x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" +) -> np.ndarray: return np.tanh(x, out=out) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 6bd37aac6c15f..45bfe5665f9d9 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -30,7 +30,9 @@ ] -def relu(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: +def relu( + x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" +) -> paddle.Tensor: if x.dtype in unsupported_dtypes: if paddle.is_complex(x): return paddle.complex(F.relu(x.real()), F.relu(x.imag())) @@ -47,6 +49,7 @@ def leaky_relu( *, alpha: float = 0.2, out: Optional[paddle.Tensor] = None, + complex_mode="jax", ) -> paddle.Tensor: if x.dtype in unsupported_dtypes: if paddle.is_complex(x): @@ -67,6 +70,7 @@ def gelu( *, approximate: bool = False, out: Optional[paddle.Tensor] = None, + complex_mode="jax", ) -> paddle.Tensor: if paddle.is_complex(x): sqrt_2_over_pi = 0.7978845608 diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py index 5baa5663eefa0..6c09b66c67d76 100644 --- a/ivy/functional/backends/paddle/elementwise.py +++ b/ivy/functional/backends/paddle/elementwise.py @@ -677,7 +677,9 @@ def not_equal( {"2.5.1 and below": {"cpu": ("bool", "bfloat16")}}, backend_version, ) -def tanh(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: +def tanh( + x: paddle.Tensor, /, *, out: 
Optional[paddle.Tensor] = None, complex_mode="jax" +) -> paddle.Tensor: if x.dtype in [ paddle.int8, paddle.int16, diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index c30240316d8ae..39d4cc6817ccd 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -18,7 +18,12 @@ def gelu( - x: Tensor, /, *, approximate: bool = False, out: Optional[Tensor] = None + x: Tensor, + /, + *, + approximate: bool = False, + out: Optional[Tensor] = None, + complex_mode="jax", ) -> Tensor: if x.dtype in [tf.complex64, tf.complex128]: return 0.5 * x * (1 + tf.math.tanh(0.7978845608 * (x + 0.044715 * x * x * x))) @@ -26,12 +31,17 @@ def gelu( def leaky_relu( - x: Tensor, /, *, alpha: float = 0.2, out: Optional[Tensor] = None + x: Tensor, + /, + *, + alpha: float = 0.2, + out: Optional[Tensor] = None, + complex_mode="jax", ) -> Tensor: return tf.nn.leaky_relu(x, alpha) -def relu(x: Tensor, /, *, out: Optional[Tensor] = None) -> Tensor: +def relu(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: return tf.nn.relu(x) diff --git a/ivy/functional/backends/tensorflow/elementwise.py b/ivy/functional/backends/tensorflow/elementwise.py index 81d1850f626a9..858eb875358ec 100644 --- a/ivy/functional/backends/tensorflow/elementwise.py +++ b/ivy/functional/backends/tensorflow/elementwise.py @@ -745,6 +745,7 @@ def tanh( /, *, out: Optional[Union[tf.Tensor, tf.Variable]] = None, + complex_mode="jax", ) -> Union[tf.Tensor, tf.Variable]: return tf.tanh(x) diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py index f94890adb1424..cee0c75dc235d 100644 --- a/ivy/functional/backends/torch/activations.py +++ b/ivy/functional/backends/torch/activations.py @@ -18,7 +18,9 @@ @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) -def relu(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def relu( + x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" +) -> torch.Tensor: return torch.relu(x) @@ -29,6 +31,7 @@ def leaky_relu( *, alpha: float = 0.2, out: Optional[torch.Tensor] = None, + complex_mode="jax", ) -> torch.Tensor: return torch.nn.functional.leaky_relu(x, alpha) @@ -40,6 +43,7 @@ def gelu( *, approximate: bool = False, out: Optional[torch.Tensor] = None, + complex_mode="jax", ) -> torch.Tensor: if approximate: return ( diff --git a/ivy/functional/backends/torch/elementwise.py b/ivy/functional/backends/torch/elementwise.py index 784452d922e1f..bf6b68aca8050 100644 --- a/ivy/functional/backends/torch/elementwise.py +++ b/ivy/functional/backends/torch/elementwise.py @@ -507,7 +507,9 @@ def not_equal( @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) @handle_numpy_arrays_in_specific_backend -def tanh(x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None) -> torch.Tensor: +def tanh( + x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" +) -> torch.Tensor: x = _cast_for_unary_op(x) return torch.tanh(x, out=out) From 71ad8cd6bc2173e37406046a2fc70671463149f7 Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Mon, 21 Aug 2023 11:42:55 +0100 Subject: [PATCH 17/24] Merge branch 'master' into complex-maths-functions --- .github/workflows/pypi.yml | 8 +- deploy_pypi.sh | 2 +- ivy/data_classes/array/general.py | 16 +++ ivy/data_classes/container/general.py | 73 +++++++++++- 
.../backends/jax/experimental/sorting.py | 2 + ivy/functional/backends/paddle/elementwise.py | 61 +++++++--- .../backends/paddle/experimental/sorting.py | 4 + ivy/functional/backends/paddle/layers.py | 60 +++++++++- .../tensorflow/experimental/sorting.py | 2 +- .../backends/torch/experimental/sorting.py | 4 + ivy/functional/frontends/jax/numpy/linalg.py | 4 +- .../frontends/jax/numpy/statistical.py | 11 +- .../mindspore/ops/function/nn_func.py | 19 ++++ .../numpy/manipulation_routines/__init__.py | 2 + .../manipulation_routines/padding_arrays.py | 10 ++ .../frontends/paddle/tensor/math.py | 12 ++ .../frontends/paddle/tensor/stat.py | 14 +++ .../frontends/paddle/vision/transforms.py | 25 ++++ ivy/functional/frontends/pandas/__init__.py | 2 + ivy/functional/frontends/pandas/dataframe.py | 15 ++- ivy/functional/frontends/pandas/generic.py | 6 +- ivy/functional/frontends/pandas/index.py | 76 +++++++++++++ ivy/functional/frontends/pandas/series.py | 10 +- ivy/functional/frontends/torch/tensor.py | 23 ++++ ivy/functional/ivy/general.py | 66 ++++++++++- .../test_ivy/helpers/function_testing.py | 23 ++-- ivy_tests/test_ivy/helpers/pipeline_helper.py | 3 + .../test_ivy/helpers/test_parameter_flags.py | 8 +- .../test_jax/test_lax/test_operators.py | 2 +- .../test_jax/test_numpy/test_dtype.py | 3 +- .../test_jax/test_numpy/test_linalg.py | 32 ++---- .../test_jax/test_numpy/test_logic.py | 5 + .../test_jax/test_numpy/test_statistical.py | 1 + .../test_function/test_mindspore_nn_func.py | 26 +++++ .../test_padding_arrays.py | 107 ++++++++++++++++++ .../test_paddle/test_tensor/test_math.py | 32 ++++++ .../test_paddle/test_tensor/test_stat.py | 31 +++++ .../test_vision/test_transforms.py | 43 +++++++ .../test_functional/test_pooling_functions.py | 5 +- .../test_frontends/test_torch/test_tensor.py | 80 +++++++++++++ .../test_core/test_elementwise.py | 2 +- .../test_core/test_sorting.py | 27 +++-- requirements/requirements.txt | 12 +- 43 files changed, 864 insertions(+), 105 deletions(-) create mode 100644 ivy/functional/frontends/numpy/manipulation_routines/padding_arrays.py create mode 100644 ivy/functional/frontends/pandas/index.py create mode 100644 ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_padding_arrays.py diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 066b28449aa98..7032f151da031 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -1,6 +1,5 @@ name: pypi on: - push: workflow_call: jobs: upload-if-tagged-commit: @@ -19,7 +18,12 @@ jobs: python3 -m pip install --upgrade build python3 -m pip install --upgrade twine - - name: Upload to Pypi + - name: Add Tag to Version + run: | + cd ivy + echo "__version__ = '${{ github.ref_name }}'" > ivy/_version.py + + - name: Upload to PyPI if: startsWith(github.ref, 'refs/tags') env: PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }} diff --git a/deploy_pypi.sh b/deploy_pypi.sh index c407487072e05..cf7548a06023a 100644 --- a/deploy_pypi.sh +++ b/deploy_pypi.sh @@ -1,2 +1,2 @@ python3 -m build -python3 -m twine upload dist/* --repository testpypi -u "__token__" -p "$PYPI_PASSWORD_TEST" --verbose +python3 -m twine upload dist/* -u "__token__" -p "$PYPI_PASSWORD" --verbose diff --git a/ivy/data_classes/array/general.py b/ivy/data_classes/array/general.py index 7463b57285e24..7c94e779a673f 100644 --- a/ivy/data_classes/array/general.py +++ b/ivy/data_classes/array/general.py @@ -1091,6 +1091,22 @@ def stable_pow( ------- ret The new item following the numerically stable power. 
+ + Examples + -------- + With :class:`ivy.Array` instance method: + + >>> x = ivy.asarray([2, 4]) + >>> y = x.stable_pow(2) + >>> print(y) + ivy.array([ 4.00004, 16.00008]) + + >>> x = ivy.asarray([[2., 4.], [6., 8.]]) + >>> y = ivy.asarray([2., 4.]) + >>> z = x.stable_pow(y) + >>> print(z) + ivy.array([[4.00004000e+00, 2.56002560e+02], + [3.60001200e+01, 4.09602048e+03]]) """ return ivy.stable_pow(self, exponent, min_base=min_base) diff --git a/ivy/data_classes/container/general.py b/ivy/data_classes/container/general.py index 290894b7a52e4..b51ae9617a9cb 100644 --- a/ivy/data_classes/container/general.py +++ b/ivy/data_classes/container/general.py @@ -1015,7 +1015,6 @@ def assert_supports_inplace( ret An ivy.Container instance of True bool values if nodes of the Container \ support in-place operations, raises IvyBackendException otherwise - Examples -------- @@ -3195,6 +3194,42 @@ def _static_stable_pow( ret A container of elements containing the new items following the numerically stable power. + + Examples + -------- + >>> x = ivy.Container(a=ivy.asarray([2, 4]), b=ivy.asarray([6, 8])) + >>> y = ivy.Container.stable_pow(x, 2) + >>> print(y) + { + a: ivy.array([4.00004, 16.00008]), + b: ivy.array([36.00012, 64.00016]) + } + + >>> x = ivy.Container(a=4, b=8) + >>> y = ivy.Container.stable_pow(x, 2) + >>> print(y) + { + a: ivy.array(16.00008), + b: ivy.array(64.00016) + } + + >>> x = ivy.Container(a=4, b=8) + >>> y = ivy.asarray([1, 2]) + >>> z = ivy.Container.stable_pow(x, y) + >>> print(z) + { + a: ivy.array([4.00001, 16.00008]), + b: ivy.array([8.00001, 64.00016]) + } + + >>> x = ivy.Container(a=ivy.asarray([2, 4]), b=ivy.asarray([6, 8])) + >>> y = ivy.Container(a=4, b=8) + >>> z = ivy.Container.stable_pow(x, y) + >>> print(z) + { + a: ivy.array([16.00032, 256.00256]), + b: ivy.array([1679638.395, 16777383.77]) + } """ return ContainerBase.cont_multi_map_in_function( "stable_pow", @@ -3248,6 +3283,42 @@ def stable_pow( ret A container of elements containing the new items following the numerically stable power. + + Examples + -------- + >>> x = ivy.Container(a=ivy.asarray([2, 4]), b=ivy.asarray([6, 8])) + >>> y = x.stable_pow(2) + >>> print(y) + { + a: ivy.array([4.00004, 16.00008]), + b: ivy.array([36.00012, 64.00016]) + } + + >>> x = ivy.Container(a=4, b=8) + >>> y = x.stable_pow(2) + >>> print(y) + { + a: ivy.array(16.00008), + b: ivy.array(64.00016) + } + + >>> x = ivy.Container(a=4, b=8) + >>> y = ivy.asarray([1, 2]) + >>> z = x.stable_pow(y) + >>> print(z) + { + a: ivy.array([4.00001, 16.00008]), + b: ivy.array([8.00001, 64.00016]) + } + + >>> x = ivy.Container(a=ivy.asarray([2, 4]), b=ivy.asarray([6, 8])) + >>> y = ivy.Container(a=4, b=8) + >>> z = x.stable_pow(y) + >>> print(z) + { + a: ivy.array([16.00032, 256.00256]), + b: ivy.array([1679638.395, 16777383.77]) + } """ return self._static_stable_pow( self, diff --git a/ivy/functional/backends/jax/experimental/sorting.py b/ivy/functional/backends/jax/experimental/sorting.py index 8a9774a4783cb..15f7f59d47b87 100644 --- a/ivy/functional/backends/jax/experimental/sorting.py +++ b/ivy/functional/backends/jax/experimental/sorting.py @@ -3,6 +3,7 @@ from typing import Optional, Union # local +import ivy from ivy.func_wrapper import with_unsupported_dtypes from ivy.functional.backends.jax import JaxArray from . 
import backend_version @@ -13,6 +14,7 @@ def invert_permutation( x: Union[JaxArray, list, tuple], /, ) -> JaxArray: + x = jnp.array(x) if not ivy.is_array(x) else x sorted_indices = jnp.argsort(x) inverse = jnp.zeros_like(sorted_indices) inverse = inverse.at[sorted_indices].set(jnp.arange(len(x))) diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py index 6c09b66c67d76..1c987c2bf806b 100644 --- a/ivy/functional/backends/paddle/elementwise.py +++ b/ivy/functional/backends/paddle/elementwise.py @@ -1,5 +1,5 @@ # global -from typing import Union, Optional +from typing import Union, Optional, Tuple, Type import paddle import math @@ -267,26 +267,55 @@ def sign( return paddle.sgn(x) -def sqrt(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: - if x.dtype in [ +# TODO: Remove `float16` from the list once paddle add it's supporting kernel to `CPU`. +def _determine_sqrt_dtype_cast( + dtype: Type[paddle.Tensor], +) -> Tuple[Optional[str], Optional[str]]: + """ + Determine the appropriate casting dtype for sqrt operations. + + Returns: + (intermediate_dtype, output_dtype) + """ + + cast_and_return_float32_dtype = { paddle.int8, paddle.int16, paddle.int32, - paddle.int64, paddle.uint8, - paddle.float16, - paddle.complex64, - paddle.complex128, paddle.bool, - ]: - if paddle.is_complex(x): - angle = paddle.angle(x) - result = paddle.complex( - paddle.cos(angle / 2), paddle.sin(angle / 2) - ) * paddle.sqrt(paddle.abs(x)) - return result - return paddle.sqrt(x.astype("float32")).astype(x.dtype) - return paddle.sqrt(x) + } + + if dtype in cast_and_return_float32_dtype: + return "float32", "float32" + elif dtype == paddle.int64: + return "float64", "float64" + elif dtype == paddle.float16: + return "float32", "float16" + elif dtype == paddle.bfloat16: + return "float32", "bfloat16" + else: + return None, None + + +def sqrt(x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None) -> paddle.Tensor: + """Calculate the square root with type handling.""" + + if paddle.is_complex(x): + angle = paddle.angle(x) + return paddle.complex( + paddle.cos(angle / 2), paddle.sin(angle / 2) + ) * paddle.sqrt(paddle.abs(x)) + + if x.dtype in {paddle.float32, paddle.float64}: + return paddle.sqrt(x) + + intermediate_dtype, output_dtype = _determine_sqrt_dtype_cast(x.dtype) + if intermediate_dtype: + result = paddle.sqrt(x.astype(intermediate_dtype)) + return result.astype(output_dtype) + + raise ValueError(f"Unsupported data type for sqrt: {x.dtype}") @with_unsupported_device_and_dtypes( diff --git a/ivy/functional/backends/paddle/experimental/sorting.py b/ivy/functional/backends/paddle/experimental/sorting.py index ab4e5548d8d8d..c2686a5d3486a 100644 --- a/ivy/functional/backends/paddle/experimental/sorting.py +++ b/ivy/functional/backends/paddle/experimental/sorting.py @@ -2,12 +2,16 @@ import paddle from typing import Optional, Union +# local +import ivy + # invert_permutation def invert_permutation( x: Union[paddle.Tensor, list, tuple], /, ) -> paddle.Tensor: + x = paddle.to_tensor(x) if not ivy.is_array(x) else x sorted_indices = paddle.argsort(x) inverse = paddle.zeros_like(sorted_indices) inverse[sorted_indices] = paddle.arange(len(x)) diff --git a/ivy/functional/backends/paddle/layers.py b/ivy/functional/backends/paddle/layers.py index f05e230d7d013..f70c374bc70de 100644 --- a/ivy/functional/backends/paddle/layers.py +++ b/ivy/functional/backends/paddle/layers.py @@ -435,4 +435,62 @@ def conv_general_transpose( bias: 
Optional[paddle.Tensor] = None, out: Optional[paddle.Tensor] = None, ): - raise IvyNotImplementedException() + if data_format == "channel_last": + x = x.transpose(x, (0, dims + 1, *range(1, dims + 1))) + strides = [strides] * dims if isinstance(strides, int) else strides + dilations = [dilations] * dims if isinstance(dilations, int) else dilations + filters = filters.transpose(dims, dims + 1, *range(dims)) + not_valid_pad, padding_list, output_padding = _pad_before_conv_tranpose( + x, filters, strides, padding, dims, dilations, output_shape, filters.shape[2:] + ) + if dims == 1: + res = paddle.nn.functional.conv1d_transpose( + x, + filters, + bias=bias, + stride=strides, + padding=padding_list, + output_padding=output_padding, + groups=feature_group_count, + dilation=dilations, + data_format="NCL", + ) + if not_valid_pad[0]: + res = res[:, :, 0:-1] + elif dims == 2: + res = paddle.nn.functional.conv2d_transpose( + x, + filters, + bias=bias, + stride=strides, + padding=padding_list, + output_padding=output_padding, + groups=feature_group_count, + dilation=dilations, + data_format="NCHW", + ) + if not_valid_pad[0]: + res = res[:, :, 0:-1, :] + if not_valid_pad[1]: + res = res[:, :, :, 0:-1] + else: + res = paddle.nn.functional.conv3d_transpose( + x, + filters, + bias=bias, + stride=strides, + padding=padding_list, + output_padding=output_padding, + groups=feature_group_count, + dilation=dilations, + data_format="NCDHW", + ) + if not_valid_pad[0]: + res = res[:, 0:-1, :, :] + if not_valid_pad[1]: + res = res[:, :, 0:-1, :] + if not_valid_pad[2]: + res = res[:, :, :, 0:-1] + if data_format == "channel_last": + res = res.transpose(0, *range(2, dims + 2), 1) + return res diff --git a/ivy/functional/backends/tensorflow/experimental/sorting.py b/ivy/functional/backends/tensorflow/experimental/sorting.py index fb8f0f5951ad2..120457e45ecad 100644 --- a/ivy/functional/backends/tensorflow/experimental/sorting.py +++ b/ivy/functional/backends/tensorflow/experimental/sorting.py @@ -8,7 +8,7 @@ def invert_permutation( x: Union[tf.Tensor, tf.Variable, list, tuple], /, ) -> Union[tf.Tensor, tf.Variable]: - return tf.math.invert_permutation(x) + return tf.cast(tf.math.invert_permutation(x), tf.int64) # lexsort diff --git a/ivy/functional/backends/torch/experimental/sorting.py b/ivy/functional/backends/torch/experimental/sorting.py index 74f0d9b92c2ac..c6b7ee0f5bb06 100644 --- a/ivy/functional/backends/torch/experimental/sorting.py +++ b/ivy/functional/backends/torch/experimental/sorting.py @@ -2,12 +2,16 @@ import torch from typing import Optional, Union +# local +import ivy + # invert_permutation def invert_permutation( x: Union[torch.Tensor, list, tuple], /, ) -> torch.Tensor: + x = torch.tensor(x) if not ivy.is_array(x) else x sorted_indices = torch.argsort(x) inverse = torch.zeros_like(sorted_indices) inverse[sorted_indices] = torch.arange(len(x)) diff --git a/ivy/functional/frontends/jax/numpy/linalg.py b/ivy/functional/frontends/jax/numpy/linalg.py index c888da224d17a..819eb9821dc69 100644 --- a/ivy/functional/frontends/jax/numpy/linalg.py +++ b/ivy/functional/frontends/jax/numpy/linalg.py @@ -129,5 +129,5 @@ def cond(x, p=None): @to_ivy_arrays_and_back -def multi_dot(a): - return ivy.multi_dot(a) +def multi_dot(arrays, *, precision=None): + return ivy.multi_dot(arrays) diff --git a/ivy/functional/frontends/jax/numpy/statistical.py b/ivy/functional/frontends/jax/numpy/statistical.py index 069e6aedba66b..a797d1c244f0b 100644 --- a/ivy/functional/frontends/jax/numpy/statistical.py +++ 
b/ivy/functional/frontends/jax/numpy/statistical.py @@ -247,7 +247,7 @@ def nanmax( ax = axis[0] % len(s) else: ax = axis % len(s) - s[ax] = 1 + s[ax] = ivy.array(1) header = ivy.full(ivy.Shape(s.to_list()), initial, dtype=ivy.dtype(a)) if axis: if isinstance(axis, (tuple, list)) or ivy.is_array(axis): @@ -268,7 +268,7 @@ def nanmax( ) if where_mask is not None and ivy.any(where_mask): res = ivy.where(ivy.logical_not(where_mask), res, ivy.nan, out=out) - return res + return res.astype(ivy.dtype(a)) @to_ivy_arrays_and_back @@ -294,7 +294,8 @@ def nanmin( ax = axis[0] % len(s) else: ax = axis % len(s) - s[ax] = 1 + + s[ax] = ivy.array(1) header = ivy.full(ivy.Shape(s.to_list()), initial, dtype=ivy.dtype(a)) if axis: if isinstance(axis, (tuple, list)) or ivy.is_array(axis): @@ -315,7 +316,7 @@ def nanmin( ) if where_mask is not None and ivy.any(where_mask): res = ivy.where(ivy.logical_not(where_mask), res, ivy.nan, out=out) - return res + return res.astype(ivy.dtype(a)) @handle_jax_dtype @@ -442,7 +443,7 @@ def nanmedian( ): return ivy.nanmedian( a, axis=axis, keepdims=keepdims, out=out, overwrite_input=overwrite_input - ) + ).astype(a.dtype) @to_ivy_arrays_and_back diff --git a/ivy/functional/frontends/mindspore/ops/function/nn_func.py b/ivy/functional/frontends/mindspore/ops/function/nn_func.py index 03b3a91561982..6c0ff69dba5dd 100644 --- a/ivy/functional/frontends/mindspore/ops/function/nn_func.py +++ b/ivy/functional/frontends/mindspore/ops/function/nn_func.py @@ -298,6 +298,25 @@ def interpolate( ) +@with_supported_dtypes( + { + "2.0 and below": ( + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "float64", + ) + }, + "mindspore", +) +@to_ivy_arrays_and_back +def hardswish(x): + return ivy.hardswish(x) + + @with_supported_dtypes( { "2.0 and below": ( diff --git a/ivy/functional/frontends/numpy/manipulation_routines/__init__.py b/ivy/functional/frontends/numpy/manipulation_routines/__init__.py index 9c6512f847f2a..07310cab242b6 100644 --- a/ivy/functional/frontends/numpy/manipulation_routines/__init__.py +++ b/ivy/functional/frontends/numpy/manipulation_routines/__init__.py @@ -8,6 +8,8 @@ from .changing_kind_of_array import * from . import changing_number_of_dimensions from .changing_number_of_dimensions import * +from . import padding_arrays +from .padding_arrays import * from . import joining_arrays from .joining_arrays import * from . 
import rearranging_elements diff --git a/ivy/functional/frontends/numpy/manipulation_routines/padding_arrays.py b/ivy/functional/frontends/numpy/manipulation_routines/padding_arrays.py new file mode 100644 index 0000000000000..a1e94e52b962f --- /dev/null +++ b/ivy/functional/frontends/numpy/manipulation_routines/padding_arrays.py @@ -0,0 +1,10 @@ +# local +import ivy +from ivy.functional.frontends.numpy.func_wrapper import ( + to_ivy_arrays_and_back, +) + + +@to_ivy_arrays_and_back +def pad(array, pad_width, mode="constant", **kwargs): + return ivy.pad(array, pad_width, mode=mode, **kwargs) diff --git a/ivy/functional/frontends/paddle/tensor/math.py b/ivy/functional/frontends/paddle/tensor/math.py index bd89adf12110a..7edf833a66feb 100644 --- a/ivy/functional/frontends/paddle/tensor/math.py +++ b/ivy/functional/frontends/paddle/tensor/math.py @@ -479,3 +479,15 @@ def amax(x, axis=None, keepdims=False): if i < 0 or i >= x.ndim: raise ValueError("axis {} is out of range [-{}:{}]".format(i, 0, x.ndim)) return ivy.max(x, axis=axis, keepdims=keepdims) + + +@with_supported_dtypes({"2.5.0 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def stanh(x, scale_a=0.67, scale_b=1.7159, name=None): + # TODO this function will be simplified as soon as the ivy.stanh(x,a,b) is added + exp_ax = ivy.exp(ivy.multiply(scale_a, x)) + exp_minus_ax = ivy.exp(ivy.multiply(-scale_a, x)) + numerator = ivy.subtract(exp_ax, exp_minus_ax) + denominator = ivy.add(exp_ax, exp_minus_ax) + ret = ivy.multiply(scale_b, ivy.divide(numerator, denominator)) + return ret diff --git a/ivy/functional/frontends/paddle/tensor/stat.py b/ivy/functional/frontends/paddle/tensor/stat.py index a97120e4e3d12..637935b943542 100644 --- a/ivy/functional/frontends/paddle/tensor/stat.py +++ b/ivy/functional/frontends/paddle/tensor/stat.py @@ -69,3 +69,17 @@ def nanmedian(x, axis=None, keepdim=True, name=None): else ivy.astype(x, ivy.float32) ) return ivy.median(x, axis=axis, keepdims=keepdim) + + +@with_supported_dtypes( + {"2.5.1 and below": ("float32", "float64", "uint16")}, + "paddle", +) +@to_ivy_arrays_and_back +def std(x, axis=None, unbiased=True, keepdim=False, name=None): + x = ( + ivy.astype(x, ivy.float64) + if ivy.dtype(x) == "float64" + else ivy.astype(x, ivy.float32) + ) + return ivy.std(x, axis=axis, correction=int(unbiased), keepdims=keepdim) diff --git a/ivy/functional/frontends/paddle/vision/transforms.py b/ivy/functional/frontends/paddle/vision/transforms.py index 36abc70964357..eef2cd26117cd 100644 --- a/ivy/functional/frontends/paddle/vision/transforms.py +++ b/ivy/functional/frontends/paddle/vision/transforms.py @@ -30,6 +30,18 @@ def _get_image_num_channels(img, data_format): return ivy.shape(img)[_get_image_c_axis(data_format)] +def _blend_images(img1, img2, ratio): + # TODO: ivy.check_float(img1) returns False for ivy array + # TODO: when lerp supports int type and when the above issue is fixed, + # replace this with ivy.check_float(img1) + max_value = ( + 1.0 if ivy.dtype(img1) == "float32" or ivy.dtype(img1) == "float64" else 255.0 + ) + return ivy.astype( + ivy.lerp(img2, img1, float(ratio)).clip(0, max_value), ivy.dtype(img1) + ) + + def _rgb_to_hsv(img): maxc = ivy.max(img, axis=-3) minc = ivy.min(img, axis=-3) @@ -120,6 +132,19 @@ def adjust_hue(img, hue_factor): return img_adjusted +@with_supported_dtypes({"2.5.1 and below": ("float32", "float64")}, "paddle") +@to_ivy_arrays_and_back +def adjust_brightness(img, brightness_factor): + assert brightness_factor >= 0, "brightness_factor should be 
non-negative." + assert _get_image_num_channels(img, "CHW") in [ + 1, + 3, + ], "channels of input should be either 1 or 3." + + extreme_target = ivy.zeros_like(img) + return _blend_images(img, extreme_target, brightness_factor) + + @with_unsupported_device_and_dtypes( { "2.5.1 and below": { diff --git a/ivy/functional/frontends/pandas/__init__.py b/ivy/functional/frontends/pandas/__init__.py index d64d79c9d4028..56f0f260ff9c6 100644 --- a/ivy/functional/frontends/pandas/__init__.py +++ b/ivy/functional/frontends/pandas/__init__.py @@ -2,6 +2,8 @@ from .pandas_func_wrappers import * from . import series from .series import * +from . import index +from .index import * from . import dataframe from .dataframe import * from . import generic diff --git a/ivy/functional/frontends/pandas/dataframe.py b/ivy/functional/frontends/pandas/dataframe.py index 03a717abd26bc..7a1c40b52b2e4 100644 --- a/ivy/functional/frontends/pandas/dataframe.py +++ b/ivy/functional/frontends/pandas/dataframe.py @@ -1,6 +1,7 @@ from .generic import NDFrame import ivy from .series import Series +from ivy.functional.frontends.pandas.index import Index class DataFrame(NDFrame): @@ -79,7 +80,7 @@ def __repr__(self): def sum(self, axis=None, skipna=True, level=None, numeric_only=None, min_count=0): _array = self.array if axis is None or axis == "index": - axis = 0 # due to https://github.com/pandas-dev/pandas/issues/54547. TODO: remove this when fixed + axis = 0 # due to https://github.com/pandas-dev/pandas/issues/54547. TODO: remove this when fixed # noqa: E501 elif axis == "columns": axis = 1 if min_count > 0: @@ -104,4 +105,14 @@ def mean(self, axis=0, skipna=True, numeric_only=None, **kwargs): ret = _array.mean(axis=axis) if axis is None: return ret # scalar case - return Series(ret, index=self.columns if axis in (0, "index") else self.index) + return Series( + ret, index=Index(self.columns) if axis in (0, "index") else self.index + ) + + def get(self, key, default=None): + if key in self.columns: + return self[key] + return default + + def keys(self): + return self.columns diff --git a/ivy/functional/frontends/pandas/generic.py b/ivy/functional/frontends/pandas/generic.py index 964da6e78035c..1b21efc89208f 100644 --- a/ivy/functional/frontends/pandas/generic.py +++ b/ivy/functional/frontends/pandas/generic.py @@ -3,6 +3,7 @@ import copy as py_copy from ivy.functional.frontends.pandas.pandas_func_wrappers import outputs_to_self_class import ivy.functional.frontends.pandas.series as series +from ivy.functional.frontends.pandas.index import Index class NDFrame: @@ -24,7 +25,7 @@ def __init__(self, data, index, columns, dtype, name, copy, *args, **kwargs): orig_data_len = len(self.orig_data) if index is None: if data_is_array_or_like: - index = ivy.arange(orig_data_len).tolist() + index = ivy.arange(orig_data_len) elif isinstance(data, dict): index = list(data.keys()) elif isinstance(data, series.Series): @@ -57,6 +58,9 @@ def __init__(self, data, index, columns, dtype, name, copy, *args, **kwargs): "Data must be one of array, dict, iterables, scalar value or Series." 
f" Got {type(data)}" ) + self.index = ( + Index(self.index) if not isinstance(self.index, Index) else self.index + ) @property def data(self): diff --git a/ivy/functional/frontends/pandas/index.py b/ivy/functional/frontends/pandas/index.py new file mode 100644 index 0000000000000..91bc8788e900c --- /dev/null +++ b/ivy/functional/frontends/pandas/index.py @@ -0,0 +1,76 @@ +import ivy +import ivy.functional.frontends.pandas.series as series + + +class Index: + def __init__(self, data, dtype=None, copy=False, name=None, tupleize_cols=True): + self.index = data + if not isinstance(data, ivy.Array): + self.index_array = ivy.array(data, dtype=dtype) + else: + self.index_array = data + self.dtype = dtype + self.name = name + self.copy = copy + self.tupleize_cols = tupleize_cols + + def __repr__(self): + return f"Index {self.index_array.to_list()}" + + def __len__(self): + return len(self.index_array) + + def __iter__(self): + return iter(self.index_array.to_list()) + + @property + def ndim(self): + return self.index_array.ndim + + @property + def size(self): + return self.index_array.size + + @property + def array(self): + return self.index_array + + @property + def shape(self): + return tuple(self.index_array.shape) + + @property + def has_duplicates(self): + return not self.is_unique() + + def unique(self, level=None): + # todo handle level with mutliindexer + self.index_array = ivy.unique_values(self) + return Index(self.index_array, dtype=self.dtype, copy=self.copy, name=self.name) + + def is_unique(self): + uniques = ivy.unique_values(self) + return len(uniques) == len(self.index_array) + + def to_list(self): + return self.index_array.to_list() + + def to_numpy(self, dtype=None, copy=False, na_value=ivy.nan, **kwargs): + if dtype: + return self.index_array.astype(dtype).to_numpy(copy=copy) + return self.index_array.to_numpy(copy=copy) + + def to_series(self, index=None, name=None): + if index is None: + index = self.index_array + return series.Series(index, index=index, name=name) + + def min(self, axis=None, skipna=True, *args, **kwargs): + return self.index_array.min() + + def max(self, axis=None, skipna=True, *args, **kwargs): + return self.index_array.max() + + def isin(self, values, level=None): + # todo handle level with mutliindexer + return ivy.isin(self.index_array, values) diff --git a/ivy/functional/frontends/pandas/series.py b/ivy/functional/frontends/pandas/series.py index a6e11d1774c91..3429e445a9116 100644 --- a/ivy/functional/frontends/pandas/series.py +++ b/ivy/functional/frontends/pandas/series.py @@ -31,7 +31,7 @@ def __repr__(self): series_name = f"{self.name} " if self.name is not None else "" return ( f"frontends.pandas.Series {series_name}({self.array.to_list()}," - f" index={self.index})" + f" index={self.index.array.to_list()})" ) def __getitem__(self, index_val): @@ -77,3 +77,11 @@ def add(self, other, level=None, fill_value=None, axis=0): # todo handle data alignment new_array = ivy.add(self.array, other.array) return Series(new_array) + + def get(self, key, default=None): + if key in self.index: + return self[key] + return default + + def keys(self): + return self.index diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index b7bf0114bc8e8..6e70ae356a545 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -883,6 +883,11 @@ def inverse(self): def neg(self): return torch_frontend.negative(self) + @with_unsupported_dtypes({"2.0.1 and below": ("bool",)}, "torch") + def 
neg_(self): + self.ivy_array = torch_frontend.negative(self).ivy_array + return self + __neg__ = neg def int(self, memory_format=None): @@ -1775,6 +1780,24 @@ def isnan(self): def lcm(self, other, *, out=None): return torch_frontend.lcm(self, other, out=out) + @with_unsupported_dtypes( + { + "2.0.1 and below": ( + "bfloat16", + "int8", + "uint8", + "int16", + "complex128", + "complex64", + "bool", + ) + }, + "torch", + ) + def triu_(self, diagonal=0): + self.ivy_array = torch_frontend.triu(self, diagonal).ivy_array + return self + @with_unsupported_dtypes( {"2.0.1 and below": ("float16", "bfloat16")}, "torch", diff --git a/ivy/functional/ivy/general.py b/ivy/functional/ivy/general.py index 478a42d463bf8..82bffa79101df 100644 --- a/ivy/functional/ivy/general.py +++ b/ivy/functional/ivy/general.py @@ -2302,6 +2302,64 @@ def stable_pow( ------- ret The new item following the numerically stable power. + + Examples + -------- + With :code:`int` input: + + >>> x = ivy.stable_pow(2, 2) + >>> print(x) + ivy.array(4.00004) + + >>> x = ivy.stable_pow(2, 2, min_base=2) + >>> print(x) + ivy.array(16) + + With float input: + + >>> x = ivy.stable_pow(4.0, .5) + >>> print(x) + ivy.array(2.00000262) + + With :code:`complex` input: + + >>> x = ivy.stable_pow(3+4j, 2j) + >>> print(x) + ivy.array(-0.15605032-0.01208451j) + + With :class:`ivy.Array` input: + + >>> x = ivy.asarray([[2, 4], + ... [6, 8]]) + >>> y = ivy.stable_pow(x, 2) + >>> print(y) + ivy.array([[ 4.00004, 16.00008], + [36.00012, 64.00016]]) + + >>> x = ivy.asarray([2, 4, 6]) + >>> y = ivy.asarray([2, 3, 4]) + >>> z = ivy.stable_pow(x, y) + >>> print(z) + ivy.array([ 4.00004, 64.00048, 1296.00864]) + + With :class:`ivy.Container` input: + + >>> x = ivy.Container(a=ivy.asarray([2, 4]), b=ivy.asarray([6, 8])) + >>> y = ivy.stable_pow(x, 2) + >>> print(y) + { + a: ivy.array([4.00004, 16.00008]), + b: ivy.array([36.00012, 64.00016]) + } + + >>> x = ivy.Container(a=ivy.asarray([2, 4]), b=ivy.asarray([6, 8])) + >>> y = ivy.Container(a=ivy.asarray([1, 3]), b=ivy.asarray([4, 5])) + >>> z = ivy.stable_pow(x, y) + >>> print(z) + { + a: ivy.array([2.00001, 64.00048]), + b: ivy.array([1296.00864, 32768.2048]) + } """ return_dtype = ivy.promote_types( ivy.default_dtype(item=base), @@ -2829,7 +2887,7 @@ def set_item( *[i for i in range(len(x.shape)) if i not in vector_inds], ] x = ivy.permute_dims(x, axes=perm) - inv_perm = ivy.invert_permutation(perm) + inv_perm = ivy.invert_permutation(perm).to_list() val = _broadcast_to(val, target_shape).astype(x.dtype) ret = ivy.scatter_nd(query, val, reduction="replace", out=x) if inv_perm is not None: @@ -3048,10 +3106,10 @@ def _parse_slice(idx, s): stop = idx.stop if stop > s: stop = s - elif stop <= -s: + elif stop < -s: + stop = -1 + elif stop == -s: stop = 0 - if start == 0: - stop = -1 elif stop < 0: stop = stop + s q_i = ivy.arange(start, stop, step).to_list() diff --git a/ivy_tests/test_ivy/helpers/function_testing.py b/ivy_tests/test_ivy/helpers/function_testing.py index bf607c54a3345..ad3771bfc1ba6 100644 --- a/ivy_tests/test_ivy/helpers/function_testing.py +++ b/ivy_tests/test_ivy/helpers/function_testing.py @@ -34,7 +34,7 @@ def compiled_if_required(backend: str, fn, test_compile=False, args=None, kwargs with BackendHandler.update_backend(backend) as ivy_backend: if test_compile: fn = ivy_backend.compile(fn, args=args, kwargs=kwargs) - return fn + return fn # Ivy Function testing ########################## @@ -944,7 +944,7 @@ def _gt_grad_fn(all_args): args, kwargs, i = all_args call_fn = 
gt_backend.__dict__[fn] if isinstance(fn, str) else fn[i] ret = compiled_if_required( - backend_to_test, + ground_truth_backend, call_fn, test_compile=test_compile, args=args, @@ -959,7 +959,7 @@ def _gt_grad_fn(all_args): ret_grad_idxs=ret_grad_idxs, ) grads_np_from_gt_flat = flatten_and_to_np( - backend=backend_to_test, ret=grads_from_gt + backend=ground_truth_backend, ret=grads_from_gt ) assert len(grads_np_flat) == len( @@ -971,15 +971,14 @@ def _gt_grad_fn(all_args): len(grads_np_from_gt_flat), ) - for grad_np_flat, grad_np_from_gt_flat in zip(grads_np_flat, grads_np_from_gt_flat): - value_test( - ret_np_flat=grad_np_flat, - ret_np_from_gt_flat=grad_np_from_gt_flat, - rtol=rtol_, - atol=atol_, - backend=backend_to_test, - ground_truth_backend=ground_truth_backend, - ) + value_test( + ret_np_flat=grads_np_flat, + ret_np_from_gt_flat=grads_np_from_gt_flat, + rtol=rtol_, + atol=atol_, + backend=backend_to_test, + ground_truth_backend=ground_truth_backend, + ) def test_method( diff --git a/ivy_tests/test_ivy/helpers/pipeline_helper.py b/ivy_tests/test_ivy/helpers/pipeline_helper.py index 5adbf0610531d..15b863932879f 100644 --- a/ivy_tests/test_ivy/helpers/pipeline_helper.py +++ b/ivy_tests/test_ivy/helpers/pipeline_helper.py @@ -22,13 +22,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): class BackendHandler: _context = WithBackendContext + _current_mode = BackendHandlerMode.WithBackend @classmethod def _update_context(cls, mode: BackendHandlerMode): if mode == BackendHandlerMode.WithBackend: cls._context = WithBackendContext + cls._current_mode = BackendHandlerMode.WithBackend elif mode == BackendHandlerMode.SetBackend: cls._context = ivy.utils.backend.ContextManager + cls._current_mode = BackendHandlerMode.SetBackend else: raise ValueError(f"Unknown backend handler mode! 
{mode}") diff --git a/ivy_tests/test_ivy/helpers/test_parameter_flags.py b/ivy_tests/test_ivy/helpers/test_parameter_flags.py index ae78322de24f0..370e0cb554618 100644 --- a/ivy_tests/test_ivy/helpers/test_parameter_flags.py +++ b/ivy_tests/test_ivy/helpers/test_parameter_flags.py @@ -7,8 +7,8 @@ @st.composite def _gradient_strategy(draw): if test_globals.CURRENT_BACKEND == "numpy": - draw(st.just(False)) - draw(st.booleans()) + return draw(st.just(False)) + return draw(st.booleans()) @st.composite @@ -28,9 +28,7 @@ def _as_varaible_strategy(draw): @st.composite -def _compile_strategy(draw): # TODO remove later when paddle is supported - if test_globals.CURRENT_BACKEND == "paddle": - draw(st.just(False)) +def _compile_strategy(draw): draw(st.booleans()) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py index 7a4ca4114a910..4ed4f7e81bd96 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_lax/test_operators.py @@ -2700,7 +2700,7 @@ def test_jax_select( helpers.test_frontend_function( input_dtypes=["bool"] + input_dtype, frontend=frontend, - bakcend_to_test=backend_fw, + backend_to_test=backend_fw, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_dtype.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_dtype.py index 2f6fefb0d05b2..50d366dc025f1 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_dtype.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_dtype.py @@ -70,8 +70,7 @@ def test_jax_promote_types( type2=type2[0], test_values=False, ) - assert str(ret._ivy_dtype) == str(frontend_ret) - print(frontend_ret) + assert str(ret._ivy_dtype) == str(frontend_ret[0]) @handle_frontend_test( diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py index f89a309afd2f5..e6d3827c2214c 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_linalg.py @@ -896,12 +896,14 @@ def test_jax_cond( # multi_dot @handle_frontend_test( - fn_tree="jax.lax.linalg.multi_dot", + fn_tree="jax.numpy.linalg.multi_dot", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float"), min_value=0, max_value=10, shape=helpers.ints(min_value=2, max_value=5).map(lambda x: tuple([x, x])), + num_arrays=2, + shared_dtype=True, ).filter( lambda x: "float16" not in x[0] and "bfloat16" not in x[0] @@ -910,34 +912,16 @@ def test_jax_cond( ), test_with_out=st.just(False), ) -def test_jax_lax_multi_dot( - *, - dtype_and_x, - on_device, - fn_tree, - frontend, - test_flags, +def test_jax_multi_dot( + *, dtype_and_x, on_device, fn_tree, frontend, test_flags, backend_fw ): dtype, x = dtype_and_x - x = np.asarray(x[0], dtype=dtype[0]) - - ret, frontend_ret = helpers.test_frontend_function( + helpers.test_frontend_function( input_dtypes=dtype, frontend=frontend, test_flags=test_flags, fn_tree=fn_tree, on_device=on_device, - test_values=False, - args=(x,), - ) - - ret = ivy.to_numpy(ret) - frontend_ret = np.asarray(frontend_ret) - - assert_all_close( - ret_np=ret, - ret_from_gt_np=frontend_ret, - rtol=1e-2, - atol=1e-2, - ground_truth_backend=frontend, + arrays=(x[0], x[1]), + backend_to_test=backend_fw, ) diff --git 
a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py index 136fb168b77b6..81bdc4ef7f786 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_logic.py @@ -700,6 +700,7 @@ def test_jax_isin( fn_tree, frontend, test_flags, + backend_fw, ): x_and_dtype = assume_unique_and_dtype_and_x dtypes, values = x_and_dtype @@ -713,6 +714,7 @@ def test_jax_isin( element=elements, test_elements=test_elements, invert=invert, + backend_to_test=backend_fw, ) @@ -832,6 +834,7 @@ def test_jax_logical_or( on_device=on_device, x1=x[0], x2=x[1], + backend_to_test=backend_fw, ) @@ -1131,6 +1134,7 @@ def test_jax_numpy_packbits( *, fn_tree, test_flags, + backend_fw, ): input_dtype, x, axis = dtype_x_axis helpers.test_frontend_function( @@ -1142,4 +1146,5 @@ def test_jax_numpy_packbits( x=x[0], axis=axis, bitorder=bitorder, + backend_to_test=backend_fw, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py index fa75caf0a5b06..b076d278fc809 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_statistical.py @@ -346,6 +346,7 @@ def test_jax_sum( keepdims=keepdims, initial=initial, where=where, + backend_to_test=backend_fw, ) diff --git a/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py b/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py index 4d40858ce4a6e..4203c9cc788a3 100644 --- a/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py +++ b/ivy_tests/test_ivy/test_frontends/test_mindspore/test_ops/test_function/test_mindspore_nn_func.py @@ -252,6 +252,32 @@ def test_mindspore_interpolate( ) +# hardswish +# @handle_frontend_test( +# fn_tree="mindspore.ops.function.nn_func.hardswish", +# dtype_and_x=helpers.dtype_and_values( +# available_dtypes=helpers.get_dtypes("valid"), +# ), +# ) +# def test_mindspore_hardswish( +# *, +# dtype_and_x, +# on_device, +# fn_tree, +# frontend, +# test_flags, +# ): +# input_dtype, x = dtype_and_x +# helpers.test_frontend_function( +# input_dtypes=input_dtype, +# frontend=frontend, +# test_flags=test_flags, +# fn_tree=fn_tree, +# on_device=on_device, +# x=x[0], +# ) + + # pad @pytest.mark.skip("Testing pipeline not yet implemented") @handle_frontend_test( diff --git a/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_padding_arrays.py b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_padding_arrays.py new file mode 100644 index 0000000000000..cb41299539350 --- /dev/null +++ b/ivy_tests/test_ivy/test_frontends/test_numpy/test_manipulation_routines/test_padding_arrays.py @@ -0,0 +1,107 @@ +# global +from hypothesis import strategies as st + +# local +import ivy_tests.test_ivy.helpers as helpers +from ivy_tests.test_ivy.helpers import handle_frontend_test + + +def st_tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False): + return st.lists( + elements, + min_size=min_size, + max_size=max_size, + unique_by=unique_by, + unique=unique, + ).map(tuple) + + +def _st_tuples_or_int(n_pairs, min_val=0): + return st.one_of( + st_tuples( + st.tuples( + st.integers(min_value=min_val, max_value=4), + 
st.integers(min_value=min_val, max_value=4), + ), + min_size=n_pairs, + max_size=n_pairs, + ), + helpers.ints(min_value=min_val, max_value=4), + ) + + +@st.composite +def _pad_helper(draw): + mode = draw( + st.sampled_from( + [ + "constant", + "edge", + "linear_ramp", + "maximum", + "mean", + "median", + "minimum", + "reflect", + "symmetric", + "wrap", + ] + ) + ) + if mode in ["median", "mean"]: + dtypes = "float" + else: + dtypes = "numeric" + dtype, input, shape = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes(dtypes), + ret_shape=True, + min_num_dims=1, + min_value=-100, + max_value=100, + ).filter( + lambda x: x[0][0] not in ["float16", "bfloat16", "complex64", "complex128"] + ), + ) + ndim = len(shape) + pad_width = draw(_st_tuples_or_int(ndim, min_val=0)) + kwargs = {} + if mode == "reflect" or mode == "symmetric": + kwargs["reflect_type"] = draw(st.sampled_from(["even", "odd"])) + if mode in ["maximum", "mean", "median", "minimum"]: + kwargs["stat_length"] = draw(_st_tuples_or_int(ndim, min_val=2)) + if mode in ["linear_ramp"]: + kwargs["end_values"] = draw(_st_tuples_or_int(ndim)) + if mode == "constant": + kwargs["constant_values"] = draw(_st_tuples_or_int(ndim)) + return dtype, input[0], pad_width, kwargs, mode + + +# pad +@handle_frontend_test( + fn_tree="numpy.pad", + args=_pad_helper(), + test_with_out=st.just(False), +) +def test_numpy_pad( + *, + args, + fn_tree, + backend_fw, + on_device, + test_flags, + frontend, +): + dtype, x, pad_width, kwargs, mode = args + helpers.test_frontend_function( + input_dtypes=dtype, + backend_to_test="numpy", + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + array=x, + pad_width=pad_width, + mode=mode, + **kwargs, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py index 3deb517b9f20a..66ba8586be7c8 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_math.py @@ -2048,3 +2048,35 @@ def test_paddle_amax( on_device=on_device, x=x[0], ) + + +# stanh +@handle_frontend_test( + fn_tree="paddle.tensor.math.stanh", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + ), + scale_a=st.floats(1e-5, 1e5), + scale_b=st.floats(1e-5, 1e5), +) +def test_paddle_stanh( + *, + dtype_and_x, + on_device, + fn_tree, + frontend, + test_flags, + scale_a, + scale_b, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + x=x[0], + scale_a=scale_a, + scale_b=scale_b, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_stat.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_stat.py index f3c4fa485b3ca..a82530bc610b9 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_stat.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_tensor/test_stat.py @@ -193,3 +193,34 @@ def test_paddle_nanmedian( axis=axis, keepdim=keepdim, ) + + +# std +@handle_frontend_test( + fn_tree="paddle.std", + dtype_and_x=_statistical_dtype_values(function="std"), + unbiased=st.booleans(), + keepdim=st.booleans(), +) +def test_paddle_std( + *, + unbiased, + dtype_and_x, + keepdim, + fn_tree, + frontend, + backend_fw, + test_flags, +): + input_dtype, x, axis, _ = dtype_and_x + 
helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + x=x[0], + axis=axis, + unbiased=unbiased, + keepdim=keepdim, + ) diff --git a/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py b/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py index 92f3335a2800e..f869af563290a 100644 --- a/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py +++ b/ivy_tests/test_ivy/test_frontends/test_paddle/test_vision/test_transforms.py @@ -1,4 +1,5 @@ # global +from hypothesis import strategies as st # local import ivy_tests.test_ivy.helpers as helpers @@ -31,6 +32,16 @@ def test_paddle_to_tensor( ) +@st.composite +def _chw_image_shape_helper(draw): + c = draw(st.sampled_from([1, 3]), label="channel") + h = draw(helpers.ints(min_value=1, max_value=100), label="height") + w = draw(helpers.ints(min_value=1, max_value=100), label="width") + + shape = (c, h, w) + return shape + + # adjust_hue @handle_frontend_test( fn_tree="paddle.vision.transforms.adjust_hue", @@ -70,6 +81,38 @@ def test_paddle_adjust_hue( ) +# adjust_brightness +@handle_frontend_test( + fn_tree="paddle.vision.transforms.adjust_brightness", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + shape=_chw_image_shape_helper(), + ), + brightness_factor=helpers.floats(min_value=0), +) +def test_paddle_adjust_brightness( + *, + dtype_and_x, + brightness_factor, + on_device, + fn_tree, + frontend, + test_flags, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_function( + input_dtypes=input_dtype, + backend_to_test=backend_fw, + frontend=frontend, + test_flags=test_flags, + fn_tree=fn_tree, + on_device=on_device, + img=x[0], + brightness_factor=brightness_factor, + ) + + @handle_frontend_test( fn_tree="paddle.vision.transforms.vflip", dtype_and_x=helpers.dtype_and_values( diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py index 2081a66eab98b..cb7c27a1ce5a7 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py @@ -2,7 +2,6 @@ from hypothesis import strategies as st # local -import ivy import ivy_tests.test_ivy.helpers as helpers from ivy_tests.test_ivy.helpers import handle_frontend_test import math @@ -69,7 +68,7 @@ def test_torch_avg_pool1d( # TODO: remove the processing of padding attribute when ivy.avg_pool # support explicit padding x_shape = [x[0].shape[2]] - padding = (pad[i] for i, pad in enumerate(padding)) + padding = [pad[i] for i, pad in enumerate(padding)] # figuring out the exact kernel_size for SAME and VALID padding # As ivy.avg_pool1d doesn't support explicit padding scheme if not sum(padding) == 0: @@ -119,7 +118,7 @@ def test_torch_avg_pool2d( input_dtype, x, kernel_size, stride, padding = dtype_x_k_s # TODO: remove the processing of padding attribute when ivy.avg_pool # support explicit padding - padding = (pad[i] for i, pad in enumerate(padding)) + padding = [pad[i] for i, pad in enumerate(padding)] x_shape = x[0].shape[2:] if not sum(padding) == 0: padding = calculate_same_padding(kernel_size, [stride[0]] * 2, x_shape) diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py 
b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index bbeb1a94a821f..b823e31b8d420 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -6419,6 +6419,44 @@ def test_torch_tensor_neg( ) +# neg_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="neg_", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_value=-1e04, + max_value=1e04, + allow_inf=False, + ), +) +def test_torch_tensor_neg_( + dtype_and_x, + frontend, + frontend_method_data, + init_flags, + method_flags, + on_device, + backend_fw, +): + input_dtype, x = dtype_and_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={ + "data": x[0], + }, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) + + # __neg__ @handle_frontend_method( class_tree=CLASS_TREE, @@ -12243,3 +12281,45 @@ def test_torch_index_fill( method_flags=method_flags, on_device=on_device, ) + + +# triu_ +@handle_frontend_method( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="triu_", + dtype_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("valid"), + min_num_dims=2, + max_num_dims=5, + min_dim_size=1, + max_dim_size=5, + ), + diagonal=st.integers( + min_value=-4, + max_value=4, + ), +) +def test_torch_triu_( + dtype_x, + diagonal, + frontend, + frontend_method_data, + init_flags, + method_flags, + on_device, + backend_fw, +): + input_dtype, x = dtype_x + helpers.test_frontend_method( + init_input_dtypes=input_dtype, + backend_to_test=backend_fw, + init_all_as_kwargs_np={"data": x[0]}, + method_input_dtypes=input_dtype, + method_all_as_kwargs_np={"diagonal": diagonal}, + frontend_method_data=frontend_method_data, + init_flags=init_flags, + method_flags=method_flags, + frontend=frontend, + on_device=on_device, + ) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py index 99de8fbc0212a..181212bfe86b9 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py @@ -1454,7 +1454,7 @@ def test_square(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): fn_tree="functional.ivy.sqrt", dtype_and_x=helpers.dtype_and_values( available_dtypes=helpers.get_dtypes("float_and_complex"), allow_inf=False - ), + ).filter(lambda x: x[0][0] not in ["bfloat16"]), ) def test_sqrt(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): input_dtype, x = dtype_and_x diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sorting.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sorting.py index 084762f728a99..6c8e666540b6b 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sorting.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_sorting.py @@ -9,34 +9,33 @@ @st.composite def _invert_permutation_helper(draw): - return ["int64"], [ - np.array( - draw( - st.permutations( - list(range(draw(st.integers(min_value=3, max_value=10)))) - ) - ) - ) - ] + perm = draw( + st.permutations(list(range(draw(st.integers(min_value=3, max_value=10))))) + ) + if 
draw(st.booleans()): + perm = np.array(perm) + dtype = draw(st.sampled_from(["int32", "int64"])) + return dtype, perm # invert_permutation @handle_test( fn_tree="functional.ivy.experimental.invert_permutation", - dtype_and_x=_invert_permutation_helper(), + dtype_and_perm=_invert_permutation_helper(), test_instance_method=st.just(False), test_with_out=st.just(False), test_gradients=st.just(False), + ground_truth_backend="numpy", ) -def test_invert_permutation(dtype_and_x, test_flags, backend_fw, fn_name, on_device): - dtype, x = dtype_and_x +def test_invert_permutation(dtype_and_perm, test_flags, backend_fw, fn_name, on_device): + dtype, perm = dtype_and_perm helpers.test_function( - input_dtypes=dtype, + input_dtypes=[dtype], test_flags=test_flags, on_device=on_device, backend_to_test=backend_fw, fn_name=fn_name, - x=x[0], + x=perm, ) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index ffa59281e835e..818e9c52d7a2d 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,10 +1,10 @@ numpy # unpinned we test with latest version now -einops==0.6.1 -psutil==5.9.1 -termcolor==1.1.0 -colorama==0.4.5 -packaging==21.3 -nvidia-ml-py<=11.495.46 # mod_name=pynvml +einops +psutil +termcolor +colorama +packaging +nvidia-ml-py # mod_name=pynvml diskcache google-auth # mod_name=google.auth requests From ad440328303d57731670dda0a8339caddd538b41 Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Tue, 22 Aug 2023 15:13:46 +0100 Subject: [PATCH 18/24] fix(framework): native out with complex_mode split When complex_mode is "split", the inner function is called twice by @handle_complex_input, so previously the result of the second function was overwriting the result of the first in the `out` array. Fixed by passing views of the real and imaginary parts of the `out` array to their respective function calls. 
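
For reference, a minimal sketch of the "split" recombination (not the
wrapper code itself; `fn` stands in for any wrapped real-valued function,
and whether `ivy.real(out)`/`ivy.imag(out)` alias the original buffer is
backend-dependent):

    import ivy

    def split_apply(fn, x, out=None):
        # Apply fn to the real and imaginary parts independently, then
        # recombine as fn(Re x) + 1j * fn(Im x). Each partial result needs
        # its own target: routing both calls through one shared `out`
        # buffer lets the second write clobber the first before the
        # recombination happens.
        real_out = ivy.real(out) if out is not None else None
        imag_out = ivy.imag(out) if out is not None else None
        real_ret = fn(ivy.real(x), out=real_out)
        imag_ret = fn(ivy.imag(x), out=imag_out)
        # casting 1j to the input dtype avoids an unwanted promotion
        return ivy.add(real_ret, ivy.multiply(ivy.array(1j, dtype=x.dtype), imag_ret))

    print(split_apply(ivy.tanh, ivy.array([1 + 2j])))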
--- ivy/func_wrapper.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ivy/func_wrapper.py b/ivy/func_wrapper.py index 7bd3cc3baadd5..3ac83acbf6fa4 100644 --- a/ivy/func_wrapper.py +++ b/ivy/func_wrapper.py @@ -1488,9 +1488,16 @@ def _handle_complex_input( if complex_mode == "split" or (complex_mode == "jax" and jax_like == "split"): real_inp = ivy.real(inp).data imag_inp = ivy.imag(inp).data + if "out" in kwargs and kwargs["out"] is not None: + out = kwargs.pop("out") + real_ret = fn(real_inp, *args, out=ivy.real(out), **kwargs) + imag_ret = fn(imag_inp, *args, out=ivy.imag(out), **kwargs) + else: + real_ret = fn(real_inp, *args, **kwargs) + imag_ret = fn(imag_inp, *args, **kwargs) return ivy.add( - fn(real_inp, *args, **kwargs), - ivy.multiply(1j, fn(imag_inp, *args, **kwargs)), + real_ret, + ivy.multiply(ivy.array(1j, dtype=inp.dtype), imag_ret), ) elif complex_mode == "magnitude" or ( From 3ed1e6cd587f6c8da69b7190c2cd853c4848fbb5 Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Tue, 22 Aug 2023 15:37:24 +0100 Subject: [PATCH 19/24] fix(test): Apply safety factor to atan test --- .../test_ivy/test_functional/test_core/test_elementwise.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py index 181212bfe86b9..8c9b552548946 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py @@ -180,7 +180,10 @@ def test_asinh(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.atan", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex") + available_dtypes=helpers.get_dtypes("float_and_complex"), + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", ), ) def test_atan(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): From a02efc134ab8d81b7cd094b545519d9a6f77d73b Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Tue, 22 Aug 2023 15:50:59 +0100 Subject: [PATCH 20/24] fix(test): apply safety factor to test_sqrt --- .../test_functional/test_core/test_elementwise.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py index 8c9b552548946..e07c018a8d0dd 100644 --- a/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py +++ b/ivy_tests/test_ivy/test_functional/test_core/test_elementwise.py @@ -1456,7 +1456,13 @@ def test_square(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): @handle_test( fn_tree="functional.ivy.sqrt", dtype_and_x=helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float_and_complex"), allow_inf=False + available_dtypes=helpers.get_dtypes("float_and_complex"), + allow_inf=False, + # Safety factor is to account for complex, where taking square root + # involves taking absolute value first + large_abs_safety_factor=2, + small_abs_safety_factor=2, + safety_factor_scale="log", ).filter(lambda x: x[0][0] not in ["bfloat16"]), ) def test_sqrt(*, dtype_and_x, test_flags, backend_fw, fn_name, on_device): From d265bdac2048c031b9d51e37a52d3ce75568c5da Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Tue, 22 Aug 2023 16:05:55 +0100 Subject: [PATCH 21/24] fix(test): apply safety factor to test_jax_leaky_relu The 
array itself had a safety factor already, but the alpha value did not. --- .../test_jax/test_nn/test_non_linear_activations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py index 714d94d9ff44c..4d0600e4a8daa 100644 --- a/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py +++ b/ivy_tests/test_ivy/test_frontends/test_jax/test_nn/test_non_linear_activations.py @@ -138,7 +138,9 @@ def test_jax_silu( small_abs_safety_factor=2, safety_factor_scale="linear", ), - negative_slope=helpers.floats(min_value=0.0, max_value=1.0), + negative_slope=helpers.floats( + min_value=0.0, max_value=1.0, small_abs_safety_factor=16 + ), test_with_out=st.just(False), ) def test_jax_leaky_relu( From a26fe2b15133fca77601598dd4f010dc4b68c63b Mon Sep 17 00:00:00 2001 From: Joe Shepherd Date: Thu, 24 Aug 2023 10:47:58 +0100 Subject: [PATCH 22/24] style: move `complex_mode` before `out` argument --- ivy/data_classes/array/activations.py | 12 +++++----- ivy/data_classes/array/elementwise.py | 4 ++-- ivy/data_classes/container/activations.py | 24 +++++++++---------- ivy/data_classes/container/elementwise.py | 8 +++---- ivy/functional/backends/jax/activations.py | 6 ++--- ivy/functional/backends/jax/elementwise.py | 2 +- ivy/functional/backends/mxnet/activations.py | 6 ++--- ivy/functional/backends/mxnet/elementwise.py | 2 +- ivy/functional/backends/numpy/activations.py | 6 ++--- ivy/functional/backends/numpy/elementwise.py | 2 +- ivy/functional/backends/paddle/activations.py | 6 ++--- ivy/functional/backends/paddle/elementwise.py | 2 +- .../backends/tensorflow/activations.py | 6 ++--- .../backends/tensorflow/elementwise.py | 2 +- ivy/functional/backends/torch/activations.py | 6 ++--- ivy/functional/backends/torch/elementwise.py | 2 +- ivy/functional/ivy/activations.py | 6 ++--- ivy/functional/ivy/elementwise.py | 2 +- 18 files changed, 52 insertions(+), 52 deletions(-) diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py index c33e995045c32..f8e33063592ae 100644 --- a/ivy/data_classes/array/activations.py +++ b/ivy/data_classes/array/activations.py @@ -14,8 +14,8 @@ def relu( self: ivy.Array, /, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ ivy.Array instance method variant of ivy.relu. This method simply wraps the @@ -45,15 +45,15 @@ def relu( >>> print(y) ivy.array([0., 0., 1.]) """ - return ivy.relu(self._data, out=out, complex_mode=complex_mode) + return ivy.relu(self._data, complex_mode=complex_mode, out=out) def leaky_relu( self: ivy.Array, /, *, alpha: float = 0.2, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ ivy.Array instance method variant of ivy.leaky_relu. 
This method simply wraps @@ -86,7 +86,7 @@ def leaky_relu( ivy.array([ 0.39, -0.17]) """ return ivy.leaky_relu( - self._data, alpha=alpha, out=out, complex_mode=complex_mode + self._data, alpha=alpha, complex_mode=complex_mode, out=out ) def gelu( @@ -94,8 +94,8 @@ def gelu( /, *, approximate: bool = False, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ ivy.Array instance method variant of ivy.gelu. This method simply wraps the @@ -128,7 +128,7 @@ def gelu( ivy.array([-0.138, -0.165, 1.4]) """ return ivy.gelu( - self._data, approximate=approximate, out=out, complex_mode=complex_mode + self._data, approximate=approximate, complex_mode=complex_mode, out=out ) def sigmoid(self: ivy.Array, /, *, out: Optional[ivy.Array] = None) -> ivy.Array: diff --git a/ivy/data_classes/array/elementwise.py b/ivy/data_classes/array/elementwise.py index 645407a9edb29..e59a048fbcb9e 100644 --- a/ivy/data_classes/array/elementwise.py +++ b/ivy/data_classes/array/elementwise.py @@ -2555,8 +2555,8 @@ def tan(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array: def tanh( self: ivy.Array, *, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ ivy.Array instance method variant of ivy.tanh. This method simply wraps the @@ -2589,7 +2589,7 @@ def tanh( >>> print(y) ivy.array([0., 0.762, 0.964]) """ - return ivy.tanh(self._data, out=out, complex_mode=complex_mode) + return ivy.tanh(self._data, complex_mode=complex_mode, out=out) def trunc(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array: """ diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py index ea5e9f6f7ca21..84ed9044b17c2 100644 --- a/ivy/data_classes/container/activations.py +++ b/ivy/data_classes/container/activations.py @@ -18,8 +18,8 @@ def _static_relu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.relu. This method simply wraps the @@ -71,8 +71,8 @@ def _static_relu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def relu( @@ -83,8 +83,8 @@ def relu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.relu. 
This method simply wraps the @@ -135,8 +135,8 @@ def relu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) @staticmethod @@ -149,8 +149,8 @@ def _static_leaky_relu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.leaky_relu. This method simply wraps @@ -204,8 +204,8 @@ def _static_leaky_relu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def leaky_relu( @@ -217,8 +217,8 @@ def leaky_relu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.leaky_relu. This method simply @@ -271,8 +271,8 @@ def leaky_relu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) @staticmethod @@ -285,8 +285,8 @@ def _static_gelu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.gelu. This method simply wraps the @@ -339,8 +339,8 @@ def _static_gelu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def gelu( @@ -352,8 +352,8 @@ def gelu( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.gelu. This method simply wraps the @@ -406,8 +406,8 @@ def gelu( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) @staticmethod diff --git a/ivy/data_classes/container/elementwise.py b/ivy/data_classes/container/elementwise.py index 815b98429f6ac..382b6b9ad4f88 100644 --- a/ivy/data_classes/container/elementwise.py +++ b/ivy/data_classes/container/elementwise.py @@ -8273,8 +8273,8 @@ def _static_tanh( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.tanh. 
This method simply wraps the @@ -8328,8 +8328,8 @@ def _static_tanh( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def tanh( @@ -8339,8 +8339,8 @@ def tanh( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.tanh. This method simply wraps the @@ -8394,8 +8394,8 @@ def tanh( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) @staticmethod diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py index c6dac4872e11a..edc8c0977dc2b 100644 --- a/ivy/functional/backends/jax/activations.py +++ b/ivy/functional/backends/jax/activations.py @@ -17,8 +17,8 @@ def gelu( /, *, approximate: bool = False, - out: Optional[JaxArray] = None, complex_mode="jax", + out: Optional[JaxArray] = None, ) -> JaxArray: return jax.nn.gelu(x, approximate) @@ -28,14 +28,14 @@ def leaky_relu( /, *, alpha: float = 0.2, - out: Optional[JaxArray] = None, complex_mode="jax", + out: Optional[JaxArray] = None, ) -> JaxArray: return jnp.asarray(jnp.where(x > 0, x, jnp.multiply(x, alpha)), x.dtype) def relu( - x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" + x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None ) -> JaxArray: return jnp.maximum(x, 0) diff --git a/ivy/functional/backends/jax/elementwise.py b/ivy/functional/backends/jax/elementwise.py index e5afb561bd29d..9bd939b1abff4 100644 --- a/ivy/functional/backends/jax/elementwise.py +++ b/ivy/functional/backends/jax/elementwise.py @@ -503,7 +503,7 @@ def tan(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray: def tanh( - x: JaxArray, /, *, out: Optional[JaxArray] = None, complex_mode="jax" + x: JaxArray, /, *, complex_mode="jax", out: Optional[JaxArray] = None ) -> JaxArray: return jnp.tanh(x) diff --git a/ivy/functional/backends/mxnet/activations.py b/ivy/functional/backends/mxnet/activations.py index 48f95e3029f6b..4f87a23028ddd 100644 --- a/ivy/functional/backends/mxnet/activations.py +++ b/ivy/functional/backends/mxnet/activations.py @@ -16,8 +16,8 @@ def gelu( /, *, approximate: bool = False, - out: Optional[None] = None, complex_mode="jax", + out: Optional[None] = None, ) -> None: if approximate: return ( @@ -27,12 +27,12 @@ def gelu( def leaky_relu( - x: None, /, *, alpha: float = 0.2, out: Optional[None] = None, complex_mode="jax" + x: None, /, *, alpha: float = 0.2, complex_mode="jax", out: Optional[None] = None ) -> None: return mx.nd.LeakyReLU(x, slope=alpha) -def relu(x: None, /, *, out: Optional[None] = None, complex_mode="jax") -> None: +def relu(x: None, /, *, complex_mode="jax", out: Optional[None] = None) -> None: return mx.nd.relu(x) diff --git a/ivy/functional/backends/mxnet/elementwise.py b/ivy/functional/backends/mxnet/elementwise.py index 2f262db28d7ca..42a8a2788b605 100644 --- a/ivy/functional/backends/mxnet/elementwise.py +++ b/ivy/functional/backends/mxnet/elementwise.py @@ -598,8 +598,8 @@ def tanh( x: Union[(None, mx.ndarray.NDArray)], /, *, - out: Optional[Union[(None, mx.ndarray.NDArray)]] = None, complex_mode="jax", + out: Optional[Union[(None, mx.ndarray.NDArray)]] = None, ) -> 
Union[(None, mx.ndarray.NDArray)]: return mx.nd.tanh(x) diff --git a/ivy/functional/backends/numpy/activations.py b/ivy/functional/backends/numpy/activations.py index 4b052baee3b75..a1f7558030c58 100644 --- a/ivy/functional/backends/numpy/activations.py +++ b/ivy/functional/backends/numpy/activations.py @@ -11,7 +11,7 @@ @_scalar_output_to_0d_array def relu( - x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" + x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None ) -> np.ndarray: return np.maximum(x, 0, out=out, dtype=x.dtype) @@ -24,8 +24,8 @@ def leaky_relu( /, *, alpha: float = 0.2, - out: Optional[np.ndarray] = None, complex_mode="jax", + out: Optional[np.ndarray] = None, ) -> np.ndarray: return np.asarray(np.where(x > 0, x, np.multiply(x, alpha)), x.dtype) @@ -36,8 +36,8 @@ def gelu( /, *, approximate: bool = False, - out: Optional[np.ndarray] = None, complex_mode="jax", + out: Optional[np.ndarray] = None, ) -> np.ndarray: if approximate: ret = 0.5 * x * (1 + np.tanh(0.7978845608 * (x + 0.044715 * x * x * x))) diff --git a/ivy/functional/backends/numpy/elementwise.py b/ivy/functional/backends/numpy/elementwise.py index 441f470fd5d94..6c4459bbaaf4f 100644 --- a/ivy/functional/backends/numpy/elementwise.py +++ b/ivy/functional/backends/numpy/elementwise.py @@ -754,7 +754,7 @@ def tan(x: np.ndarray, /, *, out: Optional[np.ndarray] = None) -> np.ndarray: @_scalar_output_to_0d_array def tanh( - x: np.ndarray, /, *, out: Optional[np.ndarray] = None, complex_mode="jax" + x: np.ndarray, /, *, complex_mode="jax", out: Optional[np.ndarray] = None ) -> np.ndarray: return np.tanh(x, out=out) diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py index 45bfe5665f9d9..0228884358ef3 100644 --- a/ivy/functional/backends/paddle/activations.py +++ b/ivy/functional/backends/paddle/activations.py @@ -31,7 +31,7 @@ def relu( - x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" + x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: if x.dtype in unsupported_dtypes: if paddle.is_complex(x): @@ -48,8 +48,8 @@ def leaky_relu( /, *, alpha: float = 0.2, - out: Optional[paddle.Tensor] = None, complex_mode="jax", + out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: if x.dtype in unsupported_dtypes: if paddle.is_complex(x): @@ -69,8 +69,8 @@ def gelu( /, *, approximate: bool = False, - out: Optional[paddle.Tensor] = None, complex_mode="jax", + out: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: if paddle.is_complex(x): sqrt_2_over_pi = 0.7978845608 diff --git a/ivy/functional/backends/paddle/elementwise.py b/ivy/functional/backends/paddle/elementwise.py index 1c987c2bf806b..a1e2d5c119542 100644 --- a/ivy/functional/backends/paddle/elementwise.py +++ b/ivy/functional/backends/paddle/elementwise.py @@ -707,7 +707,7 @@ def not_equal( backend_version, ) def tanh( - x: paddle.Tensor, /, *, out: Optional[paddle.Tensor] = None, complex_mode="jax" + x: paddle.Tensor, /, *, complex_mode="jax", out: Optional[paddle.Tensor] = None ) -> paddle.Tensor: if x.dtype in [ paddle.int8, diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py index 39d4cc6817ccd..95696acfd77e2 100644 --- a/ivy/functional/backends/tensorflow/activations.py +++ b/ivy/functional/backends/tensorflow/activations.py @@ -22,8 +22,8 @@ def gelu( /, *, approximate: bool = False, - out: Optional[Tensor] = None, 
complex_mode="jax", + out: Optional[Tensor] = None, ) -> Tensor: if x.dtype in [tf.complex64, tf.complex128]: return 0.5 * x * (1 + tf.math.tanh(0.7978845608 * (x + 0.044715 * x * x * x))) @@ -35,13 +35,13 @@ def leaky_relu( /, *, alpha: float = 0.2, - out: Optional[Tensor] = None, complex_mode="jax", + out: Optional[Tensor] = None, ) -> Tensor: return tf.nn.leaky_relu(x, alpha) -def relu(x: Tensor, /, *, out: Optional[Tensor] = None, complex_mode="jax") -> Tensor: +def relu(x: Tensor, /, *, complex_mode="jax", out: Optional[Tensor] = None) -> Tensor: return tf.nn.relu(x) diff --git a/ivy/functional/backends/tensorflow/elementwise.py b/ivy/functional/backends/tensorflow/elementwise.py index 858eb875358ec..1e060e8e8cae3 100644 --- a/ivy/functional/backends/tensorflow/elementwise.py +++ b/ivy/functional/backends/tensorflow/elementwise.py @@ -744,8 +744,8 @@ def tanh( x: Union[tf.Tensor, tf.Variable], /, *, - out: Optional[Union[tf.Tensor, tf.Variable]] = None, complex_mode="jax", + out: Optional[Union[tf.Tensor, tf.Variable]] = None, ) -> Union[tf.Tensor, tf.Variable]: return tf.tanh(x) diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py index cee0c75dc235d..b46bb51f4bcfb 100644 --- a/ivy/functional/backends/torch/activations.py +++ b/ivy/functional/backends/torch/activations.py @@ -19,7 +19,7 @@ @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) def relu( - x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" + x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None ) -> torch.Tensor: return torch.relu(x) @@ -30,8 +30,8 @@ def leaky_relu( /, *, alpha: float = 0.2, - out: Optional[torch.Tensor] = None, complex_mode="jax", + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: return torch.nn.functional.leaky_relu(x, alpha) @@ -42,8 +42,8 @@ def gelu( /, *, approximate: bool = False, - out: Optional[torch.Tensor] = None, complex_mode="jax", + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: if approximate: return ( diff --git a/ivy/functional/backends/torch/elementwise.py b/ivy/functional/backends/torch/elementwise.py index bf6b68aca8050..a0a3f5c11e053 100644 --- a/ivy/functional/backends/torch/elementwise.py +++ b/ivy/functional/backends/torch/elementwise.py @@ -508,7 +508,7 @@ def not_equal( @with_unsupported_dtypes({"2.0.1 and below": ("float16",)}, backend_version) @handle_numpy_arrays_in_specific_backend def tanh( - x: torch.Tensor, /, *, out: Optional[torch.Tensor] = None, complex_mode="jax" + x: torch.Tensor, /, *, complex_mode="jax", out: Optional[torch.Tensor] = None ) -> torch.Tensor: x = _cast_for_unary_op(x) return torch.tanh(x, out=out) diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py index 27529e2985437..88aaf9ee91abe 100644 --- a/ivy/functional/ivy/activations.py +++ b/ivy/functional/ivy/activations.py @@ -45,8 +45,8 @@ def gelu( /, *, approximate: bool = False, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the Gaussian error linear unit (GELU) activation function. @@ -135,8 +135,8 @@ def leaky_relu( /, *, alpha: float = 0.2, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ Apply the leaky rectified linear unit function element-wise. 
@@ -307,8 +307,8 @@ def relu(
     x: Union[ivy.Array, ivy.NativeArray],
     /,
     *,
-    out: Optional[ivy.Array] = None,
     complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    out: Optional[ivy.Array] = None,
 ) -> ivy.Array:
     """
     Apply the rectified linear unit function element-wise.
diff --git a/ivy/functional/ivy/elementwise.py b/ivy/functional/ivy/elementwise.py
index 0792615f5a378..0f0541ca00f15 100644
--- a/ivy/functional/ivy/elementwise.py
+++ b/ivy/functional/ivy/elementwise.py
@@ -6400,8 +6400,8 @@ def tanh(
     x: Union[ivy.Array, ivy.NativeArray],
     /,
     *,
-    out: Optional[ivy.Array] = None,
     complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    out: Optional[ivy.Array] = None,
 ) -> ivy.Array:
     """
     Calculate an implementation-dependent approximation to the hyperbolic tangent,

From 641bed27886c6431edd9f4c6ad66bcf349cea4d6 Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Thu, 24 Aug 2023 11:07:54 +0100
Subject: [PATCH 23/24] style: move `complex_mode` before `out` argument

Apply this change to the `softplus` function, which newly gained this
argument. Also reflect this change in docstrings of relevant functions.

---
 ivy/data_classes/array/activations.py         | 26 ++++-----
 ivy/data_classes/array/elementwise.py         |  6 +-
 ivy/data_classes/container/activations.py     | 56 +++++++++----------
 ivy/data_classes/container/elementwise.py     | 12 ++--
 ivy/functional/backends/jax/activations.py    |  2 +-
 ivy/functional/backends/mxnet/activations.py  |  2 +-
 ivy/functional/backends/numpy/activations.py  |  2 +-
 ivy/functional/backends/paddle/activations.py |  2 +-
 .../backends/tensorflow/activations.py        |  2 +-
 ivy/functional/backends/torch/activations.py  |  2 +-
 ivy/functional/ivy/activations.py             | 26 ++++-----
 ivy/functional/ivy/elementwise.py             |  6 +-
 12 files changed, 72 insertions(+), 72 deletions(-)

diff --git a/ivy/data_classes/array/activations.py b/ivy/data_classes/array/activations.py
index cd5b8bd137ec3..8bc0745927d91 100644
--- a/ivy/data_classes/array/activations.py
+++ b/ivy/data_classes/array/activations.py
@@ -26,12 +26,12 @@ def relu(
         ----------
         self
             input array.
-        out
-            optional output array, for writing the result to. It must have a shape
-            that the inputs broadcast to.
         complex_mode
             optional specifier for how to handle complex data types. See
             ``ivy.func_wrapper.handle_complex_input`` for more detail.
+        out
+            optional output array, for writing the result to. It must have a shape
+            that the inputs broadcast to.
 
         Returns
         -------
@@ -66,12 +66,12 @@ def leaky_relu(
             input array.
         alpha
             the slope of the negative section.
-        out
-            optional output array, for writing the result to. It must have a shape
-            that the inputs broadcast to.
         complex_mode
             optional specifier for how to handle complex data types. See
             ``ivy.func_wrapper.handle_complex_input`` for more detail.
+        out
+            optional output array, for writing the result to. It must have a shape
+            that the inputs broadcast to.
 
         Returns
         -------
@@ -108,12 +108,12 @@ def gelu(
             input array.
         approximate
             whether to use the approximate version of the gelu function.
-        out
-            optional output array, for writing the result to. It must have a shape
-            that the inputs broadcast to.
         complex_mode
            optional specifier for how to handle complex data types. See
            ``ivy.func_wrapper.handle_complex_input`` for more detail.
+        out
+            optional output array, for writing the result to. It must have a shape
+            that the inputs broadcast to.
Returns ------- @@ -203,8 +203,8 @@ def softplus( *, beta: Optional[Union[int, float]] = None, threshold: Optional[Union[int, float]] = None, - out: Optional[ivy.Array] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Array] = None, ) -> ivy.Array: """ ivy.Array instance method variant of ivy.softplus. This method simply wraps the @@ -219,11 +219,11 @@ def softplus( the beta parameter of the softplus function. threshold the threshold parameter of the softplus function. - out - optional output array, for writing the result to. It must have a shape complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output array, for writing the result to. It must have a shape Returns ------- @@ -251,8 +251,8 @@ def softplus( self._data, beta=beta, threshold=threshold, - out=out, complex_mode=complex_mode, + out=out, ) def log_softmax( diff --git a/ivy/data_classes/array/elementwise.py b/ivy/data_classes/array/elementwise.py index e59a048fbcb9e..794da705e123c 100644 --- a/ivy/data_classes/array/elementwise.py +++ b/ivy/data_classes/array/elementwise.py @@ -2568,12 +2568,12 @@ def tanh( self input array whose elements each represent a hyperbolic angle. Should have a real-valued floating-point data type. - out - optional output, for writing the result to. It must have a shape that the - inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output, for writing the result to. It must have a shape that the + inputs broadcast to. Returns ------- diff --git a/ivy/data_classes/container/activations.py b/ivy/data_classes/container/activations.py index bffb6b906d764..af9e04c0f6f15 100644 --- a/ivy/data_classes/container/activations.py +++ b/ivy/data_classes/container/activations.py @@ -41,12 +41,12 @@ def _static_relu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -106,12 +106,12 @@ def relu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -174,12 +174,12 @@ def _static_leaky_relu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. 
Returns ------- @@ -242,12 +242,12 @@ def leaky_relu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -310,12 +310,12 @@ def _static_gelu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -377,12 +377,12 @@ def gelu( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -665,8 +665,8 @@ def _static_softplus( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container static method variant of ivy.softplus. This method simply wraps @@ -692,12 +692,12 @@ def _static_softplus( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. It must have a shape + that the inputs broadcast to. Returns ------- @@ -730,8 +730,8 @@ def _static_softplus( to_apply=to_apply, prune_unapplied=prune_unapplied, map_sequences=map_sequences, - out=out, complex_mode=complex_mode, + out=out, ) def softplus( @@ -744,8 +744,8 @@ def softplus( to_apply: Union[bool, ivy.Container] = True, prune_unapplied: Union[bool, ivy.Container] = False, map_sequences: Union[bool, ivy.Container] = False, - out: Optional[ivy.Container] = None, complex_mode: Literal["split", "magnitude", "jax"] = "jax", + out: Optional[ivy.Container] = None, ) -> ivy.Container: """ ivy.Container instance method variant of ivy.softplus. This method simply wraps @@ -771,12 +771,12 @@ def softplus( map_sequences Whether to also map method to sequences (lists, tuples). Default is ``False``. - out - optional output container, for writing the result to. It must have a shape - that the inputs broadcast to. complex_mode optional specifier for how to handle complex data types. See ``ivy.func_wrapper.handle_complex_input`` for more detail. + out + optional output container, for writing the result to. 
It must have a shape
+            that the inputs broadcast to.
 
         Returns
         -------
@@ -807,8 +807,8 @@ def softplus(
             to_apply=to_apply,
             prune_unapplied=prune_unapplied,
             map_sequences=map_sequences,
-            out=out,
             complex_mode=complex_mode,
+            out=out,
         )
 
     @staticmethod
diff --git a/ivy/data_classes/container/elementwise.py b/ivy/data_classes/container/elementwise.py
index 382b6b9ad4f88..226a9fe15edf2 100644
--- a/ivy/data_classes/container/elementwise.py
+++ b/ivy/data_classes/container/elementwise.py
@@ -8297,12 +8297,12 @@ def _static_tanh(
         map_sequences
             Whether to also map method to sequences (lists, tuples). Default is
             ``False``.
-        out
-            optional output container, for writing the result to. It must have a shape
-            that the inputs broadcast to.
         complex_mode
             optional specifier for how to handle complex data types. See
             ``ivy.func_wrapper.handle_complex_input`` for more detail.
+        out
+            optional output container, for writing the result to. It must have a shape
+            that the inputs broadcast to.
 
         Returns
         -------
@@ -8363,12 +8363,12 @@ def tanh(
         map_sequences
             Whether to also map method to sequences (lists, tuples). Default is
             ``False``.
-        out
-            optional output container, for writing the result to. It must have a shape
-            that the inputs broadcast to.
         complex_mode
             optional specifier for how to handle complex data types. See
             ``ivy.func_wrapper.handle_complex_input`` for more detail.
+        out
+            optional output container, for writing the result to. It must have a shape
+            that the inputs broadcast to.
 
         Returns
         -------
diff --git a/ivy/functional/backends/jax/activations.py b/ivy/functional/backends/jax/activations.py
index a2aea2f07d9c1..6d980aaf6022e 100644
--- a/ivy/functional/backends/jax/activations.py
+++ b/ivy/functional/backends/jax/activations.py
@@ -58,8 +58,8 @@ def softplus(
     *,
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
-    out: Optional[JaxArray] = None,
     complex_mode="jax",
+    out: Optional[JaxArray] = None,
 ) -> JaxArray:
     if beta is not None and beta != 1:
         x_beta = x * beta
diff --git a/ivy/functional/backends/mxnet/activations.py b/ivy/functional/backends/mxnet/activations.py
index 8ac888626f28c..8efd486936334 100644
--- a/ivy/functional/backends/mxnet/activations.py
+++ b/ivy/functional/backends/mxnet/activations.py
@@ -52,8 +52,8 @@ def softplus(
     *,
     beta: Optional[Union[(int, float)]] = None,
     threshold: Optional[Union[(int, float)]] = None,
-    out: Optional[None] = None,
     complex_mode="jax",
+    out: Optional[None] = None,
 ) -> None:
     if beta is not None and beta != 1:
         x_beta = x * beta
diff --git a/ivy/functional/backends/numpy/activations.py b/ivy/functional/backends/numpy/activations.py
index 294e53a33afaf..2e230498b9574 100644
--- a/ivy/functional/backends/numpy/activations.py
+++ b/ivy/functional/backends/numpy/activations.py
@@ -70,8 +70,8 @@ def softplus(
     *,
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
-    out: Optional[np.ndarray] = None,
     complex_mode="jax",
+    out: Optional[np.ndarray] = None,
 ) -> np.ndarray:
     if beta is not None and beta != 1:
         x_beta = x * beta
diff --git a/ivy/functional/backends/paddle/activations.py b/ivy/functional/backends/paddle/activations.py
index 1866908dbc4ee..d090f69a66a4b 100644
--- a/ivy/functional/backends/paddle/activations.py
+++ b/ivy/functional/backends/paddle/activations.py
@@ -122,8 +122,8 @@ def softplus(
     *,
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
-    out: Optional[paddle.Tensor] = None,
     complex_mode="jax",
+    out: Optional[paddle.Tensor] = None,
 ) -> paddle.Tensor:
     if beta is not None and beta != 1:
         x_beta = x * beta
diff --git a/ivy/functional/backends/tensorflow/activations.py b/ivy/functional/backends/tensorflow/activations.py
index 97f455a666648..de43f95f9a8f7 100644
--- a/ivy/functional/backends/tensorflow/activations.py
+++ b/ivy/functional/backends/tensorflow/activations.py
@@ -77,8 +77,8 @@ def softplus(
     *,
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
-    out: Optional[Tensor] = None,
     complex_mode="jax",
+    out: Optional[Tensor] = None,
 ) -> Tensor:
     if beta is not None and beta != 1:
         x_beta = x * beta
diff --git a/ivy/functional/backends/torch/activations.py b/ivy/functional/backends/torch/activations.py
index 39d79c19bc6d2..afb85ebe96f5d 100644
--- a/ivy/functional/backends/torch/activations.py
+++ b/ivy/functional/backends/torch/activations.py
@@ -82,8 +82,8 @@ def softplus(
     *,
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
-    out: Optional[torch.Tensor] = None,
     complex_mode="jax",
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     kwargs = {
         k: v for k, v in {"beta": beta, "threshold": threshold}.items() if v is not None
diff --git a/ivy/functional/ivy/activations.py b/ivy/functional/ivy/activations.py
index 2a54783418ce4..f19180ab7a529 100644
--- a/ivy/functional/ivy/activations.py
+++ b/ivy/functional/ivy/activations.py
@@ -58,12 +58,12 @@ def gelu(
     approximate
         Whether to approximate, default is ``True``. An approximation is
         always used if the input array is complex.
-    out
-        optional output array, for writing the result to. It must have a shape that the
-        inputs broadcast to.
     complex_mode
         optional specifier for how to handle complex data types. See
         ``ivy.func_wrapper.handle_complex_input`` for more detail.
+    out
+        optional output array, for writing the result to. It must have a shape that the
+        inputs broadcast to.
 
     Returns
     -------
@@ -152,12 +152,12 @@ def leaky_relu(
         Input array.
     alpha
         Negative slope for ReLU.
-    out
-        optional output array, for writing the result to. It must have a shape that the
-        inputs broadcast to.
     complex_mode
         optional specifier for how to handle complex data types. See
         ``ivy.func_wrapper.handle_complex_input`` for more detail.
+    out
+        optional output array, for writing the result to. It must have a shape that the
+        inputs broadcast to.
 
     Returns
     -------
@@ -322,12 +322,12 @@ def relu(
     ----------
     x
         input array
-    out
-        optional output array, for writing the result to. It must have a shape that the
-        inputs broadcast to.
     complex_mode
         optional specifier for how to handle complex data types. See
         ``ivy.func_wrapper.handle_complex_input`` for more detail.
+    out
+        optional output array, for writing the result to. It must have a shape that the
+        inputs broadcast to.
 
     Returns
     -------
@@ -551,8 +551,8 @@ def softplus(
     *,
     beta: Optional[Union[int, float]] = None,
     threshold: Optional[Union[int, float]] = None,
-    out: Optional[ivy.Array] = None,
     complex_mode: Literal["split", "magnitude", "jax"] = "jax",
+    out: Optional[ivy.Array] = None,
 ) -> ivy.Array:
     """
     Apply the softplus function element-wise.
@@ -572,12 +572,12 @@ def softplus(
     threshold
         values above this revert to a linear function
        If the input is complex, only its real part is considered.
        Default: ``None``
-    out
-        optional output array, for writing the result to. It must have a shape that the
-        inputs broadcast to.
     complex_mode
         optional specifier for how to handle complex data types. See
         ``ivy.func_wrapper.handle_complex_input`` for more detail.
+    out
+        optional output array, for writing the result to. It must have a shape that the
+        inputs broadcast to.
 
     Returns
     -------
diff --git a/ivy/functional/ivy/elementwise.py b/ivy/functional/ivy/elementwise.py
index 0f0541ca00f15..0910dea53c4cd 100644
--- a/ivy/functional/ivy/elementwise.py
+++ b/ivy/functional/ivy/elementwise.py
@@ -6470,12 +6470,12 @@ def tanh(
         input array whose elements each represent a hyperbolic angle. Should have a
         real-valued floating-point data type.
-    out
-        optional output, for writing the result to. It must have a shape that the inputs
-        broadcast to.
     complex_mode
         optional specifier for how to handle complex data types. See
         ``ivy.func_wrapper.handle_complex_input`` for more detail.
+    out
+        optional output, for writing the result to. It must have a shape that the inputs
+        broadcast to.
 
     Returns
     -------

From 0a1f32fea0af541ea76d5887afdc0e5d3b550a3d Mon Sep 17 00:00:00 2001
From: Joe Shepherd
Date: Wed, 30 Aug 2023 11:52:09 +0100
Subject: [PATCH 24/24] tests: add complex_mode to ivy.softplus test

---
 .../test_functional/test_nn/test_activations.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
index 4da89e05797b4..2399daf623170 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_activations.py
@@ -246,9 +246,18 @@ def test_softmax(*, dtype_and_x, axis, test_flags, backend_fw, fn_name, on_devic
     ),
     beta=st.one_of(helpers.number(min_value=0.1, max_value=10), st.none()),
     threshold=st.one_of(helpers.number(min_value=0.1, max_value=30), st.none()),
+    complex_mode=st.sampled_from(["jax", "split", "magnitude"]),
 )
 def test_softplus(
-    *, dtype_and_x, beta, threshold, test_flags, backend_fw, fn_name, on_device
+    *,
+    dtype_and_x,
+    beta,
+    threshold,
+    complex_mode,
+    test_flags,
+    backend_fw,
+    fn_name,
+    on_device,
 ):
     assume(beta != 0)
     assume(threshold != 0)
@@ -264,4 +273,5 @@ def test_softplus(
         x=x[0],
         beta=beta,
         threshold=threshold,
+        complex_mode=complex_mode,
     )
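
A minimal usage sketch of the reordered `ivy.softplus` signature that the test above exercises. The signature and keyword names come from this patch; the backend choice and input values are illustrative assumptions, not part of the change:

```python
import ivy

# Assumed setup for illustration; any installed ivy backend would do.
ivy.set_backend("numpy")

x = ivy.array([-1.0, 0.0, 1.0])

# complex_mode now precedes out in the keyword-only block; since both are
# keyword-only, call sites like this one are unaffected by the reorder.
y = ivy.softplus(x, beta=1.0, threshold=30.0, complex_mode="jax")
print(y)  # approximately ivy.array([0.313, 0.693, 1.313])
```

Because every parameter after `*` is keyword-only, swapping `out` and `complex_mode` is purely a consistency change (ivy conventionally lists `out` last in docstrings and signatures) and cannot break existing callers.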