Skip to content

Commit ed52263

Browse files
committed
Centralize broadcast support for structured matrices
1 parent c892056 commit ed52263

File tree

7 files changed

+71
-36
lines changed

7 files changed

+71
-36
lines changed

base/broadcast.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ BroadcastStyle(::Type{<:Ref}) = DefaultArrayStyle{0}()
119119
# 3 or more arguments still return an `ArrayConflict`.
120120
struct ArrayConflict <: AbstractArrayStyle{Any} end
121121

122+
# This will be used for Diagonal, Bidiagonal, Tridiagonal, and SymTridiagonal
123+
struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
124+
122125
### Binary BroadcastStyle rules
123126
"""
124127
BroadcastStyle(::Style1, ::Style2) = Style3()

base/linalg/bidiag.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,23 @@ Bidiagonal{T}(A::Bidiagonal) where {T} =
172172
# When asked to convert Bidiagonal to AbstractMatrix{T}, preserve structure by converting to Bidiagonal{T} <: AbstractMatrix{T}
173173
AbstractMatrix{T}(A::Bidiagonal) where {T} = convert(Bidiagonal{T}, A)
174174

175-
broadcast(::typeof(big), B::Bidiagonal) = Bidiagonal(big.(B.dv), big.(B.ev), B.uplo)
175+
function copyto!(dest::Bidiagonal, bc::Broadcasted{PromoteToSparse})
176+
axs = axes(dest)
177+
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
178+
for i in axs[1]
179+
dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
180+
end
181+
if dest.uplo == 'U'
182+
for i = 1:size(dest, 1)-1
183+
dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
184+
end
185+
else
186+
for i = 1:size(dest, 1)-1
187+
dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i))
188+
end
189+
end
190+
dest
191+
end
176192

177193
# For B<:Bidiagonal, similar(B[, neweltype]) should yield a Bidiagonal matrix.
178194
# On the other hand, similar(B, [neweltype,] shape...) should yield a sparse matrix.
@@ -234,18 +250,9 @@ function size(M::Bidiagonal, d::Integer)
234250
end
235251

236252
#Elementary operations
237-
broadcast(::typeof(abs), M::Bidiagonal) = Bidiagonal(abs.(M.dv), abs.(M.ev), M.uplo)
238-
broadcast(::typeof(round), M::Bidiagonal) = Bidiagonal(round.(M.dv), round.(M.ev), M.uplo)
239-
broadcast(::typeof(trunc), M::Bidiagonal) = Bidiagonal(trunc.(M.dv), trunc.(M.ev), M.uplo)
240-
broadcast(::typeof(floor), M::Bidiagonal) = Bidiagonal(floor.(M.dv), floor.(M.ev), M.uplo)
241-
broadcast(::typeof(ceil), M::Bidiagonal) = Bidiagonal(ceil.(M.dv), ceil.(M.ev), M.uplo)
242253
for func in (:conj, :copy, :real, :imag)
243254
@eval ($func)(M::Bidiagonal) = Bidiagonal(($func)(M.dv), ($func)(M.ev), M.uplo)
244255
end
245-
broadcast(::typeof(round), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(round.(T, M.dv), round.(T, M.ev), M.uplo)
246-
broadcast(::typeof(trunc), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(trunc.(T, M.dv), trunc.(T, M.ev), M.uplo)
247-
broadcast(::typeof(floor), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(floor.(T, M.dv), floor.(T, M.ev), M.uplo)
248-
broadcast(::typeof(ceil), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(ceil.(T, M.dv), ceil.(T, M.ev), M.uplo)
249256

250257
transpose(M::Bidiagonal) = Bidiagonal(M.dv, M.ev, M.uplo == 'U' ? :L : :U)
251258
adjoint(M::Bidiagonal) = Bidiagonal(conj(M.dv), conj(M.ev), M.uplo == 'U' ? :L : :U)

base/linalg/diagonal.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,18 @@ isposdef(D::Diagonal) = all(x -> x > 0, D.diag)
111111

112112
factorize(D::Diagonal) = D
113113

114-
broadcast(::typeof(abs), D::Diagonal) = Diagonal(abs.(D.diag))
115114
real(D::Diagonal) = Diagonal(real(D.diag))
116115
imag(D::Diagonal) = Diagonal(imag(D.diag))
117116

117+
function copyto!(dest::Diagonal, bc::Broadcasted{PromoteToSparse})
118+
axs = axes(dest)
119+
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
120+
for i in axs[1]
121+
dest.diag[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
122+
end
123+
dest
124+
end
125+
118126
istriu(D::Diagonal) = true
119127
istril(D::Diagonal) = true
120128
function triu!(D::Diagonal,k::Integer=0)

base/linalg/linalg.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
1717
StridedReshapedArray, strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
1818
using Base: hvcat_fill, iszero, IndexLinear, _length, promote_op, promote_typeof,
1919
@propagate_inbounds, @pure, reduce, typed_vcat
20+
using Base.Broadcast: Broadcasted, PromoteToSparse
21+
2022
# We use `_length` because of non-1 indices; releases after julia 0.5
2123
# can go back to `length`. `_length(A)` is equivalent to `length(linearindices(A))`.
2224

base/linalg/tridiag.jl

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -113,19 +113,22 @@ end
113113
similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T), similar(S.ev, T))
114114
similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...)
115115

116+
function copyto!(dest::SymTridiagonal, bc::Broadcasted{PromoteToSparse})
117+
axs = axes(dest)
118+
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
119+
for i in axs[1]
120+
dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
121+
end
122+
for i = 1:size(dest, 1)-1
123+
dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
124+
end
125+
dest
126+
end
127+
116128
#Elementary operations
117-
broadcast(::typeof(abs), M::SymTridiagonal) = SymTridiagonal(abs.(M.dv), abs.(M.ev))
118-
broadcast(::typeof(round), M::SymTridiagonal) = SymTridiagonal(round.(M.dv), round.(M.ev))
119-
broadcast(::typeof(trunc), M::SymTridiagonal) = SymTridiagonal(trunc.(M.dv), trunc.(M.ev))
120-
broadcast(::typeof(floor), M::SymTridiagonal) = SymTridiagonal(floor.(M.dv), floor.(M.ev))
121-
broadcast(::typeof(ceil), M::SymTridiagonal) = SymTridiagonal(ceil.(M.dv), ceil.(M.ev))
122129
for func in (:conj, :copy, :real, :imag)
123130
@eval ($func)(M::SymTridiagonal) = SymTridiagonal(($func)(M.dv), ($func)(M.ev))
124131
end
125-
broadcast(::typeof(round), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(round.(T, M.dv), round.(T, M.ev))
126-
broadcast(::typeof(trunc), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(trunc.(T, M.dv), trunc.(T, M.ev))
127-
broadcast(::typeof(floor), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(floor.(T, M.dv), floor.(T, M.ev))
128-
broadcast(::typeof(ceil), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(ceil.(T, M.dv), ceil.(T, M.ev))
129132

130133
transpose(M::SymTridiagonal) = M #Identity operation
131134
adjoint(M::SymTridiagonal) = conj(M)
@@ -500,24 +503,11 @@ similar(M::Tridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spz
500503
copyto!(dest::Tridiagonal, src::Tridiagonal) = (copyto!(dest.dl, src.dl); copyto!(dest.d, src.d); copyto!(dest.du, src.du); dest)
501504

502505
#Elementary operations
503-
broadcast(::typeof(abs), M::Tridiagonal) = Tridiagonal(abs.(M.dl), abs.(M.d), abs.(M.du))
504-
broadcast(::typeof(round), M::Tridiagonal) = Tridiagonal(round.(M.dl), round.(M.d), round.(M.du))
505-
broadcast(::typeof(trunc), M::Tridiagonal) = Tridiagonal(trunc.(M.dl), trunc.(M.d), trunc.(M.du))
506-
broadcast(::typeof(floor), M::Tridiagonal) = Tridiagonal(floor.(M.dl), floor.(M.d), floor.(M.du))
507-
broadcast(::typeof(ceil), M::Tridiagonal) = Tridiagonal(ceil.(M.dl), ceil.(M.d), ceil.(M.du))
508506
for func in (:conj, :copy, :real, :imag)
509507
@eval function ($func)(M::Tridiagonal)
510508
Tridiagonal(($func)(M.dl), ($func)(M.d), ($func)(M.du))
511509
end
512510
end
513-
broadcast(::typeof(round), ::Type{T}, M::Tridiagonal) where {T<:Integer} =
514-
Tridiagonal(round.(T, M.dl), round.(T, M.d), round.(T, M.du))
515-
broadcast(::typeof(trunc), ::Type{T}, M::Tridiagonal) where {T<:Integer} =
516-
Tridiagonal(trunc.(T, M.dl), trunc.(T, M.d), trunc.(T, M.du))
517-
broadcast(::typeof(floor), ::Type{T}, M::Tridiagonal) where {T<:Integer} =
518-
Tridiagonal(floor.(T, M.dl), floor.(T, M.d), floor.(T, M.du))
519-
broadcast(::typeof(ceil), ::Type{T}, M::Tridiagonal) where {T<:Integer} =
520-
Tridiagonal(ceil.(T, M.dl), ceil.(T, M.d), ceil.(T, M.du))
521511

522512
transpose(M::Tridiagonal) = Tridiagonal(M.du, M.d, M.dl)
523513
adjoint(M::Tridiagonal) = conj(transpose(M))
@@ -576,6 +566,19 @@ function Base.replace_in_print_matrix(A::Tridiagonal,i::Integer,j::Integer,s::Ab
576566
i==j-1||i==j||i==j+1 ? s : Base.replace_with_centered_mark(s)
577567
end
578568

569+
function copyto!(dest::Tridiagonal, bc::Broadcasted{PromoteToSparse})
570+
axs = axes(dest)
571+
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
572+
for i in axs[1]
573+
dest.d[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
574+
end
575+
for i = 1:size(dest, 1)-1
576+
dest.du[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
577+
dest.dl[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i))
578+
end
579+
dest
580+
end
581+
579582
#tril and triu
580583

581584
istriu(M::Tridiagonal) = iszero(M.dl)

base/sparse/higherorderfns.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Base: map, map!, broadcast, copy, copyto!
99
using Base: TupleLL, front, tail, to_shape
1010
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
1111
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
12-
using Base.Broadcast: BroadcastStyle, Broadcasted, flatten
12+
using Base.Broadcast: BroadcastStyle, Broadcasted, PromoteToSparse, Args1, Args2, flatten
1313

1414
# This module is organized as follows:
1515
# (0) Define BroadcastStyle rules and convenience types for dispatch
@@ -54,7 +54,6 @@ SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
5454

5555
Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle()
5656

57-
struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
5857
StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
5958
Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse()
6059

@@ -969,6 +968,7 @@ function _copy(::Any, bc::Broadcasted{<:SPVM})
969968
parevalf, passedargstup = capturescalars(bcf.f, args)
970969
return broadcast(parevalf, passedargstup...)
971970
end
971+
972972
function _shapecheckbc(bc::Broadcasted)
973973
args = Tuple(bc.args)
974974
_aresameshape(bc.args) ? _noshapecheck_map(bc.f, args...) : _diffshape_broadcast(bc.f, args...)
@@ -1044,10 +1044,22 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
10441044
# and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.
10451045

10461046
function copy(bc::Broadcasted{PromoteToSparse})
1047+
if bc.args isa Args1{<:StructuredMatrix} || bc.args isa Args2{<:Type,<:StructuredMatrix}
1048+
if _iszero(fzero(bc.f, bc.args))
1049+
T = Broadcast.combine_eltypes(bc.f, bc.args)
1050+
M = get_matrix(bc.args)
1051+
dest = similar(M, T)
1052+
return copyto!(dest, bc)
1053+
end
1054+
end
10471055
bcf = flatten(bc)
10481056
As = Tuple(bcf.args)
10491057
broadcast(bcf.f, map(_sparsifystructured, As)...)
10501058
end
1059+
get_matrix(args::Args1{<:StructuredMatrix}) = args.head
1060+
get_matrix(args::Args2{<:Type,<:StructuredMatrix}) = args.rest.head
1061+
fzero(f::Tf, args::Args1{<:StructuredMatrix}) where Tf = f(zero(eltype(get_matrix(args))))
1062+
fzero(f::Tf, args::Args2{<:Type, <:StructuredMatrix}) where Tf = f(args.head, zero(eltype(get_matrix(args))))
10511063

10521064
function copyto!(dest::SparseVecOrMat, bc::Broadcasted{PromoteToSparse})
10531065
bcf = flatten(bc)

test/sparse/higherorderfns.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ end
382382
structuredarrays = (D, B, T, S)
383383
fstructuredarrays = map(Array, structuredarrays)
384384
for (X, fX) in zip(structuredarrays, fstructuredarrays)
385-
@test (Q = broadcast(sin, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(sin, fX)))
385+
@test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == sparse(broadcast(sin, fX)))
386386
@test broadcast!(sin, Z, X) == sparse(broadcast(sin, fX))
387387
@test (Q = broadcast(cos, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(cos, fX)))
388388
@test broadcast!(cos, Z, X) == sparse(broadcast(cos, fX))

0 commit comments

Comments
 (0)