Skip to content

Commit 9d4397f

Browse files
authored
Sparse versions of repeat for matrices and vectors (#532)
* Move col_length offset to helper function: The col_length function argument was pretty confusing, as it expected not the number of nonzero elements in the column, but rather one less than that. This moves the offset by one into the for-loop that actually needs it, which clarifies the meaning and also simplifies call sites.
* Simplify arguments of helper functions: Instead of taking the whole input matrix, only take the actually required parts (row indices and nonzero values). That allows using it in more contexts.
* Add efficient repeat for sparse matrices.
* Add efficient repeat for sparse vectors.
* Avoid using temporary vars for sparse data arrays: Just use the accessor functions rowvals, nonzeros, getcolptr whenever needed instead. Especially in the sparse vector case, avoid using findnz, which creates a copy of the data. Use nonzeroinds and nonzeros instead.
* Return next insert position from stuffcol helper: Simplifies updating the insert position at the call site.
1 parent a09f90b commit 9d4397f

File tree

4 files changed

+96
-12
lines changed

4 files changed

+96
-12
lines changed

src/sparsematrix.jl

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3908,31 +3908,26 @@ function vcat(X::AbstractSparseMatrixCSC...)
39083908
ptr_res = colptr[c]
39093909
for i = 1 : num
39103910
colptrXi = getcolptr(X[i])
3911-
col_length = (colptrXi[c + 1] - 1) - colptrXi[c]
3911+
col_length = colptrXi[c + 1] - colptrXi[c]
39123912
ptr_Xi = colptrXi[c]
39133913

3914-
stuffcol!(X[i], colptr, rowval, nzval,
3915-
ptr_res, ptr_Xi, col_length, mX_sofar)
3916-
3917-
ptr_res += col_length + 1
3914+
ptr_res = stuffcol!(rowval, nzval, ptr_res, rowvals(X[i]), nonzeros(X[i]), ptr_Xi,
3915+
col_length, mX_sofar)
39183916
mX_sofar += mX[i]
39193917
end
39203918
colptr[c + 1] = ptr_res
39213919
end
39223920
SparseMatrixCSC(m, n, colptr, rowval, nzval)
39233921
end
39243922

3925-
@inline function stuffcol!(Xi::AbstractSparseMatrixCSC, colptr, rowval, nzval,
3926-
ptr_res, ptr_Xi, col_length, mX_sofar)
3927-
colptrXi = getcolptr(Xi)
3928-
rowvalXi = rowvals(Xi)
3929-
nzvalXi = nonzeros(Xi)
3930-
3931-
for k=ptr_res:(ptr_res + col_length)
3923+
# Copy `col_length` consecutive stored entries from the source arrays
# (`rowvalXi`, `nzvalXi`), beginning at source position `ptr_Xi`, into the
# destination arrays (`rowval`, `nzval`), beginning at destination position `ptr_res`.
# `mX_sofar` is added to every copied row index.
#
# Returns the next free destination position, i.e. `ptr_res + col_length`.
# Nothing is copied when `col_length` is not positive.
@inline function stuffcol!(rowval, nzval, ptr_res, rowvalXi, nzvalXi, ptr_Xi,
                           col_length, mX_sofar)
    for offset in 0:(col_length - 1)
        dest = ptr_res + offset
        src = ptr_Xi + offset
        # Bounds checks are elided here, so callers must size both array pairs
        # such that `dest` and `src` are valid positions.
        @inbounds rowval[dest] = rowvalXi[src] + mX_sofar
        @inbounds nzval[dest] = nzvalXi[src]
    end
    return ptr_res + col_length
end
39373932

39383933
function hcat(X::AbstractSparseMatrixCSC...)
@@ -3973,6 +3968,42 @@ function hcat(X::AbstractSparseMatrixCSC...)
39733968
SparseMatrixCSC(m, n, colptr, rowval, nzval)
39743969
end
39753970

3971+
3972+
# Efficient repetition of sparse matrices
3973+
3974+
# Stack `m` copies of `A` on top of each other without going through a dense
# intermediate. The result has `size(A, 1) * m` rows and `size(A, 2)` columns.
function Base.repeat(A::AbstractSparseMatrixCSC, m)
    nrows, ncols = size(A)
    total_nnz = nnz(A) * m

    # Output buffers share the element types of the corresponding arrays of `A`.
    colptr = similar(getcolptr(A), length(getcolptr(A)))
    rowval = similar(rowvals(A), total_nnz)
    nzval = similar(nonzeros(A), total_nnz)

    colptr[1] = 1
    for col in 1:ncols
        src_start = getcolptr(A)[col]
        stored_in_col = getcolptr(A)[col + 1] - src_start
        dest = colptr[col]
        # Write the same source column `m` times, shifting the row indices by one
        # full height of `A` for each successive copy.
        for rep in 0:(m - 1)
            dest = stuffcol!(rowval, nzval, dest, rowvals(A), nonzeros(A), src_start,
                             stored_in_col, rep * nrows)
        end
        colptr[col + 1] = dest
    end
    # Internal consistency check: every output slot must have been filled.
    @assert colptr[end] == total_nnz + 1

    return SparseMatrixCSC(nrows * m, ncols, colptr, rowval, nzval)
end
3996+
3997+
# Repeat `A` `m` times along the rows and `n` times along the columns.
#
# The row repetition is delegated to `repeat(A, m)`. The resulting block `B` is then
# tiled horizontally: its stored row indices and values are repeated `n` times, and the
# column pointers are rebuilt from the per-column entry counts of `B`.
function Base.repeat(A::AbstractSparseMatrixCSC, m, n)
    B = repeat(A, m)
    Ti = eltype(getcolptr(B))
    nnz_per_column = diff(getcolptr(B))
    # `vcat(1, ...)` promotes to `Int` and `cumsum` widens small integer types, so the
    # pointers are converted back to the index type of `B`. Otherwise the result would
    # not keep the index type of the input. The conversion is checked and throws an
    # `InexactError` if a pointer does not fit into `Ti`.
    colptr = convert(Vector{Ti}, cumsum(vcat(1, repeat(nnz_per_column, n))))
    rowval = repeat(rowvals(B), n)
    nzval = repeat(nonzeros(B), n)
    return SparseMatrixCSC(size(B, 1), size(B, 2) * n, colptr, rowval, nzval)
end
4005+
4006+
39764007
"""
39774008
blockdiag(A...)
39784009

src/sparsevector.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,34 @@ hvcat(rows::Tuple{Vararg{Int}}, n1::Number, ns::Vararg{Number}) = invoke(hvcat,
13081308
hvcat(rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N}) where {N<:Number} = invoke(hvcat, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...)
13091309

13101310

1311+
### Efficient repetition of sparse vectors
1312+
1313+
# Concatenate `m` copies of the sparse vector `v`, producing a `SparseVector` of
# length `length(v) * m`. The stored indices and values are read through
# `nonzeroinds` / `nonzeros` directly.
function Base.repeat(v::AbstractSparseVector, m)
    len = length(v)
    stored = nnz(v)
    total = stored * m

    nzind = similar(nonzeroinds(v), total)
    nzval = similar(nonzeros(v), total)

    next_slot = 1
    for rep in 0:(m - 1)
        # Each copy reuses all stored entries of `v`, with indices shifted by `rep`
        # full lengths of `v`.
        next_slot = stuffcol!(nzind, nzval, next_slot, nonzeroinds(v), nonzeros(v), 1,
                              stored, rep * len)
    end
    # Internal consistency check: every output slot must have been filled.
    @assert next_slot == total + 1

    return SparseVector(len * m, nzind, nzval)
end
1329+
1330+
# Repeat the sparse vector `v` `m` times vertically and use the result as each of the
# `n` columns of a `SparseMatrixCSC`.
function Base.repeat(v::AbstractSparseVector, m, n)
    column = repeat(v, m)
    Ti = eltype(nonzeroinds(column))
    # Every column holds the same number of stored entries, so the column pointers
    # form an arithmetic progression starting at 1.
    column_starts = 1 .+ nnz(column) * (0:n)
    colptr = Vector{Ti}(column_starts)
    rowval = repeat(nonzeroinds(column), n)
    nzval = repeat(nonzeros(column), n)
    return SparseMatrixCSC(length(column), n, colptr, rowval, nzval)
end
1337+
1338+
13111339
# make sure UniformScaling objects are converted to sparse matrices for concatenation
13121340
promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix
13131341
promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n)

test/sparsematrix_constructors_indexing.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,19 @@ end
302302
end
303303
end
304304

305+
# Compare the sparse `repeat` methods against the dense implementation for all
# combinations of row/column repetition counts in 0:3 (including empty results).
@testset "repeat tests" begin
    A = sprand(6, 4, 0.5)
    A_full = Matrix(A)
    for m in 0:3
        stacked = repeat(A, m)
        @test issparse(stacked)
        @test stacked == repeat(A_full, m)
        for n in 0:3
            tiled = repeat(A, m, n)
            @test issparse(tiled)
            @test tiled == repeat(A_full, m, n)
        end
    end
end
317+
305318
@testset "copyto!" begin
306319
A = sprand(5, 5, 0.2)
307320
B = sprand(5, 5, 0.2)

test/sparsevector.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,18 @@ end
635635
end
636636
end
637637
end
638+
639+
# Compare `repeat` on the sparse vector fixture `spv_x1` against its dense counterpart
# `x1_full` for all combinations of repetition counts in 0:3.
@testset "repeat" begin
    for m in 0:3
        stacked = repeat(spv_x1, m)
        @test issparse(stacked)
        @test stacked == repeat(x1_full, m)
        for n in 0:3
            tiled = repeat(spv_x1, m, n)
            @test issparse(tiled)
            @test tiled == repeat(x1_full, m, n)
        end
    end
end
649+
638650
@testset "sparsemat: combinations with sparse matrix" begin
639651
let S = sprand(4, 8, 0.5)
640652
Sf = Array(S)

0 commit comments

Comments
 (0)