diff --git a/src/sparsematrix.jl b/src/sparsematrix.jl index f94318b2..4fa2adf9 100644 --- a/src/sparsematrix.jl +++ b/src/sparsematrix.jl @@ -3908,13 +3908,11 @@ function vcat(X::AbstractSparseMatrixCSC...) ptr_res = colptr[c] for i = 1 : num colptrXi = getcolptr(X[i]) - col_length = (colptrXi[c + 1] - 1) - colptrXi[c] + col_length = colptrXi[c + 1] - colptrXi[c] ptr_Xi = colptrXi[c] - stuffcol!(X[i], colptr, rowval, nzval, - ptr_res, ptr_Xi, col_length, mX_sofar) - - ptr_res += col_length + 1 + ptr_res = stuffcol!(rowval, nzval, ptr_res, rowvals(X[i]), nonzeros(X[i]), ptr_Xi, + col_length, mX_sofar) mX_sofar += mX[i] end colptr[c + 1] = ptr_res @@ -3922,17 +3920,14 @@ function vcat(X::AbstractSparseMatrixCSC...) SparseMatrixCSC(m, n, colptr, rowval, nzval) end -@inline function stuffcol!(Xi::AbstractSparseMatrixCSC, colptr, rowval, nzval, - ptr_res, ptr_Xi, col_length, mX_sofar) - colptrXi = getcolptr(Xi) - rowvalXi = rowvals(Xi) - nzvalXi = nonzeros(Xi) - - for k=ptr_res:(ptr_res + col_length) +@inline function stuffcol!(rowval, nzval, ptr_res, rowvalXi, nzvalXi, ptr_Xi, + col_length, mX_sofar) + for k=ptr_res:(ptr_res + col_length - 1) @inbounds rowval[k] = rowvalXi[ptr_Xi] + mX_sofar @inbounds nzval[k] = nzvalXi[ptr_Xi] ptr_Xi += 1 end + return ptr_res + col_length end function hcat(X::AbstractSparseMatrixCSC...) @@ -3973,6 +3968,42 @@ function hcat(X::AbstractSparseMatrixCSC...) SparseMatrixCSC(m, n, colptr, rowval, nzval) end + +# Efficient repetition of sparse matrices + +function Base.repeat(A::AbstractSparseMatrixCSC, m) + nnz_new = nnz(A) * m + colptr = similar(getcolptr(A), length(getcolptr(A))) + rowval = similar(rowvals(A), nnz_new) + nzval = similar(nonzeros(A), nnz_new) + + colptr[1] = 1 + for c = 1 : size(A, 2) + ptr_res = colptr[c] + ptr_source = getcolptr(A)[c] + col_length = getcolptr(A)[c + 1] - ptr_source + for index_repetition = 0 : (m - 1) + row_offset = index_repetition * size(A, 1) + ptr_res = stuffcol!(rowval, nzval, ptr_res, rowvals(A), nonzeros(A), ptr_source, + col_length, row_offset) + end + colptr[c + 1] = ptr_res + end + @assert colptr[end] == nnz_new + 1 + + SparseMatrixCSC(size(A, 1) * m, size(A, 2), colptr, rowval, nzval) +end + +function Base.repeat(A::AbstractSparseMatrixCSC, m, n) + B = repeat(A, m) + nnz_per_column = diff(getcolptr(B)) + colptr = cumsum(vcat(1, repeat(nnz_per_column, n))) + rowval = repeat(rowvals(B), n) + nzval = repeat(nonzeros(B), n) + SparseMatrixCSC(size(B, 1), size(B, 2) * n, colptr, rowval, nzval) +end + + """ blockdiag(A...) diff --git a/src/sparsevector.jl b/src/sparsevector.jl index 97cf979e..a2102eff 100644 --- a/src/sparsevector.jl +++ b/src/sparsevector.jl @@ -1308,6 +1308,34 @@ hvcat(rows::Tuple{Vararg{Int}}, n1::Number, ns::Vararg{Number}) = invoke(hvcat, hvcat(rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N}) where {N<:Number} = invoke(hvcat, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...) +### Efficient repetition of sparse vectors + +function Base.repeat(v::AbstractSparseVector, m) + nnz_source = nnz(v) + nnz_new = nnz_source * m + + nzind = similar(nonzeroinds(v), nnz_new) + nzval = similar(nonzeros(v), nnz_new) + + ptr_res = 1 + for index_repetition = 0:(m-1) + row_offset = index_repetition * length(v) + ptr_res = stuffcol!(nzind, nzval, ptr_res, nonzeroinds(v), nonzeros(v), 1, nnz_source, row_offset) + end + @assert ptr_res == nnz_new + 1 + + SparseVector(length(v) * m, nzind, nzval) +end + +function Base.repeat(v::AbstractSparseVector, m, n) + w = repeat(v, m) + colptr = Vector{eltype(nonzeroinds(w))}(1 .+ nnz(w) * (0:n)) + rowval = repeat(nonzeroinds(w), n) + nzval = repeat(nonzeros(w), n) + SparseMatrixCSC(length(w), n, colptr, rowval, nzval) +end + + # make sure UniformScaling objects are converted to sparse matrices for concatenation promote_to_array_type(A::Tuple{Vararg{Union{_SparseConcatGroup,UniformScaling}}}) = anysparse(A...) ? SparseMatrixCSC : Matrix promote_to_arrays_(n::Int, ::Type{SparseMatrixCSC}, J::UniformScaling) = sparse(J, n, n) diff --git a/test/sparsematrix_constructors_indexing.jl b/test/sparsematrix_constructors_indexing.jl index b1b08dbf..a0396f40 100644 --- a/test/sparsematrix_constructors_indexing.jl +++ b/test/sparsematrix_constructors_indexing.jl @@ -302,6 +302,19 @@ end end end +@testset "repeat tests" begin + A = sprand(6, 4, 0.5) + A_full = Matrix(A) + for m = 0:3 + @test issparse(repeat(A, m)) + @test repeat(A, m) == repeat(A_full, m) + for n = 0:3 + @test issparse(repeat(A, m, n)) + @test repeat(A, m, n) == repeat(A_full, m, n) + end + end +end + @testset "copyto!" begin A = sprand(5, 5, 0.2) B = sprand(5, 5, 0.2) diff --git a/test/sparsevector.jl b/test/sparsevector.jl index 15499e12..c91f0d69 100644 --- a/test/sparsevector.jl +++ b/test/sparsevector.jl @@ -635,6 +635,18 @@ end end end end + +@testset "repeat" begin + for m = 0:3 + @test issparse(repeat(spv_x1, m)) + @test repeat(spv_x1, m) == repeat(x1_full, m) + for n = 0:3 + @test issparse(repeat(spv_x1, m, n)) + @test repeat(spv_x1, m, n) == repeat(x1_full, m, n) + end + end +end + @testset "sparsemat: combinations with sparse matrix" begin let S = sprand(4, 8, 0.5) Sf = Array(S)