From 758727b9f16cb402374fbc344121a69360132467 Mon Sep 17 00:00:00 2001 From: Andreas Dutzler Date: Sun, 12 Feb 2023 11:56:26 +0100 Subject: [PATCH 1/2] add `math.split()` --- src/tensortrax/math/__init__.py | 1 + src/tensortrax/math/_math_tensor.py | 15 +++++++++++++++ tests/test_math.py | 1 + 3 files changed, 17 insertions(+) diff --git a/src/tensortrax/math/__init__.py b/src/tensortrax/math/__init__.py index 42ffa9a..dd83557 100644 --- a/src/tensortrax/math/__init__.py +++ b/src/tensortrax/math/__init__.py @@ -31,6 +31,7 @@ sign, sin, sinh, + split, sqrt, stack, sum, diff --git a/src/tensortrax/math/_math_tensor.py b/src/tensortrax/math/_math_tensor.py index d3b9898..65c5fe2 100644 --- a/src/tensortrax/math/_math_tensor.py +++ b/src/tensortrax/math/_math_tensor.py @@ -308,3 +308,18 @@ def stack(arrays, axis=0): ) else: return np.stack(arrays, axis=0) + + +def split(ary, indices_or_sections, axis=0): + "Split an array into multiple sub-arrays as views into ary." + + if isinstance(ary, Tensor): + return Tensor( + x=np.split(f(ary), indices_or_sections=indices_or_sections, axis=axis), + δx=np.split(δ(ary), indices_or_sections=indices_or_sections, axis=axis), + Δx=np.split(Δ(ary), indices_or_sections=indices_or_sections, axis=axis), + Δδx=np.split(Δδ(ary), indices_or_sections=indices_or_sections, axis=axis), + ntrax=ary.ntrax, + ) + else: + return np.split(ary, indices_or_sections=indices_or_sections, axis=axis) diff --git a/tests/test_math.py b/tests/test_math.py index 2cecad7..5de14d3 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -99,6 +99,7 @@ def test_math(): assert np.allclose(tm.stack([T, T]).x, tm.stack([F, F])) assert np.allclose(tm.repeat(T, 3).x, tm.repeat(F, 3)) assert np.allclose(tm.tile(T, 3).x, tm.tile(F, 3)) + assert np.allclose(tm.split(T, [1, 2]).x, tm.split(F, [1, 2])) def test_einsum(): From 2062caf1aea3bbd700dd0322059801d86d0cea52 Mon Sep 17 00:00:00 2001 From: Andreas Dutzler Date: Sun, 12 Feb 2023 12:23:48 +0100 Subject: [PATCH 2/2] Update __about__.py --- src/tensortrax/__about__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tensortrax/__about__.py b/src/tensortrax/__about__.py index 796998f..4fd8045 100644 --- a/src/tensortrax/__about__.py +++ b/src/tensortrax/__about__.py @@ -2,4 +2,4 @@ tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes. """ -__version__ = "0.8.2" +__version__ = "0.8.3"