Skip to content

Commit 2a74b04

Browse files
authored
Merge pull request #397 from martinholters/mh/diagm
diagm 0.7-style
2 parents 9d80122 + b965434 commit 2a74b04

File tree

4 files changed

+57
-44
lines changed

4 files changed

+57
-44
lines changed

src/linalg.jl

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,24 +182,30 @@ end
182182
end
183183
#end
184184

185-
@inline diagm(v::StaticVector, k::Type{Val{D}}=Val{0}) where {D} = _diagm(Size(v), v, k)
186-
@generated function _diagm(::Size{S}, v::StaticVector, ::Type{Val{D}}) where {S,D}
187-
S1 = S[1]
188-
Snew1 = S1+abs(D)
189-
Snew = (Snew1, Snew1)
190-
Lnew = Snew1 * Snew1
191-
T = eltype(v)
192-
ind = diagind(Snew1, Snew1, D)
193-
exprs = fill(:(zero($T)), Lnew)
194-
for n = 1:S[1]
195-
exprs[ind[n]] = :(v[$n])
185+
@generated function diagm(kvs::Pair{<:Val,<:StaticVector}...)
186+
N = maximum(abs(kv.parameters[1].parameters[1]) + length(kv.parameters[2]) for kv in kvs)
187+
X = [Symbol("x_$(i)_$(j)") for i in 1:N, j in 1:N]
188+
T = promote_type((eltype(kv.parameters[2]) for kv in kvs)...)
189+
exprs = fill(:(zero($T)), N*N)
190+
for m in eachindex(kvs)
191+
kv = kvs[m]
192+
ind = diagind(N, N, kv.parameters[1].parameters[1])
193+
for n = 1:length(kv.parameters[2])
194+
exprs[ind[n]] = :(kvs[$m].second[$n])
195+
end
196196
end
197197
return quote
198198
$(Expr(:meta, :inline))
199-
@inbounds return similar_type($v, Size($Snew))(tuple($(exprs...)))
199+
@inbounds return SMatrix{$N,$N,$T}(tuple($(exprs...)))
200200
end
201201
end
202202

203+
if VERSION < v"v0.7.0-DEV.2161"
204+
@inline diagm(v::StaticVector, k::Type{Val{D}}=Val{0}) where {D} = diagm(k() => v)
205+
else
206+
@deprecate(diagm(v::StaticVector, k::Type{Val{D}}=Val{0}) where {D}, diagm(k() => v))
207+
end
208+
203209
@inline diag(m::StaticMatrix, k::Type{Val{D}}=Val{0}) where {D} = _diag(Size(m), m, k)
204210
@generated function _diag(::Size{S}, m::StaticMatrix, ::Type{Val{D}}) where {S,D}
205211
S1, S2 = S

test/SDiagonal.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222

2323
@test StaticArrays.scalem(@SMatrix([1 1 1;1 1 1; 1 1 1]), @SVector [1,2,3]) === @SArray [1 2 3; 1 2 3; 1 2 3]
2424
@test StaticArrays.scalem(@SVector([1,2,3]),@SMatrix [1 1 1;1 1 1; 1 1 1])' === @SArray [1 2 3; 1 2 3; 1 2 3]
25-
26-
m = SDiagonal(@SVector [11, 12, 13, 14])
27-
28-
@test diag(m) === m.diag
29-
30-
m2 = diagm([11, 12, 13, 14])
31-
25+
26+
m = SDiagonal(@SVector [11, 12, 13, 14])
27+
28+
@test diag(m) === m.diag
29+
30+
m2 = diagm(0 => [11, 12, 13, 14])
31+
3232
@test logdet(m) == logdet(m2)
3333
@test logdet(im*m) logdet(im*m2)
3434
@test det(m) == det(m2)
@@ -113,7 +113,7 @@
113113
@test m\[1; 1; 1; 1] == [11; 12; 13; 14].\[1; 1; 1; 1]
114114
@test SMatrix{4,4}(Matrix{Float64}(I, 4, 4))*m == m
115115
@test m*SMatrix{4,4}(Matrix{Float64}(I, 4, 4)) == m
116-
@test SMatrix{4,4}(Matrix{Float64}(I, 4, 4))/m == diagm([11; 12; 13; 14].\[1; 1; 1; 1])
117-
@test m\SMatrix{4,4}(Matrix{Float64}(I, 4, 4)) == diagm([11; 12; 13; 14].\[1; 1; 1; 1])
116+
@test SMatrix{4,4}(Matrix{Float64}(I, 4, 4))/m == diagm(0 => [11; 12; 13; 14].\[1; 1; 1; 1])
117+
@test m\SMatrix{4,4}(Matrix{Float64}(I, 4, 4)) == diagm(0 => [11; 12; 13; 14].\[1; 1; 1; 1])
118118
end
119119
end

