Skip to content

5-term mul! with Diagonal #603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 28, 2025
193 changes: 175 additions & 18 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1877,47 +1877,204 @@
## scale methods

# Copy colptr and rowval from one sparse matrix to another
function copyinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC)
if getcolptr(C) !== getcolptr(A)
function copyinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC; copy_rows=true, copy_cols=true)
if copy_cols && getcolptr(C) !== getcolptr(A)
resize!(getcolptr(C), length(getcolptr(A)))
copyto!(getcolptr(C), getcolptr(A))
end
if rowvals(C) !== rowvals(A)
if copy_rows && rowvals(C) !== rowvals(A)
resize!(rowvals(C), length(rowvals(A)))
copyto!(rowvals(C), rowvals(A))
end
end

"""
rowcheck_index(A::AbstractSparseMatrixCSC, row::Integer, col::Integer)

Check if A[row, col] is a stored value, and return the index of the row in `rowvals(A)`.
Returns `(row_exists, row_ind)`, where `row_exists::Bool` signifies
whether the corresponding index is populated, and `row_ind` is the index.
If `row_exists` is `false`, the `row_ind` is the index where the value should be inserted into
`rowvals(A)` such that the subarray `@view rowvals(A)[nzrange(A, col)]` remains sorted.
"""
@inline function rowcheck_index(A::AbstractSparseMatrixCSC, row::Integer, col::Integer)
nzinds = nzrange(A, col)
rows_col = @view rowvals(A)[nzinds]
# faster implementation of row ∈ rows_col and obtaining the index,
# assuming that rows_col is sorted
row_ind_col = searchsortedfirst(rows_col, row)
row_exists = row_ind_col ∈ axes(rows_col,1) && rows_col[row_ind_col] == row
row_ind = row_ind_col + first(nzinds) - firstindex(nzinds)
row_exists, row_ind
end

"""
mergeinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC)

Update `C` to contain stored values corresponding to the stored indices of `A`.
Stored indices common to `C` and `A` are not touched. Indices of `A` at which
`C` did not have a stored value are populated with zeros after the call.

# Examples
```jldoctest
julia> A = spzeros(3,3);

julia> A[4:4:8] .= 1;

julia> A
3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
⋅ 1.0 ⋅
⋅ ⋅ 1.0
⋅ ⋅ ⋅

julia> C = spzeros(3,3);

julia> C[2:4:6] .= 2;

julia> C
3×3 SparseMatrixCSC{Float64, Int64} with 2 stored entries:
⋅ ⋅ ⋅
2.0 ⋅ ⋅
⋅ 2.0 ⋅

julia> SparseArrays.mergeinds!(C, A)
3×3 SparseMatrixCSC{Float64, Int64} with 4 stored entries:
⋅ 0.0 ⋅
2.0 ⋅ 0.0
⋅ 2.0 ⋅
```
"""
function mergeinds!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC)
C_colptr = getcolptr(C)
for col in axes(A,2)
n_extra = 0
for ind in nzrange(A, col)
row = rowvals(A)[ind]
row_exists, ind = rowcheck_index(C, row, col)
if !row_exists
n_extra += 1
insert!(rowvals(C), ind, row)
insert!(nonzeros(C), ind, zero(eltype(C)))
C_colptr[col+1] += 1
end
end
if !iszero(n_extra)
@views C_colptr[col+2:end] .+= n_extra
end
end
C
end

