Skip to content

Commit

Permalink
Fix performance trap for sparse view multiplication (#476)
Browse files Browse the repository at this point in the history
  • Loading branch information
dkarrasch authored Jan 7, 2024
1 parent 63459e5 commit 75081bc
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 29 deletions.
30 changes: 15 additions & 15 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ const DenseTriangular = UpperOrLowerTriangular{<:Any,<:DenseMatrixUnion}
const DenseInputVector = Union{StridedVector, BitVector}
const DenseVecOrMat = Union{DenseMatrixUnion, DenseInputVector}

matprod_dest(A::SparseMatrixCSCUnion, B::DenseTriangular, TS) =
matprod_dest(A::SparseMatrixCSCUnion2, B::DenseTriangular, TS) =
similar(B, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::DenseTriangular, TS) =
matprod_dest(A::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, B::DenseTriangular, TS) =
similar(B, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::StridedMaybeAdjOrTransMat, B::SparseMatrixCSCUnion, TS) =
matprod_dest(A::StridedMaybeAdjOrTransMat, B::SparseMatrixCSCUnion2, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::SparseMatrixCSCUnion, TS) =
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::SparseMatrixCSCUnion2, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::DenseTriangular, B::SparseMatrixCSCUnion, TS) =
matprod_dest(A::DenseTriangular, B::SparseMatrixCSCUnion2, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::StridedMaybeAdjOrTransMat, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
matprod_dest(A::StridedMaybeAdjOrTransMat, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::DenseTriangular, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
matprod_dest(A::DenseTriangular, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))

for op ∈ (:+, :-), Wrapper ∈ (:Hermitian, :Symmetric)
Expand All @@ -45,11 +45,11 @@ for op ∈ (:+, :-)
end
end

LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion, B::DenseMatrixUnion, _add::MulAddMul) =
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add)
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion, B::AbstractTriangular, _add::MulAddMul) =
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add)
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion, B::DenseInputVector, _add::MulAddMul) =
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) =
spdensemul!(C, tA, 'N', A, B, _add)

Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add)
Expand Down Expand Up @@ -114,7 +114,7 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
C
end

Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::AbstractSparseMatrixCSC, _add::MulAddMul)
Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul)
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
if tB == 'N'
_spmul!(C, transA(A), B, _add.alpha, _add.beta)
Expand All @@ -125,7 +125,7 @@ Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::Strided
end
return C
end
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::AbstractSparseMatrixCSC, α::Number, β::Number)
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) ||
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
Expand All @@ -145,7 +145,7 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::AbstractSparseMatrixC
end
C
end
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::AbstractSparseMatrixCSC, α::Number, β::Number)
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) ||
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
Expand All @@ -164,7 +164,7 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::A
C
end

function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::AbstractSparseMatrixCSC, α::Number, β::Number)
function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
mA, nA = size(A)
nA == size(B, 2) ||
throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))"))
Expand Down
24 changes: 16 additions & 8 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,23 @@ end
# underlying SparseMatrixCSC
const SparseMatrixCSCView{Tv,Ti} =
SubArray{Tv,2,<:AbstractSparseMatrixCSC{Tv,Ti},
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractUnitRange}
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractUnitRange{<:Integer}}
const SparseMatrixCSCUnion{Tv,Ti} = Union{AbstractSparseMatrixCSC{Tv,Ti}, SparseMatrixCSCView{Tv,Ti}}
# Define an alias for views of a SparseMatrixCSC which include all rows and a selection of the columns.
# Also define a union of SparseMatrixCSC and this view since many methods can be defined efficiently for
# this union by extracting the fields via the get function: getrowval, and getnzval, BUT NOT getcolptr!
const SparseMatrixCSCColumnSubset{Tv,Ti} =
SubArray{Tv,2,<:AbstractSparseMatrixCSC{Tv,Ti},
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractVector{<:Integer}}
const SparseMatrixCSCUnion2{Tv,Ti} = Union{AbstractSparseMatrixCSC{Tv,Ti}, SparseMatrixCSCColumnSubset{Tv,Ti}}

