Skip to content

Commit

Permalink
Add math-functions concatenate(arrays), maximum(x1, x2) and `mini…
Browse files Browse the repository at this point in the history
…mum(x1, x2)`

and also fix the ignored `axis`-argument in `stack()` for array-inputs
  • Loading branch information
adtzlr committed May 3, 2024
1 parent 23851eb commit 2787799
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/tensortrax/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ._math_tensor import (
abs,
array,
concatenate,
cos,
cosh,
diagonal,
Expand All @@ -20,6 +21,8 @@
log,
log10,
matmul,
maximum,
minimum,
repeat,
sign,
sin,
Expand All @@ -45,6 +48,7 @@
"broadcast_to",
"cos",
"cosh",
"concatenate",
"diagonal",
"dot",
"dual2real",
Expand All @@ -55,6 +59,8 @@
"log",
"log10",
"matmul",
"maximum",
"minimum",
"ravel",
"repeat",
"reshape",
Expand Down
35 changes: 34 additions & 1 deletion src/tensortrax/math/_math_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,22 @@ def stack(arrays, axis=0):
ntrax=min([A.ntrax for A in arrays]),
)
else:
return np.stack(arrays, axis=0)
return np.stack(arrays, axis=axis)


def concatenate(arrays, axis=0):
"Join a sequence of arrays along an existing axis."

if isinstance(arrays[0], Tensor):
return Tensor(
x=np.concatenate([f(A) for A in arrays], axis=axis),
δx=np.concatenate([δ(A) for A in arrays], axis=axis),
Δx=np.concatenate([Δ(A) for A in arrays], axis=axis),
Δδx=np.concatenate([Δδ(A) for A in arrays], axis=axis),
ntrax=min([A.ntrax for A in arrays]),
)
else:
return np.concatenate(arrays, axis=axis)


def split(ary, indices_or_sections, axis=0):
Expand Down Expand Up @@ -401,3 +416,21 @@ def if_else(cond, true, false):
)

return out


def maximum(x1, x2):
"Element-wise maximum of array elements."

if isinstance(x1, Tensor):
return if_else(x1 > x2, x1, x2)
else:
return np.maximum(x1, x2)


def minimum(x1, x2):
"Element-wise minimum of array elements."

if isinstance(x1, Tensor):
return if_else(x1 < x2, x1, x2)
else:
return np.minimum(x1, x2)
11 changes: 11 additions & 0 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def test_math():
assert np.allclose(tm.vstack([T, T]).x, tm.vstack([F, F]))
assert np.allclose(tm.hstack([T, T]).x, tm.hstack([F, F]))
assert np.allclose(tm.stack([T, T]).x, tm.stack([F, F]))
assert np.allclose(tm.concatenate([T, T]).x, tm.concatenate([F, F]))
assert np.allclose(tm.repeat(T, 3).x, tm.repeat(F, 3))
assert np.allclose(tm.tile(T, 3).x, tm.tile(F, 3))
assert np.allclose(tm.split(T, [1, 2])[1].x, tm.split(F, [1, 2])[1])
Expand Down Expand Up @@ -226,6 +227,16 @@ def test_condition():
Y = tm.if_else(F >= G, 2 * F, G / 2)
Z = tm.if_else(T >= V, 2 * T, V / 2)

max_array = tm.maximum(F, G)
max_tensor = tm.maximum(T, V)

assert np.allclose(max_array, max_tensor.x)

min_array = tm.minimum(F, G)
min_tensor = tm.minimum(T, V)

assert np.allclose(min_array, min_tensor.x)

np.allclose(Y, Z.x)

with pytest.raises(NotImplementedError):
Expand Down

0 comments on commit 2787799

Please sign in to comment.