Skip to content

Commit

Permalink
cleanup functions of Hermitian matrices (#55951)
Browse files Browse the repository at this point in the history
The functions of Hermitian matrices are a bit of a mess. For example, if
we have a Hermitian matrix `a` with negative eigenvalues, `a^0.5`
doesn't produce the `Symmetric` wrapper, but `sqrt(a)` does. On the
other hand, if we have a positive definite `b`, `b^0.5` will be
`Hermitian`, but `sqrt(b)` will be `Symmetric`:
```julia
using LinearAlgebra
a = Hermitian([1.0 2.0;2.0 1.0])
a^0.5
sqrt(a)
b = Hermitian([2.0 1.0; 1.0 2.0])
b^0.5
sqrt(b)
```
This sort of arbitrary assignment of wrappers happens with pretty much
all functions defined there. There's also some oddities, such as `cis`
being the only function defined for `SymTridiagonal`, even though all
`eigen`-based functions work, and `cbrt` being the only function not
defined for complex Hermitian matrices.

I did a cleanup: I defined all functions for `SymTridiagonal` and
`Hermitian{<:Complex}`, and always assigned the appropriate wrapper,
preserving the input one when possible.

There's an inconsistency remaining that I didn't fix, that only `sqrt`
and `log` accept a tolerance argument, as changing that is probably
breaking.

There were also hardly any tests that I could find (only `exp`, `log`,
`cis`, and `sqrt`). I'm happy to add them if it's desired.
  • Loading branch information
araujoms authored Oct 7, 2024
1 parent c2a2e38 commit 4d21841
Showing 1 changed file with 93 additions and 65 deletions.
158 changes: 93 additions & 65 deletions stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -810,26 +810,32 @@ end
# Matrix functions
^(A::Symmetric{<:Real}, p::Integer) = sympow(A, p)
^(A::Symmetric{<:Complex}, p::Integer) = sympow(A, p)
function sympow(A::Symmetric, p::Integer)
if p < 0
return Symmetric(Base.power_by_squaring(inv(A), -p))
else
return Symmetric(Base.power_by_squaring(A, p))
end
end
function ^(A::Symmetric{<:Real}, p::Real)
isinteger(p) && return integerpow(A, p)
F = eigen(A)
if all-> λ 0, F.values)
return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors')
^(A::SymTridiagonal{<:Real}, p::Integer) = sympow(A, p)
^(A::SymTridiagonal{<:Complex}, p::Integer) = sympow(A, p)
for hermtype in (:Symmetric, :SymTridiagonal)
@eval begin
function sympow(A::$hermtype, p::Integer)
if p < 0
return Symmetric(Base.power_by_squaring(inv(A), -p))
else
return Symmetric(Base.power_by_squaring(A, p))
end
end
function ^(A::$hermtype{<:Real}, p::Real)
isinteger(p) && return integerpow(A, p)
F = eigen(A)
if all-> λ 0, F.values)
return Symmetric((F.vectors * Diagonal((F.values).^p)) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(complex.(F.values).^p)) * F.vectors')
end
end
function ^(A::$hermtype{<:Complex}, p::Real)
isinteger(p) && return integerpow(A, p)
return Symmetric(schurpow(A, p))
end
end
end
function ^(A::Symmetric{<:Complex}, p::Real)
isinteger(p) && return integerpow(A, p)
return Symmetric(schurpow(A, p))
end
function ^(A::Hermitian, p::Integer)
if p < 0
retmat = Base.power_by_squaring(inv(A), -p)
Expand All @@ -855,16 +861,25 @@ function ^(A::Hermitian{T}, p::Real) where T
return Hermitian(retmat)
end
else
return (F.vectors * Diagonal((complex.(F.values).^p))) * F.vectors'
retmat = (F.vectors * Diagonal((complex.(F.values).^p))) * F.vectors'
if T <: Real
return Symmetric(retmat)
else
return retmat
end
end
end

for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh)
@eval begin
function ($func)(A::HermOrSym{<:Real})
F = eigen(A)
return Symmetric((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh, :cbrt)
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function ($func)(A::$hermtype{<:Real})
F = eigen(A)
return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
end
end
end
@eval begin
function ($func)(A::Hermitian{<:Complex})
n = checksquare(A)
F = eigen(A)
Expand All @@ -877,23 +892,34 @@ for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh)
end
end

