From 4d2184159b15137b587b2f765c508621dce3a4d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateus=20Ara=C3=BAjo?= Date: Mon, 7 Oct 2024 04:31:49 +0200 Subject: [PATCH] cleanup functions of Hermitian matrices (#55951) The functions of Hermitian matrices are a bit of a mess. For example, if we have a Hermitian matrix `a` with negative eigenvalues, `a^0.5` doesn't produce the `Symmetric` wrapper, but `sqrt(a)` does. On the other hand, if we have a positive definite `b`, `b^0.5` will be `Hermitian`, but `sqrt(b)` will be `Symmetric`: ```julia using LinearAlgebra a = Hermitian([1.0 2.0;2.0 1.0]) a^0.5 sqrt(a) b = Hermitian([2.0 1.0; 1.0 2.0]) b^0.5 sqrt(b) ``` This sort of arbitrary assignment of wrappers happens with pretty much all functions defined there. There's also some oddities, such as `cis` being the only function defined for `SymTridiagonal`, even though all `eigen`-based functions work, and `cbrt` being the only function not defined for complex Hermitian matrices. I did a cleanup: I defined all functions for `SymTridiagonal` and `Hermitian{<:Complex}`, and always assigned the appropriate wrapper, preserving the input one when possible. There's an inconsistency remaining that I didn't fix, that only `sqrt` and `log` accept a tolerance argument, as changing that is probably breaking. There were also hardly any tests that I could find (only `exp`, `log`, `cis`, and `sqrt`). I'm happy to add them if it's desired. --- stdlib/LinearAlgebra/src/symmetric.jl | 158 +++++++++++++++----------- 1 file changed, 93 insertions(+), 65 deletions(-) diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index a7739596a73bb..e17eb80d25453 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -810,26 +810,32 @@ end # Matrix functions ^(A::Symmetric{<:Real}, p::Integer) = sympow(A, p) ^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p) -function sympow(A::Symmetric, p::Integer) - if p < 0 - return Symmetric(Base.power_by_squaring(inv(A), -p)) - else - return Symmetric(Base.power_by_squaring(A, p)) - end -end -function ^(A::Symmetric{<:Real}, p::Real) - isinteger(p) && return integerpow(A, p) - F = eigen(A) - if all(λ -> λ ≥ 0, F.values) - return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors') - else - return Symmetric((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors') +^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p) +^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p) +for hermtype in (:Symmetric, :SymTridiagonal) + @eval begin + function sympow(A::$hermtype, p::Integer) + if p < 0 + return Symmetric(Base.power_by_squaring(inv(A), -p)) + else + return Symmetric(Base.power_by_squaring(A, p)) + end + end + function ^(A::$hermtype{<:Real}, p::Real) + isinteger(p) && return integerpow(A, p) + F = eigen(A) + if all(λ -> λ ≥ 0, F.values) + return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors') + else + return Symmetric((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors') + end + end + function ^(A::$hermtype{<:Complex}, p::Real) + isinteger(p) && return integerpow(A, p) + return Symmetric(schurpow(A, p)) + end end end -function ^(A::Symmetric{<:Complex}, p::Real) - isinteger(p) && return integerpow(A, p) - return Symmetric(schurpow(A, p)) -end function ^(A::Hermitian, p::Integer) if p < 0 retmat = Base.power_by_squaring(inv(A), -p) @@ -855,16 +861,25 @@ function ^(A::Hermitian{T}, p::Real) where T return Hermitian(retmat) end else - return (F.vectors * Diagonal((complex.(F.values).^p))) * F.vectors' + retmat = (F.vectors * Diagonal((complex.(F.values).^p))) * F.vectors' + if T <: Real + return Symmetric(retmat) + else + return retmat + end end end -for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh) - @eval begin - function ($func)(A::HermOrSym{<:Real}) - F = eigen(A) - return Symmetric((F.vectors * Diagonal(($func).(F.values))) * F.vectors') +for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt) + for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)] + @eval begin + function ($func)(A::$hermtype{<:Real}) + F = eigen(A) + return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors') + end end + end + @eval begin function ($func)(A::Hermitian{<:Complex}) n = checksquare(A) F = eigen(A) @@ -877,23 +892,34 @@ for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh) end end -function cis(A::Union{RealHermSymComplexHerm,SymTridiagonal{<:Real}}) +for wrapper in (:Symmetric, :Hermitian, :SymTridiagonal) + @eval begin + function cis(A::$wrapper{<:Real}) + F = eigen(A) + return Symmetric(F.vectors .* cis.(F.values') * F.vectors') + end + end +end +function cis(A::Hermitian{<:Complex}) F = eigen(A) - # The returned matrix is unitary, and is complex-symmetric for real A return F.vectors .* cis.(F.values') * F.vectors' end + for func in (:acos, :asin) - @eval begin - function ($func)(A::HermOrSym{<:Real}) - F = eigen(A) - if all(λ -> -1 ≤ λ ≤ 1, F.values) - retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors' - else - retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors' + for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)] + @eval begin + function ($func)(A::$hermtype{<:Real}) + F = eigen(A) + if all(λ -> -1 ≤ λ ≤ 1, F.values) + return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors') + else + return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors') + end end - return Symmetric(retmat) end + end + @eval begin function ($func)(A::Hermitian{<:Complex}) n = checksquare(A) F = eigen(A) @@ -910,14 +936,17 @@ for func in (:acos, :asin) end end -function acosh(A::HermOrSym{<:Real}) - F = eigen(A) - if all(λ -> λ ≥ 1, F.values) - retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors' - else - retmat = (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors' +for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)] + @eval begin + function acosh(A::$hermtype{<:Real}) + F = eigen(A) + if all(λ -> λ ≥ 1, F.values) + return $wrapper((F.vectors * Diagonal(acosh.(F.values))) * F.vectors') + else + return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors') + end + end end - return Symmetric(retmat) end function acosh(A::Hermitian{<:Complex}) n = checksquare(A) @@ -933,14 +962,18 @@ function acosh(A::Hermitian{<:Complex}) end end -function sincos(A::HermOrSym{<:Real}) - n = checksquare(A) - F = eigen(A) - S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,))) - for i in 1:n - S.diag[i], C.diag[i] = sincos(F.values[i]) +for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)] + @eval begin + function sincos(A::$hermtype{<:Real}) + n = checksquare(A) + F = eigen(A) + S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,))) + for i in 1:n + S.diag[i], C.diag[i] = sincos(F.values[i]) + end + return $wrapper((F.vectors * S) * F.vectors'), $wrapper((F.vectors * C) * F.vectors') + end end - return Symmetric((F.vectors * S) * F.vectors'), Symmetric((F.vectors * C) * F.vectors') end function sincos(A::Hermitian{<:Complex}) n = checksquare(A) @@ -962,18 +995,20 @@ for func in (:log, :sqrt) # sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[] rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0 - @eval begin - function ($func)(A::HermOrSym{T}; $(rtolarg...)) where {T<:Real} - F = eigen(A) - λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff - if all(λ -> λ ≥ λ₀, F.values) - retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors' - else - retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors' + for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)] + @eval begin + function ($func)(A::$hermtype{T}; $(rtolarg...)) where {T<:Real} + F = eigen(A) + λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff + if all(λ -> λ ≥ λ₀, F.values) + return $wrapper((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors') + else + return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors') + end end - return Symmetric(retmat) end - + end + @eval begin function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex} n = checksquare(A) F = eigen(A) @@ -992,13 +1027,6 @@ for func in (:log, :sqrt) end end -# Cube root of a real-valued symmetric matrix -function cbrt(A::HermOrSym{<:Real}) - F = eigen(A) - A = F.vectors * Diagonal(cbrt.(F.values)) * F.vectors' - return A -end - """ hermitianpart(A::AbstractMatrix, uplo::Symbol=:U) -> Hermitian