Skip to content

Fix performance trap for sparse view multiplication #476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ const DenseTriangular = UpperOrLowerTriangular{<:Any,<:DenseMatrixUnion}
const DenseInputVector = Union{StridedVector, BitVector}
const DenseVecOrMat = Union{DenseMatrixUnion, DenseInputVector}

matprod_dest(A::SparseMatrixCSCUnion, B::DenseTriangular, TS) =
matprod_dest(A::SparseMatrixCSCUnion2, B::DenseTriangular, TS) =
similar(B, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::DenseTriangular, TS) =
matprod_dest(A::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, B::DenseTriangular, TS) =
similar(B, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::StridedMaybeAdjOrTransMat, B::SparseMatrixCSCUnion, TS) =
matprod_dest(A::StridedMaybeAdjOrTransMat, B::SparseMatrixCSCUnion2, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::SparseMatrixCSCUnion, TS) =
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::SparseMatrixCSCUnion2, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::DenseTriangular, B::SparseMatrixCSCUnion, TS) =
matprod_dest(A::DenseTriangular, B::SparseMatrixCSCUnion2, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::StridedMaybeAdjOrTransMat, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
matprod_dest(A::StridedMaybeAdjOrTransMat, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::DenseTriangular, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
matprod_dest(A::DenseTriangular, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))

for op ∈ (:+, :-), Wrapper ∈ (:Hermitian, :Symmetric)
Expand All @@ -45,11 +45,11 @@ for op ∈ (:+, :-)
end
end

LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion, B::DenseMatrixUnion, _add::MulAddMul) =
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add)
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion, B::AbstractTriangular, _add::MulAddMul) =
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add)
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion, B::DenseInputVector, _add::MulAddMul) =
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) =
spdensemul!(C, tA, 'N', A, B, _add)

Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add)
Expand Down Expand Up @@ -114,7 +114,7 @@ function _At_or_Ac_mul_B!(tfun::Function, C, A, B, α, β)
C
end

Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::AbstractSparseMatrixCSC, _add::MulAddMul)
Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul)
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
if tB == 'N'
_spmul!(C, transA(A), B, _add.alpha, _add.beta)
Expand All @@ -125,7 +125,7 @@ Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::Strided
end
return C
end
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::AbstractSparseMatrixCSC, α::Number, β::Number)
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) ||
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
Expand All @@ -145,7 +145,7 @@ function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::AbstractSparseMatrixC
end
C
end
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::AbstractSparseMatrixCSC, α::Number, β::Number)
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) ||
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
Expand All @@ -164,7 +164,7 @@ function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::A
C
end

function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::AbstractSparseMatrixCSC, α::Number, β::Number)
function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
mA, nA = size(A)
nA == size(B, 2) ||
throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))"))
Expand Down
24 changes: 16 additions & 8 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,23 @@ end
# underlying SparseMatrixCSC
const SparseMatrixCSCView{Tv,Ti} =
SubArray{Tv,2,<:AbstractSparseMatrixCSC{Tv,Ti},
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractUnitRange}
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractUnitRange{<:Integer}}
const SparseMatrixCSCUnion{Tv,Ti} = Union{AbstractSparseMatrixCSC{Tv,Ti}, SparseMatrixCSCView{Tv,Ti}}
# Define an alias for views of a SparseMatrixCSC which include all rows and a selection of the columns.
# Also define a union of SparseMatrixCSC and this view since many methods can be defined efficiently for
# this union by extracting the fields via the get function: getrowval, and getnzval, BUT NOT getcolptr!
const SparseMatrixCSCColumnSubset{Tv,Ti} =
SubArray{Tv,2,<:AbstractSparseMatrixCSC{Tv,Ti},
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractVector{<:Integer}}
const SparseMatrixCSCUnion2{Tv,Ti} = Union{AbstractSparseMatrixCSC{Tv,Ti}, SparseMatrixCSCColumnSubset{Tv,Ti}}

