Skip to content

Commit 75081bc

Browse files
authored
Fix performance trap for sparse view multiplication (#476)
1 parent 63459e5 commit 75081bc

File tree

3 files changed

+39
-29
lines changed

3 files changed

+39
-29
lines changed

src/linalg.jl

+15-15
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,21 @@ const DenseTriangular = UpperOrLowerTriangular{<:Any,<:DenseMatrixUnion}
1010
const DenseInputVector = Union{StridedVector, BitVector}
1111
const DenseVecOrMat = Union{DenseMatrixUnion, DenseInputVector}
1212

13-
matprod_dest(A::SparseMatrixCSCUnion, B::DenseTriangular, TS) =
13+
matprod_dest(A::SparseMatrixCSCUnion2, B::DenseTriangular, TS) =
1414
similar(B, TS, (size(A, 1), size(B, 2)))
15-
matprod_dest(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::DenseTriangular, TS) =
15+
matprod_dest(A::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, B::DenseTriangular, TS) =
1616
similar(B, TS, (size(A, 1), size(B, 2)))
17-
matprod_dest(A::StridedMaybeAdjOrTransMat, B::SparseMatrixCSCUnion, TS) =
17+
matprod_dest(A::StridedMaybeAdjOrTransMat, B::SparseMatrixCSCUnion2, TS) =
1818
similar(A, TS, (size(A, 1), size(B, 2)))
19-
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::SparseMatrixCSCUnion, TS) =
19+
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::SparseMatrixCSCUnion2, TS) =
2020
similar(A, TS, (size(A, 1), size(B, 2)))
21-
matprod_dest(A::DenseTriangular, B::SparseMatrixCSCUnion, TS) =
21+
matprod_dest(A::DenseTriangular, B::SparseMatrixCSCUnion2, TS) =
2222
similar(A, TS, (size(A, 1), size(B, 2)))
23-
matprod_dest(A::StridedMaybeAdjOrTransMat, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
23+
matprod_dest(A::StridedMaybeAdjOrTransMat, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
2424
similar(A, TS, (size(A, 1), size(B, 2)))
25-
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
25+
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
2626
similar(A, TS, (size(A, 1), size(B, 2)))
27-
matprod_dest(A::DenseTriangular, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
27+
matprod_dest(A::DenseTriangular, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
2828
similar(A, TS, (size(A, 1), size(B, 2)))
2929

3030
for op ∈ (:+, :-), Wrapper ∈ (:Hermitian, :Symmetric)
@@ -45,11 +45,11 @@ for op ∈ (:+, :-)
4545
end
4646
end
4747

48-
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion, B::DenseMatrixUnion, _add::MulAddMul) =
48+
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) =
4949
spdensemul!(C, tA, tB, A, B, _add)
50-
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion, B::AbstractTriangular, _add::MulAddMul) =
50+
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) =
5151
spdensemul!(C, tA, tB, A, B, _add)
52-
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion, B::DenseInputVector, _add::MulAddMul) =
52+
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) =
5353
spdensemul!(C, tA, 'N', A, B, _add)
5454

5555
Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add)
@@ -114,7 +114,7 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
114114
C
115115
end
116116

117-
Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::AbstractSparseMatrixCSC, _add::MulAddMul)
117+
Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul)
118118
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
119119
if tB == 'N'
120120
_spmul!(C, transA(A), B, _add.alpha, _add.beta)
@@ -125,7 +125,7 @@ Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::Strided
125125
end
126126
return C
127127
end
128-
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::AbstractSparseMatrixCSC, α::Number, β::Number)
128+
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
129129
mX, nX = size(X)
130130
nX == size(A, 1) ||
131131
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
@@ -145,7 +145,7 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::AbstractSparseMatrixC
145145
end
146146
C
147147
end
148-
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::AbstractSparseMatrixCSC, α::Number, β::Number)
148+
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number)
149149
mX, nX = size(X)
150150
nX == size(A, 1) ||
151151
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
@@ -164,7 +164,7 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::A
164164
C
165165
end
166166

167-
function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::AbstractSparseMatrixCSC, α::Number, β::Number)
167+
function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
168168
mA, nA = size(A)
169169
nA == size(B, 2) ||
170170
throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))"))

src/sparsematrix.jl

+16-8
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,23 @@ end
179179
# underlying SparseMatrixCSC
180180
const SparseMatrixCSCView{Tv,Ti} =
181181
SubArray{Tv,2,<:AbstractSparseMatrixCSC{Tv,Ti},
182-
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractUnitRange}
182+
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractUnitRange{<:Integer}}
183183
const SparseMatrixCSCUnion{Tv,Ti} = Union{AbstractSparseMatrixCSC{Tv,Ti}, SparseMatrixCSCView{Tv,Ti}}
184+
# Define an alias for views of a SparseMatrixCSC which include all rows and a selection of the columns.
185+
# Also define a union of SparseMatrixCSC and this view since many methods can be defined efficiently for
186+
# this union by extracting the fields via the getter functions getrowval and getnzval, BUT NOT getcolptr!
187+
const SparseMatrixCSCColumnSubset{Tv,Ti} =
188+
SubArray{Tv,2,<:AbstractSparseMatrixCSC{Tv,Ti},
189+
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractVector{<:Integer}}
190+
const SparseMatrixCSCUnion2{Tv,Ti} = Union{AbstractSparseMatrixCSC{Tv,Ti}, SparseMatrixCSCColumnSubset{Tv,Ti}}
184191

