diff --git a/src/tensortrax/__about__.py b/src/tensortrax/__about__.py index f88c25b..43ea437 100644 --- a/src/tensortrax/__about__.py +++ b/src/tensortrax/__about__.py @@ -2,4 +2,4 @@ tensorTRAX: Math on (Hyper-Dual) Tensors with Trailing Axes. """ -__version__ = "0.22.0" +__version__ = "0.23.0" diff --git a/src/tensortrax/math/linalg/_linalg_tensor.py b/src/tensortrax/math/linalg/_linalg_tensor.py index 9535f30..49911ec 100644 --- a/src/tensortrax/math/linalg/_linalg_tensor.py +++ b/src/tensortrax/math/linalg/_linalg_tensor.py @@ -126,27 +126,55 @@ def eigh(A, eps=np.sqrt(np.finfo(float).eps)): # alpha = [0, 1, 2] # beta = [(1, 2), (2, 0), (0, 1)] - alpha = np.arange(dim) beta = [ np.concatenate([np.arange(a + 1, dim), np.arange(a)]) for a in np.arange(dim) ] δN = [] + ΔN = [] for α in alpha: δNα = [] + ΔNα = [] for β in beta[α]: Mαβ = einsum("i...,j...->ij...", N[α], N[β]) δAαβ = einsum("ij...,ij...->...", Mαβ, δ(A)) + ΔAαβ = einsum("ij...,ij...->...", Mαβ, Δ(A)) λαβ = λ[α] - λ[β] δNα.append(1 / λαβ * N[β] * δAαβ) + ΔNα.append(1 / λαβ * N[β] * ΔAαβ) δN.append(sum(δNα, axis=0)) + ΔN.append(sum(ΔNα, axis=0)) + + ΔδN = [] + for α in alpha: + ΔδNα = [] + for β in beta[α]: + Mαβ = einsum("i...,j...->ij...", N[α], N[β]) + δAαβ = einsum("ij...,ij...->...", Mαβ, δ(A)) + ΔδAαβ = einsum("ij...,ij...->...", Mαβ, Δδ(A)) + λαβ = λ[α] - λ[β] + Δλαβ = Δλ[α] - Δλ[β] + ΔδNα.append( + -(λαβ**-2) * Δλαβ * N[β] * δAαβ + + 1 / λαβ * ΔN[β] * δAαβ + + 1 / λαβ * N[β] * ΔδAαβ + ) + ΔδN.append(sum(ΔδNα, axis=0)) δM = einsum("ai...,aj...->aij...", δN, N) + einsum("ai...,aj...->aij...", N, δN) + ΔM = einsum("ai...,aj...->aij...", ΔN, N) + einsum("ai...,aj...->aij...", N, ΔN) Δδλ = einsum("aij...,ij...->a...", δM, Δ(A)) + einsum( "aij...,ij...->a...", M, Δδ(A) ) + ΔδM = ( + einsum("ai...,aj...->aij...", δN, ΔN) + + einsum("ai...,aj...->aij...", ΔN, δN) + + einsum("ai...,aj...->aij...", ΔδN, N) + + einsum("ai...,aj...->aij...", N, ΔδN) + ) + return ( Tensor( x=λ, @@ -158,8 +186,8 @@ def eigh(A, eps=np.sqrt(np.finfo(float).eps)): Tensor( x=M, δx=δM, - Δx=δM * np.nan, - Δδx=δM * np.nan, + Δx=ΔM, + Δδx=ΔδM, ntrax=A.ntrax, ), )