Skip to content

Commit ce545a6

Browse files
committed
Test sparse broadcast! over combinations of broadcast scalars and sparse vectors/matrices.
1 parent 746dbb0 commit ce545a6

File tree

2 files changed

+46
-16
lines changed

2 files changed

+46
-16
lines changed

base/sparse/higherorderfns.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -863,26 +863,26 @@ end
863863
return broadcast!(parevalf, dest, passedsrcargstup...)
864864
end
865865
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
866-
# broadcast scalar arguments (mixedargs), and returns a function (parevalf) and a reduced
867-
# argument tuple (passedargstup) containing only the sparse vectors/matrices in mixedargs
868-
# in their orginal order, and such that the result of broadcast(g, passedargstup...) is
869-
# broadcast(f, mixedargs...)
870-
capturescalars(f, mixedargs) =
866+
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
867+
# evaluated f) and a reduced argument tuple (passedargstup) containing only the sparse
868+
# vectors/matrices in mixedargs in their orginal order, and such that the result of
869+
# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...)
870+
@inline capturescalars(f, mixedargs) =
871871
capturescalars((passed, tofill) -> f(tofill...), (), mixedargs...)
872872
# Recursion cases for capturescalars
873-
capturescalars(f, passedargstup, scalararg, mixedargs...) =
873+
@inline capturescalars(f, passedargstup, scalararg, mixedargs...) =
874874
capturescalars(capturescalar(f, scalararg), passedargstup, mixedargs...)
875-
capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
875+
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
876876
capturescalars(passnonscalar(f), (passedargstup..., nonscalararg), mixedargs...)
877-
passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
878-
capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
877+
@inline passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
878+
@inline capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
879879
# Base cases for capturescalars
880-
capturescalars(f, passedargstup, scalararg) =
880+
@inline capturescalars(f, passedargstup, scalararg) =
881881
(capturelastscalar(f, scalararg), passedargstup)
882-
capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
882+
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
883883
(passlastnonscalar(f), (passedargstup..., nonscalararg))
884-
passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
885-
capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))
884+
@inline passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
885+
@inline capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))
886886

887887
# NOTE: The following two method definitions work around #19096.
888888
broadcast{Tf,T}(f::Tf, ::Type{T}, A::SparseMatrixCSC) = broadcast(y -> f(T, y), A)

test/sparse/higherorderfns.jl

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ end
193193
end
194194
end
195195

196+
196197
@testset "sparse map/broadcast with result eltype not a concrete subtype of Number (#19561/#19589)" begin
197198
intoneorfloatzero(x) = x != 0.0 ? Int(1) : Float64(x)
198199
stringorfloatzero(x) = x != 0.0 ? "Hello" : Float64(x)
@@ -202,10 +203,10 @@ end
202203
@test broadcast(stringorfloatzero, speye(4)) == sparse(broadcast(stringorfloatzero, eye(4)))
203204
end
204205

205-
@testset "broadcast over combinations of scalars and sparse vectors/matrices" begin
206-
N, M, p = 10, 12, 0.3
206+
@testset "broadcast[!] over combinations of scalars and sparse vectors/matrices" begin
207+
N, M, p = 10, 12, 0.5
207208
elT = Float64
208-
s = elT(2.0)
209+
s = Float32(2.0)
209210
V = sprand(elT, N, p)
210211
A = sprand(elT, N, M, p)
211212
fV, fA = Array(V), Array(A)
@@ -235,8 +236,29 @@ end
235236
((s, spargsl..., s, s, spargsr...), (s, dargsl..., s, s, dargsr...)),
236237
((spargsl..., s, s, spargsr..., s), (dargsl..., s, s, dargsr..., s)),
237238
((spargsl..., s, s, s, spargsr...), (dargsl..., s, s, s, dargsr...)), )
239+
# test broadcast entry point
238240
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
239241
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
242+
# test broadcast! entry point
243+
fX = broadcast(*, sparseargs...); X = sparse(fX)
244+
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
245+
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
246+
X = sparse(fX) # reset / warmup for @allocated test
247+
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
248+
# This test (and the analog below) fails for three reasons:
249+
# (1) In all cases, generating the closures that capture the scalar arguments
250+
# results in allocation, not sure why.
251+
# (2) In some cases, though _broadcast_eltype (which wraps _return_type)
252+
# consistently provides the correct result eltype when passed the closure
253+
# that incorporates the scalar arguments to broadcast (and, with #19667,
254+
# is inferable, so the overall return type from broadcast is inferred),
255+
# in some cases inference seems unable to determine the return type of
256+
# direct calls to that closure. This issue causes variables in both the
257+
# broadcast[!] entry points (fofzeros = f(_zeros_eltypes(args...)...)) and
258+
# the driver routines (Cx in _map_zeropres! and _broadcast_zeropres!) to have
259+
# inferred type Any, resulting in allocation and lackluster performance.
260+
# (3) The sparseargs... splat in the call above allocates a bit, but of course
261+
# that issue is negligible and perhaps could be accounted for in the test.
240262
end
241263
end
242264
# test combinations at the limit of inference (eight arguments net)
@@ -248,8 +270,16 @@ end
248270
((s, V, A, s, V, A, s, A), (s, fV, fA, s, fV, fA, s, fA)), # three scalars, five sparse vectors/matrices
249271
((V, A, V, s, A, V, A, s), (fV, fA, fV, s, fA, fV, fA, s)), # two scalars, six sparse vectors/matrices
250272
((V, A, V, A, s, V, A, V), (fV, fA, fV, fA, s, fV, fA, fV)) ) # one scalar, seven sparse vectors/matrices
273+
# test broadcast entry point
251274
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
252275
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
276+
# test broadcast! entry point
277+
fX = broadcast(*, sparseargs...); X = sparse(fX)
278+
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
279+
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
280+
X = sparse(fX) # reset / warmup for @allocated test
281+
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
282+
# please see the note a few lines above re. this @test_broken
253283
end
254284
end
255285

0 commit comments

Comments
 (0)