Skip to content

Commit 747c67f

Browse files
committed
Resize buffers in sparse! to satisfy buffer checks in constructor (#314)
With this patch the output buffers to `sparse!` are resized in order to satisfy the buffer length checks in the `SparseMatrixCSC` constructor that were introduced in JuliaLang/julia#40523. Previously `csccolptr` was never resized, and `cscrowval` and `cscnzval` were only resized if the buffers were too short (i.e. never truncated). The requirement `length(csccolptr) >= n + 1` could be kept, but seems unnecessary since all buffers need to be resized anyway (to pass the constructor checks). In particular this fixes calling `sparse!` with `I`, `J`, `V` as both input and output buffers: `sparse!(I, J, V, m, n, ..., I, J, V)`. Fixes #313. (cherry picked from commit 85a381b)
1 parent 6811df2 commit 747c67f

File tree

2 files changed

+87
-6
lines changed

2 files changed

+87
-6
lines changed

src/sparsematrix.jl

+12-6
Original file line numberDiff line numberDiff line change
@@ -1064,9 +1064,8 @@ intermediate CSR forms and require `length(csrrowptr) >= m + 1`,
10641064
`length(csrcolval) >= length(I)`, and `length(csrnzval >= length(I))`. Input
10651065
array `klasttouch`, workspace for the second stage, requires `length(klasttouch) >= n`.
10661066
Optional input arrays `csccolptr`, `cscrowval`, and `cscnzval` constitute storage for the
1067-
returned CSC form `S`. `csccolptr` requires `length(csccolptr) >= n + 1`. If necessary,
1068-
`cscrowval` and `cscnzval` are automatically resized to satisfy
1069-
`length(cscrowval) >= nnz(S)` and `length(cscnzval) >= nnz(S)`; hence, if `nnz(S)` is
1067+
returned CSC form `S`. If necessary, these are resized automatically to satisfy
1068+
`length(csccolptr) = n + 1`, `length(cscrowval) = nnz(S)` and `length(cscnzval) = nnz(S)`; hence, if `nnz(S)` is
10701069
unknown at the outset, passing in empty vectors of the appropriate type (`Vector{Ti}()`
10711070
and `Vector{Tv}()` respectively) suffices, or calling the `sparse!` method
10721071
neglecting `cscrowval` and `cscnzval`.
@@ -1077,6 +1076,7 @@ representation of the result's transpose.
10771076
You may reuse the input arrays' storage (`I`, `J`, `V`) for the output arrays
10781077
(`csccolptr`, `cscrowval`, `cscnzval`). For example, you may call
10791078
`sparse!(I, J, V, csrrowptr, csrcolval, csrnzval, I, J, V)`.
1079+
Note that they will be resized to satisfy the conditions above.
10801080
10811081
For the sake of efficiency, this method performs no argument checking beyond
10821082
`1 <= I[k] <= m` and `1 <= J[k] <= n`. Use with care. Testing with `--check-bounds=yes`
@@ -1131,6 +1131,9 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti},
11311131
end
11321132
# This completes the unsorted-row, has-repeats CSR form's construction
11331133

1134+
# The output array csccolptr can now be resized safely even if aliased with I
1135+
resize!(csccolptr, n + 1)
1136+
11341137
# Sweep through the CSR form, simultaneously (1) calculating the CSC form's column
11351138
# counts and storing them shifted forward by one in csccolptr; (2) detecting repeated
11361139
# entries; and (3) repacking the CSR form with the repeated entries combined.
@@ -1175,10 +1178,13 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti},
11751178
Base.hastypemax(Ti) && (countsum <= typemax(Ti) || throw(ArgumentError("more than typemax(Ti)-1 == $(typemax(Ti)-1) entries")))
11761179
end
11771180

1178-
# Now knowing the CSC form's entry count, resize cscrowval and cscnzval if necessary
1181+
# Now knowing the CSC form's entry count, resize cscrowval and cscnzval
1182+
# Note: This is done unconditionally to appease the buffer checks in the SparseMatrixCSC
1183+
# constructor. If these checks are lifted this resizing is only needed if the
1184+
# buffers are too short. csccolptr is resized above.
11791185
cscnnz = countsum - Tj(1)
1180-
length(cscrowval) < cscnnz && resize!(cscrowval, cscnnz)
1181-
length(cscnzval) < cscnnz && resize!(cscnzval, cscnnz)
1186+
resize!(cscrowval, cscnnz)
1187+
resize!(cscnzval, cscnnz)
11821188

