Skip to content

Commit ce8da51

Browse files
committed
Improve inferrability of eigen()
Introduce `_make_static(::Eigen)` to allow the compiler to infer the output type of eigen() as a Union of concrete types. This seems necessary because the compiler's Union-splitting can't act on several statements at once, so pulling out the values and vectors fields of the Eigen individually will result in a pair of Unions which the compiler considers uncorrellated. They are correllated however and the _make_static function barrier allows the compiler to understand this. Also return an Eigen from all the other utility functions for convenience.
1 parent c0d7c5e commit ce8da51

File tree

2 files changed

+30
-21
lines changed

2 files changed

+30
-21
lines changed

src/eigen.jl

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -123,26 +123,33 @@ end
123123
return SVector{s[1], T}(vals)
124124
end
125125

126+
# Utility to rewrap `Eigen` of normal `Array` into an Eigen containing `SArray`.
127+
@inline function _make_static(s::Size, E::Eigen{T,V}) where {T,V}
128+
Eigen(similar_type(SVector, V, Size(s[1]))(E.values),
129+
similar_type(SMatrix, T, s)(E.vectors))
130+
end
126131

127132
@inline function _eig(s::Size, A::T, permute, scale) where {T <: StaticMatrix}
128-
# For the non-hermitian branch, fall back to LinearAlgebra
129133
if ishermitian(A)
130-
eivals, eivecs = _eig(s, Hermitian(A), permute, scale)
131-
return eivals, eivecs
134+
return _eig(s, Hermitian(A), permute, scale)
132135
else
133-
eivals, eivecs = eigen(Array(A); permute = permute, scale = scale)
134-
return SVector{s[1]}(eivals), SMatrix{s[1],s[2]}(eivecs)
136+
# For the non-hermitian branch fall back to LinearAlgebra eigen().
137+
# Eigenvalues could be real or complex so a Union of concrete types is
138+
# inferred. Having _make_static a separate function allows inference to
139+
# preserve the union of concrete types:
140+
# Union{E{A,B},E{C,D}} -> Union{E{SA,SB},E{SC,SD}}
141+
_make_static(s, eigen(Array(A); permute = permute, scale = scale))
135142
end
136143
end
137144

138145
@inline function _eig(s::Size, A::LinearAlgebra.RealHermSymComplexHerm{T}, permute, scale) where {T <: Real}
139146
E = eigen(Hermitian(Array(parent(A))))
140-
return (SVector{s[1], T}(E.values), SMatrix{s[1], s[2], eltype(A)}(E.vectors))
147+
return Eigen(SVector{s[1], T}(E.values), SMatrix{s[1], s[2], eltype(A)}(E.vectors))
141148
end
142149

143150

144151
@inline function _eig(::Size{(1,1)}, A::LinearAlgebra.RealHermSymComplexHerm{T}, permute, scale) where {T <: Real}
145-
@inbounds return (SVector{1,T}((real(A[1]),)), SMatrix{1,1,eltype(A)}(I))
152+
@inbounds return Eigen(SVector{1,T}((real(A[1]),)), SMatrix{1,1,eltype(A)}(I))
146153
end
147154

148155
@inline function _eig(::Size{(2,2)}, A::LinearAlgebra.RealHermSymComplexHerm{T}, permute, scale) where {T <: Real}
@@ -171,7 +178,7 @@ end
171178
vecs = @SMatrix [ v11 v21 ;
172179
v12 v22 ]
173180

174-
return (vals, vecs)
181+
return Eigen(vals, vecs)
175182
end
176183
else # A.uplo == 'L'
177184
if !iszero(a[2]) # A is not diagonal
@@ -195,7 +202,7 @@ end
195202
vecs = @SMatrix [ v11 v21 ;
196203
v12 v22 ]
197204

198-
return (vals,vecs)
205+
return Eigen(vals,vecs)
199206
end
200207
end
201208

@@ -211,7 +218,7 @@ end
211218
vecs = @SMatrix [convert(TA, 0) convert(TA, 1);
212219
convert(TA, 1) convert(TA, 0)]
213220
end
214-
return (vals,vecs)
221+
return Eigen(vals,vecs)
215222
end
216223

217224
# A small part of the code in the following method was inspired by works of David
@@ -244,19 +251,19 @@ end
244251

245252
if a11 < a22
246253
if a22 < a33
247-
return (SVector(a11, a22, a33), hcat(v1,v2,v3))
254+
return Eigen(SVector((a11, a22, a33)), hcat(v1,v2,v3))
248255
elseif a33 < a11
249-
return (SVector(a33, a11, a22), hcat(v3,v1,v2))
256+
return Eigen(SVector((a33, a11, a22)), hcat(v3,v1,v2))
250257
else
251-
return (SVector(a11, a33, a22), hcat(v1,v3,v2))
258+
return Eigen(SVector((a11, a33, a22)), hcat(v1,v3,v2))
252259
end
253260
else #a22 < a11
254261
if a11 < a33
255-
return (SVector(a22, a11, a33), hcat(v2,v1,v3))
262+
return Eigen(SVector((a22, a11, a33)), hcat(v2,v1,v3))
256263
elseif a33 < a22
257-
return (SVector(a33, a22, a11), hcat(v3,v2,v1))
264+
return Eigen(SVector((a33, a22, a11)), hcat(v3,v2,v1))
258265
else
259-
return (SVector(a22, a33, a11), hcat(v2,v3,v1))
266+
return Eigen(SVector((a22, a33, a11)), hcat(v2,v3,v1))
260267
end
261268
end
262269
end
@@ -393,12 +400,11 @@ end
393400
(eigvec1, eigvec3) = (eigvec3, eigvec1)
394401
end
395402

396-
return (SVector(eig1, eig2, eig3), hcat(eigvec1, eigvec2, eigvec3))
403+
return Eigen(SVector(eig1, eig2, eig3), hcat(eigvec1, eigvec2, eigvec3))
397404
end
398405

399406
@inline function eigen(A::StaticMatrix; permute::Bool=true, scale::Bool=true)
400-
vals, vecs = _eig(Size(A), A, permute, scale)
401-
return Eigen(vals, vecs)
407+
_eig(Size(A), A, permute, scale)
402408
end
403409

404410
# to avoid method ambiguity with LinearAlgebra
@@ -408,8 +414,7 @@ end
408414
@inline eigen(A::Symmetric{<:Complex,<:StaticMatrix}; kwargs...) = _eigen(A; kwargs...)
409415

410416
@inline function _eigen(A::LinearAlgebra.HermOrSym; permute::Bool=true, scale::Bool=true)
411-
vals, vecs = _eig(Size(A), A, permute, scale)
412-
return Eigen(vals, vecs)
417+
_eig(Size(A), A, permute, scale)
413418
end
414419

415420
# NOTE: The following Boost Software License applies to parts of the method:

test/eigen.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,10 @@ using StaticArrays, Test, LinearAlgebra
254254
@inferred eigen(Hermitian(m))
255255
@inferred eigen(Symmetric(m))
256256

257+
# Test that general eigen() gives a small union of concrete types
258+
SEigen{T} = Eigen{T, T, SArray{Tuple{n,n},T,2,n*n}, SArray{Tuple{n},T,1,n}}
259+
@inferred_maybe_allow Union{SEigen{ComplexF64},SEigen{Float64}} eigen(m)
260+
257261
mc = @SMatrix randn(ComplexF64, n, n)
258262
@inferred eigen(Hermitian(mc + mc'))
259263
end

0 commit comments

Comments
 (0)