Skip to content

Commit

Permalink
Add einsum4()
Browse files Browse the repository at this point in the history
  • Loading branch information
adtzlr committed Jun 13, 2024
1 parent 191b0c2 commit 0e86c33
Show file tree
Hide file tree
Showing 3 changed files with 331 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# Highlights
- Write differentiable code with Tensors based on NumPy arrays
- Efficient evaluation of batches by elementwise-operating trailing axes
- Essential vector/tensor Hyper-Dual number math, including limited support for `einsum` (restricted to max. three operands)
- Essential vector/tensor Hyper-Dual number math, including limited support for `einsum` (restricted to max. four operands)
- Math is limited but similar to NumPy; inside functions to be differentiated, use `import tensortrax.math as tm` instead of `import numpy as np`
- Forward Mode Automatic Differentiation (AD) using Hyper-Dual Tensors, up to second order derivatives
- Create functions in terms of Hyper-Dual Tensors
Expand Down
313 changes: 311 additions & 2 deletions src/tensortrax/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,313 @@ def reshape(A, newshape, order="C"):
return np.reshape(A, newshape=newshape, order=order)


def einsum4(subscripts, *operands):
    """Einsum with four operands, any of which may be a Tensor.

    Operands that are not Tensors are handed to ``np.einsum`` as given.
    For the Tensor operands, the δ-, Δ- and Δδ-parts of the result are
    built by summing one ``np.einsum`` call per term: each first-order
    part replaces a single Tensor by its δ (or Δ) part, and the mixed
    part combines the Δδ-terms with every ordered (δ, Δ) pairing of two
    distinct Tensors. All other Tensor operands enter a term via ``f``.

    If none of the operands is a Tensor, the plain ``np.einsum`` result
    is returned instead of a Tensor.
    """
    A, B, C, D = operands  # unpacking enforces exactly four operands
    args = (A, B, C, D)

    # positions of the operands that carry variations
    slots = [k for k, item in enumerate(args) if isinstance(item, Tensor)]

    if not slots:
        return np.einsum(subscripts, *operands)

    def contract(parts):
        # ``parts`` maps an operand position to the accessor (δ, Δ or Δδ)
        # applied at that position; every other Tensor contributes f(.),
        # and non-Tensor operands are used unchanged.
        values = []
        for k, item in enumerate(args):
            if k in parts:
                values.append(parts[k](item))
            elif k in slots:
                values.append(f(item))
            else:
                values.append(item)
        return np.einsum(subscripts, *values)

    def total(selections):
        # accumulate strictly left to right so the summation order of the
        # terms is fixed by the order of ``selections``
        result = None
        for parts in selections:
            term = contract(parts)
            result = term if result is None else result + term
        return result

    x = contract({})
    δx = total([{k: δ} for k in slots])
    Δx = total([{k: Δ} for k in slots])

    # mixed part: the Δδ-terms first, then both (δ, Δ) assignments for
    # each pair of distinct Tensor operands, pairs in ascending position
    mixed = [{k: Δδ} for k in slots]
    for n, first in enumerate(slots):
        for second in slots[n + 1:]:
            mixed.append({first: δ, second: Δ})
            mixed.append({first: Δ, second: δ})
    Δδx = total(mixed)

    ntrax = min(args[k].ntrax for k in slots)

    return Tensor(x=x, δx=δx, Δx=Δx, Δδx=Δδx, ntrax=ntrax)


def einsum3(subscripts, *operands):
"Einsum with three operands."
A, B, C = operands
Expand Down Expand Up @@ -743,15 +1050,17 @@ def _einsum(*operands):


def einsum(subscripts, *operands):
    """Einsum limited to one, two, three or four operands.

    Dispatches on the number of operands to ``einsum1``, ``einsum2``,
    ``einsum3`` or ``einsum4`` and returns that function's result.

    Raises
    ------
    NotImplementedError
        If the number of operands is not one, two, three or four.
    """
    count = len(operands)

    if count == 1:
        return einsum1(subscripts, *operands)
    elif count == 2:
        return einsum2(subscripts, *operands)
    elif count == 3:
        return einsum3(subscripts, *operands)
    elif count == 4:
        return einsum4(subscripts, *operands)

    # any other operand count has no implementation
    raise NotImplementedError("More than four operands are not supported.")


def transpose(A):
Expand Down
21 changes: 19 additions & 2 deletions tests/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,26 @@ def test_einsum():
tm.einsum("ij...,kl...,mn...->ijklmn...", T, F, T)
tm.einsum("ij...,kl...,mn...->ijklmn...", T, T, F)
tm.einsum("ij...,kl...,mn...->ijklmn...", T, T, T)


tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", F, F, F, F)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", F, F, F, T)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", F, F, T, F)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", F, T, F, F)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, F, F, F)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", F, F, T, T)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", F, T, F, T)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, F, F, T)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, F, T, F)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, T, F, F)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", F, T, T, T)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, F, T, T)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, T, F, T)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, T, T, F)
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, T, T, T)


with pytest.raises(NotImplementedError):
tm.einsum("ij...,kl...,mn...,pq...->ijklmnpq...", T, T, T, T)
tm.einsum("ij...,kl...,mn...,pq...,rs...->ijklmnpqrs...", T, T, T, T, T)


def test_slice():
Expand Down

0 comments on commit 0e86c33

Please sign in to comment.