Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix performance trap for sparse view multiplication #476

Merged
merged 9 commits into from
Jan 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@
const DenseInputVector = Union{StridedVector, BitVector}
const DenseVecOrMat = Union{DenseMatrixUnion, DenseInputVector}

matprod_dest(A::SparseMatrixCSCUnion, B::DenseTriangular, TS) =
matprod_dest(A::SparseMatrixCSCUnion2, B::DenseTriangular, TS) =
similar(B, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, B::DenseTriangular, TS) =
matprod_dest(A::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, B::DenseTriangular, TS) =
similar(B, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::StridedMaybeAdjOrTransMat, B::SparseMatrixCSCUnion, TS) =
matprod_dest(A::StridedMaybeAdjOrTransMat, B::SparseMatrixCSCUnion2, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::SparseMatrixCSCUnion, TS) =
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::SparseMatrixCSCUnion2, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::DenseTriangular, B::SparseMatrixCSCUnion, TS) =
matprod_dest(A::DenseTriangular, B::SparseMatrixCSCUnion2, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::StridedMaybeAdjOrTransMat, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
matprod_dest(A::StridedMaybeAdjOrTransMat, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
matprod_dest(A::Union{BitMatrix,AdjOrTrans{<:Any,BitMatrix}}, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))
matprod_dest(A::DenseTriangular, B::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}, TS) =
matprod_dest(A::DenseTriangular, B::AdjOrTrans{<:Any,<:SparseMatrixCSCUnion2}, TS) =
similar(A, TS, (size(A, 1), size(B, 2)))

for op ∈ (:+, :-), Wrapper ∈ (:Hermitian, :Symmetric)
Expand All @@ -45,11 +45,11 @@
end
end

LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion, B::DenseMatrixUnion, _add::MulAddMul) =
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::DenseMatrixUnion, _add::MulAddMul) =
spdensemul!(C, tA, tB, A, B, _add)
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion, B::AbstractTriangular, _add::MulAddMul) =
LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::SparseMatrixCSCUnion2, B::AbstractTriangular, _add::MulAddMul) =

Check warning on line 50 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L50

Added line #L50 was not covered by tests
spdensemul!(C, tA, tB, A, B, _add)
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion, B::DenseInputVector, _add::MulAddMul) =
LinearAlgebra.generic_matvecmul!(C::StridedVecOrMat, tA, A::SparseMatrixCSCUnion2, B::DenseInputVector, _add::MulAddMul) =
spdensemul!(C, tA, 'N', A, B, _add)

Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, _add)
Expand Down Expand Up @@ -114,7 +114,7 @@
C
end

Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::AbstractSparseMatrixCSC, _add::MulAddMul)
Base.@constprop :aggressive function LinearAlgebra.generic_matmatmul!(C::StridedMatrix, tA, tB, A::DenseMatrixUnion, B::SparseMatrixCSCUnion2, _add::MulAddMul)
transA = tA == 'N' ? identity : tA == 'T' ? transpose : adjoint
if tB == 'N'
_spmul!(C, transA(A), B, _add.alpha, _add.beta)
Expand All @@ -125,7 +125,7 @@
end
return C
end
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::AbstractSparseMatrixCSC, α::Number, β::Number)
function _spmul!(C::StridedMatrix, X::DenseMatrixUnion, A::SparseMatrixCSCUnion2, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) ||
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
Expand All @@ -145,7 +145,7 @@
end
C
end
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::AbstractSparseMatrixCSC, α::Number, β::Number)
function _spmul!(C::StridedMatrix, X::AdjOrTrans{<:Any,<:DenseMatrixUnion}, A::SparseMatrixCSCUnion2, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) ||
throw(DimensionMismatch("second dimension of X, $nX, does not match the first dimension of A, $(size(A,1))"))
Expand All @@ -164,7 +164,7 @@
C
end

function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::AbstractSparseMatrixCSC, α::Number, β::Number)
function _A_mul_Bt_or_Bc!(tfun::Function, C::StridedMatrix, A::AbstractMatrix, B::SparseMatrixCSCUnion2, α::Number, β::Number)
mA, nA = size(A)
nA == size(B, 2) ||
throw(DimensionMismatch("second dimension of A, $nA, does not match the second dimension of B, $(size(B,2))"))
Expand Down
24 changes: 16 additions & 8 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,23 @@
# underlying SparseMatrixCSC
const SparseMatrixCSCView{Tv,Ti} =
SubArray{Tv,2,<:AbstractSparseMatrixCSC{Tv,Ti},
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractUnitRange}
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractUnitRange{<:Integer}}
const SparseMatrixCSCUnion{Tv,Ti} = Union{AbstractSparseMatrixCSC{Tv,Ti}, SparseMatrixCSCView{Tv,Ti}}
# Define an alias for views of a SparseMatrixCSC which include all rows and a selection of the columns.
# Also define a union of SparseMatrixCSC and this view since many methods can be defined efficiently for
# this union by extracting the fields via the get function: getrowval, and getnzval, BUT NOT getcolptr!
const SparseMatrixCSCColumnSubset{Tv,Ti} =
SubArray{Tv,2,<:AbstractSparseMatrixCSC{Tv,Ti},
Tuple{Base.Slice{Base.OneTo{Int}},I}} where {I<:AbstractVector{<:Integer}}
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved
const SparseMatrixCSCUnion2{Tv,Ti} = Union{AbstractSparseMatrixCSC{Tv,Ti}, SparseMatrixCSCColumnSubset{Tv,Ti}}

