Skip to content

Commit 2afab2d

Browse files
committed
Remove explicit dependence of sparse broadcast on type inference
Instead of determining the output element type beforehand by querying inference, the element type is deduced from the actually computed output values (similar to broadcast over Array, but taking into account the output for the all-inputs-zero case). For the type-unstable case, performance is sub-optimal, but at least it gives the correct result. Closes #19595.
1 parent 2d8f5bf commit 2afab2d

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

base/sparse/sparsematrix.jl

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,7 +1413,7 @@ function map{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
14131413
fofzeros = f(_zeros_eltypes(A, Bs...)...)
14141414
fpreszeros = fofzeros == zero(fofzeros)
14151415
maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
1416-
entrytypeC = _broadcast_type(f, A, Bs...)
1416+
entrytypeC = typeof(fofzeros)
14171417
indextypeC = _promote_indtype(A, Bs...)
14181418
Ccolptr = Vector{indextypeC}(A.n + 1)
14191419
Crowval = Vector{indextypeC}(maxnnzC)
@@ -1438,7 +1438,7 @@ function broadcast{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N
14381438
fofzeros = f(_zeros_eltypes(A, Bs...)...)
14391439
fpreszeros = fofzeros == zero(fofzeros)
14401440
indextypeC = _promote_indtype(A, Bs...)
1441-
entrytypeC = _broadcast_type(f, A, Bs...)
1441+
entrytypeC = typeof(fofzeros)
14421442
Cm, Cn = Base.to_shape(Base.Broadcast.broadcast_indices(A, Bs...))
14431443
maxnnzC = fpreszeros ? _checked_maxnnzbcres(Cm, Cn, A, Bs...) : (Cm * Cn)
14441444
Ccolptr = Vector{indextypeC}(Cn + 1)
@@ -1464,28 +1464,34 @@ _maxnnzfrom(Cm, Cn, A) = nnz(A) * div(Cm, A.m) * div(Cn, A.n)
14641464
@inline _maxnnzfrom_each(Cm, Cn, As) = (_maxnnzfrom(Cm, Cn, first(As)), _maxnnzfrom_each(Cm, Cn, tail(As))...)
14651465
@inline _unchecked_maxnnzbcres(Cm, Cn, As) = min(Cm * Cn, sum(_maxnnzfrom_each(Cm, Cn, As)))
14661466
@inline _checked_maxnnzbcres(Cm, Cn, As...) = Cm != 0 && Cn != 0 ? _unchecked_maxnnzbcres(Cm, Cn, As) : 0
1467-
_broadcast_type(f, As...) = Base._promote_op(f, Base.Broadcast.typestuple(As...))
1467+
@inline _update_nzval!{T}(nzval::Vector{T}, k, x::T) = (nzval[k] = x; nzval)
1468+
@inline function _update_nzval!{T,Tx}(nzval::Vector{T}, k, x::Tx)
1469+
nzval = convert(Vector{typejoin(Tx, T)}, nzval)
1470+
nzval[k] = x
1471+
return nzval
1472+
end
14681473

14691474
# _map_zeropres!/_map_notzeropres! specialized for a single sparse matrix
14701475
"Stores only the nonzero entries of `map(f, Matrix(A))` in `C`."
14711476
function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC)
14721477
spaceC = min(length(C.rowval), length(C.nzval))
14731478
Ck = 1
1479+
nzval = C.nzval
14741480
@inbounds for j in 1:C.n
14751481
C.colptr[j] = Ck
14761482
for Ak in nzrange(A, j)
14771483
Cx = f(A.nzval[Ak])
14781484
if Cx != zero(eltype(C))
14791485
Ck > spaceC && (spaceC = _expandstorage!(C, Ck + nnz(A) - (Ak - 1)))
14801486
C.rowval[Ck] = A.rowval[Ak]
1481-
C.nzval[Ck] = Cx
1487+
nzval = _update_nzval!(nzval, Ck, Cx)
14821488
Ck += 1
14831489
end
14841490
end
14851491
end
14861492
@inbounds C.colptr[C.n + 1] = Ck
14871493
_trimstorage!(C, Ck - 1)
1488-
return C
1494+
return nzval === C.nzval ? C : SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval)
14891495
end
14901496
"""
14911497
Densifies `C`, storing `fillvalue` in place of each unstored entry in `A` and
@@ -1496,13 +1502,14 @@ function _map_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMa
14961502
_densestructure!(C)
14971503
# Populate values
14981504
fill!(C.nzval, fillvalue)
1505+
nzval = C.nzval
14991506
@inbounds for (j, jo) in zip(1:C.n, 0:C.m:(C.m*C.n - 1)), Ak in nzrange(A, j)
15001507
Cx = f(A.nzval[Ak])
1501-
Cx != fillvalue && (C.nzval[jo + A.rowval[Ak]] = Cx)
1508+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + A.rowval[Ak], Cx))
15021509
end
15031510
# NOTE: Combining the fill! above into the loop above to avoid multiple sweeps over /
15041511
# nonsequential access of C.nzval does not appear to improve performance.
1505-
return C
1512+
return nzval === C.nzval ? C : SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval)
15061513
end
15071514
# helper functions for these methods and some of those below
15081515
function _expandstorage!(X::SparseMatrixCSC, maxstored)
@@ -1533,6 +1540,7 @@ function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::Sp
15331540
spaceC = min(length(C.rowval), length(C.nzval))
15341541
rowsentinelA = convert(eltype(A.rowval), C.m + 1)
15351542
rowsentinelB = convert(eltype(B.rowval), C.m + 1)
1543+
nzval = C.nzval
15361544
Ck = 1
15371545
@inbounds for j in 1:C.n
15381546
C.colptr[j] = Ck
@@ -1562,12 +1570,13 @@ function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::Sp
15621570
if Cx != zero(eltype(C))
15631571
Ck > spaceC && (spaceC = _expandstorage!(C, Ck + (nnz(A) - (Ak - 1)) + (nnz(B) - (Bk - 1))))
15641572
C.rowval[Ck] = Ci
1565-
C.nzval[Ck] = Cx
1573+
nzval = _update_nzval!(nzval, Ck, Cx)
15661574
Ck += 1
15671575
end
15681576
end
15691577
end
15701578
@inbounds C.colptr[C.n + 1] = Ck
1579+
nzval === C.nzval || (C = SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval))
15711580
_trimstorage!(C, Ck - 1)
15721581
return C
15731582
end
@@ -1578,6 +1587,7 @@ function _map_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMa
15781587
fill!(C.nzval, fillvalue)
15791588
# NOTE: Combining this fill! into the loop below to avoid multiple sweeps over /
15801589
# nonsequential access of C.nzval does not appear to improve performance.
1590+
nzval = C.nzval
15811591
rowsentinelA = convert(eltype(A.rowval), C.m + 1)
15821592
rowsentinelB = convert(eltype(B.rowval), C.m + 1)
15831593
@inbounds for (j, jo) in zip(1:C.n, 0:C.m:(C.m*C.n - 1))
@@ -1598,10 +1608,10 @@ function _map_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMa
15981608
Cx, Ci = f(zero(eltype(A)), B.nzval[Bk]), Bi
15991609
Bk += one(Bk); Bi = Bk < stopBk ? B.rowval[Bk] : rowsentinelB
16001610
end
1601-
Cx != fillvalue && (C.nzval[jo + Ci] = Cx)
1611+
Cx != fillvalue && (nzval = _update_nzval!(nzval, jo + Ci, Cx))
16021612
end
16031613
end
1604-
return C
1614+
return nzval === C.nzval ? C : SparseMatrixCSC(C.m, C.n, C.colptr, C.rowval, nzval)
16051615
end
16061616
# _broadcast_zeropres!/_broadcast_notzeropres! specialized for a pair of (input) sparse matrices
16071617
function _broadcast_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::SparseMatrixCSC)

test/sparse/sparse.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,3 +1828,16 @@ let
18281828
@test_throws DimensionMismatch broadcast(+, A, B, speye(N))
18291829
@test_throws DimensionMismatch broadcast!(+, X, A, B, speye(N))
18301830
end
1831+
1832+
# Issue #19595 - broadcasting over sparse matrices with abstract eltype
1833+
let x = sparse(eye(Real,3,3))
1834+
@test eltype(x) === Real
1835+
@test eltype(x + x) <: Real
1836+
@test eltype(x .+ x) <: Real
1837+
@test eltype(map(+, x, x)) <: Real
1838+
@test eltype(broadcast(+, x, x)) <: Real
1839+
@test eltype(x + x + x) <: Real
1840+
@test eltype(x .+ x .+ x) <: Real
1841+
@test eltype(map(+, map(+, x, x), x)) <: Real
1842+
@test eltype(broadcast(+, broadcast(+, x, x), x)) <: Real
1843+
end

0 commit comments

Comments
 (0)