getcolptr(S::SorF) = getfield(S, :colptr)
getcolptr(S::SparseMatrixCSCView) = view(getcolptr(parent(S)), first(axes(S, 2)):(last(axes(S, 2)) + 1))
getcolptr(S::SparseMatrixCSCView) = view(getcolptr(parent(S)), first(S.indices[2]):(last(S.indices[2]) + 1))
getcolptr(S::SparseMatrixCSCColumnSubset) = error("getcolptr not well-defined for $(typeof(S))")
getrowval(S::AbstractSparseMatrixCSC) = rowvals(S)
getrowval(S::SparseMatrixCSCView) = rowvals(parent(S))
getrowval(S::SparseMatrixCSCColumnSubset) = rowvals(parent(S))
getnzval( S::AbstractSparseMatrixCSC) = nonzeros(S)
getnzval( S::SparseMatrixCSCView) = nonzeros(parent(S))
getnzval( S::SparseMatrixCSCColumnSubset) = nonzeros(parent(S))
nzvalview(S::AbstractSparseMatrixCSC) = view(nonzeros(S), 1:nnz(S))

"""
Expand All @@ -212,7 +220,7 @@ nnz(S::ReshapedArray{<:Any,1,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
nnz(S::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
nnz(S::UpperTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
nnz(S::LowerTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
nnz(S::SparseMatrixCSCView) = nnz1(S)
nnz(S::SparseMatrixCSCColumnSubset) = nnz1(S)
nnz1(S) = sum(length.(nzrange.(Ref(S), axes(S, 2))))

function Base._simple_count(pred, S::AbstractSparseMatrixCSC, init::T) where T
Expand Down Expand Up @@ -244,7 +252,7 @@ julia> nonzeros(A)
```
"""
nonzeros(S::SorF) = getfield(S, :nzval)
nonzeros(S::SparseMatrixCSCView) = nonzeros(S.parent)
nonzeros(S::SparseMatrixCSCColumnSubset) = nonzeros(S.parent)
nonzeros(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}) = nonzeros(S.data)
nonzeros(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}) = nonzeros(S.data)

Expand Down Expand Up @@ -272,7 +280,7 @@ julia> rowvals(A)
```
"""
rowvals(S::SorF) = getfield(S, :rowval)
rowvals(S::SparseMatrixCSCView) = rowvals(S.parent)
rowvals(S::SparseMatrixCSCColumnSubset) = rowvals(S.parent)
rowvals(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}) = rowvals(S.data)
rowvals(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}) = rowvals(S.data)

Expand All @@ -299,7 +307,7 @@ column. In conjunction with [`nonzeros`](@ref) and
Adding or removing nonzero elements to the matrix may invalidate the `nzrange`, one should not mutate the matrix while iterating.
"""
nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
nzrange(S::SparseMatrixCSCView, col::Integer) = nzrange(S.parent, S.indices[2][col])
nzrange(S::SparseMatrixCSCColumnSubset, col::Integer) = nzrange(S.parent, S.indices[2][col])
nzrange(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}, i::Integer) = nzrangeup(S.data, i)
nzrange(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}, i::Integer) = nzrangelo(S.data, i)

Expand Down
14 changes: 8 additions & 6 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,15 @@ begin
A = sprand(rng, n, n, 0.01)
MA = Matrix(A)
lA = sprand(rng, n, n+10, 0.01)
@test nnz(lA[:, n+1:n+10]) == nnz(view(lA, :, n+1:n+10))
@testset "triangular multiply with $tr($wr)" for tr in (identity, adjoint, transpose),
wr in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
AW = tr(wr(A))
MAW = tr(wr(MA))
@test AW * B ≈ MAW * B
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
vAW = tr(wr(view(A, :, 1:n)))
vMAW = tr(wr(view(MA, :, 1:n)))
@test vAW * B ≈ vMAW * B
vAW = tr(wr(view([zero(A)+I A], :, (n+1):2n)))
@test vAW * B ≈ AW * B
end
a = sprand(rng, ComplexF64, n, n, 0.01)
ma = Matrix(a)
Expand All @@ -172,9 +172,8 @@ begin
MAW = tr(wr(ma))
@test AW * B ≈ MAW * B
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
vAW = tr(wr(view(a, :, 1:n)))
vMAW = tr(wr(view(ma, :, 1:n)))
@test vAW * B ≈ vMAW * B
vAW = tr(wr(view([zero(a)+I a], :, (n+1):2n)))
@test vAW * B ≈ AW * B
end
A = A - Diagonal(diag(A)) + 2I # avoid rounding errors by division
MA = Matrix(A)
Expand All @@ -183,6 +182,9 @@ begin
AW = tr(wr(A))
MAW = tr(wr(MA))
@test AW \ B ≈ MAW \ B
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
vAW = tr(wr(view([zero(A)+I A], :, (n+1):2n)))
@test vAW \ B ≈ AW \ B
end
@testset "triangular singular exceptions" begin
A = LowerTriangular(sparse([0 2.0;0 1]))
Expand Down

0 comments on commit 75081bc

Please sign in to comment.