getcolptr(S::SorF) = getfield(S, :colptr)
getcolptr(S::SparseMatrixCSCView) = view(getcolptr(parent(S)), first(axes(S, 2)):(last(axes(S, 2)) + 1))
getcolptr(S::SparseMatrixCSCView) = view(getcolptr(parent(S)), first(S.indices[2]):(last(S.indices[2]) + 1))
getcolptr(S::SparseMatrixCSCColumnSubset) = error("getcolptr not well-defined for $(typeof(S))")

Check warning on line 194 in src/sparsematrix.jl

View check run for this annotation

Codecov / codecov/patch

src/sparsematrix.jl#L194

Added line #L194 was not covered by tests
getrowval(S::AbstractSparseMatrixCSC) = rowvals(S)
getrowval(S::SparseMatrixCSCView) = rowvals(parent(S))
getrowval(S::SparseMatrixCSCColumnSubset) = rowvals(parent(S))
getnzval( S::AbstractSparseMatrixCSC) = nonzeros(S)
getnzval( S::SparseMatrixCSCView) = nonzeros(parent(S))
getnzval( S::SparseMatrixCSCColumnSubset) = nonzeros(parent(S))
nzvalview(S::AbstractSparseMatrixCSC) = view(nonzeros(S), 1:nnz(S))

"""
Expand All @@ -212,7 +220,7 @@
nnz(S::AdjOrTrans{<:Any,<:AbstractSparseMatrixCSC}) = nnz(parent(S))
nnz(S::UpperTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
nnz(S::LowerTriangular{<:Any,<:AbstractSparseMatrixCSC}) = nnz1(S)
nnz(S::SparseMatrixCSCView) = nnz1(S)
nnz(S::SparseMatrixCSCColumnSubset) = nnz1(S)
nnz1(S) = sum(length.(nzrange.(Ref(S), axes(S, 2))))

function Base._simple_count(pred, S::AbstractSparseMatrixCSC, init::T) where T
Expand Down Expand Up @@ -244,7 +252,7 @@
```
"""
nonzeros(S::SorF) = getfield(S, :nzval)
nonzeros(S::SparseMatrixCSCView) = nonzeros(S.parent)
nonzeros(S::SparseMatrixCSCColumnSubset) = nonzeros(S.parent)
nonzeros(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}) = nonzeros(S.data)
nonzeros(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}) = nonzeros(S.data)

Expand Down Expand Up @@ -272,7 +280,7 @@
```
"""
rowvals(S::SorF) = getfield(S, :rowval)
rowvals(S::SparseMatrixCSCView) = rowvals(S.parent)
rowvals(S::SparseMatrixCSCColumnSubset) = rowvals(S.parent)
rowvals(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}) = rowvals(S.data)
rowvals(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}) = rowvals(S.data)

Expand All @@ -299,7 +307,7 @@
Adding or removing nonzero elements to the matrix may invalidate the `nzrange`, one should not mutate the matrix while iterating.
"""
nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
nzrange(S::SparseMatrixCSCView, col::Integer) = nzrange(S.parent, S.indices[2][col])
nzrange(S::SparseMatrixCSCColumnSubset, col::Integer) = nzrange(S.parent, S.indices[2][col])
nzrange(S::UpperTriangular{<:Any,<:SparseMatrixCSCUnion}, i::Integer) = nzrangeup(S.data, i)
nzrange(S::LowerTriangular{<:Any,<:SparseMatrixCSCUnion}, i::Integer) = nzrangelo(S.data, i)

Expand Down
14 changes: 8 additions & 6 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,15 @@ begin
A = sprand(rng, n, n, 0.01)
MA = Matrix(A)
lA = sprand(rng, n, n+10, 0.01)
@test nnz(lA[:, n+1:n+10]) == nnz(view(lA, :, n+1:n+10))
@testset "triangular multiply with $tr($wr)" for tr in (identity, adjoint, transpose),
wr in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular)
AW = tr(wr(A))
MAW = tr(wr(MA))
@test AW * B ≈ MAW * B
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
vAW = tr(wr(view(A, :, 1:n)))
vMAW = tr(wr(view(MA, :, 1:n)))
@test vAW * B ≈ vMAW * B
vAW = tr(wr(view([zero(A)+I A], :, (n+1):2n)))
@test vAW * B ≈ AW * B
end
a = sprand(rng, ComplexF64, n, n, 0.01)
ma = Matrix(a)
Expand All @@ -170,9 +170,8 @@ begin
MAW = tr(wr(ma))
@test AW * B ≈ MAW * B
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
vAW = tr(wr(view(a, :, 1:n)))
vMAW = tr(wr(view(ma, :, 1:n)))
@test vAW * B ≈ vMAW * B
vAW = tr(wr(view([zero(a)+I a], :, (n+1):2n)))
@test vAW * B ≈ AW * B
end
A = A - Diagonal(diag(A)) + 2I # avoid rounding errors by division
MA = Matrix(A)
Expand All @@ -181,6 +180,9 @@ begin
AW = tr(wr(A))
MAW = tr(wr(MA))
@test AW \ B ≈ MAW \ B
# and for SparseMatrixCSCView - a view of all rows and unit range of cols
vAW = tr(wr(view([zero(A)+I A], :, (n+1):2n)))
@test vAW \ B ≈ AW \ B
end
@testset "triangular singular exceptions" begin
A = LowerTriangular(sparse([0 2.0;0 1]))
Expand Down
Loading