function cis(A::Union{RealHermSymComplexHerm,SymTridiagonal{<:Real}})
for wrapper in (:Symmetric, :Hermitian, :SymTridiagonal)
@eval begin
function cis(A::$wrapper{<:Real})
F = eigen(A)
return Symmetric(F.vectors .* cis.(F.values') * F.vectors')
end
end
end
function cis(A::Hermitian{<:Complex})
F = eigen(A)
# The returned matrix is unitary, and is complex-symmetric for real A
return F.vectors .* cis.(F.values') * F.vectors'
end


for func in (:acos, :asin)
@eval begin
function ($func)(A::HermOrSym{<:Real})
F = eigen(A)
if all-> -1 λ 1, F.values)
retmat = (F.vectors * Diagonal(($func).(F.values))) * F.vectors'
else
retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function ($func)(A::$hermtype{<:Real})
F = eigen(A)
if all-> -1 λ 1, F.values)
return $wrapper((F.vectors * Diagonal(($func).(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
end
return Symmetric(retmat)
end
end
@eval begin
function ($func)(A::Hermitian{<:Complex})
n = checksquare(A)
F = eigen(A)
Expand All @@ -910,14 +936,17 @@ for func in (:acos, :asin)
end
end

function acosh(A::HermOrSym{<:Real})
F = eigen(A)
if all-> λ 1, F.values)
retmat = (F.vectors * Diagonal(acosh.(F.values))) * F.vectors'
else
retmat = (F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors'
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function acosh(A::$hermtype{<:Real})
F = eigen(A)
if all-> λ 1, F.values)
return $wrapper((F.vectors * Diagonal(acosh.(F.values))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(acosh.(complex.(F.values)))) * F.vectors')
end
end
end
return Symmetric(retmat)
end
function acosh(A::Hermitian{<:Complex})
n = checksquare(A)
Expand All @@ -933,14 +962,18 @@ function acosh(A::Hermitian{<:Complex})
end
end

function sincos(A::HermOrSym{<:Real})
n = checksquare(A)
F = eigen(A)
S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,)))
for i in 1:n
S.diag[i], C.diag[i] = sincos(F.values[i])
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function sincos(A::$hermtype{<:Real})
n = checksquare(A)
F = eigen(A)
S, C = Diagonal(similar(A, (n,))), Diagonal(similar(A, (n,)))
for i in 1:n
S.diag[i], C.diag[i] = sincos(F.values[i])
end
return $wrapper((F.vectors * S) * F.vectors'), $wrapper((F.vectors * C) * F.vectors')
end
end
return Symmetric((F.vectors * S) * F.vectors'), Symmetric((F.vectors * C) * F.vectors')
end
function sincos(A::Hermitian{<:Complex})
n = checksquare(A)
Expand All @@ -962,18 +995,20 @@ for func in (:log, :sqrt)
# sqrt has rtol arg to handle matrices that are semidefinite up to roundoff errors
rtolarg = func === :sqrt ? Any[Expr(:kw, :(rtol::Real), :(eps(real(float(one(T))))*size(A,1)))] : Any[]
rtolval = func === :sqrt ? :(-maximum(abs, F.values) * rtol) : 0
@eval begin
function ($func)(A::HermOrSym{T}; $(rtolarg...)) where {T<:Real}
F = eigen(A)
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all-> λ λ₀, F.values)
retmat = (F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors'
else
retmat = (F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors'
for (hermtype, wrapper) in [(:Symmetric, :Symmetric), (:SymTridiagonal, :Symmetric), (:Hermitian, :Hermitian)]
@eval begin
function ($func)(A::$hermtype{T}; $(rtolarg...)) where {T<:Real}
F = eigen(A)
λ₀ = $rtolval # treat λ ≥ λ₀ as "zero" eigenvalues up to roundoff
if all-> λ λ₀, F.values)
return $wrapper((F.vectors * Diagonal(($func).(max.(0, F.values)))) * F.vectors')
else
return Symmetric((F.vectors * Diagonal(($func).(complex.(F.values)))) * F.vectors')
end
end
return Symmetric(retmat)
end

end
@eval begin
function ($func)(A::Hermitian{T}; $(rtolarg...)) where {T<:Complex}
n = checksquare(A)
F = eigen(A)
Expand All @@ -992,13 +1027,6 @@ for func in (:log, :sqrt)
end
end

# Cube root of a real-valued symmetric matrix
function cbrt(A::HermOrSym{<:Real})
F = eigen(A)
A = F.vectors * Diagonal(cbrt.(F.values)) * F.vectors'
return A
end

"""
hermitianpart(A::AbstractMatrix, uplo::Symbol=:U) -> Hermitian
Expand Down

0 comments on commit 4d21841

Please sign in to comment.