test/eigen.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,35 +30,35 @@
3030
(vals, vecs) = eig(m)
3131
@test vals::SVector vals_a
3232
@test eigvals(m) vals
33-
@test (vecs*diagm(vals)*vecs')::SMatrix m
33+
@test (vecs*diagm(Val(0) => vals)*vecs')::SMatrix m
3434
ef = eigfact(m)
3535
@test ef[:values]::SVector vals_a
36-
@test (ef[:vectors]*diagm(vals)*ef[:vectors]')::SMatrix m
36+
@test (ef[:vectors]*diagm(Val(0) => vals)*ef[:vectors]')::SMatrix m
3737

3838
(vals, vecs) = eig(Symmetric(m))
3939
@test vals::SVector vals_a
4040
@test eigvals(m) vals
41-
@test (vecs*diagm(vals)*vecs')::SMatrix m
41+
@test (vecs*diagm(Val(0) => vals)*vecs')::SMatrix m
4242
ef = eigfact(Symmetric(m))
4343
@test ef[:values]::SVector vals_a
44-
@test (ef[:vectors]*diagm(vals)*ef[:vectors]')::SMatrix m
44+
@test (ef[:vectors]*diagm(Val(0) => vals)*ef[:vectors]')::SMatrix m
4545
ef = eigfact(Symmetric(m, :L))
4646
@test ef[:values]::SVector vals_a
47-
@test (ef[:vectors]*diagm(vals)*ef[:vectors]')::SMatrix m
47+
@test (ef[:vectors]*diagm(Val(0) => vals)*ef[:vectors]')::SMatrix m
4848

4949
(vals, vecs) = eig(Hermitian(m))
5050
@test vals::SVector vals_a
5151
@test eigvals(Hermitian(m)) vals
5252
@test eigvals(Hermitian(m, :L)) vals
53-
@test (vecs*diagm(vals)*vecs')::SMatrix m
53+
@test (vecs*diagm(Val(0) => vals)*vecs')::SMatrix m
5454
ef = eigfact(Hermitian(m))
5555
@test ef[:values]::SVector vals_a
56-
@test (ef[:vectors]*diagm(vals)*ef[:vectors]')::SMatrix m
56+
@test (ef[:vectors]*diagm(Val(0) => vals)*ef[:vectors]')::SMatrix m
5757
ef = eigfact(Hermitian(m, :L))
5858
@test ef[:values]::SVector vals_a
59-
@test (ef[:vectors]*diagm(vals)*ef[:vectors]')::SMatrix m
59+
@test (ef[:vectors]*diagm(Val(0) => vals)*ef[:vectors]')::SMatrix m
6060

61-
m_d = randn(SVector{2}); m = diagm(m_d)
61+
m_d = randn(SVector{2}); m = diagm(Val(0) => m_d)
6262
(vals, vecs) = eig(Hermitian(m))
6363
@test vals::SVector sort(m_d)
6464
(vals, vecs) = eig(Hermitian(m, :L))
@@ -76,19 +76,19 @@
7676
(vals, vecs) = eig(m)
7777
@test vals::SVector vals_a
7878
@test eigvals(m) vals
79-
@test (vecs*diagm(vals)*vecs')::SMatrix m
79+
@test (vecs*diagm(Val(0) => vals)*vecs')::SMatrix m
8080

8181
(vals, vecs) = eig(Symmetric(m))
8282
@test vals::SVector vals_a
8383
@test eigvals(m) vals
8484
@test eigvals(Hermitian(m)) vals
8585
@test eigvals(Hermitian(m, :L)) vals
86-
@test (vecs*diagm(vals)*vecs')::SMatrix m
86+
@test (vecs*diagm(Val(0) => vals)*vecs')::SMatrix m
8787

8888
(vals, vecs) = eig(Symmetric(m, :L))
8989
@test vals::SVector vals_a
9090

91-
m_d = randn(SVector{3}); m = diagm(m_d)
91+
m_d = randn(SVector{3}); m = diagm(Val(0) => m_d)
9292
(vals, vecs) = eig(Hermitian(m))
9393
@test vals::SVector sort(m_d)
9494
(vals, vecs) = eig(Hermitian(m, :L))
@@ -146,7 +146,7 @@
146146
vals, vecs = eig(m)::Tuple{SVector,SMatrix}
147147

148148
@test vals [0.0, 1.0, 2.0]
149-
@test vecs*diagm(vals)*vecs' m
149+
@test vecs*diagm(Val(0) => vals)*vecs' m
150150
@test eigvals(m) vals
151151

152152
m = @SMatrix [1.0 0.0 1.0;
@@ -155,7 +155,7 @@
155155
vals, vecs = eig(m)::Tuple{SVector,SMatrix}
156156

157157
@test vals [0.0, 1.0, 2.0]
158-
@test vecs*diagm(vals)*vecs' m
158+
@test vecs*diagm(Val(0) => vals)*vecs' m
159159
@test eigvals(m) vals
160160

161161
m = @SMatrix [1.0 1.0 0.0;
@@ -164,7 +164,7 @@
164164
vals, vecs = eig(m)::Tuple{SVector,SMatrix}
165165

166166
@test vals [0.0, 1.0, 2.0]
167-
@test vecs*diagm(vals)*vecs' m
167+
@test vecs*diagm(Val(0) => vals)*vecs' m
168168
@test eigvals(m) vals
169169
end
170170

@@ -177,18 +177,18 @@
177177
(vals, vecs) = eig(m)
178178
@test vals::SVector vals_a
179179
@test eigvals(m) vals
180-
@test (vecs*diagm(vals)*vecs')::SMatrix m
180+
@test (vecs*diagm(Val(0) => vals)*vecs')::SMatrix m
181181

182182
(vals, vecs) = eig(Symmetric(m))
183183
@test vals::SVector vals_a
184184
@test eigvals(m) vals
185185
@test eigvals(Hermitian(m)) vals
186186
@test eigvals(Hermitian(m, :L)) vals
187-
@test (vecs*diagm(vals)*vecs')::SMatrix m
187+
@test (vecs*diagm(Val(0) => vals)*vecs')::SMatrix m
188188

189189
(vals, vecs) = eig(Symmetric(m, :L))
190190
@test vals::SVector vals_a
191-
m_d = randn(SVector{4}); m = diagm(m_d)
191+
m_d = randn(SVector{4}); m = diagm(Val(0) => m_d)
192192
(vals, vecs) = eig(Hermitian(m))
193193
@test vals::SVector sort(m_d)
194194
(vals, vecs) = eig(Hermitian(m, :L))
@@ -207,8 +207,8 @@
207207
A = Hermitian(SMatrix{n,n}(a))
208208
D,V = eig(A)
209209
@test V'V eye(n)
210-
@test V*diagm(D)*V' A
211-
@test V'*A*V diagm(D)
210+
@test V*diagm(Val(0) => D)*V' A
211+
@test V'*A*V diagm(Val(0) => D)
212212
end
213213
end
214214
end

test/linalg.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,16 @@ using StaticArrays, Base.Test
4040
end
4141

4242
@testset "diagm()" begin
43-
@test @inferred(diagm(SVector(1,2))) === @SMatrix [1 0; 0 2]
44-
@test @inferred(diagm(SVector(1,2,3), Val{2}))::SMatrix == diagm([1,2,3], 2)
45-
@test @inferred(diagm(SVector(1,2,3), Val{-2}))::SMatrix == diagm([1,2,3], -2)
43+
@test @inferred(diagm(Val(0) => SVector(1,2))) === @SMatrix [1 0; 0 2]
44+
@test @inferred(diagm(Val(2) => SVector(1,2,3)))::SMatrix == diagm(2 => [1,2,3])
45+
@test @inferred(diagm(Val(-2) => SVector(1,2,3)))::SMatrix == diagm(-2 => [1,2,3])
46+
@test @inferred(diagm(Val(-2) => SVector(1,2,3), Val(1) => SVector(4,5)))::SMatrix == diagm(-2 => [1,2,3], 1 => [4,5])
47+
if VERSION < v"0.7.0-DEV.2161"
48+
# old interface, deprecated in Julia 0.7
49+
@test @inferred(diagm(SVector(1,2))) === @SMatrix [1 0; 0 2]
50+
@test @inferred(diagm(SVector(1,2,3), Val{2}))::SMatrix == diagm(2 => [1,2,3])
51+
@test @inferred(diagm(SVector(1,2,3), Val{-2}))::SMatrix == diagm(-2 => [1,2,3])
52+
end
4653
end
4754

4855
@testset "diag()" begin

0 commit comments

Comments
 (0)