Skip to content

Commit 422410f

Browse files
committed
Extend sparse broadcast (non-!) to combinations of broadcast scalars and sparse vectors/matrices.
1 parent c1cde2f commit 422410f

File tree

1 file changed

+43
-13
lines changed

1 file changed

+43
-13
lines changed

base/sparse/higherorderfns.jl

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ module HigherOrderFns
55
# This module provides higher order functions specialized for sparse arrays,
66
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
77
import Base: map, map!, broadcast, broadcast!
8+
import Base.Broadcast: containertype, promote_containertype, broadcast_indices, broadcast_c
89

9-
using Base: tail, to_shape
10-
using ..SparseArrays: SparseVector, SparseMatrixCSC, indtype
10+
using Base: front, tail, to_shape
11+
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype
1112

1213
# This module is organized as follows:
1314
# (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
@@ -835,16 +836,45 @@ end
835836

836837

837838
# (9) broadcast[!] over combinations of broadcast scalars and sparse vectors/matrices
838-
#
839-
# TODO: The minimal snippet below is not satisfying: A better solution would achieve
840-
# the same for (1) all broadcast scalar types (Base.Broadcast.containertype(x) == Any?) and
841-
# (2) any combination (number, order, type mixture) of broadcast scalars.
842-
#
843-
broadcast{Tf}(f::Tf, x::Union{Number,Bool}, A::SparseMatrixCSC) = broadcast(y -> f(x, y), A)
844-
broadcast{Tf}(f::Tf, A::SparseMatrixCSC, y::Union{Number,Bool}) = broadcast(x -> f(x, y), A)
845-
# NOTE: The following two method definitions work around #19096. These definitions should
846-
# be folded into the two preceding definitions on resolution of #19096.
847-
broadcast{Tf,T}(f::Tf, ::Type{T}, A::SparseMatrixCSC) = broadcast(y -> f(T, y), A)
848-
broadcast{Tf,T}(f::Tf, A::SparseMatrixCSC, ::Type{T}) = broadcast(x -> f(x, T), A)
839+
840+
# broadcast shape promotion for combinations of sparse arrays and other types
841+
broadcast_indices(::Type{AbstractSparseArray}, A) = indices(A)
842+
# broadcast container type promotion for combinations of sparse arrays and other types
843+
containertype{T<:SparseVecOrMat}(::Type{T}) = AbstractSparseArray
844+
# combinations of sparse arrays with broadcast scalars should yield sparse arrays
845+
promote_containertype(::Type{Any}, ::Type{AbstractSparseArray}) = AbstractSparseArray
846+
promote_containertype(::Type{AbstractSparseArray}, ::Type{Any}) = AbstractSparseArray
847+
# combinations of sparse arrays with anything else should fall back to generic dense broadcast
848+
promote_containertype(::Type{Array}, ::Type{AbstractSparseArray}) = Array
849+
promote_containertype(::Type{Tuple}, ::Type{AbstractSparseArray}) = Array
850+
promote_containertype(::Type{AbstractSparseArray}, ::Type{Array}) = Array
851+
promote_containertype(::Type{AbstractSparseArray}, ::Type{Tuple}) = Array
852+
853+
# broadcast entry point for combinations of sparse arrays and other types
854+
function broadcast_c(f, ::Type{AbstractSparseArray}, mixedargs...)
855+
parevalf, passedargstup = capturescalars(f, mixedargs)
856+
return broadcast(parevalf, passedargstup...)
857+
end
858+
# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
859+
# broadcast scalar arguments (mixedargs), and returns a function (parevalf) and a reduced
860+
# argument tuple (passedargstup) containing only the sparse vectors/matrices in mixedargs
861+
# in their orginal order, and such that the result of broadcast(g, passedargstup...) is
862+
# broadcast(f, mixedargs...)
863+
capturescalars(f, mixedargs) =
864+
capturescalars((passed, tofill) -> f(tofill...), (), mixedargs...)
865+
# Recursion cases for capturescalars
866+
capturescalars(f, passedargstup, scalararg, mixedargs...) =
867+
capturescalars(capturescalar(f, scalararg), passedargstup, mixedargs...)
868+
capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat, mixedargs...) =
869+
capturescalars(passnonscalar(f), (passedargstup..., nonscalararg), mixedargs...)
870+
passnonscalar(f) = (passed, tofill) -> f(Base.front(passed), (last(passed), tofill...))
871+
capturescalar(f, scalararg) = (passed, tofill) -> f(passed, (scalararg, tofill...))
872+
# Base cases for capturescalars
873+
capturescalars(f, passedargstup, scalararg) =
874+
(capturelastscalar(f, scalararg), passedargstup)
875+
capturescalars(f, passedargstup, nonscalararg::SparseVecOrMat) =
876+
(passlastnonscalar(f), (passedargstup..., nonscalararg))
877+
passlastnonscalar(f) = (passed...) -> f(Base.front(passed), (last(passed),))
878+
capturelastscalar(f, scalararg) = (passed...) -> f(passed, (scalararg,))
849879

850880
end

0 commit comments

Comments
 (0)