@@ -5,9 +5,10 @@ module HigherOrderFns
5
5
# This module provides higher order functions specialized for sparse arrays,
6
6
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
7
7
import Base: map, map!, broadcast, broadcast!
8
+ import Base. Broadcast: containertype, promote_containertype, broadcast_indices, broadcast_c
8
9
9
- using Base: tail, to_shape
10
- using .. SparseArrays: SparseVector, SparseMatrixCSC, indtype
10
+ using Base: front, tail, to_shape
11
+ using .. SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseArray, indtype
11
12
12
13
# This module is organized as follows:
13
14
# (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
@@ -835,16 +836,45 @@ end
835
836
836
837
837
838
# (9) broadcast[!] over combinations of broadcast scalars and sparse vectors/matrices
838
- #
839
- # TODO : The minimal snippet below is not satisfying: A better solution would achieve
840
- # the same for (1) all broadcast scalar types (Base.Broadcast.containertype(x) == Any?) and
841
- # (2) any combination (number, order, type mixture) of broadcast scalars.
842
- #
843
- broadcast {Tf} (f:: Tf , x:: Union{Number,Bool} , A:: SparseMatrixCSC ) = broadcast (y -> f (x, y), A)
844
- broadcast {Tf} (f:: Tf , A:: SparseMatrixCSC , y:: Union{Number,Bool} ) = broadcast (x -> f (x, y), A)
845
- # NOTE: The following two method definitions work around #19096. These definitions should
846
- # be folded into the two preceding definitions on resolution of #19096.
847
- broadcast {Tf,T} (f:: Tf , :: Type{T} , A:: SparseMatrixCSC ) = broadcast (y -> f (T, y), A)
848
- broadcast {Tf,T} (f:: Tf , A:: SparseMatrixCSC , :: Type{T} ) = broadcast (x -> f (x, T), A)
839
+
840
+ # broadcast shape promotion for combinations of sparse arrays and other types
841
+ broadcast_indices (:: Type{AbstractSparseArray} , A) = indices (A)
842
+ # broadcast container type promotion for combinations of sparse arrays and other types
843
+ containertype {T<:SparseVecOrMat} (:: Type{T} ) = AbstractSparseArray
844
+ # combinations of sparse arrays with broadcast scalars should yield sparse arrays
845
+ promote_containertype (:: Type{Any} , :: Type{AbstractSparseArray} ) = AbstractSparseArray
846
+ promote_containertype (:: Type{AbstractSparseArray} , :: Type{Any} ) = AbstractSparseArray
847
+ # combinations of sparse arrays with anything else should fall back to generic dense broadcast
848
+ promote_containertype (:: Type{Array} , :: Type{AbstractSparseArray} ) = Array
849
+ promote_containertype (:: Type{Tuple} , :: Type{AbstractSparseArray} ) = Array
850
+ promote_containertype (:: Type{AbstractSparseArray} , :: Type{Array} ) = Array
851
+ promote_containertype (:: Type{AbstractSparseArray} , :: Type{Tuple} ) = Array
852
+
853
+ # broadcast entry point for combinations of sparse arrays and other types
854
+ function broadcast_c (f, :: Type{AbstractSparseArray} , mixedargs... )
855
+ parevalf, passedargstup = capturescalars (f, mixedargs)
856
+ return broadcast (parevalf, passedargstup... )
857
+ end
858
+ # capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
859
+ # broadcast scalar arguments (mixedargs), and returns a function (parevalf) and a reduced
860
+ # argument tuple (passedargstup) containing only the sparse vectors/matrices in mixedargs
861
+ # in their orginal order, and such that the result of broadcast(g, passedargstup...) is
862
+ # broadcast(f, mixedargs...)
863
+ capturescalars (f, mixedargs) =
864
+ capturescalars ((passed, tofill) -> f (tofill... ), (), mixedargs... )
865
+ # Recursion cases for capturescalars
866
+ capturescalars (f, passedargstup, scalararg, mixedargs... ) =
867
+ capturescalars (capturescalar (f, scalararg), passedargstup, mixedargs... )
868
+ capturescalars (f, passedargstup, nonscalararg:: SparseVecOrMat , mixedargs... ) =
869
+ capturescalars (passnonscalar (f), (passedargstup... , nonscalararg), mixedargs... )
870
+ passnonscalar (f) = (passed, tofill) -> f (Base. front (passed), (last (passed), tofill... ))
871
+ capturescalar (f, scalararg) = (passed, tofill) -> f (passed, (scalararg, tofill... ))
872
+ # Base cases for capturescalars
873
+ capturescalars (f, passedargstup, scalararg) =
874
+ (capturelastscalar (f, scalararg), passedargstup)
875
+ capturescalars (f, passedargstup, nonscalararg:: SparseVecOrMat ) =
876
+ (passlastnonscalar (f), (passedargstup... , nonscalararg))
877
+ passlastnonscalar (f) = (passed... ) -> f (Base. front (passed), (last (passed),))
878
+ capturelastscalar (f, scalararg) = (passed... ) -> f (passed, (scalararg,))
849
879
850
880
end
0 commit comments