Skip to content

Commit b93c864

Browse files
committed
Test sparse broadcast! over combinations of broadcast scalars and sparse vectors/matrices.
1 parent 8c43f48 commit b93c864

File tree

2 files changed

+41
-12
lines changed

2 files changed

+41
-12
lines changed

base/sparse/higherorderfns.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -864,22 +864,22 @@ end
864864
# argument tuple (passedargstup) containing only the sparse vectors/matrices in mixedargs
865865
# in their orginal order, and such that the result of broadcast(g, passedargstup...) is
866866
# broadcast(f, mixedargs...)
867-
capturescalars(f, mixedargs) =
867+
@inline capturescalars(f, mixedargs) =
868868
capturescalars((passed, tofill) -> f(tofill...), (), mixedargs...)
869869
# Recursion cases for capturescalars
870-
capturescalars(f, passedargstup, scalararg, mixedargs...) =
870+
@inline capturescalars(f, passedargstup, scalararg, mixedargs...) =
871871
capturescalars(capturescalar(f, scalararg), passedargstup, mixedargs...)
872-
capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
872+
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
873873
capturescalars(passnonscalar(f), (passedargstup..., nonscalararg), mixedargs...)
874-
passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
875-
capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
874+
@inline passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
875+
@inline capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
876876
# Base cases for capturescalars
877-
capturescalars(f, passedargstup, scalararg) =
877+
@inline capturescalars(f, passedargstup, scalararg) =
878878
(capturelastscalar(f, scalararg), passedargstup)
879-
capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
879+
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
880880
(passlastnonscalar(f), (passedargstup..., nonscalararg))
881-
passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
882-
capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))
881+
@inline passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
882+
@inline capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))
883883

884884
# NOTE: The following two method definitions work around #19096.
885885
broadcast{Tf,T}(f::Tf, ::Type{T}, A::SparseMatrixCSC) = broadcast(y -> f(T, y), A)

test/sparse/higherorderfns.jl

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -193,10 +193,10 @@ end
193193
end
194194
end
195195

196-
@testset "broadcast over combinations of scalars and sparse vectors/matrices" begin
197-
N, M, p = 10, 12, 0.3
196+
@testset "broadcast[!] over combinations of scalars and sparse vectors/matrices" begin
197+
N, M, p = 10, 12, 0.5
198198
elT = Float64
199-
s = elT(2.0)
199+
s = Float32(2.0)
200200
V = sprand(elT, N, p)
201201
A = sprand(elT, N, M, p)
202202
fV, fA = Array(V), Array(A)
@@ -226,8 +226,29 @@ end
226226
((s, spargsl..., s, s, spargsr...), (s, dargsl..., s, s, dargsr...)),
227227
((spargsl..., s, s, spargsr..., s), (dargsl..., s, s, dargsr..., s)),
228228
((spargsl..., s, s, s, spargsr...), (dargsl..., s, s, s, dargsr...)), )
229+
# test broadcast entry point
229230
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
230231
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
232+
# test broadcast! entry point
233+
fX = broadcast(*, sparseargs...); X = sparse(fX)
234+
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
235+
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
236+
X = sparse(fX) # reset / warmup for @allocated test
237+
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
238+
# This test (and the analog below) fails for three reasons:
239+
# (1) In all cases, generating the closures that capture the scalar arguments
240+
# results in allocation, not sure why.
241+
# (2) In some cases, though _broadcast_eltype (which wraps _return_type)
242+
# consistently provides the correct result eltype when passed the closure
243+
# that incorporates the scalar arguments to broadcast (and, with #19667,
244+
# is inferable, so the overall return type from broadcast is inferred),
245+
# in some cases inference seems unable to determine the return type of
246+
# direct calls to that closure. This issue causes variables in both the
247+
# broadcast[!] entry points (fofzeros = f(_zeros_eltypes(args...)...)) and
248+
# the driver routines (Cx in _map_zeropres! and _broadcast_zeropres!) to have
249+
# inferred type Any, resulting in allocation and lackluster performance.
250+
# (3) The sparseargs... splat in the call above allocates a bit, but of course
251+
# that issue is negligible and perhaps could be accounted for in the test.
231252
end
232253
end
233254
# test combinations at the limit of inference (eight arguments net)
@@ -239,8 +260,16 @@ end
239260
((s, V, A, s, V, A, s, A), (s, fV, fA, s, fV, fA, s, fA)), # three scalars, five sparse vectors/matrices
240261
((V, A, V, s, A, V, A, s), (fV, fA, fV, s, fA, fV, fA, s)), # two scalars, six sparse vectors/matrices
241262
((V, A, V, A, s, V, A, V), (fV, fA, fV, fA, s, fV, fA, fV)) ) # one scalar, seven sparse vectors/matrices
263+
# test broadcast entry point
242264
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
243265
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
266+
# test broadcast! entry point
267+
fX = broadcast(*, sparseargs...); X = sparse(fX)
268+
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
269+
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
270+
X = sparse(fX) # reset / warmup for @allocated test
271+
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
272+
# please see the note a few lines above re. this @test_broken
244273
end
245274
end
246275

0 commit comments

Comments
 (0)