11831189
# Finally counting-sort the row and nonzero values from the CSR form into cscrowval and
11841190
# cscnzval. Tracking write positions in csccolptr corrects the column pointers.

test/sparsematrix_constructors_indexing.jl

+75
Original file line numberDiff line numberDiff line change
@@ -1598,4 +1598,79 @@ end
15981598
@test_throws ArgumentError SparseArrays.expandptr([2; 3])
15991599
end
16001600

1601+
@testset "sparse!" begin
1602+
using SparseArrays: sparse!, getcolptr, getrowval, nonzeros
1603+
1604+
function same_structure(A, B)
1605+
return all(getfield(A, f) == getfield(B, f) for f in (:m, :n, :colptr, :rowval))
1606+
end
1607+
1608+
function allocate_arrays(m, n)
1609+
N = round(Int, 0.5 * m * n)
1610+
Tv, Ti = Float64, Int
1611+
I = Ti[rand(1:m) for _ in 1:N]; I = Ti[I; I]
1612+
J = Ti[rand(1:n) for _ in 1:N]; J = Ti[J; J]
1613+
V = Tv.(I)
1614+
csrrowptr = Vector{Ti}(undef, m + 1)
1615+
csrcolval = Vector{Ti}(undef, length(I))
1616+
csrnzval = Vector{Tv}(undef, length(I))
1617+
klasttouch = Vector{Ti}(undef, n)
1618+
csccolptr = Vector{Ti}(undef, n + 1)
1619+
cscrowval = Vector{Ti}()
1620+
cscnzval = Vector{Tv}()
1621+
return I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr, cscrowval, cscnzval
1622+
end
1623+
1624+
for (m, n) in ((10, 5), (5, 10), (10, 10))
1625+
# Passing csr vectors
1626+
I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval = allocate_arrays(m, n)
1627+
S = sparse(I, J, V, m, n)
1628+
S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval)
1629+
@test S == S!
1630+
@test same_structure(S, S!)
1631+
1632+
# Passing csr vectors + csccolptr
1633+
I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr = allocate_arrays(m, n)
1634+
S = sparse(I, J, V, m, n)
1635+
S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr)
1636+
@test S == S!
1637+
@test same_structure(S, S!)
1638+
@test getcolptr(S!) === csccolptr
1639+
1640+
# Passing csr vectors, and csc vectors
1641+
I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr, cscrowval, cscnzval =
1642+
allocate_arrays(m, n)
1643+
S = sparse(I, J, V, m, n)
1644+
S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval,
1645+
csccolptr, cscrowval, cscnzval)
1646+
@test S == S!
1647+
@test same_structure(S, S!)
1648+
@test getcolptr(S!) === csccolptr
1649+
@test getrowval(S!) === cscrowval
1650+
@test nonzeros(S!) === cscnzval
1651+
1652+
# Passing csr vectors, and csc vectors of insufficient lengths
1653+
I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr, cscrowval, cscnzval =
1654+
allocate_arrays(m, n)
1655+
S = sparse(I, J, V, m, n)
1656+
S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval,
1657+
resize!(csccolptr, 0), resize!(cscrowval, 0), resize!(cscnzval, 0))
1658+
@test S == S!
1659+
@test same_structure(S, S!)
1660+
@test getcolptr(S!) === csccolptr
1661+
@test getrowval(S!) === cscrowval
1662+
@test nonzeros(S!) === cscnzval
1663+
1664+
# Passing csr vectors, and csc vectors aliased with I, J, V
1665+
I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval = allocate_arrays(m, n)
1666+
S = sparse(I, J, V, m, n)
1667+
S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval, I, J, V)
1668+
@test S == S!
1669+
@test same_structure(S, S!)
1670+
@test getcolptr(S!) === I
1671+
@test getrowval(S!) === J
1672+
@test nonzeros(S!) === V
1673+
end
1674+
end
1675+
16011676
end

0 commit comments

Comments
 (0)