Skip to content

Commit 85812fe

Browse files
authored
Merge pull request #683 from CarpeNecopinum/eigen_fallback
Handle non-hermitian matrices in eigen
2 parents 57eab2f + ce8da51 commit 85812fe

File tree

2 files changed

+72
-24
lines changed

2 files changed

+72
-24
lines changed

src/eigen.jl

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -123,25 +123,33 @@ end
123123
return SVector{s[1], T}(vals)
124124
end
125125

126+
# Utility to rewrap `Eigen` of normal `Array` into an Eigen containing `SArray`.
127+
@inline function _make_static(s::Size, E::Eigen{T,V}) where {T,V}
128+
Eigen(similar_type(SVector, V, Size(s[1]))(E.values),
129+
similar_type(SMatrix, T, s)(E.vectors))
130+
end
126131

127-
@inline function _eig(s::Size, A::StaticMatrix, permute, scale)
128-
# Only cover the hermitian branch, for now at least
129-
# This also solves some type-stability issues that arise in Base
132+
@inline function _eig(s::Size, A::T, permute, scale) where {T <: StaticMatrix}
130133
if ishermitian(A)
131-
return _eig(s, Hermitian(A), permute, scale)
134+
return _eig(s, Hermitian(A), permute, scale)
132135
else
133-
error("Only hermitian matrices are diagonalizable by *StaticArrays*. Non-Hermitian matrices should be converted to `Array` first.")
136+
# For the non-hermitian branch fall back to LinearAlgebra eigen().
137+
# Eigenvalues could be real or complex so a Union of concrete types is
138+
# inferred. Having _make_static a separate function allows inference to
139+
# preserve the union of concrete types:
140+
# Union{E{A,B},E{C,D}} -> Union{E{SA,SB},E{SC,SD}}
141+
_make_static(s, eigen(Array(A); permute = permute, scale = scale))
134142
end
135143
end
136144

137145
@inline function _eig(s::Size, A::LinearAlgebra.RealHermSymComplexHerm{T}, permute, scale) where {T <: Real}
138146
E = eigen(Hermitian(Array(parent(A))))
139-
return (SVector{s[1], T}(E.values), SMatrix{s[1], s[2], eltype(A)}(E.vectors))
147+
return Eigen(SVector{s[1], T}(E.values), SMatrix{s[1], s[2], eltype(A)}(E.vectors))
140148
end
141149

142150

143151
@inline function _eig(::Size{(1,1)}, A::LinearAlgebra.RealHermSymComplexHerm{T}, permute, scale) where {T <: Real}
144-
@inbounds return (SVector{1,T}((real(A[1]),)), SMatrix{1,1,eltype(A)}(I))
152+
@inbounds return Eigen(SVector{1,T}((real(A[1]),)), SMatrix{1,1,eltype(A)}(I))
145153
end
146154

147155
@inline function _eig(::Size{(2,2)}, A::LinearAlgebra.RealHermSymComplexHerm{T}, permute, scale) where {T <: Real}
@@ -170,7 +178,7 @@ end
170178
vecs = @SMatrix [ v11 v21 ;
171179
v12 v22 ]
172180

173-
return (vals, vecs)
181+
return Eigen(vals, vecs)
174182
end
175183
else # A.uplo == 'L'
176184
if !iszero(a[2]) # A is not diagonal
@@ -194,7 +202,7 @@ end
194202
vecs = @SMatrix [ v11 v21 ;
195203
v12 v22 ]
196204

197-
return (vals,vecs)
205+
return Eigen(vals,vecs)
198206
end
199207
end
200208

@@ -210,7 +218,7 @@ end
210218
vecs = @SMatrix [convert(TA, 0) convert(TA, 1);
211219
convert(TA, 1) convert(TA, 0)]
212220
end
213-
return (vals,vecs)
221+
return Eigen(vals,vecs)
214222
end
215223

216224
# A small part of the code in the following method was inspired by works of David
@@ -243,19 +251,19 @@ end
243251

