Skip to content

Commit 4240f1b

Browse files
committed
Resize buffers in sparse! to satisfy buffer checks in constructor
With this patch the output buffers to `sparse!` are resized in order to satisfy the buffer length checks in the `SparseMatrixCSC` constructor that were introduced in JuliaLang/julia#40523. Previously `csccolptr` was never resized, and `cscrowval` and `cscnzval` were only resized if the buffers were too short (i.e. never truncated). The requirement `length(csccolptr) >= n + 1` could be kept, but seems unnecessary since all buffers need to be resized anyway (to pass the constructor checks). In particular this fixes calling `sparse!` with `I`, `J`, `V` as both input and output buffers: `sparse!(I, J, V, m, n, ..., I, J, V)`. Fixes #313.
1 parent 57cbb74 commit 4240f1b

File tree

2 files changed

+87
-9
lines changed

2 files changed

+87
-9
lines changed

src/sparsematrix.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,9 +1068,8 @@ intermediate CSR forms and require `length(csrrowptr) >= m + 1`,
10681068
`length(csrcolval) >= length(I)`, and `length(csrnzval >= length(I))`. Input
10691069
array `klasttouch`, workspace for the second stage, requires `length(klasttouch) >= n`.
10701070
Optional input arrays `csccolptr`, `cscrowval`, and `cscnzval` constitute storage for the
1071-
returned CSC form `S`. `csccolptr` requires `length(csccolptr) >= n + 1`. If necessary,
1072-
`cscrowval` and `cscnzval` are automatically resized to satisfy
1073-
`length(cscrowval) >= nnz(S)` and `length(cscnzval) >= nnz(S)`; hence, if `nnz(S)` is
1071+
returned CSC form `S`. If necessary, these are resized automatically to satisfy
1072+
`length(csccolptr) = n + 1`, `length(cscrowval) = nnz(S)` and `length(cscnzval) = nnz(S)`; hence, if `nnz(S)` is
10741073
unknown at the outset, passing in empty vectors of the appropriate type (`Vector{Ti}()`
10751074
and `Vector{Tv}()` respectively) suffices, or calling the `sparse!` method
10761075
neglecting `cscrowval` and `cscnzval`.
@@ -1081,6 +1080,7 @@ representation of the result's transpose.
10811080
You may reuse the input arrays' storage (`I`, `J`, `V`) for the output arrays
10821081
(`csccolptr`, `cscrowval`, `cscnzval`). For example, you may call
10831082
`sparse!(I, J, V, csrrowptr, csrcolval, csrnzval, I, J, V)`.
1083+
Note that they will be resized to satisfy the conditions above.
10841084
10851085
For the sake of efficiency, this method performs no argument checking beyond
10861086
`1 <= I[k] <= m` and `1 <= J[k] <= n`. Use with care. Testing with `--check-bounds=yes`
@@ -1140,6 +1140,9 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti}, V::Union{Tv,Abstr
11401140
end
11411141
# This completes the unsorted-row, has-repeats CSR form's construction
11421142

1143+
# The output array csccolptr can now be resized safely even if aliased with I
1144+
resize!(csccolptr, n + 1)
1145+
11431146
# Sweep through the CSR form, simultaneously (1) calculating the CSC form's column
11441147
# counts and storing them shifted forward by one in csccolptr; (2) detecting repeated
11451148
# entries; and (3) repacking the CSR form with the repeated entries combined.
@@ -1188,10 +1191,13 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti}, V::Union{Tv,Abstr
11881191
Base.hastypemax(Ti) && (countsum <= typemax(Ti) || throw(ArgumentError("more than typemax(Ti)-1 == $(typemax(Ti)-1) entries")))
11891192
end
11901193

1191-
# Now knowing the CSC form's entry count, resize cscrowval and cscnzval if necessary
1194+
# Now knowing the CSC form's entry count, resize cscrowval and cscnzval
1195+
# Note: This is done unconditionally to appease the buffer checks in the SparseMatrixCSC
1196+
# constructor. If these checks are lifted this resizing is only needed if the
1197+
# buffers are too short. csccolptr is resized above.
11921198
cscnnz = countsum - Tj(1)
1193-
length(cscrowval) < cscnnz && resize!(cscrowval, cscnnz)
1194-
length(cscnzval) < cscnnz && resize!(cscnzval, cscnnz)
1199+
resize!(cscrowval, cscnnz)
1200+
resize!(cscnzval, cscnnz)
11951201

11961202
# Finally counting-sort the row and nonzero values from the CSR form into cscrowval and
11971203
# cscnzval. Tracking write positions in csccolptr corrects the column pointers.

test/sparsematrix_constructors_indexing.jl

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ using Dates
1212
include("forbidproperties.jl")
1313
include("simplesmatrix.jl")
1414

15+
function same_structure(A, B)
16+
return all(getfield(A, f) == getfield(B, f) for f in (:m, :n, :colptr, :rowval))
17+
end
18+
1519
@testset "uniform scaling should not change type #103" begin
1620
A = spzeros(Float32, Int8, 5, 5)
1721
B = I - A
@@ -59,9 +63,6 @@ end
5963
end
6064

6165
@testset "spzeros for pattern creation (structural zeros)" begin
62-
function same_structure(A, B)
63-
return all(getfield(A, f) == getfield(B, f) for f in (:m, :n, :colptr, :rowval))
64-
end
6566
I = [1, 2, 3]
6667
J = [1, 3, 4]
6768
V = zeros(length(I))
@@ -1625,4 +1626,75 @@ end
16251626
@test_throws ArgumentError SparseArrays.expandptr([2; 3])
16261627
end
16271628

1629+
@testset "sparse!" begin
1630+
using SparseArrays: sparse!, getcolptr, getrowval, nonzeros
1631+
1632+
function allocate_arrays(m, n)
1633+
N = round(Int, 0.5 * m * n)
1634+
Tv, Ti = Float64, Int
1635+
I = Ti[rand(1:m) for _ in 1:N]; I = Ti[I; I]
1636+
J = Ti[rand(1:n) for _ in 1:N]; J = Ti[J; J]
1637+
V = Tv.(I)
1638+
csrrowptr = Vector{Ti}(undef, m + 1)
1639+
csrcolval = Vector{Ti}(undef, length(I))
1640+
csrnzval = Vector{Tv}(undef, length(I))
1641+
klasttouch = Vector{Ti}(undef, n)
1642+
csccolptr = Vector{Ti}(undef, n + 1)
1643+
cscrowval = Vector{Ti}()
1644+
cscnzval = Vector{Tv}()
1645+
return I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr, cscrowval, cscnzval
1646+
end
1647+
1648+
for (m, n) in ((10, 5), (5, 10), (10, 10))
1649+
# Passing csr vectors
1650+
I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval = allocate_arrays(m, n)
1651+
S = sparse(I, J, V, m, n)
1652+
S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval)
1653+
@test S == S!
1654+
@test same_structure(S, S!)
1655+
1656+
# Passing csr vectors + csccolptr
1657+
I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr = allocate_arrays(m, n)
1658+
S = sparse(I, J, V, m, n)
1659+
S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr)
1660+
@test S == S!
1661+
@test same_structure(S, S!)
1662+
@test getcolptr(S!) === csccolptr
1663+
1664+
# Passing csr vectors, and csc vectors
1665+
I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr, cscrowval, cscnzval =
1666+
allocate_arrays(m, n)
1667+
S = sparse(I, J, V, m, n)
1668+
S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval,
1669+
csccolptr, cscrowval, cscnzval)
1670+
@test S == S!
1671+
@test same_structure(S, S!)
1672+
@test getcolptr(S!) === csccolptr
1673+
@test getrowval(S!) === cscrowval
1674+
@test nonzeros(S!) === cscnzval
1675+
1676+
# Passing csr vectors, and csc vectors of insufficient lengths
1677+
I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval, csccolptr, cscrowval, cscnzval =
1678+
allocate_arrays(m, n)
1679+
S = sparse(I, J, V, m, n)
1680+
S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval,
1681+
resize!(csccolptr, 0), resize!(cscrowval, 0), resize!(cscnzval, 0))
1682+
@test S == S!
1683+
@test same_structure(S, S!)
1684+
@test getcolptr(S!) === csccolptr
1685+
@test getrowval(S!) === cscrowval
1686+
@test nonzeros(S!) === cscnzval
1687+
1688+
# Passing csr vectors, and csc vectors aliased with I, J, V
1689+
I, J, V, klasttouch, csrrowptr, csrcolval, csrnzval = allocate_arrays(m, n)
1690+
S = sparse(I, J, V, m, n)
1691+
S! = sparse!(I, J, V, m, n, +, klasttouch, csrrowptr, csrcolval, csrnzval, I, J, V)
1692+
@test S == S!
1693+
@test same_structure(S, S!)
1694+
@test getcolptr(S!) === I
1695+
@test getrowval(S!) === J
1696+
@test nonzeros(S!) === V
1697+
end
1698+
end
1699+
16281700
end

0 commit comments

Comments
 (0)