Commit 99b9e7d

mbauman, timholy, vtjnash, and ajkeller34 committed
Customizable lazy fused broadcasting in pure Julia
This patch represents the combined efforts of four individuals, over 60 commits, and an iterated design over (at least) three pull requests that spanned nearly an entire year (closes #22063, #23692, #25377 by superseding them).

This introduces a pure Julia data structure that represents a fused broadcast expression. For example, the expression `2 .* (x .+ 1)` lowers to:

```julia
julia> Meta.@lower 2 .* (x .+ 1)
:($(Expr(:thunk, CodeInfo(:(begin
      Core.SSAValue(0) = (Base.getproperty)(Base.Broadcast, :materialize)
      Core.SSAValue(1) = (Base.getproperty)(Base.Broadcast, :make)
      Core.SSAValue(2) = (Base.getproperty)(Base.Broadcast, :make)
      Core.SSAValue(3) = (Core.SSAValue(2))(+, x, 1)
      Core.SSAValue(4) = (Core.SSAValue(1))(*, 2, Core.SSAValue(3))
      Core.SSAValue(5) = (Core.SSAValue(0))(Core.SSAValue(4))
      return Core.SSAValue(5)
  end)))))
```

Or, slightly more readably as:

```julia
using .Broadcast: materialize, make
materialize(make(*, 2, make(+, x, 1)))
```

The `Broadcast.make` function serves two purposes. Its primary purpose is to construct the `Broadcast.Broadcasted` objects that hold onto the function, the tuple of arguments (potentially including nested `Broadcasted` arguments), and sometimes a set of `axes` to include knowledge of the outer shape. The secondary purpose, however, is to allow an "out" for objects that _don't_ want to participate in fusion. For example, if `x` is a range in the above `2 .* (x .+ 1)` expression, it needn't allocate an array and operate elementwise; it can just compute and return a new range. Thus custom structures are able to specialize `Broadcast.make(f, args...)` just as they'd specialize on `f` normally to return an immediate result.

`Broadcast.materialize` is identity for everything _except_ `Broadcasted` objects, for which it allocates an appropriate result and computes the broadcast. It does two things: it `instantiate`s the outermost `Broadcasted` object to compute its axes and then `copy`s it.
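As a rough illustration of the lazy tree described above, the sketch below builds and materializes the same expression by hand. It assumes a released Julia (1.0 or later), where the constructor this commit calls `Broadcast.make` is spelled `Broadcast.broadcasted`; the variable names are illustrative only.

```julia
using Base.Broadcast: Broadcasted, broadcasted, materialize

x = [1, 2, 3]

# Build the lazy expression tree for 2 .* (x .+ 1); nothing is computed yet.
# (`broadcasted` is the released name of what this commit calls `make`.)
bc = broadcasted(*, 2, broadcasted(+, x, 1))
@assert bc isa Broadcasted
@assert bc.f === *
@assert bc.args[2] isa Broadcasted      # the nested `x .+ 1` node

# materialize computes the axes, allocates the result, and runs the fused loop.
result = materialize(bc)
@assert result == 2 .* (x .+ 1)

# The "out" for types that do not want fusion: a range plus a scalar
# returns a new range immediately instead of a Broadcasted object.
r = broadcasted(+, 1:3, 1)
@assert r isa AbstractRange
@assert materialize(r) === r            # identity for non-Broadcasted values
```

Because `materialize` is the identity on anything that is not a `Broadcasted`, the lowered code can wrap every dotted expression in it unconditionally.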
Similarly, an in-place fused broadcast like `y .= 2 .* (x .+ 1)` uses the exact same expression tree to compute the right-hand side of the expression as above, and then uses `materialize!(y, make(*, 2, make(+, x, 1)))` to `instantiate` the `Broadcasted` expression tree and then `copyto!` it into the given destination.

Altogether, this forms a complete API for custom types to extend and customize the behavior of broadcast (fixes #22060). It uses the existing `BroadcastStyle`s throughout to simplify dispatch on many arguments:

* Custom types can opt out of broadcast fusion by specializing `Broadcast.make(f, args...)` or `Broadcast.make(::BroadcastStyle, f, args...)`.
* The `Broadcasted` object computes and stores the type of the combined `BroadcastStyle` of its arguments as its first type parameter, allowing for easy dispatch and specialization.
* Custom broadcast storage is still allocated via `broadcast_similar`; however, instead of passing just a function as a first argument, the entire `Broadcasted` object is passed as a final argument. This potentially allows for much more runtime specialization dependent upon the exact expression given.
* Custom broadcast implementations for a `CustomStyle` are defined by specializing `copy(bc::Broadcasted{CustomStyle})` or `copyto!(dest::AbstractArray, bc::Broadcasted{CustomStyle})`.
* Fallback broadcast specializations for a given output object of type `Dest` (for the `DefaultArrayStyle` or another such style that hasn't implemented assignments into such an object) are defined by specializing `copyto!(dest::Dest, bc::Broadcasted{Nothing})`.

As it fully supports range broadcasting, this now deprecates `(1:5) + 2` to `.+`, just as had been done for all `AbstractArray`s in general.

As a first-mover proof of concept, LinearAlgebra uses this new system to improve broadcasting over structured arrays. Before, broadcasting over a structured matrix would result in a sparse array.
Now, broadcasting over a structured matrix will _either_ return an appropriately structured matrix _or_ a dense array. This does incur a type instability (in the form of a discriminated union) in some situations, but thanks to type-based introspection of the `Broadcasted` wrapper, commonly used functions can be special-cased to be type stable. For example:

```julia
julia> f(d) = round.(Int, d)
f (generic function with 1 method)

julia> @inferred f(Diagonal(rand(3)))
3×3 Diagonal{Int64,Array{Int64,1}}:
 0  ⋅  ⋅
 ⋅  0  ⋅
 ⋅  ⋅  1

julia> @inferred Diagonal(rand(3)) .* 3
ERROR: return type Diagonal{Float64,Array{Float64,1}} does not match inferred return type Union{Array{Float64,2}, Diagonal{Float64,Array{Float64,1}}}
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] top-level scope

julia> @inferred Diagonal(1:4) .+ Bidiagonal(rand(4), rand(3), 'U') .* Tridiagonal(1:3, 1:4, 1:3)
4×4 Tridiagonal{Float64,Array{Float64,1}}:
 1.30771  0.838589   ⋅          ⋅
 0.0      3.89109    0.0459757  ⋅
  ⋅       0.0        4.48033    2.51508
  ⋅        ⋅         0.0        6.23739
```

In addition to the issues referenced above, it fixes:

* Fixes #19313, #22053, #23445, and #24586: Literals are no longer treated specially in a fused broadcast; they're just arguments in a `Broadcasted` object like everything else.
* Fixes #21094: Since broadcasting is now represented by a pure Julia data structure, it can be created within `@generated` functions and serialized.
* Fixes #26097: The fallback destination-array specialization method of `copyto!` is specifically implemented as `Broadcasted{Nothing}` and will not be confused by `nothing` arguments.
* Fixes the broadcast-specific element of #25499: The default base broadcast implementation no longer depends upon `Base._return_type` to allocate its array (except in the empty or concretely-typed cases). Note that the sparse implementation (#19595) is still dependent upon inference and is _not_ fixed.
* Fixes #25340: Functions are treated like normal values just like arguments and only evaluated once.
* Fixes #22255, and is performant with 12+ fused broadcasts. Okay, that one was fixed on master already, but this fixes it now, too.
* Fixes #25521.
* The performance of this patch has been thoroughly tested through its iterative development process in #25377. There remain [two classes of performance regressions](#25377) that Nanosoldier flagged.
* #25691: Constant literals still lose their constant-ness upon going through the broadcast machinery. I believe quite a large number of functions would need to be marked as `@pure` to support this, including functions that are intended to be specialized.

(For bookkeeping, this is the squashed version of the [teh-jn/lazydotfuse](JuliaLang/julia#25377) branch as of a1d4e7ec9756ada74fb48f2c514615b9d981cf5c. Squashed and separated out to make it easier to review and commit.)

Co-authored-by: Tim Holy <[email protected]>
Co-authored-by: Jameson Nash <[email protected]>
Co-authored-by: Andrew Keller <[email protected]>
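To make the extension points listed in the message more concrete, here is a minimal sketch of a custom style that specializes `copy(bc::Broadcasted{CustomStyle})`. It is written against the names in released Julia (1.0 or later), and `Tagged` and `TaggedStyle` are hypothetical types invented for this illustration; they are not part of this commit.

```julia
using Base.Broadcast: Broadcasted, AbstractArrayStyle, DefaultArrayStyle

# Hypothetical wrapper type that should survive broadcasting.
struct Tagged{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(t::Tagged) = size(t.data)
Base.getindex(t::Tagged, i::Int) = t.data[i]

# Hypothetical style: handles 0-1 dimensions, falls back to dense otherwise.
struct TaggedStyle <: AbstractArrayStyle{1} end
TaggedStyle(::Val{0}) = TaggedStyle()
TaggedStyle(::Val{1}) = TaggedStyle()
TaggedStyle(::Val{N}) where {N} = DefaultArrayStyle{N}()
Base.Broadcast.BroadcastStyle(::Type{<:Tagged}) = TaggedStyle()

# Out-of-place entry point: compute with the default dense machinery,
# then re-wrap the result. The whole fused tree arrives as `bc`.
function Base.copy(bc::Broadcasted{TaggedStyle})
    dense = copy(convert(Broadcasted{DefaultArrayStyle{1}}, bc))
    return Tagged(dense)
end

t = Tagged([1, 2, 3])
r = 2 .* (t .+ 1)   # one fused Broadcasted{TaggedStyle}, one call to copy
```

The `convert(Broadcasted{NewStyle}, bc)` step mirrors the fallback used for `PromoteToSparse` in the diff: it changes only the style parameter, so the default implementation can take over.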
1 parent 3aea74e commit 99b9e7d

2 files changed: +172 -166 lines changed
src/higherorderfns.jl

+152 -106
@@ -4,15 +4,16 @@ module HigherOrderFns
 
 # This module provides higher order functions specialized for sparse arrays,
 # particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
-import Base: map, map!, broadcast, broadcast!
+import Base: map, map!, broadcast, copy, copyto!
 
 using Base: front, tail, to_shape
 using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
                       AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
-using Base.Broadcast: BroadcastStyle
+using Base.Broadcast: BroadcastStyle, Broadcasted, flatten
 using LinearAlgebra
 
 # This module is organized as follows:
+# (0) Define BroadcastStyle rules and convenience types for dispatch
 # (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
 # map[!]/broadcast[!]'s purposes. The methods below are written against this interface.
 # (2) Define entry points for map[!] (short children of _map_[not]zeropres!).
@@ -29,11 +30,79 @@ using LinearAlgebra
 # (12) Define map[!] methods handling combinations of sparse and structured matrices.
 
 
+# (0) BroadcastStyle rules and convenience types for dispatch
+
+SparseVecOrMat = Union{SparseVector,SparseMatrixCSC}
+
+# broadcast container type promotion for combinations of sparse arrays and other types
+struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
+struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
+Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle()
+Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle()
+const SPVM = Union{SparseVecStyle,SparseMatStyle}
+
+# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions.
+# SparseVecStyle promotes to SparseMatStyle for 2 dimensions.
+# Fall back to DefaultArrayStyle for higher dimensionality.
+SparseVecStyle(::Val{0}) = SparseVecStyle()
+SparseVecStyle(::Val{1}) = SparseVecStyle()
+SparseVecStyle(::Val{2}) = SparseMatStyle()
+SparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+SparseMatStyle(::Val{0}) = SparseMatStyle()
+SparseMatStyle(::Val{1}) = SparseMatStyle()
+SparseMatStyle(::Val{2}) = SparseMatStyle()
+SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+
+Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle()
+
+# Tuples promote to dense
+Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}()
+Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()
+
+struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
+PromoteToSparse(::Val{0}) = PromoteToSparse()
+PromoteToSparse(::Val{1}) = PromoteToSparse()
+PromoteToSparse(::Val{2}) = PromoteToSparse()
+PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+
+const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
+Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
+Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
+
+Broadcast.BroadcastStyle(s::SPVM, ::Broadcast.AbstractArrayStyle{0}) = s
+Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
+Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
+
+Broadcast.BroadcastStyle(::SPVM, ::LinearAlgebra.StructuredMatrixStyle{<:StructuredMatrix}) = PromoteToSparse()
+Broadcast.BroadcastStyle(::PromoteToSparse, ::LinearAlgebra.StructuredMatrixStyle{<:StructuredMatrix}) = PromoteToSparse()
+
+Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()
+Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()
+
+# FIXME: currently sparse broadcasts are only well-tested on known array types, while any AbstractArray
+# could report itself as a DefaultArrayStyle().
+# See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details
+is_supported_sparse_broadcast() = true
+is_supported_sparse_broadcast(::AbstractArray, rest...) = false
+is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_sparse_broadcast(rest...)
+is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...)
+is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...)
+is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...)
+is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...)
+is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...)
+
+# Dispatch on broadcast operations by number of arguments
+const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},Axes,F} =
+    Broadcasted{Style,Axes,F,Tuple{}}
+const SpBroadcasted1{Style<:SPVM,Axes,F,Args<:Tuple{SparseVecOrMat}} =
+    Broadcasted{Style,Axes,F,Args}
+const SpBroadcasted2{Style<:SPVM,Axes,F,Args<:Tuple{SparseVecOrMat,SparseVecOrMat}} =
+    Broadcasted{Style,Axes,F,Args}
+
 # (1) The definitions below provide a common interface to sparse vectors and matrices
 # sufficient for the purposes of map[!]/broadcast[!]. This interface treats sparse vectors
 # as n-by-one sparse matrices which, though technically incorrect, is how broacast[!] views
 # sparse vectors in practice.
-SparseVecOrMat = Union{SparseVector,SparseMatrixCSC}
 @inline numrows(A::SparseVector) = A.n
 @inline numrows(A::SparseMatrixCSC) = A.m
 @inline numcols(A::SparseVector) = 1
@@ -85,18 +154,18 @@ function _noshapecheck_map(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N
     fofzeros = f(_zeros_eltypes(A, Bs...)...)
     fpreszeros = _iszero(fofzeros)
     maxnnzC = fpreszeros ? min(length(A), _sumnnzs(A, Bs...)) : length(A)
-    entrytypeC = Base.Broadcast.combine_eltypes(f, A, Bs...)
+    entrytypeC = Base.Broadcast.combine_eltypes(f, (A, Bs...))
     indextypeC = _promote_indtype(A, Bs...)
     C = _allocres(size(A), indextypeC, entrytypeC, maxnnzC)
     return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
                         _map_notzeropres!(f, fofzeros, C, A, Bs...)
 end
 # (3) broadcast[!] entry points
-broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A)
-broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
+copy(bc::SpBroadcasted1) = _noshapecheck_map(bc.f, bc.args[1])
 
-@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Nothing) where Tf
+@inline function copyto!(C::SparseVecOrMat, bc::Broadcasted0{Nothing})
     isempty(C) && return _finishempty!(C)
+    f = bc.f
     fofnoargs = f()
     if _iszero(fofnoargs) # f() is zero, so empty C
         trimstorage!(C, 0)
@@ -109,19 +178,12 @@ broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
     return C
 end
 
-# the following three similar defs are necessary for type stability in the mixed vector/matrix case
-broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} =
-    _aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
-broadcast(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N}) where {Tf,N} =
-    _aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
-broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} =
-    _diffshape_broadcast(f, A, Bs...)
 function _diffshape_broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
     fofzeros = f(_zeros_eltypes(A, Bs...)...)
     fpreszeros = _iszero(fofzeros)
     indextypeC = _promote_indtype(A, Bs...)
-    entrytypeC = Base.Broadcast.combine_eltypes(f, A, Bs...)
-    shapeC = to_shape(Base.Broadcast.combine_indices(A, Bs...))
+    entrytypeC = Base.Broadcast.combine_eltypes(f, (A, Bs...))
+    shapeC = to_shape(Base.Broadcast.combine_axes(A, Bs...))
     maxnnzC = fpreszeros ? _checked_maxnnzbcres(shapeC, A, Bs...) : _densennz(shapeC)
     C = _allocres(shapeC, indextypeC, entrytypeC, maxnnzC)
     return fpreszeros ? _broadcast_zeropres!(f, C, A, Bs...) :
@@ -141,6 +203,10 @@ end
 @inline _aresameshape(A, B) = size(A) == size(B)
 @inline _aresameshape(A, B, Cs...) = _aresameshape(A, B) ? _aresameshape(B, Cs...) : false
 @inline _checksameshape(As...) = _aresameshape(As...) || throw(DimensionMismatch("argument shapes must match"))
+@inline _all_args_isa(t::Tuple{Any}, ::Type{T}) where T = isa(t[1], T)
+@inline _all_args_isa(t::Tuple{Any,Vararg{Any}}, ::Type{T}) where T = isa(t[1], T) & _all_args_isa(tail(t), T)
+@inline _all_args_isa(t::Tuple{Broadcasted}, ::Type{T}) where T = _all_args_isa(t[1].args, T)
+@inline _all_args_isa(t::Tuple{Broadcasted,Vararg{Any}}, ::Type{T}) where T = _all_args_isa(t[1].args, T) & _all_args_isa(tail(t), T)
 @inline _densennz(shape::NTuple{1}) = shape[1]
 @inline _densennz(shape::NTuple{2}) = shape[1] * shape[2]
 _maxnnzfrom(shape::NTuple{1}, A) = nnz(A) * div(shape[1], A.n)
@@ -887,37 +953,56 @@ end
 
 # (10) broadcast over combinations of broadcast scalars and sparse vectors/matrices
 
-# broadcast container type promotion for combinations of sparse arrays and other types
-struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
-struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
-Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle()
-Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle()
-const SPVM = Union{SparseVecStyle,SparseMatStyle}
+# broadcast entry points for combinations of sparse arrays and other (scalar) types
+@inline function copy(bc::Broadcasted{<:SPVM})
+    bcf = flatten(bc)
+    return _copy(bcf.f, bcf.args...)
+end
 
-# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions.
-# SparseVecStyle promotes to SparseMatStyle for 2 dimensions.
-# Fall back to DefaultArrayStyle for higher dimensionality.
-SparseVecStyle(::Val{0}) = SparseVecStyle()
-SparseVecStyle(::Val{1}) = SparseVecStyle()
-SparseVecStyle(::Val{2}) = SparseMatStyle()
-SparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
-SparseMatStyle(::Val{0}) = SparseMatStyle()
-SparseMatStyle(::Val{1}) = SparseMatStyle()
-SparseMatStyle(::Val{2}) = SparseMatStyle()
-SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
+_copy(f, args::SparseVector...) = _shapecheckbc(f, args...)
+_copy(f, args::SparseMatrixCSC...) = _shapecheckbc(f, args...)
+_copy(f, args::SparseVecOrMat...) = _diffshape_broadcast(f, args...)
+# Otherwise, we incorporate scalars into the function and re-dispatch
+function _copy(f, args...)
+    parevalf, passedargstup = capturescalars(f, args)
+    return _copy(parevalf, passedargstup...)
+end
 
-Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle()
+function _shapecheckbc(f, args...)
+    _aresameshape(args...) ? _noshapecheck_map(f, args...) : _diffshape_broadcast(f, args...)
+end
 
-# Tuples promote to dense
-Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}()
-Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()
 
-# broadcast entry points for combinations of sparse arrays and other (scalar) types
-function broadcast(f, ::SPVM, ::Nothing, ::Nothing, mixedargs::Vararg{Any,N}) where N
-    parevalf, passedargstup = capturescalars(f, mixedargs)
-    return broadcast(parevalf, passedargstup...)
+@inline function copyto!(dest::SparseVecOrMat, bc::Broadcasted{<:SPVM})
+    if bc.f === identity && bc isa SpBroadcasted1 && Base.axes(dest) == (A = bc.args[1]; Base.axes(A))
+        return copyto!(dest, A)
+    end
+    bcf = flatten(bc)
+    As = map(arg->Base.unalias(dest, arg), bcf.args)
+    return _copyto!(bcf.f, dest, As...)
+end
+
+@inline function _copyto!(f, dest, As::SparseVecOrMat...)
+    _aresameshape(dest, As...) && return _noshapecheck_map!(f, dest, As...)
+    Base.Broadcast.check_broadcast_axes(axes(dest), As...)
+    fofzeros = f(_zeros_eltypes(As...)...)
+    if _iszero(fofzeros)
+        return _broadcast_zeropres!(f, dest, As...)
+    else
+        return _broadcast_notzeropres!(f, fofzeros, dest, As...)
+    end
+end
+
+@inline function _copyto!(f, dest, args...)
+    # args contains nothing but SparseVecOrMat and scalars
+    # See below for capturescalars
+    parevalf, passedsrcargstup = capturescalars(f, args)
+    _copyto!(parevalf, dest, passedsrcargstup...)
+end
+
+struct CapturedScalars{F, Args, Order}
+    args::Args
 end
-# for broadcast! see (11)
 
 # capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
 # broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
@@ -930,6 +1015,13 @@ end
         return (parevalf, passedsrcargstup)
     end
 end
+# Work around losing Type{T}s as DataTypes within the tuple that makeargs creates
+@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} =
+    capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
+@inline capturescalars(f, mixedargs::Tuple{SparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} =
+    capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
+@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
+    capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))
 
 nonscalararg(::SparseVecOrMat) = true
 nonscalararg(::Any) = false
@@ -942,11 +1034,17 @@ end
 @inline function _capturescalars(arg, mixedargs...)
     let (rest, f) = _capturescalars(mixedargs...)
         if nonscalararg(arg)
-            return (arg, rest...), (head, tail...) -> (head, f(tail...)...) # pass-through to broadcast
+            return (arg, rest...), @inline function(head, tail...)
+                (head, f(tail...)...)
+            end # pass-through to broadcast
         elseif scalarwrappedarg(arg)
-            return rest, (tail...) -> (arg[], f(tail...)...) # unwrap and add back scalararg after (in makeargs)
+            return rest, @inline function(tail...)
+                (arg[], f(tail...)...) # TODO: This can put a Type{T} in a tuple
+            end # unwrap and add back scalararg after (in makeargs)
         else
-            return rest, (tail...) -> (arg, f(tail...)...) # add back scalararg after (in makeargs)
+            return rest, @inline function(tail...)
+                (arg, f(tail...)...)
+            end # add back scalararg after (in makeargs)
         end
     end
 end
@@ -972,69 +1070,18 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
 # vectors/matrices, promote all structured matrices and dense vectors/matrices to sparse
 # and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.
 
-struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
-PromoteToSparse(::Val{0}) = PromoteToSparse()
-PromoteToSparse(::Val{1}) = PromoteToSparse()
-PromoteToSparse(::Val{2}) = PromoteToSparse()
-PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
-
-const StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
-Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse()
-Broadcast.BroadcastStyle(::Type{<:Adjoint{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
-Broadcast.BroadcastStyle(::Type{<:Transpose{T,<:Union{SparseVector,SparseMatrixCSC}} where T}) = PromoteToSparse()
-
-Broadcast.BroadcastStyle(s::SPVM, ::Broadcast.AbstractArrayStyle{0}) = s
-Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
-Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
-
-Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()
-Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()
-
-# FIXME: currently sparse broadcasts are only well-tested on known array types, while any AbstractArray
-# could report itself as a DefaultArrayStyle().
-# See https://github.com/JuliaLang/julia/pull/23939#pullrequestreview-72075382 for more details
-is_supported_sparse_broadcast() = true
-is_supported_sparse_broadcast(::AbstractArray, rest...) = false
-is_supported_sparse_broadcast(::AbstractSparseArray, rest...) = is_supported_sparse_broadcast(rest...)
-is_supported_sparse_broadcast(::StructuredMatrix, rest...) = is_supported_sparse_broadcast(rest...)
-is_supported_sparse_broadcast(::Array, rest...) = is_supported_sparse_broadcast(rest...)
-is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_supported_sparse_broadcast(t.parent, rest...)
-is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...)
-is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...)
-function broadcast(f, s::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N}
-    if is_supported_sparse_broadcast(As...)
-        return broadcast(f, map(_sparsifystructured, As)...)
+function copy(bc::Broadcasted{PromoteToSparse})
+    bcf = flatten(bc)
+    if is_supported_sparse_broadcast(bcf.args...)
+        broadcast(bcf.f, map(_sparsifystructured, bcf.args)...)
     else
-        return broadcast(f, Broadcast.ArrayConflict(), nothing, nothing, As...)
+        return copy(convert(Broadcasted{Broadcast.DefaultArrayStyle{2}}, bc))
     end
 end
 
-# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
-# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
-# we can handle it here, otherwise see below for the promotion machinery.
-function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
-    if f isa typeof(identity) && N == 0 && Base.axes(dest) == Base.axes(A)
-        return copyto!(dest, A)
-    end
-    A′ = Base.unalias(dest, A)
-    Bs′ = map(B->Base.unalias(dest, B), Bs)
-    _aresameshape(dest, A′, Bs′...) && return _noshapecheck_map!(f, dest, A′, Bs′...)
-    Base.Broadcast.check_broadcast_indices(axes(dest), A′, Bs′...)
-    fofzeros = f(_zeros_eltypes(A′, Bs′...)...)
-    fpreszeros = _iszero(fofzeros)
-    fpreszeros ? _broadcast_zeropres!(f, dest, A′, Bs′...) :
-        _broadcast_notzeropres!(f, fofzeros, dest, A′, Bs′...)
-    return dest
-end
-function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
-    # mixedsrcargs contains nothing but SparseVecOrMat and scalars
-    parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
-    broadcast!(parevalf, dest, passedsrcargstup...)
-    return dest
-end
-function broadcast!(f::Tf, dest::SparseVecOrMat, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
-    broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...)
-    return dest
+@inline function copyto!(dest::SparseVecOrMat, bc::Broadcasted{PromoteToSparse})
+    bcf = flatten(bc)
+    broadcast!(bcf.f, dest, map(_sparsifystructured, bcf.args)...)
 end
 
 _sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M)
@@ -1047,8 +1094,7 @@ _sparsifystructured(x) = x
 
 
 # (12) map[!] over combinations of sparse and structured matrices
-SparseOrStructuredMatrix = Union{SparseMatrixCSC,StructuredMatrix}
-map(f::Tf, A::StructuredMatrix) where {Tf} = _noshapecheck_map(f, _sparsifystructured(A))
+SparseOrStructuredMatrix = Union{SparseMatrixCSC,LinearAlgebra.StructuredMatrix}
 map(f::Tf, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} =
     (_checksameshape(A, Bs...); _noshapecheck_map(f, _sparsifystructured(A), map(_sparsifystructured, Bs)...))
 map!(f::Tf, C::SparseMatrixCSC, A::SparseOrStructuredMatrix, Bs::Vararg{SparseOrStructuredMatrix,N}) where {Tf,N} =
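
The `capturescalars` step used by `_copy` and `_copyto!` in the diff folds scalar arguments into the broadcast function, so that only the array arguments reach the sparse kernels. The sketch below is a deliberately simplified stand-in for that idea, not the implementation in this commit: `capture_scalars_sketch` is a hypothetical name, it treats any `AbstractArray` as a non-scalar (the diff uses `SparseVecOrMat`), and it is not type-stable, whereas the real code recurses over tuples to stay inferable.

```julia
# Hypothetical, simplified illustration of partially evaluating scalars
# into the broadcast function (not the code from this commit).
function capture_scalars_sketch(f, mixedargs::Tuple)
    args = collect(Any, mixedargs)
    # Positions of the arguments that stay as real broadcast arguments.
    keep = findall(a -> a isa AbstractArray, args)
    passed = Tuple(args[keep])
    # parevalf receives one element per kept argument and re-inserts
    # the captured scalars at their original positions before calling f.
    function parevalf(passedargs...)
        full = copy(args)
        full[keep] = collect(Any, passedargs)
        return f(full...)
    end
    return parevalf, passed
end

g, passed = capture_scalars_sketch(+, ([1, 2], 10, [3, 4]))
result = broadcast(g, passed...)    # same values as [1, 2] .+ 10 .+ [3, 4]
```

After this rewrite, the kernel only ever sees arguments of one kind, which is why the diff can re-dispatch to the `_copy(f, args::SparseVecOrMat...)` methods.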