244252
if a11 < a22
245253
if a22 < a33
246-
return (SVector(a11, a22, a33), hcat(v1,v2,v3))
254+
return Eigen(SVector((a11, a22, a33)), hcat(v1,v2,v3))
247255
elseif a33 < a11
248-
return (SVector(a33, a11, a22), hcat(v3,v1,v2))
256+
return Eigen(SVector((a33, a11, a22)), hcat(v3,v1,v2))
249257
else
250-
return (SVector(a11, a33, a22), hcat(v1,v3,v2))
258+
return Eigen(SVector((a11, a33, a22)), hcat(v1,v3,v2))
251259
end
252260
else #a22 < a11
253261
if a11 < a33
254-
return (SVector(a22, a11, a33), hcat(v2,v1,v3))
262+
return Eigen(SVector((a22, a11, a33)), hcat(v2,v1,v3))
255263
elseif a33 < a22
256-
return (SVector(a33, a22, a11), hcat(v3,v2,v1))
264+
return Eigen(SVector((a33, a22, a11)), hcat(v3,v2,v1))
257265
else
258-
return (SVector(a22, a33, a11), hcat(v2,v3,v1))
266+
return Eigen(SVector((a22, a33, a11)), hcat(v2,v3,v1))
259267
end
260268
end
261269
end
@@ -392,12 +400,11 @@ end
392400
(eigvec1, eigvec3) = (eigvec3, eigvec1)
393401
end
394402

395-
return (SVector(eig1, eig2, eig3), hcat(eigvec1, eigvec2, eigvec3))
403+
return Eigen(SVector(eig1, eig2, eig3), hcat(eigvec1, eigvec2, eigvec3))
396404
end
397405

398406
@inline function eigen(A::StaticMatrix; permute::Bool=true, scale::Bool=true)
399-
vals, vecs = _eig(Size(A), A, permute, scale)
400-
return Eigen(vals, vecs)
407+
_eig(Size(A), A, permute, scale)
401408
end
402409

403410
# to avoid method ambiguity with LinearAlgebra
@@ -407,8 +414,7 @@ end
407414
@inline eigen(A::Symmetric{<:Complex,<:StaticMatrix}; kwargs...) = _eigen(A; kwargs...)
408415

409416
@inline function _eigen(A::LinearAlgebra.HermOrSym; permute::Bool=true, scale::Bool=true)
410-
vals, vecs = _eig(Size(A), A, permute, scale)
411-
return Eigen(vals, vecs)
417+
_eig(Size(A), A, permute, scale)
412418
end
413419

414420
# NOTE: The following Boost Software License applies to parts of the method:

test/eigen.jl

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,6 @@ using StaticArrays, Test, LinearAlgebra
232232
@test vals::SVector sort(m_d)
233233
@test eigvals(m) sort(m_d)
234234
@test eigvals(Hermitian(m)) sort(m_d)
235-
236-
# not Hermitian
237-
@test_throws Exception eigen(@SMatrix randn(4,4))
238235
end
239236

240237
@testset "complex" begin
@@ -248,4 +245,49 @@ using StaticArrays, Test, LinearAlgebra
248245
@test V'*A*V diagm(Val(0) => D)
249246
end
250247
end
248+
249+
@testset "hermitian type stability" begin
250+
for n=1:4
251+
m = @SMatrix randn(n,n)
252+
m += m'
253+
254+
@inferred eigen(Hermitian(m))
255+
@inferred eigen(Symmetric(m))
256+
257+
# Test that general eigen() gives a small union of concrete types
258+
SEigen{T} = Eigen{T, T, SArray{Tuple{n,n},T,2,n*n}, SArray{Tuple{n},T,1,n}}
259+
@inferred_maybe_allow Union{SEigen{ComplexF64},SEigen{Float64}} eigen(m)
260+
261+
mc = @SMatrix randn(ComplexF64, n, n)
262+
@inferred eigen(Hermitian(mc + mc'))
263+
end
264+
end
265+
266+
@testset "non-hermitian 2d" begin
267+
for n=1:5
268+
angle = 2π * rand()
269+
rot = @SMatrix [cos(angle) -sin(angle); sin(angle) cos(angle)]
270+
271+
vals, vecs = eigen(rot)
272+
273+
@test norm(vals[1]) 1.0
274+
@test norm(vals[2]) 1.0
275+
276+
@test vecs[:,1] conj.(vecs[:,2])
277+
end
278+
end
279+
280+
@testset "non-hermitian 3d" begin
281+
for n=1:5
282+
angle = 2π * rand()
283+
rot = @SMatrix [cos(angle) 0.0 -sin(angle); 0.0 1.0 0.0; sin(angle) 0.0 cos(angle)]
284+
285+
vals, vecs = eigen(rot)
286+
287+
@test norm(vals[1]) 1.0
288+
@test norm(vals[2]) 1.0
289+
290+
@test vecs[:,1] conj.(vecs[:,2])
291+
end
292+
end
251293
end

0 commit comments

Comments
 (0)