185192
getcolptr(S::SorF) = getfield(S, :colptr)
186-
getcolptr(S::SparseMatrixCSCView) = view(getcolptr(parent(S)), first(axes(S, 2)):(last(axes(S, 2)) + 1))
193+
getcolptr(S::SparseMatrixCSCView) = view(getcolptr(parent(S)), first(S.indices[2]):(last(S.indices[2]) + 1))
194+
getcolptr(S::SparseMatrixCSCColumnSubset) = error("getcolptr not well-defined for $(typeof(S))")
187195
getrowval(S::AbstractSparseMatrixCSC) = rowvals(S)
188-
getrowval(S::SparseMatrixCSCView) = rowvals(parent(S))
196+
getrowval(S::SparseMatrixCSCColumnSubset) = rowvals(parent(S))
189197
getnzval( S::AbstractSparseMatrixCSC) = nonzeros(S)
190-
getnzval( S::SparseMatrixCSCView) = nonzeros(parent(S))
198+
getnzval( S::SparseMatrixCSCColumnSubset) = nonzeros(parent(S))
191199
nzvalview(S::AbstractSparseMatrixCSC) = view(nonzeros(S), 1:nnz(S))
192200

193201
"""
@@ -212,7 +220,7 @@ nnz(S::ReshapedArray{<:Any,1,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
212220
nnz(S::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
213221
nnz(S::UpperTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
214222
nnz(S::LowerTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
215-
nnz(S::SparseMatrixCSCView) = nnz1(S)
223+
nnz(S::SparseMatrixCSCColumnSubset) = nnz1(S)
216224
nnz1(S) = sum(length.(nzrange.(Ref(S), axes(S, 2))))
217225

218226
function Base._simple_count(pred, S::AbstractSparseMatrixCSC, init::T) where T
@@ -244,7 +252,7 @@ julia> nonzeros(A)
244252
```
245253
"""
246254
nonzeros(S::SorF) = getfield(S, :nzval)
247-
nonzeros(S::SparseMatrixCSCView) = nonzeros(S.parent)
255+
nonzeros(S::SparseMatrixCSCColumnSubset) = nonzeros(S.parent)
248256
nonzeros(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}) = nonzeros(S.data)
249257
nonzeros(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}) = nonzeros(S.data)
250258

@@ -272,7 +280,7 @@ julia> rowvals(A)
272280
```
273281
"""
274282
rowvals(S::SorF) = getfield(S, :rowval)
275-
rowvals(S::SparseMatrixCSCView) = rowvals(S.parent)
283+
rowvals(S::SparseMatrixCSCColumnSubset) = rowvals(S.parent)
276284
rowvals(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}) = rowvals(S.data)
277285
rowvals(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}) = rowvals(S.data)
278286

@@ -299,7 +307,7 @@ column. In conjunction with [`nonzeros`](@ref) and
299307
Adding or removing nonzero elements to the matrix may invalidate the `nzrange`, one should not mutate the matrix while iterating.
300308
"""
301309
nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
302-
nzrange(S::SparseMatrixCSCView, col::Integer) = nzrange(S.parent, S.indices[2][col])
310+
nzrange(S::SparseMatrixCSCColumnSubset, col::Integer) = nzrange(S.parent, S.indices[2][col])
303311
nzrange(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}, i::Integer) = nzrangeup(S.data, i)
304312
nzrange(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}, i::Integer) = nzrangelo(S.data, i)
305313

test/linalg.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -154,15 +154,15 @@ begin
154154
A = sprand(rng, n, n, 0.01)
155155
MA = Matrix(A)
156156
lA = sprand(rng, n, n+10, 0.01)
157+
@test nnz(lA[:, n+1:n+10]) == nnz(view(lA, :, n+1:n+10))
157158
@testset "triangular multiply with $tr($wr)" for tr in (identity, adjoint, transpose),
158159
wr in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
159160
AW = tr(wr(A))
160161
MAW = tr(wr(MA))
161162
@test AW * B ≈ MAW * B
162163
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
163-
vAW = tr(wr(view(A, :, 1:n)))
164-
vMAW = tr(wr(view(MA, :, 1:n)))
165-
@test vAW * B ≈ vMAW * B
164+
vAW = tr(wr(view([zero(A)+I A], :, (n+1):2n)))
165+
@test vAW * B ≈ AW * B
166166
end
167167
a = sprand(rng, ComplexF64, n, n, 0.01)
168168
ma = Matrix(a)
@@ -172,9 +172,8 @@ begin
172172
MAW = tr(wr(ma))
173173
@test AW * B ≈ MAW * B
174174
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
175-
vAW = tr(wr(view(a, :, 1:n)))
176-
vMAW = tr(wr(view(ma, :, 1:n)))
177-
@test vAW * B ≈ vMAW * B
175+
vAW = tr(wr(view([zero(a)+I a], :, (n+1):2n)))
176+
@test vAW * B ≈ AW * B
178177
end
179178
A = A - Diagonal(diag(A)) + 2I # avoid rounding errors by division
180179
MA = Matrix(A)
@@ -183,6 +182,9 @@ begin
183182
AW = tr(wr(A))
184183
MAW = tr(wr(MA))
185184
@test AW \ B ≈ MAW \ B
185+
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
186+
vAW = tr(wr(view([zero(A)+I A], :, (n+1):2n)))
187+
@test vAW \ B ≈ AW \ B
186188
end
187189
@testset "triangular singular exceptions" begin
188190
A = LowerTriangular(sparse([0 2.0;0 1]))

0 commit comments

Comments
 (0)