Skip to content

Commit

Permalink
Sparse versions of repeat for matrices and vectors (#532)
Browse files Browse the repository at this point in the history
* Move col_length offset to helper function

The col_length function argument was pretty confusing, as it expected
not the number of nonzero elements in the column, but rather one less
than that. This moves this offset by one to the actual for-loop that it
is for, which clarifies the meaning and also simplifies callsites.

* Simplify arguments of helper functions

Instead of taking the whole input matrix, only take the actuallly
required parts (row indices and nonzero values). That allows using it in
more contexts.

* Add efficient repeat for sparse matrices

* Add efficient repeat for sparse vectors

* Avoid using temporary vars for sparse data arrays

Just use the accessor functions rowvals, nonzeros, getcolptr whenever
needed instead. Especially in the sparse vector case avoid using findnz,
which creates a copy of the data. Use nonzeroinds and nonzeros instead.

* Return next insert position from stuffcol helper

Simplifies updating the insert position at the call site.
  • Loading branch information
mjacobse authored May 21, 2024
1 parent a09f90b commit 9d4397f
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 12 deletions.
55 changes: 43 additions & 12 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3908,31 +3908,26 @@ function vcat(X::AbstractSparseMatrixCSC...)
ptr_res = colptr[c]
for i = 1 : num
colptrXi = getcolptr(X[i])
col_length = (colptrXi[c + 1] - 1) - colptrXi[c]
col_length = colptrXi[c + 1] - colptrXi[c]
ptr_Xi = colptrXi[c]

stuffcol!(X[i], colptr, rowval, nzval,
ptr_res, ptr_Xi, col_length, mX_sofar)

ptr_res += col_length + 1
ptr_res = stuffcol!(rowval, nzval, ptr_res, rowvals(X[i]), nonzeros(X[i]), ptr_Xi,
col_length, mX_sofar)
mX_sofar += mX[i]
end
colptr[c + 1] = ptr_res
end
SparseMatrixCSC(m, n, colptr, rowval, nzval)
end

@inline function stuffcol!(Xi::AbstractSparseMatrixCSC, colptr, rowval, nzval,
ptr_res, ptr_Xi, col_length, mX_sofar)
colptrXi = getcolptr(Xi)
rowvalXi = rowvals(Xi)
nzvalXi = nonzeros(Xi)

for k=ptr_res:(ptr_res + col_length)
@inline function stuffcol!(rowval, nzval, ptr_res, rowvalXi, nzvalXi, ptr_Xi,
col_length, mX_sofar)
for k=ptr_res:(ptr_res + col_length - 1)
@inbounds rowval[k] = rowvalXi[ptr_Xi] + mX_sofar
@inbounds nzval[k] = nzvalXi[ptr_Xi]
ptr_Xi += 1
end
return ptr_res + col_length
end

function hcat(X::AbstractSparseMatrixCSC...)
Expand Down Expand Up @@ -3973,6 +3968,42 @@ function hcat(X::AbstractSparseMatrixCSC...)
SparseMatrixCSC(m, n, colptr, rowval, nzval)
end


# Efficient repetition of sparse matrices

function Base.repeat(A::AbstractSparseMatrixCSC, m)
nnz_new = nnz(A) * m
colptr = similar(getcolptr(A), length(getcolptr(A)))
rowval = similar(rowvals(A), nnz_new)
nzval = similar(nonzeros(A), nnz_new)

colptr[1] = 1
for c = 1 : size(A, 2)
ptr_res = colptr[c]
ptr_source = getcolptr(A)[c]
col_length = getcolptr(A)[c + 1] - ptr_source
for index_repetition = 0 : (m - 1)
row_offset = index_repetition * size(A, 1)
ptr_res = stuffcol!(rowval, nzval, ptr_res, rowvals(A), nonzeros(A), ptr_source,
col_length, row_offset)
end
colptr[c + 1] = ptr_res
end
@assert colptr[end] == nnz_new + 1

SparseMatrixCSC(size(A, 1) * m, size(A, 2), colptr, rowval, nzval)
end

function Base.repeat(A::AbstractSparseMatrixCSC, m, n)
B = repeat(A, m)
nnz_per_column = diff(getcolptr(B))
colptr = cumsum(vcat(1, repeat(nnz_per_column, n)))
rowval = repeat(rowvals(B), n)
nzval = repeat(nonzeros(B), n)
SparseMatrixCSC(size(B, 1), size(B, 2) * n, colptr, rowval, nzval)
end


"""
blockdiag(A...)
Expand Down
28 changes: 28 additions & 0 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,34 @@ hvcat(rows::Tuple{Vararg{Int}}, n1::Number, ns::Vararg{Number}) = invoke(hvcat,
hvcat(rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N}) where {N<:Number} = invoke(hvcat, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...)


### Efficient repetition of sparse vectors

function Base.repeat(v::AbstractSparseVector, m)
nnz_source = nnz(v)
nnz_new = nnz_source * m

nzind = similar(nonzeroinds(v), nnz_new)
nzval = similar(nonzeros(v), nnz_new)

ptr_res = 1
for index_repetition = 0:(m-1)
row_offset = index_repetition * length(v)
ptr_res = stuffcol!(nzind, nzval, ptr_res, nonzeroinds(v), nonzeros(v), 1, nnz_source, row_offset)
end
@assert ptr_res == nnz_new + 1

SparseVector(length(v) * m, nzind, nzval)
end

function Base.repeat(v::AbstractSparseVector, m, n)
w = repeat(v, m)
colptr = Vector{eltype(nonzeroinds(w))}(1 .+ nnz(w) * (0:n))
rowval = repeat(nonzeroinds(w), n)
nzval = repeat(nonzeros(w), n)
SparseMatrixCSC(length(w), n, colptr, rowval, nzval)
end


# make sure UniformScaling objects are converted to sparse matrices for concatenation
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix
promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n)
Expand Down
13 changes: 13 additions & 0 deletions test/sparsematrix_constructors_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,19 @@ end
end
end

@testset "repeat tests" begin
A = sprand(6, 4, 0.5)
A_full = Matrix(A)
for m = 0:3
@test issparse(repeat(A, m))
@test repeat(A, m) == repeat(A_full, m)
for n = 0:3
@test issparse(repeat(A, m, n))
@test repeat(A, m, n) == repeat(A_full, m, n)
end
end
end

@testset "copyto!" begin
A = sprand(5, 5, 0.2)
B = sprand(5, 5, 0.2)
Expand Down
12 changes: 12 additions & 0 deletions test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,18 @@ end
end
end
end

@testset "repeat" begin
for m = 0:3
@test issparse(repeat(spv_x1, m))
@test repeat(spv_x1, m) == repeat(x1_full, m)
for n = 0:3
@test issparse(repeat(spv_x1, m, n))
@test repeat(spv_x1, m, n) == repeat(x1_full, m, n)
end
end
end

@testset "sparsemat: combinations with sparse matrix" begin
let S = sprand(4, 8, 0.5)
Sf = Array(S)
Expand Down

0 comments on commit 9d4397f

Please sign in to comment.