Skip to content

Commit 49023b5

Browse files
dalumandreasnoack
authored andcommitted
Preserve types when adding/subtracting Herm/Sym/UniformScaling (#29500)
* Preserve types when adding/subtracting Herm/Sym/UniformScaling * Make `real(::SymOrHerm{<:Real})` consistent with `real(::Array)`. * Fix embarrassing ambiguity * More tests, remove imag(::Hermitian), simplify code * Remove `.λ`
1 parent 640b155 commit 49023b5

File tree

4 files changed

+75
-3
lines changed

4 files changed

+75
-3
lines changed

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,12 @@ transpose(A::Hermitian{<:Real}) = A
330330
adjoint(A::Symmetric) = Adjoint(A)
331331
transpose(A::Hermitian) = Transpose(A)
332332

333+
real(A::Symmetric{<:Real}) = A
334+
real(A::Hermitian{<:Real}) = A
335+
real(A::Symmetric) = Symmetric(real(A.data), sym_uplo(A.uplo))
336+
real(A::Hermitian) = Hermitian(real(A.data), sym_uplo(A.uplo))
337+
imag(A::Symmetric) = Symmetric(imag(A.data), sym_uplo(A.uplo))
338+
333339
Base.copy(A::Adjoint{<:Any,<:Hermitian}) = copy(A.parent)
334340
Base.copy(A::Transpose{<:Any,<:Symmetric}) = copy(A.parent)
335341
Base.copy(A::Adjoint{<:Any,<:Symmetric}) =
@@ -394,6 +400,14 @@ end
394400
(-)(A::Symmetric{Tv,S}) where {Tv,S} = Symmetric{Tv,S}(-A.data, A.uplo)
395401
(-)(A::Hermitian{Tv,S}) where {Tv,S} = Hermitian{Tv,S}(-A.data, A.uplo)
396402

403+
## Addition/subtraction
404+
for f in (:+, :-)
405+
@eval $f(A::Symmetric, B::Symmetric) = Symmetric($f(A.data, B), sym_uplo(A.uplo))
406+
@eval $f(A::Hermitian, B::Hermitian) = Hermitian($f(A.data, B), sym_uplo(A.uplo))
407+
@eval $f(A::Hermitian, B::Symmetric{<:Real}) = Hermitian($f(A.data, B), sym_uplo(A.uplo))
408+
@eval $f(A::Symmetric{<:Real}, B::Hermitian) = Hermitian($f(A.data, B), sym_uplo(A.uplo))
409+
end
410+
397411
## Matvec
398412
mul!(y::StridedVector{T}, A::Symmetric{T,<:StridedMatrix}, x::StridedVector{T}) where {T<:BlasFloat} =
399413
BLAS.symv!(A.uplo, one(T), A.data, x, zero(T), y)

stdlib/LinearAlgebra/src/uniformscaling.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,31 @@ for (t1, t2) in ((:UnitUpperTriangular, :UpperTriangular),
112112
end
113113
end
114114

115+
# Adding a complex UniformScaling to the diagonal of a Hermitian
116+
# matrix breaks the hermiticity, if the UniformScaling is non-real.
117+
# However, to preserve type stability, we do not special-case a
118+
# UniformScaling{<:Complex} that happens to be real.
119+
function (+)(A::Hermitian{T,S}, J::UniformScaling{<:Complex}) where {T,S}
120+
A_ = copytri!(copy(parent(A)), A.uplo)
121+
B = convert(AbstractMatrix{Base._return_type(+, Tuple{eltype(A), typeof(J)})}, A_)
122+
@inbounds for i in diagind(B)
123+
B[i] += J
124+
end
125+
return B
126+
end
127+
128+
function (-)(J::UniformScaling{<:Complex}, A::Hermitian{T,S}) where {T,S}
129+
A_ = copytri!(copy(parent(A)), A.uplo)
130+
B = convert(AbstractMatrix{Base._return_type(+, Tuple{eltype(A), typeof(J)})}, A_)
131+
@inbounds for i in eachindex(B)
132+
B[i] = -B[i]
133+
end
134+
@inbounds for i in diagind(B)
135+
B[i] += J
136+
end
137+
return B
138+
end
139+
115140
function (+)(A::AbstractMatrix, J::UniformScaling)
116141
checksquare(A)
117142
B = copy_oftype(A, Base._return_type(+, Tuple{eltype(A), typeof(J)}))

stdlib/LinearAlgebra/test/symmetric.jl

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,17 @@ end
9090
@test (-Hermitian(aherm))::typeof(Hermitian(aherm)) == -aherm
9191
end
9292

93+
@testset "Addition and subtraction for Symmetric/Hermitian matrices" begin
94+
for f in (+, -)
95+
@test (f(Symmetric(asym), Symmetric(aposs)))::typeof(Symmetric(asym)) == f(asym, aposs)
96+
@test (f(Hermitian(aherm), Hermitian(apos)))::typeof(Hermitian(aherm)) == f(aherm, apos)
97+
@test (f(Symmetric(real(asym)), Hermitian(aherm)))::typeof(Hermitian(aherm)) == f(real(asym), aherm)
98+
@test (f(Hermitian(aherm), Symmetric(real(asym))))::typeof(Hermitian(aherm)) == f(aherm, real(asym))
99+
@test (f(Symmetric(asym), Hermitian(aherm))) == f(asym, aherm)
100+
@test (f(Hermitian(aherm), Symmetric(asym))) == f(aherm, asym)
101+
end
102+
end
103+
93104
@testset "getindex and unsafe_getindex" begin
94105
@test aherm[1,1] == Hermitian(aherm)[1,1]
95106
@test asym[1,1] == Symmetric(asym)[1,1]
@@ -153,6 +164,21 @@ end
153164
@test transpose(H) == Hermitian(copy(transpose(aherm)))
154165
end
155166
end
167+
168+
@testset "real, imag" begin
169+
S = Symmetric(asym)
170+
H = Hermitian(aherm)
171+
@test issymmetric(real(S))
172+
@test ishermitian(real(H))
173+
if eltya <: Real
174+
@test real(S) === S == asym
175+
@test real(H) === H == aherm
176+
elseif eltya <: Complex
177+
@test issymmetric(imag(S))
178+
@test !ishermitian(imag(H))
179+
end
180+
end
181+
156182
end
157183

158184
@testset "linalg unary ops" begin
@@ -415,9 +441,6 @@ end
415441

416442
@test T([true false; false true]) .+ true == T([2 1; 1 2])
417443
end
418-
419-
@test_throws ArgumentError Hermitian(X) + 2im*I
420-
@test_throws ArgumentError Hermitian(X) - 2im*I
421444
end
422445

423446
@testset "Issue #21981" begin

stdlib/LinearAlgebra/test/uniformscaling.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,16 @@ let
178178
@test @inferred(J - T) == J - Array(T)
179179
@test @inferred(T\I) == inv(T)
180180

181+
if isa(A, Array)
182+
T = Hermitian(randn(3,3))
183+
else
184+
T = Hermitian(view(randn(3,3), 1:3, 1:3))
185+
end
186+
@test @inferred(T + J) == Array(T) + J
187+
@test @inferred(J + T) == J + Array(T)
188+
@test @inferred(T - J) == Array(T) - J
189+
@test @inferred(J - T) == J - Array(T)
190+
181191
@test @inferred(I\A) == A
182192
@test @inferred(A\I) == inv(A)
183193
@test @inferred\I) === UniformScaling(1/λ)

0 commit comments

Comments
 (0)