Skip to content

Commit 9bbd3f0

Browse files
committed
Test sparse broadcast! over combinations of broadcast scalars and sparse vectors/matrices.
1 parent d31145b commit 9bbd3f0

File tree

2 files changed

+41
-12
lines changed

2 files changed

+41
-12
lines changed

base/sparse/higherorderfns.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -865,21 +865,21 @@ end
865865
# argument tuple (passedargstup) containing only the sparse vectors/matrices in mixedargs
866866
# in their orginal order, and such that the result of broadcast(g, passedargstup...) is
867867
# broadcast(f, mixedargs...)
868-
capturescalars(f, mixedargs) =
868+
@inline capturescalars(f, mixedargs) =
869869
capturescalars((passed, tofill) -> f(tofill...), (), mixedargs...)
870870
# Recursion cases for capturescalars
871-
capturescalars(f, passedargstup, scalararg, mixedargs...) =
871+
@inline capturescalars(f, passedargstup, scalararg, mixedargs...) =
872872
capturescalars(capturescalar(f, scalararg), passedargstup, mixedargs...)
873-
capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
873+
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
874874
capturescalars(passnonscalar(f), (passedargstup..., nonscalararg), mixedargs...)
875-
passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
876-
capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
875+
@inline passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
876+
@inline capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
877877
# Base cases for capturescalars
878-
capturescalars(f, passedargstup, scalararg) =
878+
@inline capturescalars(f, passedargstup, scalararg) =
879879
(capturelastscalar(f, scalararg), passedargstup)
880-
capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
880+
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
881881
(passlastnonscalar(f), (passedargstup..., nonscalararg))
882-
passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
883-
capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))
882+
@inline passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
883+
@inline capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))
884884

885885
end

test/sparse/higherorderfns.jl

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,10 +179,10 @@ end
179179
end
180180
end
181181

182-
@testset "broadcast over combinations of scalars and sparse vectors/matrices" begin
183-
N, M, p = 10, 12, 0.3
182+
@testset "broadcast[!] over combinations of scalars and sparse vectors/matrices" begin
183+
N, M, p = 10, 12, 0.5
184184
elT = Float64
185-
s = elT(2.0)
185+
s = Float32(2.0)
186186
V = sprand(elT, N, p)
187187
A = sprand(elT, N, M, p)
188188
fV, fA = Array(V), Array(A)
@@ -212,8 +212,29 @@ end
212212
((s, spargsl..., s, s, spargsr...), (s, dargsl..., s, s, dargsr...)),
213213
((spargsl..., s, s, spargsr..., s), (dargsl..., s, s, dargsr..., s)),
214214
((spargsl..., s, s, s, spargsr...), (dargsl..., s, s, s, dargsr...)), )
215+
# test broadcast entry point
215216
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
216217
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
218+
# test broadcast! entry point
219+
fX = broadcast(*, sparseargs...); X = sparse(fX)
220+
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
221+
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
222+
X = sparse(fX) # reset / warmup for @allocated test
223+
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
224+
# This test (and the analog below) fails for three reasons:
225+
# (1) In all cases, generating the closures that capture the scalar arguments
226+
# results in allocation, not sure why.
227+
# (2) In some cases, though _broadcast_eltype (which wraps _return_type)
228+
# consistently provides the correct result eltype when passed the closure
229+
# that incorporates the scalar arguments to broadcast (and, with #19667,
230+
# is inferable, so the overall return type from broadcast is inferred),
231+
# in some cases inference seems unable to determine the return type of
232+
# direct calls to that closure. This issue causes variables in in both the
233+
# broadcast[!] entry points (fofzeros = f(_zeros_eltypes(args...)...)) and
234+
# the driver routines (Cx in _map_zeropres! and _broadcast_zeroprs!) to have
235+
# inferred type Any, resulting in allocation and lackluster performance.
236+
# (3) The sparseargs... splat in the call above allocates a bit, but of course
237+
# that issue is negligible and perhaps could be accounted for in the test.
217238
end
218239
end
219240
# test combinations at the limit of inference (eight arguments net)
@@ -225,8 +246,16 @@ end
225246
((s, V, A, s, V, A, s, A), (s, fV, fA, s, fV, fA, s, fA)), # three scalars, five sparse vectors/matrices
226247
((V, A, V, s, A, V, A, s), (fV, fA, fV, s, fA, fV, fA, s)), # two scalars, six sparse vectors/matrices
227248
((V, A, V, A, s, V, A, V), (fV, fA, fV, fA, s, fV, fA, fV)) ) # one scalar, seven sparse vectors/matrices
249+
# test broadcast entry point
228250
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
229251
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
252+
# test broadcast! entry point
253+
fX = broadcast(*, sparseargs...); X = sparse(fX)
254+
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
255+
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
256+
X = sparse(fX) # reset / warmup for @allocated test
257+
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
258+
# please see the note a few lines above re. this @test_broken
230259
end
231260
end
232261

0 commit comments

Comments
 (0)