Skip to content

Commit 071fa65

Browse files
committed
add option for read-only off-diagonal elements
1 parent 9d03a80 commit 071fa65

File tree

2 files changed

+35
-16
lines changed

2 files changed

+35
-16
lines changed

src/PackedArrays.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,15 @@ import LinearAlgebra: mul!, BLAS, BlasFloat, generic_matvecmul!, MulAddMul
77

88
export SymmetricPacked
99

10-
struct SymmetricPacked{T,S<:AbstractMatrix{<:T}} <: AbstractMatrix{T}
10+
struct SymmetricPacked{T,S<:AbstractMatrix{<:T},V} <: AbstractMatrix{T}
1111
tri::Vector{T}
1212
n::Int
1313
uplo::Char
1414

15-
function SymmetricPacked{T,S}(tri, n, uplo) where {T,S<:AbstractMatrix{<:T}}
15+
function SymmetricPacked{T,S,V}(tri, n, uplo) where {T,S<:AbstractMatrix{<:T},V}
1616
require_one_based_indexing(tri)
1717
uplo=='U' || uplo=='L' || throw(ArgumentError("uplo must be either 'U' (upper) or 'L' (lower)"))
18-
new{T,S}(tri, n, uplo)
18+
new{T,S,V}(tri, n, uplo)
1919
end
2020
end
2121

@@ -33,10 +33,11 @@ function pack(A::AbstractMatrix{T}, uplo::Symbol) where {T}
3333
end
3434

3535
"""
36-
SymmetricPacked(A, uplo=:U)
36+
SymmetricPacked(A, uplo=:U, offdiag=Val(:RO))
3737
3838
Construct a `Symmetric` matrix in packed form of the upper (if `uplo = :U`)
39-
or lower (if `uplo = :L`) triangle of the matrix `A`.
39+
or lower (if `uplo = :L`) triangle of the matrix `A`. `offdiag` specifies
40+
whether elements not on the diagaonal can be set (if `:RW`) or not (if `:RO`).
4041
4142
# Examples
4243
```jldoctest
@@ -63,20 +64,20 @@ julia> Base.summarysize(AP)
6364
184
6465
```
6566
"""
66-
function SymmetricPacked(A::AbstractMatrix{T}, uplo::Symbol=:U) where {T}
67+
function SymmetricPacked(A::AbstractMatrix{T}, uplo::Symbol=:U, offdiag=Val(:RO)) where {T}
6768
n = checksquare(A)
68-
SymmetricPacked{T,typeof(A)}(pack(A, uplo), n, char_uplo(uplo))
69+
SymmetricPacked{T,typeof(A),offdiag}(pack(A, uplo), n, char_uplo(uplo))
6970
end
7071

71-
function SymmetricPacked(x::SymmetricPacked{T,S}) where{T,S}
72-
SymmetricPacked{T,S}(T.(x.tri), x.n, x.uplo)
72+
function SymmetricPacked(x::SymmetricPacked{T,S,V}) where{T,S,V}
73+
SymmetricPacked{T,S,V}(T.(x.tri), x.n, x.uplo)
7374
end
7475

7576
checksquare(x::SymmetricPacked) = x.n
7677

77-
convert(::Type{SymmetricPacked{T,S}}, x::SymmetricPacked) where {T,S} = SymmetricPacked{T,S}(T.(x.tri), x.n, x.uplo)
78+
convert(::Type{SymmetricPacked{T,S,V}}, x::SymmetricPacked) where {T,S,V} = SymmetricPacked{T,S}(T.(x.tri), x.n, x.uplo)
7879

79-
unsafe_convert(::Type{Ptr{T}}, A::SymmetricPacked{T,S}) where {T,S} = Base.unsafe_convert(Ptr{T}, A.tri)
80+
unsafe_convert(::Type{Ptr{T}}, A::SymmetricPacked{T,S,V}) where {T,S,V} = Base.unsafe_convert(Ptr{T}, A.tri)
8081

8182
size(A::SymmetricPacked) = (A.n,A.n)
8283

@@ -97,7 +98,7 @@ end
9798
return r
9899
end
99100

100-
function setindex!(A::SymmetricPacked, v, i::Int, j::Int)
101+
function _setindex!(A::SymmetricPacked, v, i::Int, j::Int)
101102
@boundscheck checkbounds(A, i, j)
102103
if A.uplo=='U'
103104
i,j = minmax(i,j)
@@ -109,15 +110,22 @@ function setindex!(A::SymmetricPacked, v, i::Int, j::Int)
109110
return v
110111
end
111112

113+
function setindex!(A::SymmetricPacked{T,S,Val(:RO)}, v, i::Int, j::Int) where {T,S}
114+
i!=j && throw(ArgumentError("Cannot set a non-diagonal index in a symmetric matrix"))
115+
_setindex!(A, v, i, j)
116+
end
117+
118+
setindex!(A::SymmetricPacked{T,S,Val(:RW)}, v, i::Int, j::Int) where {T,S} = _setindex!(A, v, i, j)
119+
112120
function copy(A::SymmetricPacked{T,S}) where {T,S}
113121
B = copy(A.tri)
114122
SymmetricPacked{T,S}(B, A.n, A.uplo)
115123
end
116124

117125
@inline function mul!(y::StridedVector{T},
118-
AP::SymmetricPacked{T,<:StridedMatrix},
126+
AP::SymmetricPacked{T,<:StridedMatrix,V},
119127
x::StridedVector{T},
120-
α::Number, β::Number) where {T<:BlasFloat}
128+
α::Number, β::Number) where {T<:BlasFloat,V}
121129
alpha, beta = promote(α, β, zero(T))
122130
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
123131
BLAS.spmv!(AP.uplo, alpha, AP.tri, x, beta, y)

test/runtests.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using PackedArrays, Test, LinearAlgebra
33
A = collect(reshape(1:9.0,3,3))
44

55
@testset "upper triangle" begin
6-
APU = SymmetricPacked(A, :U)
6+
APU = SymmetricPacked(A, :U, Val(:RW))
77

88
@test APU[1,1] == A[1,1]
99
@test APU[1,2] == A[1,2]
@@ -20,7 +20,7 @@ A = collect(reshape(1:9.0,3,3))
2020
end
2121

2222
@testset "lower triangle" begin
23-
APL = SymmetricPacked(A, :L)
23+
APL = SymmetricPacked(A, :L, Val(:RW))
2424

2525
@test APL[1,1] == A[1,1]
2626
@test APL[1,2] == A[2,1]
@@ -36,6 +36,17 @@ end
3636
@test APL[2,1] == 0
3737
end
3838

39+
@testset "read-only" begin
40+
APL = SymmetricPacked(A, :L)
41+
APL[1,1]=3
42+
@test APL[1,1] == 3
43+
@test_throws ArgumentError APL[1,3]=3
44+
APL = SymmetricPacked(A, :L, Val(:RO))
45+
APL[1,1]=3
46+
@test APL[1,1] == 3
47+
@test_throws ArgumentError APL[1,3]=3
48+
end
49+
3950
@testset "mul!" begin
4051
for uplo in [:U, :L]
4152
y = Float64[1,2,3]

0 commit comments

Comments
 (0)