Skip to content

Commit 1494f43

Browse files
authored
Merge pull request #19724 from Sacha0/mixedbc
broadcast[!] over combinations of scalars and sparse vectors/matrices
2 parents 7c34d69 + ce545a6 commit 1494f43

File tree

3 files changed

+134
-12
lines changed

3 files changed

+134
-12
lines changed

base/broadcast.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,9 @@ Note that `dest` is only used to store the result, and does not supply
202202
arguments to `f` unless it is also listed in the `As`,
203203
as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
204204
"""
205-
@inline function broadcast!{N}(f, C::AbstractArray, A, Bs::Vararg{Any,N})
205+
@inline broadcast!{N}(f, C::AbstractArray, A, Bs::Vararg{Any,N}) =
206+
broadcast_c!(f, containertype(C, A, Bs...), C, A, Bs...)
207+
@inline function broadcast_c!{N}(f, ::Type, C::AbstractArray, A, Bs::Vararg{Any,N})
206208
shape = indices(C)
207209
@boundscheck check_broadcast_indices(shape, A, Bs...)
208210
keeps, Idefaults = map_newindexer(shape, A, Bs)

base/sparse/higherorderfns.jl

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@ module HigherOrderFns
55
# This module provides higher order functions specialized for sparse arrays,
66
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
77
import Base: map, map!, broadcast, broadcast!
8+
import Base.Broadcast: containertype, promote_containertype,
9+
broadcast_indices, broadcast_c, broadcast_c!
810

9-
using Base: tail, to_shape
10-
using ..SparseArrays: SparseVector, SparseMatrixCSC, indtype
11+
using Base: front, tail, to_shape
12+
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype
1113

1214
# This module is organized as follows:
1315
# (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
@@ -837,15 +839,52 @@ end
837839

838840

839841
# (9) broadcast[!] over combinations of broadcast scalars and sparse vectors/matrices
840-
#
841-
# TODO: The minimal snippet below is not satisfying: A better solution would achieve
842-
# the same for (1) all broadcast scalar types (Base.Broadcast.containertype(x) == Any?) and
843-
# (2) any combination (number, order, type mixture) of broadcast scalars.
844-
#
845-
broadcast{Tf}(f::Tf, x::Union{Number,Bool}, A::SparseMatrixCSC) = broadcast(y -> f(x, y), A)
846-
broadcast{Tf}(f::Tf, A::SparseMatrixCSC, y::Union{Number,Bool}) = broadcast(x -> f(x, y), A)
847-
# NOTE: The following two method definitions work around #19096. These definitions should
848-
# be folded into the two preceding definitions on resolution of #19096.
842+
843+
# broadcast shape promotion for combinations of sparse arrays and other types
844+
broadcast_indices(::Type{AbstractSparseArray}, A) = indices(A)
845+
# broadcast container type promotion for combinations of sparse arrays and other types
846+
containertype{T<:SparseVecOrMat}(::Type{T}) = AbstractSparseArray
847+
# combinations of sparse arrays with broadcast scalars should yield sparse arrays
848+
promote_containertype(::Type{Any}, ::Type{AbstractSparseArray}) = AbstractSparseArray
849+
promote_containertype(::Type{AbstractSparseArray}, ::Type{Any}) = AbstractSparseArray
850+
# combinations of sparse arrays with anything else should fall back to generic dense broadcast
851+
promote_containertype(::Type{Array}, ::Type{AbstractSparseArray}) = Array
852+
promote_containertype(::Type{Tuple}, ::Type{AbstractSparseArray}) = Array
853+
promote_containertype(::Type{AbstractSparseArray}, ::Type{Array}) = Array
854+
promote_containertype(::Type{AbstractSparseArray}, ::Type{Tuple}) = Array
855+
856+
# broadcast[!] entry points for combinations of sparse arrays and other types
857+
@inline function broadcast_c{N}(f, ::Type{AbstractSparseArray}, mixedargs::Vararg{Any,N})
858+
parevalf, passedargstup = capturescalars(f, mixedargs)
859+
return broadcast(parevalf, passedargstup...)
860+
end
861+
@inline function broadcast_c!{N}(f, ::Type{AbstractSparseArray}, dest::SparseVecOrMat, mixedsrcargs::Vararg{Any,N})
862+
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
863+
return broadcast!(parevalf, dest, passedsrcargstup...)
864+
end
865+
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
866+
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
867+
# evaluated f) and a reduced argument tuple (passedargstup) containing only the sparse
868+
# vectors/matrices in mixedargs in their orginal order, and such that the result of
869+
# broadcast(parevalf, passedargstup...) is broadcast(f, mixedargs...)
870+
@inline capturescalars(f, mixedargs) =
871+
capturescalars((passed, tofill) -> f(tofill...), (), mixedargs...)
872+
# Recursion cases for capturescalars
873+
@inline capturescalars(f, passedargstup, scalararg, mixedargs...) =
874+
capturescalars(capturescalar(f, scalararg), passedargstup, mixedargs...)
875+
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
876+
capturescalars(passnonscalar(f), (passedargstup..., nonscalararg), mixedargs...)
877+
@inline passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
878+
@inline capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
879+
# Base cases for capturescalars
880+
@inline capturescalars(f, passedargstup, scalararg) =
881+
(capturelastscalar(f, scalararg), passedargstup)
882+
@inline capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
883+
(passlastnonscalar(f), (passedargstup..., nonscalararg))
884+
@inline passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
885+
@inline capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))
886+
887+
# NOTE: The following two method definitions work around #19096.
849888
broadcast{Tf,T}(f::Tf, ::Type{T}, A::SparseMatrixCSC) = broadcast(y -> f(T, y), A)
850889
broadcast{Tf,T}(f::Tf, A::SparseMatrixCSC, ::Type{T}) = broadcast(x -> f(x, T), A)
851890

test/sparse/higherorderfns.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ end
193193
end
194194
end
195195

196+
196197
@testset "sparse map/broadcast with result eltype not a concrete subtype of Number (#19561/#19589)" begin
197198
intoneorfloatzero(x) = x != 0.0 ? Int(1) : Float64(x)
198199
stringorfloatzero(x) = x != 0.0 ? "Hello" : Float64(x)
@@ -202,6 +203,86 @@ end
202203
@test broadcast(stringorfloatzero, speye(4)) == sparse(broadcast(stringorfloatzero, eye(4)))
203204
end
204205

206+
@testset "broadcast[!] over combinations of scalars and sparse vectors/matrices" begin
207+
N, M, p = 10, 12, 0.5
208+
elT = Float64
209+
s = Float32(2.0)
210+
V = sprand(elT, N, p)
211+
A = sprand(elT, N, M, p)
212+
fV, fA = Array(V), Array(A)
213+
# test combinations involving one to three scalars and one to five sparse vectors/matrices
214+
spargseq, dargseq = Iterators.cycle((A, V)), Iterators.cycle((fA, fV))
215+
for nargs in 1:5 # number of tensor arguments
216+
nargsl = cld(nargs, 2) # number in "left half" of tensor arguments
217+
nargsr = fld(nargs, 2) # number in "right half" of tensor arguments
218+
spargsl = tuple(Iterators.take(spargseq, nargsl)...) # "left half" of tensor args
219+
spargsr = tuple(Iterators.take(spargseq, nargsr)...) # "right half" of tensor args
220+
dargsl = tuple(Iterators.take(dargseq, nargsl)...) # "left half" of tensor args, densified
221+
dargsr = tuple(Iterators.take(dargseq, nargsr)...) # "right half" of tensor args, densified
222+
for (sparseargs, denseargs) in ( # argument combinations including scalars
223+
# a few combinations involving one scalar
224+
((s, spargsl..., spargsr...), (s, dargsl..., dargsr...)),
225+
((spargsl..., s, spargsr...), (dargsl..., s, dargsr...)),
226+
((spargsl..., spargsr..., s), (dargsl..., dargsr..., s)),
227+
# a few combinations involving two scalars
228+
((s, spargsl..., s, spargsr...), (s, dargsl..., s, dargsr...)),
229+
((s, spargsl..., spargsr..., s), (s, dargsl..., dargsr..., s)),
230+
((spargsl..., s, spargsr..., s), (dargsl..., s, dargsr..., s)),
231+
((s, s, spargsl..., spargsr...), (s, s, dargsl..., dargsr...)),
232+
((spargsl..., s, s, spargsr...), (dargsl..., s, s, dargsr...)),
233+
((spargsl..., spargsr..., s, s), (dargsl..., dargsr..., s, s)),
234+
# a few combinations involving three scalars
235+
((s, spargsl..., s, spargsr..., s), (s, dargsl..., s, dargsr..., s)),
236+
((s, spargsl..., s, s, spargsr...), (s, dargsl..., s, s, dargsr...)),
237+
((spargsl..., s, s, spargsr..., s), (dargsl..., s, s, dargsr..., s)),
238+
((spargsl..., s, s, s, spargsr...), (dargsl..., s, s, s, dargsr...)), )
239+
# test broadcast entry point
240+
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
241+
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
242+
# test broadcast! entry point
243+
fX = broadcast(*, sparseargs...); X = sparse(fX)
244+
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
245+
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
246+
X = sparse(fX) # reset / warmup for @allocated test
247+
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
248+
# This test (and the analog below) fails for three reasons:
249+
# (1) In all cases, generating the closures that capture the scalar arguments
250+
# results in allocation, not sure why.
251+
# (2) In some cases, though _broadcast_eltype (which wraps _return_type)
252+
# consistently provides the correct result eltype when passed the closure
253+
# that incorporates the scalar arguments to broadcast (and, with #19667,
254+
# is inferable, so the overall return type from broadcast is inferred),
255+
# in some cases inference seems unable to determine the return type of
256+
# direct calls to that closure. This issue causes variables in both the
257+
# broadcast[!] entry points (fofzeros = f(_zeros_eltypes(args...)...)) and
258+
# the driver routines (Cx in _map_zeropres! and _broadcast_zeropres!) to have
259+
# inferred type Any, resulting in allocation and lackluster performance.
260+
# (3) The sparseargs... splat in the call above allocates a bit, but of course
261+
# that issue is negligible and perhaps could be accounted for in the test.
262+
end
263+
end
264+
# test combinations at the limit of inference (eight arguments net)
265+
for (sparseargs, denseargs) in (
266+
((s, s, s, A, s, s, s, s), (s, s, s, fA, s, s, s, s)), # seven scalars, one sparse matrix
267+
((s, s, V, s, s, A, s, s), (s, s, fV, s, s, fA, s, s)), # six scalars, two sparse vectors/matrices
268+
((s, s, V, s, A, s, V, s), (s, s, fV, s, fA, s, fV, s)), # five scalars, three sparse vectors/matrices
269+
((s, V, s, A, s, V, s, A), (s, fV, s, fA, s, fV, s, fA)), # four scalars, four sparse vectors/matrices
270+
((s, V, A, s, V, A, s, A), (s, fV, fA, s, fV, fA, s, fA)), # three scalars, five sparse vectors/matrices
271+
((V, A, V, s, A, V, A, s), (fV, fA, fV, s, fA, fV, fA, s)), # two scalars, six sparse vectors/matrices
272+
((V, A, V, A, s, V, A, V), (fV, fA, fV, fA, s, fV, fA, fV)) ) # one scalar, seven sparse vectors/matrices
273+
# test broadcast entry point
274+
@test broadcast(*, sparseargs...) == sparse(broadcast(*, denseargs...))
275+
@test isa(@inferred(broadcast(*, sparseargs...)), SparseMatrixCSC{elT})
276+
# test broadcast! entry point
277+
fX = broadcast(*, sparseargs...); X = sparse(fX)
278+
@test broadcast!(*, X, sparseargs...) == sparse(broadcast!(*, fX, denseargs...))
279+
@test isa(@inferred(broadcast!(*, X, sparseargs...)), SparseMatrixCSC{elT})
280+
X = sparse(fX) # reset / warmup for @allocated test
281+
@test_broken (@allocated broadcast!(*, X, sparseargs...)) == 0
282+
# please see the note a few lines above re. this @test_broken
283+
end
284+
end
285+
205286
# Older tests of sparse broadcast, now largely covered by the tests above
206287
@testset "assorted tests of sparse broadcast over two input arguments" begin
207288
N, p = 10, 0.3

0 commit comments

Comments
 (0)