Skip to content

Commit 2de5dab

Browse files
c42fandreasnoack
authored andcommitted
Make svdvals(::Matrix{<:Complex}) type inferrable (#22443)
* Make svdvals(::Matrix{<:Complex}) type inferrable Ensure that svdvals(zeros(Complex128,0,0)) returns a complex real matrix to avoid type instability. Also add some simplistic but explicit tests for svdvals and svdfact, including ensuring this case is inferred. * Add unitarity test for SVD U and Vt parts. * Use \approxeq in tests instead of clobbering \approx
1 parent 236e486 commit 2de5dab

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

base/linalg/svd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ end
128128
Returns the singular values of `A`, saving space by overwriting the input.
129129
See also [`svdvals`](@ref).
130130
"""
131-
svdvals!(A::StridedMatrix{T}) where {T<:BlasFloat} = findfirst(size(A), 0) > 0 ? zeros(T, 0) : LAPACK.gesdd!('N', A)[2]
131+
svdvals!(A::StridedMatrix{T}) where {T<:BlasFloat} = isempty(A) ? zeros(real(T), 0) : LAPACK.gesdd!('N', A)[2]
132132
svdvals(A::AbstractMatrix{<:BlasFloat}) = svdvals!(copy(A))
133133

134134
"""

test/linalg/svd.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,31 @@ using Base.Test
44

55
using Base.LinAlg: BlasComplex, BlasFloat, BlasReal, QRPivoted
66

7+
@testset "Simple svdvals / svdfact tests" begin
8+
(x,y) = isapprox(x,y,rtol=1e-15)
9+
10+
m1 = [2 0; 0 0]
11+
m2 = [2 -2; 1 1]/sqrt(2)
12+
m2c = Complex.([2 -2; 1 1]/sqrt(2))
13+
@test @inferred(svdvals(m1)) [2, 0]
14+
@test @inferred(svdvals(m2)) [2, 1]
15+
@test @inferred(svdvals(m2c)) [2, 1]
16+
17+
sf1 = svdfact(m1)
18+
sf2 = svdfact(m2)
19+
@test sf1.S [2, 0]
20+
@test sf2.S [2, 1]
21+
# U & Vt are unitary
22+
@test sf1.U*sf1.U' eye(2)
23+
@test sf1.Vt*sf1.Vt' eye(2)
24+
@test sf2.U*sf2.U' eye(2)
25+
@test sf2.Vt*sf2.Vt' eye(2)
26+
# SVD not uniquely determined, so just test we can reconstruct the
27+
# matrices from the factorization as expected.
28+
@test sf1.U*Diagonal(sf1.S)*sf1.Vt' m1
29+
@test sf2.U*Diagonal(sf2.S)*sf2.Vt' m2
30+
end
31+
732
n = 10
833

934
# Split n into 2 parts for tests needing two matrices

0 commit comments

Comments
 (0)