Skip to content

Commit

Permalink
Merge pull request #113 from adtzlr/add-math-special-erf
Browse files Browse the repository at this point in the history
Add the (Gauss) error function `math.special.erf(z)`
  • Loading branch information
adtzlr authored Jun 1, 2024
2 parents 67aed44 + 4f2c813 commit 2402c82
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/tensortrax/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes.
"""

__version__ = "0.24.0"
__version__ = "0.25.0"
2 changes: 2 additions & 0 deletions src/tensortrax/math/special/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ._special_tensor import (
ddot,
dev,
erf,
from_triu_1d,
from_triu_2d,
sym,
Expand All @@ -17,6 +18,7 @@
__all__ = [
"ddot",
"dev",
"erf",
"from_triu_1d",
"from_triu_2d",
"sym",
Expand Down
18 changes: 18 additions & 0 deletions src/tensortrax/math/special/_special_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from ..._helpers import Δ, Δδ, f, δ
from ..._tensor import Tensor
from .. import _math_array as array
from .._math_tensor import einsum, sqrt, stack, trace, transpose
Expand All @@ -20,6 +21,23 @@ def dev(A):
return A - trace(A) / dim * array.eye(A)


def erf(z):
"The (Gauss) error function."
from scipy.special import erf

if isinstance(z, Tensor):
derf = 2 / np.sqrt(np.pi) * np.exp(-f(z) ** 2)
return Tensor(
x=erf(f(z)),
δx=derf * δ(z),
Δx=derf * Δ(z),
Δδx=-2 * f(z) * derf * δ(z) * Δ(z) + derf * Δδ(z),
ntrax=z.ntrax,
)
else:
return erf(z)


def sym(A):
"Symmetric part of a Tensor."
return (A + transpose(A)) / 2
Expand Down
1 change: 1 addition & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def test_math():
tm.ravel,
tm.abs,
tm.sign,
tm.special.erf,
]:
assert np.allclose(fun(F), fun(T).x)

Expand Down

0 comments on commit 2402c82

Please sign in to comment.