Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complex maths functions #21539

Merged
merged 46 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
9dd84f6
Add complex handling to tanh
jshepherd01 Aug 7, 2023
75aa6c1
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 7, 2023
20680cc
Modify gelu to allow complex in paddle
jshepherd01 Aug 8, 2023
93dc7db
Refactor paddle backend pow and exp
jshepherd01 Aug 9, 2023
2841ca2
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 9, 2023
0de9919
Add complex number handling to trig functions in paddle backend
jshepherd01 Aug 10, 2023
cada95d
Implement complex number handling for inverse trig
jshepherd01 Aug 10, 2023
cab0691
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 10, 2023
d705b2f
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 14, 2023
f7be90d
Fix some failing tests
jshepherd01 Aug 14, 2023
38be825
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 14, 2023
c7ab3b5
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 15, 2023
ee7370d
Fix some test failures
jshepherd01 Aug 15, 2023
6e5710b
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 15, 2023
96daad1
Fix more test failures
jshepherd01 Aug 15, 2023
fb463b3
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 16, 2023
36cefa6
Add complex_mode to test_tanh, fix associated bugs
jshepherd01 Aug 16, 2023
78f158b
Change logic for testing for complex input
jshepherd01 Aug 16, 2023
d61dca1
Fix line length issue
jshepherd01 Aug 16, 2023
1e06ede
Merge branch 'main' of github.com:unifyai/ivy
jshepherd01 Aug 16, 2023
f661929
Merge branch 'main' of github.com:unifyai/ivy
jshepherd01 Aug 16, 2023
0de40cc
Minor changes to pow and exp
jshepherd01 Aug 16, 2023
af62ab9
Merge branch 'main' of github.com:unifyai/ivy
jshepherd01 Aug 17, 2023
7f53dc5
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 17, 2023
7ce8343
Implemented suggested changes
jshepherd01 Aug 17, 2023
e182d4f
Remove unnecessary `str()` operation
jshepherd01 Aug 17, 2023
3cc05de
Fix doctstrings and tests for complex activations
jshepherd01 Aug 17, 2023
a6e228b
Add complex_mode to backend call signatures
jshepherd01 Aug 18, 2023
3ec9bda
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 18, 2023
71ad8cd
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 21, 2023
0cb8f29
Merge branch 'main' into complex-maths-functions
jshepherd01 Aug 21, 2023
ff03dcb
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 22, 2023
ad44032
fix(framework): native out with complex_mode split
jshepherd01 Aug 22, 2023
3ed1e6c
fix(test): Apply safety factor to atan test
jshepherd01 Aug 22, 2023
a02efc1
fix(test): apply safety factor to test_sqrt
jshepherd01 Aug 22, 2023
d265bda
fix(test): apply safety factor to test_jax_leaky_relu
jshepherd01 Aug 22, 2023
9aceb26
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 22, 2023
92347a4
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 23, 2023
e0947f4
Merge branch 'main' into complex-maths-functions
jshepherd01 Aug 23, 2023
a26fe2b
style: move `complex_mode` before `out` argument
jshepherd01 Aug 24, 2023
8ff6e71
Merge branch 'complex-maths-functions' of github.com:jshepherd01/ivy …
jshepherd01 Aug 24, 2023
8713949
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 24, 2023
641bed2
style: move `complex_mode` after `out` argument
jshepherd01 Aug 24, 2023
fe5ecf5
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 29, 2023
de3d93e
Merge branch 'master' into complex-maths-functions
jshepherd01 Aug 30, 2023
0a1f32f
tests: add complex_mode to ivy.softplus test
jshepherd01 Aug 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions ivy/data_classes/array/elementwise.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# global
import abc
from typing import Optional, Union
from typing import Optional, Union, Literal

