From 4f2c813a3f23a6f65f5cdf875123438f48b9a359 Mon Sep 17 00:00:00 2001 From: Andreas Dutzler Date: Sun, 2 Jun 2024 00:50:24 +0200 Subject: [PATCH] Add the (Gauss) error function `math.special.erf(z)` --- src/tensortrax/__about__.py | 2 +- src/tensortrax/math/special/__init__.py | 2 ++ src/tensortrax/math/special/_special_tensor.py | 18 ++++++++++++++++++ tests/test_math.py | 1 + 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/tensortrax/__about__.py b/src/tensortrax/__about__.py index 10a4943..826ea2b 100644 --- a/src/tensortrax/__about__.py +++ b/src/tensortrax/__about__.py @@ -2,4 +2,4 @@ tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes. """ -__version__ = "0.24.0" +__version__ = "0.25.0" diff --git a/src/tensortrax/math/special/__init__.py b/src/tensortrax/math/special/__init__.py index f11fc68..a9bac1b 100644 --- a/src/tensortrax/math/special/__init__.py +++ b/src/tensortrax/math/special/__init__.py @@ -5,6 +5,7 @@ from ._special_tensor import ( ddot, dev, + erf, from_triu_1d, from_triu_2d, sym, @@ -17,6 +18,7 @@ __all__ = [ "ddot", "dev", + "erf", "from_triu_1d", "from_triu_2d", "sym", diff --git a/src/tensortrax/math/special/_special_tensor.py b/src/tensortrax/math/special/_special_tensor.py index b8bbc25..56d1f98 100644 --- a/src/tensortrax/math/special/_special_tensor.py +++ b/src/tensortrax/math/special/_special_tensor.py @@ -4,6 +4,7 @@ import numpy as np +from ..._helpers import Δ, Δδ, f, δ from ..._tensor import Tensor from .. import _math_array as array from .._math_tensor import einsum, sqrt, stack, trace, transpose @@ -20,6 +21,23 @@ def dev(A): return A - trace(A) / dim * array.eye(A) +def erf(z): + "The (Gauss) error function." + from scipy.special import erf + + if isinstance(z, Tensor): + derf = 2 / np.sqrt(np.pi) * np.exp(-f(z) ** 2) + return Tensor( + x=erf(f(z)), + δx=derf * δ(z), + Δx=derf * Δ(z), + Δδx=-2 * f(z) * derf * δ(z) * Δ(z) + derf * Δδ(z), + ntrax=z.ntrax, + ) + else: + return erf(z) + + def sym(A): "Symmetric part of a Tensor." return (A + transpose(A)) / 2 diff --git a/tests/test_math.py b/tests/test_math.py index a154c4d..4b9f2e9 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -69,6 +69,7 @@ def test_math(): tm.ravel, tm.abs, tm.sign, + tm.special.erf, ]: assert np.allclose(fun(F), fun(T).x)