Skip to content

Commit

Permalink
Merge pull request #70 from adtzlr/add-math-split
Browse files Browse the repository at this point in the history
Add `math.split()`
  • Loading branch information
adtzlr authored Feb 12, 2023
2 parents 8003206 + 2062caf commit 65ce811
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/tensortrax/__about__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes.
"""

__version__ = "0.8.2"
__version__ = "0.8.3"
1 change: 1 addition & 0 deletions src/tensortrax/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
sign,
sin,
sinh,
split,
sqrt,
stack,
sum,
Expand Down
15 changes: 15 additions & 0 deletions src/tensortrax/math/_math_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,3 +308,18 @@ def stack(arrays, axis=0):
)
else:
return np.stack(arrays, axis=0)


def split(ary, indices_or_sections, axis=0):
"Split an array into multiple sub-arrays as views into ary."

if isinstance(ary, Tensor):
return Tensor(
x=np.split(f(ary), indices_or_sections=indices_or_sections, axis=axis),
δx=np.split(δ(ary), indices_or_sections=indices_or_sections, axis=axis),
Δx=np.split(Δ(ary), indices_or_sections=indices_or_sections, axis=axis),
Δδx=np.split(Δδ(ary), indices_or_sections=indices_or_sections, axis=axis),
ntrax=ary.ntrax,
)
else:
return np.split(ary, indices_or_sections=indices_or_sections, axis=axis)
1 change: 1 addition & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_math():
assert np.allclose(tm.stack([T, T]).x, tm.stack([F, F]))
assert np.allclose(tm.repeat(T, 3).x, tm.repeat(F, 3))
assert np.allclose(tm.tile(T, 3).x, tm.tile(F, 3))
assert np.allclose(tm.split(T, [1, 2]).x, tm.split(F, [1, 2]))


def test_einsum():
Expand Down

0 comments on commit 65ce811

Please sign in to comment.