# local
import ivy
Expand Down Expand Up @@ -2552,7 +2552,12 @@ def tan(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array:
"""
return ivy.tan(self._data, out=out)

def tanh(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array:
def tanh(
self: ivy.Array,
*,
out: Optional[ivy.Array] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.tanh. This method simply wraps the
function, and so the docstring for ivy.tanh also applies to this method with
Expand All @@ -2566,6 +2571,8 @@ def tanh(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array:
out
optional output, for writing the result to. It must have a shape that the
inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A simple clarification of what each mode means would be great

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an explanation in the base function and I was trying to keep the array method docstrings as minimal as possible, I can definitely add it if that's preferred though

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree with @MahmoudAshraf97, the parameter descriptions should ideally be the same as the functional one

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. Would it be enough to add See ``ivy.func_wrapper.handle_complex_input`` for more detail. or should I put a description in directly?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's completely fine to add "See ivy.func_wrapper.handle_complex_input for more detail.". The only thing is the parameter descriptions for the ivy.tanh function should be the same as the parameter descriptions for the ivy.Array.tanh method 😄


Returns
-------
Expand All @@ -2581,7 +2588,7 @@ def tanh(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array:
>>> print(y)
ivy.array([0., 0.762, 0.964])
"""
return ivy.tanh(self._data, out=out)
return ivy.tanh(self._data, out=out, complex_mode=complex_mode)

def trunc(self: ivy.Array, *, out: Optional[ivy.Array] = None) -> ivy.Array:
"""
Expand Down
10 changes: 9 additions & 1 deletion ivy/data_classes/container/elementwise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from typing import Optional, Union, List, Dict
from typing import Optional, Union, List, Dict, Literal

# local
import ivy
Expand Down Expand Up @@ -8274,6 +8274,7 @@ def _static_tanh(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.tanh. This method simply wraps the
Expand All @@ -8299,6 +8300,8 @@ def _static_tanh(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.

Returns
-------
Expand All @@ -8325,6 +8328,7 @@ def _static_tanh(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

def tanh(
Expand All @@ -8335,6 +8339,7 @@ def tanh(
prune_unapplied: Union[bool, ivy.Container] = False,
map_sequences: Union[bool, ivy.Container] = False,
out: Optional[ivy.Container] = None,
complex_mode: Literal["split", "magnitude", "jax"] = "jax",
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.tanh. This method simply wraps the
Expand All @@ -8360,6 +8365,8 @@ def tanh(
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
complex_mode
optional specifier for how to handle complex data types.

Returns
-------
Expand All @@ -8386,6 +8393,7 @@ def tanh(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
complex_mode=complex_mode,
)

@staticmethod
Expand Down
4 changes: 4 additions & 0 deletions ivy/functional/backends/jax/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ def log10(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:


def log1p(x: JaxArray, /, *, out: Optional[JaxArray] = None) -> JaxArray:
if x.dtype in ["complex64", "complex128"]:
# For certain complex numbers, jnp.log1p(z) != jnp.log(1+z)
vedpatwardhan marked this conversation as resolved.
Show resolved Hide resolved
# I don't know why
return jnp.log(1 + x)
return jnp.log1p(x)


Expand Down
19 changes: 7 additions & 12 deletions ivy/functional/backends/paddle/activations.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you added only the approximate case and left the other one, is there a reason for that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other case relies on erf, which none of the backends implement for complex numbers (the only implementation I was able to find anywhere was in scipy, and it's a very long and complicated function that falls back on taylor series expansions for different cases). There's also a wrapper on the ivy API level which sets approximate to True if the input is complex, so the other case should never be reachable with a complex input

Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def leaky_relu(


@with_unsupported_device_and_dtypes(
{"2.5.1 and below": {"cpu": ("complex128", "complex64")}}, backend_version
{"2.5.1 and below": {"cpu": ("bfloat16",)}}, backend_version
)
def gelu(
x: paddle.Tensor,
Expand All @@ -68,18 +68,13 @@ def gelu(
approximate: bool = False,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
if x.dtype in [paddle.complex64, paddle.complex128]:
jshepherd01 marked this conversation as resolved.
Show resolved Hide resolved
return (
0.5
* x
* (1 + paddle_backend.tanh(0.7978845608 * (x + 0.044715 * x * x * x)))
vedpatwardhan marked this conversation as resolved.
Show resolved Hide resolved
)
if x.dtype in unsupported_dtypes:
if paddle.is_complex(x):
if approximate:
return (
0.5
* x
* (
1
+ paddle_backend.tanh(0.7978845608 * (x + 0.044715 * x * x * x))
)
)
return 0.5 * x * (1 + paddle_backend.erf(x / paddle_backend.sqrt(2)))
return F.gelu(x.cast("float32"), approximate=approximate).cast(x.dtype)
return F.gelu(x, approximate=approximate)

Expand Down
Loading
Loading