# multiply by diagonal matrix as vector
function mul!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, D::Diagonal)
function mul!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, D::Diagonal, alpha::Number, beta::Number)
m, n = size(A)
b = D.diag
b = D.diag
lb = length(b)
n == lb || throw(DimensionMismatch("A has size ($m, $n) but D has size ($lb, $lb)"))
size(A)==size(C) || throw(DimensionMismatch("A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
copyinds!(C, A)
n == lb || throw(DimensionMismatch(lazy"A has size ($m, $n) but D has size ($lb, $lb)"))
size(A)==size(C) || throw(DimensionMismatch(lazy"A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
beta_is_zero = iszero(beta)
rows_match = rowvals(C) == rowvals(A)
cols_match = getcolptr(C) == getcolptr(A)
identical_nzinds = rows_match && cols_match
Cnzval = nonzeros(C)
Anzval = nonzeros(A)
resize!(Cnzval, length(Anzval))
for col in axes(A,2), p in nzrange(A, col)
@inbounds Cnzval[p] = Anzval[p] * b[col]
if beta_is_zero || identical_nzinds
identical_nzinds || copyinds!(C, A, copy_rows = !rows_match, copy_cols = !cols_match)
resize!(Cnzval, length(Anzval))
if beta_is_zero
if isone(alpha)
for col in axes(A,2), p in nzrange(A, col)
@inbounds Cnzval[p] = Anzval[p] * b[col]
end
else
for col in axes(A,2), p in nzrange(A, col)
@inbounds Cnzval[p] = Anzval[p] * b[col] * alpha
end

Check warning on line 1992 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L1990-L1992

Added lines #L1990 - L1992 were not covered by tests
end
else
if isone(alpha)
for col in axes(A,2), p in nzrange(A, col)
@inbounds Cnzval[p] = Anzval[p] * b[col] + Cnzval[p] * beta
end
else
for col in axes(A,2), p in nzrange(A, col)
@inbounds Cnzval[p] = Anzval[p] * b[col] * alpha + Cnzval[p] * beta
end
end
end
else
mergeinds!(C, A)
for col in axes(C,2), p in nzrange(C, col)
row = rowvals(C)[p]
# check if the index (row, col) is stored in A
row_exists, row_ind_A = rowcheck_index(A, row, col)
if row_exists
if isone(alpha)
@inbounds Cnzval[p] = Anzval[row_ind_A] * b[col] + Cnzval[p] * beta
else
@inbounds Cnzval[p] = Anzval[row_ind_A] * b[col] * alpha + Cnzval[p] * beta
end
else # A[row,col] == 0
@inbounds Cnzval[p] = Cnzval[p] * beta
end
end
end
C
end

function mul!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCSC)
function mul!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCSC, alpha::Number, beta::Number)
m, n = size(A)
b = D.diag
lb = length(b)
m == lb || throw(DimensionMismatch("D has size ($lb, $lb) but A has size ($m, $n)"))
size(A)==size(C) || throw(DimensionMismatch("A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
copyinds!(C, A)
m == lb || throw(DimensionMismatch(lazy"D has size ($lb, $lb) but A has size ($m, $n)"))
size(A)==size(C) || throw(DimensionMismatch(lazy"A has size ($m, $n), D has size ($lb, $lb), C has size $(size(C))"))
beta_is_zero = iszero(beta)
rows_match = rowvals(C) == rowvals(A)
cols_match = getcolptr(C) == getcolptr(A)
identical_nzinds = rows_match && cols_match
Cnzval = nonzeros(C)
Anzval = nonzeros(A)
Arowval = rowvals(A)
resize!(Cnzval, length(Anzval))
for col in axes(A,2), p in nzrange(A, col)
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
if beta_is_zero || identical_nzinds
identical_nzinds || copyinds!(C, A, copy_rows = !rows_match, copy_cols = !cols_match)
resize!(Cnzval, length(Anzval))
if beta_is_zero
if isone(alpha)
for col in axes(A,2), p in nzrange(A, col)
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
end
else
for col in axes(A,2), p in nzrange(A, col)
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] * alpha
end

Check warning on line 2049 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L2047-L2049

Added lines #L2047 - L2049 were not covered by tests
end
else
if isone(alpha)
for col in axes(A,2), p in nzrange(A, col)
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] + Cnzval[p] * beta
end
else
for col in axes(A,2), p in nzrange(A, col)
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p] * alpha + Cnzval[p] * beta
end
end
end
else
mergeinds!(C, A)
for col in axes(C,2), p in nzrange(C, col)
row = rowvals(C)[p]
# check if the index (row, col) is stored in A
row_exists, row_ind_A = rowcheck_index(A, row, col)
if row_exists
if isone(alpha)
@inbounds Cnzval[p] = b[row] * Anzval[row_ind_A] + Cnzval[p] * beta
else
@inbounds Cnzval[p] = b[row] * Anzval[row_ind_A] * alpha + Cnzval[p] * beta
end
else # A[row,col] == 0
@inbounds Cnzval[p] = Cnzval[p] * beta
end
end
end
C
end
Expand Down
4 changes: 2 additions & 2 deletions test/issues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,8 @@ end
A = sprand(5,5,0.5)
D = Diagonal(rand(5))
C = copy(A)
m1 = @which mul!(C,A,D)
m2 = @which mul!(C,D,A)
m1 = @which mul!(C,A,D,true,false)
m2 = @which mul!(C,D,A,true,false)
@test m1.module == SparseArrays
@test m2.module == SparseArrays
end
Expand Down
35 changes: 35 additions & 0 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,41 @@ end
@test lmul!(D, copy(sA)) ≈ D * dA
@test mul!(sC, D, copy(sA)) ≈ D * dA
end

@testset "5-arg mul!" begin
@testset "merge indices" begin
# for zero arrays, merge and copy are identical
A = spzeros(size(sA))
SparseArrays.mergeinds!(A, sA)
B = spzeros(size(sA))
SparseArrays.copyinds!(B, sA)
@test all(col -> nzrange(A, col) == nzrange(B, col), axes(A,2))
# for arrays with different indices populated, merge should combine these
A = spzeros(5,5)
A[diagind(A,1)] .= 5
B = spzeros(5,5)
B[diagind(A,-1)] .= 10
SparseArrays.mergeinds!(B, A)
@test rowvals(B) == [2, 1,3, 2,4, 3,5, 4]
@test [nzrange(B,col) for col in axes(B,2)] == [1:1, 2:3, 4:5, 6:7, 8:8]
@test nonzeros(B) == [10, 0,10, 0,10, 0,10, 0]
# for arrays with overlapping indices, merge should only add the extra ones
A[diagind(A,2)] .= 5
SparseArrays.mergeinds!(B, A)
@test rowvals(B) == [2, 1,3, 1,2,4, 2,3,5, 3,4]
@test [nzrange(B,col) for col in axes(B,2)] == [1:1, 2:3, 4:6, 7:9, 10:11]
@test nonzeros(B) == [10, 0,10, 0,0,10, 0,0,10, 0,0]
end
for sA2 in (similar(sA), sprand(size(sA)..., 0.1))
nonzeros(sA2) .= 1
@testset for (alpha, beta) in [(true, false), (true, true), (2,3)]
D = Diagonal(rand(size(sA,2)))
@test mul!(copy(sA2), sA, D, alpha, beta) ≈ dA * D * alpha + sA2 * beta
D = Diagonal(rand(size(sA,1)))
@test mul!(copy(sA2), D, sA, alpha, beta) ≈ D * dA * alpha + sA2 * beta
end
end
end
end

@testset "conj" begin
Expand Down
Loading