getcolptr(S::SorF) = getfield(S, :colptr)
getcolptr(S::SparseMatrixCSCView) = view(getcolptr(parent(S)), first(axes(S, 2)):(last(axes(S, 2)) + 1))
getcolptr(S::SparseMatrixCSCView) = view(getcolptr(parent(S)), first(S.indices[2]):(last(S.indices[2]) + 1))
getcolptr(S::SparseMatrixCSCColumnSubset) = error("getcolptr not well-defined for $(typeof(S))")
getrowval(S::AbstractSparseMatrixCSC) = rowvals(S)
getrowval(S::SparseMatrixCSCView) = rowvals(parent(S))
getrowval(S::SparseMatrixCSCColumnSubset) = rowvals(parent(S))
getnzval( S::AbstractSparseMatrixCSC) = nonzeros(S)
getnzval( S::SparseMatrixCSCView) = nonzeros(parent(S))
getnzval( S::SparseMatrixCSCColumnSubset) = nonzeros(parent(S))
nzvalview(S::AbstractSparseMatrixCSC) = view(nonzeros(S), 1:nnz(S))

"""
Expand All @@ -212,7 +220,7 @@ nnz(S::ReshapedArray{<:Any,1,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
nnz(S::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
nnz(S::UpperTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
nnz(S::LowerTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
nnz(S::SparseMatrixCSCView) = nnz1(S)
nnz(S::SparseMatrixCSCColumnSubset) = nnz1(S)
nnz1(S) = sum(length.(nzrange.(Ref(S), axes(S, 2))))

function Base._simple_count(pred, S::AbstractSparseMatrixCSC, init::T) where T
Expand Down Expand Up @@ -244,7 +252,7 @@ julia> nonzeros(A)
```
"""
nonzeros(S::SorF) = getfield(S, :nzval)
nonzeros(S::SparseMatrixCSCView) = nonzeros(S.parent)
nonzeros(S::SparseMatrixCSCColumnSubset) = nonzeros(S.parent)
nonzeros(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}) = nonzeros(S.data)
nonzeros(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}) = nonzeros(S.data)

Expand Down Expand Up @@ -272,7 +280,7 @@ julia> rowvals(A)
```
"""
rowvals(S::SorF) = getfield(S, :rowval)
rowvals(S::SparseMatrixCSCView) = rowvals(S.parent)
rowvals(S::SparseMatrixCSCColumnSubset) = rowvals(S.parent)
rowvals(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}) = rowvals(S.data)
rowvals(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}) = rowvals(S.data)

Expand All @@ -299,7 +307,7 @@ column. In conjunction with [`nonzeros`](@ref) and
Adding or removing nonzero elements to the matrix may invalidate the `nzrange`, one should not mutate the matrix while iterating.
"""
nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
nzrange(S::SparseMatrixCSCView, col::Integer) = nzrange(S.parent, S.indices[2][col])
nzrange(S::SparseMatrixCSCColumnSubset, col::Integer) = nzrange(S.parent, S.indices[2][col])
nzrange(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}, i::Integer) = nzrangeup(S.data, i)
nzrange(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}, i::Integer) = nzrangelo(S.data, i)

Expand Down
14 changes: 8 additions & 6 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ begin
A = sprand(rng, n, n, 0.01)
MA = Matrix(A)
lA = sprand(rng, n, n+10, 0.01)
@test nnz(lA[:, n+1:n+10]) == nnz(view(lA, :, n+1:n+10))
@testset "triangular multiply with $tr($wr)" for tr in (identity, adjoint, transpose),
wr in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
AW = tr(wr(A))
MAW = tr(wr(MA))
@test AW * B ≈ MAW * B
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
vAW = tr(wr(view(A, :, 1:n)))
vMAW = tr(wr(view(MA, :, 1:n)))
@test vAW * B ≈ vMAW * B
vAW = tr(wr(view([zero(A)+I A], :, (n+1):2n)))
@test vAW * B ≈ AW * B
end
a = sprand(rng, ComplexF64, n, n, 0.01)
ma = Matrix(a)
Expand All @@ -170,9 +170,8 @@ begin
MAW = tr(wr(ma))
@test AW * B ≈ MAW * B
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
vAW = tr(wr(view(a, :, 1:n)))
vMAW = tr(wr(view(ma, :, 1:n)))
@test vAW * B ≈ vMAW * B
vAW = tr(wr(view([zero(a)+I a], :, (n+1):2n)))
@test vAW * B ≈ AW * B
end
A = A - Diagonal(diag(A)) + 2I # avoid rounding errors by division
MA = Matrix(A)
Expand All @@ -181,6 +180,9 @@ begin
AW = tr(wr(A))
MAW = tr(wr(MA))
@test AW \ B ≈ MAW \ B
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
vAW = tr(wr(view([zero(A)+I A], :, (n+1):2n)))
@test vAW \ B ≈ AW \ B
end
@testset "triangular singular exceptions" begin
A = LowerTriangular(sparse([0 2.0;0 1]))
Expand Down