Skip to content

Commit e0567b6

Browse files
Merge pull request #41 from vchuravy/vc/broadcast
[WIP] Rework broadcast
2 parents 7e830c3 + e666b74 commit e0567b6

File tree

2 files changed

+73
-111
lines changed

2 files changed

+73
-111
lines changed

src/array_partition.jl

+71-109
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ struct ArrayPartition{T,S<:Tuple} <: AbstractVector{T}
33
end
44

55
## constructors
6-
6+
@inline ArrayPartition(f::F, N) where F<:Function = ArrayPartition(ntuple(f, Val(N)))
77
ArrayPartition(x...) = ArrayPartition((x...,))
88

99
function ArrayPartition(x::S, ::Type{Val{copy_x}}=Val{false}) where {S<:Tuple,copy_x}
@@ -23,26 +23,25 @@ Base.similar(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(similar.(
2323
Base.similar(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = similar(A)
2424

2525
# similar array partition of common type
26-
@generated function Base.similar(A::ArrayPartition, ::Type{T}) where {T}
26+
@inline function Base.similar(A::ArrayPartition, ::Type{T}) where {T}
2727
N = npartitions(A)
28-
expr = :(similar(A.x[i], T))
29-
30-
build_arraypartition(N, expr)
28+
ArrayPartition(i->similar(A.x[i], T), N)
3129
end
3230

3331
# ignore dims since array partitions are vectors
3432
Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N} = similar(A, T)
3533

3634
# similar array partition with different types
37-
@generated function Base.similar(A::ArrayPartition, ::Type{T}, ::Type{S},
38-
R::Vararg{Type}) where {T,S}
35+
function Base.similar(A::ArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}
3936
N = npartitions(A)
4037
N != length(R) + 2 &&
4138
throw(DimensionMismatch("number of types must be equal to number of partitions"))
4239

43-
types = (T, S, parameter.(R)...) # new types
44-
expr = :(similar(A.x[i], ($types)[i]))
45-
build_arraypartition(N, expr)
40+
types = (T, S, R...) # new types
41+
@inline function f(i)
42+
similar(A.x[i], types[i])
43+
end
44+
ArrayPartition(f, N)
4645
end
4746

4847
Base.copy(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(copy.(A.x))
@@ -52,17 +51,16 @@ Base.zero(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(zero.(A.x))
5251
# ignore dims since array partitions are vectors
5352
Base.zero(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = zero(A)
5453

55-
56-
5754
## ones
5855

5956
# special to work with units
60-
@generated function Base.ones(A::ArrayPartition)
57+
function Base.ones(A::ArrayPartition)
6158
N = npartitions(A)
62-
63-
expr = :(fill!(similar(A.x[i]), oneunit(eltype(A.x[i]))))
64-
65-
build_arraypartition(N, expr)
59+
out = similar(A)
60+
for i in 1:N
61+
fill!(out.x[i], oneunit(eltype(out.x[i])))
62+
end
63+
out
6664
end
6765

6866
# ignore dims since array partitions are vectors
@@ -72,50 +70,32 @@ Base.ones(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = ones(A)
7270

7371
for op in (:+, :-)
7472
@eval begin
75-
@generated function Base.$op(A::ArrayPartition, B::ArrayPartition)
76-
N = npartitions(A, B)
77-
expr = :($($op).(A.x[i], B.x[i]))
78-
79-
build_arraypartition(N, expr)
73+
function Base.$op(A::ArrayPartition, B::ArrayPartition)
74+
Base.broadcast($op, A, B)
8075
end
8176

82-
@generated function Base.$op(A::ArrayPartition, B::Number)
83-
N = npartitions(A)
84-
expr = :($($op).(A.x[i], B))
85-
86-
build_arraypartition(N, expr)
77+
function Base.$op(A::ArrayPartition, B::Number)
78+
Base.broadcast($op, A, B)
8779
end
8880

89-
@generated function Base.$op(A::Number, B::ArrayPartition)
90-
N = npartitions(B)
91-
expr = :($($op).(A, B.x[i]))
92-
93-
build_arraypartition(N, expr)
81+
function Base.$op(A::Number, B::ArrayPartition)
82+
Base.broadcast($op, A, B)
9483
end
9584
end
9685
end
9786

9887
for op in (:*, :/)
99-
@eval @generated function Base.$op(A::ArrayPartition, B::Number)
100-
N = npartitions(A)
101-
expr = :($($op).(A.x[i], B))
102-
103-
build_arraypartition(N, expr)
88+
@eval function Base.$op(A::ArrayPartition, B::Number)
89+
Base.broadcast($op, A, B)
10490
end
10591
end
10692

107-
@generated function Base.:*(A::Number, B::ArrayPartition)
108-
N = npartitions(B)
109-
expr = :((*).(A, B.x[i]))
110-
111-
build_arraypartition(N, expr)
93+
function Base.:*(A::Number, B::ArrayPartition)
94+
Base.broadcast(*, A, B)
11295
end
11396

114-
@generated function Base.:\(A::Number, B::ArrayPartition)
115-
N = npartitions(B)
116-
expr = :((/).(B.x[i], A))
117-
118-
build_arraypartition(N, expr)
97+
function Base.:\(A::Number, B::ArrayPartition)
98+
Base.broadcast(/, B, A)
11999
end
120100

121101
## Functional Constructs
@@ -232,90 +212,72 @@ Base.show(io::IO, m::MIME"text/plain", A::ArrayPartition) = show(io, m, A.x)
232212

233213
## broadcasting
234214

235-
struct APStyle <: Broadcast.BroadcastStyle end
236-
Base.BroadcastStyle(::Type{<:ArrayPartition}) = Broadcast.ArrayStyle{ArrayPartition}()
237-
Base.BroadcastStyle(::Broadcast.ArrayStyle{ArrayPartition},::Broadcast.ArrayStyle) = Broadcast.ArrayStyle{ArrayPartition}()
238-
Base.BroadcastStyle(::Broadcast.ArrayStyle,::Broadcast.ArrayStyle{ArrayPartition}) = Broadcast.ArrayStyle{ArrayPartition}()
239-
Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}},::Type{ElType}) where ElType = similar(bc)
215+
struct ArrayPartitionStyle{Style <: Broadcast.BroadcastStyle} <: Broadcast.AbstractArrayStyle{Any} end
216+
ArrayPartitionStyle(::S) where {S} = ArrayPartitionStyle{S}()
217+
ArrayPartitionStyle(::S, ::Val{N}) where {S,N} = ArrayPartitionStyle(S(Val(N)))
218+
ArrayPartitionStyle(::Val{N}) where N = ArrayPartitionStyle{Broadcast.DefaultArrayStyle{N}}()
240219

241-
function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}})
242-
ret = Broadcast.flatten(bc)
243-
__broadcast(ret.f,ret.args...)
220+
# promotion rules
221+
function Broadcast.BroadcastStyle(::ArrayPartitionStyle{AStyle}, ::ArrayPartitionStyle{BStyle}) where {AStyle, BStyle}
222+
ArrayPartitionStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
244223
end
245224

246-
@generated function __broadcast(f,as...)
247-
248-
# common number of partitions
249-
N = npartitions(as...)
225+
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
226+
combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
227+
combine_styles(args::Tuple{Any, Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), Broadcast.BroadcastStyle(args[2]))
228+
@inline combine_styles(args::Tuple) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), combine_styles(Base.tail(args)))
250229

251-
# broadcast partitions separately
252-
expr = :(broadcast(f,
253-
# index partitions
254-
$((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d])
255-
for d in 1:length(as))...)))
256-
build_arraypartition(N, expr)
230+
function Broadcast.BroadcastStyle(::Type{ArrayPartition{T,S}}) where {T, S}
231+
Style = combine_styles((S.parameters...,))
232+
ArrayPartitionStyle(Style)
257233
end
258234

259-
function Base.copyto!(dest::AbstractArray,bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}})
260-
ret = Broadcast.flatten(bc)
261-
__broadcast!(ret.f,dest,ret.args...)
262-
end
263-
264-
@generated function __broadcast!(f, dest, as...)
265-
# common number of partitions
266-
N = npartitions(dest, as...)
267-
268-
# broadcast partitions separately
269-
quote
270-
for i in 1:$N
271-
broadcast!(f, dest.x[i],
272-
# index partitions
273-
$((as[d] <: ArrayPartition ? :(as[$d].x[i]) : :(as[$d])
274-
for d in 1:length(as))...))
275-
end
276-
dest
235+
@inline function Base.copy(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where Style
236+
N = npartitions(bc)
237+
@inline function f(i)
238+
copy(unpack(bc, i))
277239
end
240+
ArrayPartition(f, N)
278241
end
279242

280-
## utils
281-
282-
"""
283-
build_arraypartition(N::Int, expr::Expr)
284-
285-
Build `ArrayPartition` consisting of `N` partitions, each the result of an evaluation of
286-
`expr` with variable `i` set to the partition index in the range of 1 to `N`.
287-
288-
This can help to write a type-stable method in cases in which the correct return type can
289-
can not be inferred for a simpler implementation with generators.
290-
"""
291-
function build_arraypartition(N::Int, expr::Expr)
292-
quote
293-
@Base.nexprs $N i->(A_i = $expr)
294-
partitions = @Base.ncall $N tuple i->A_i
295-
ArrayPartition(partitions)
243+
@inline function Base.copyto!(dest::ArrayPartition, bc::Broadcast.Broadcasted)
244+
N = npartitions(dest, bc)
245+
for i in 1:N
246+
copyto!(dest.x[i], unpack(bc, i))
296247
end
248+
dest
297249
end
298250

251+
## broadcasting utils
252+
299253
"""
300254
npartitions(A...)
301255
302256
Retrieve number of partitions of `ArrayPartitions` in `A...`, or throw an error if there are
303257
`ArrayPartitions` with a different number of partitions.
304258
"""
305259
npartitions(A) = 0
306-
npartitions(::Type{ArrayPartition{T,S}}) where {T,S} = length(S.parameters)
307-
npartitions(A, B...) = common_number(npartitions(A), npartitions(B...))
260+
npartitions(A::ArrayPartition) = length(A.x)
261+
npartitions(bc::Broadcast.Broadcasted) = _npartitions(bc.args)
262+
npartitions(A, Bs...) = common_number(npartitions(A), _npartitions(Bs))
263+
264+
@inline _npartitions(args::Tuple) = common_number(npartitions(args[1]), _npartitions(Base.tail(args)))
265+
_npartitions(args::Tuple{Any}) = npartitions(args[1])
266+
_npartitions(args::Tuple{}) = 0
267+
268+
# drop axes because it is easier to recompute
269+
@inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
270+
@inline unpack(bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
271+
unpack(x,::Any) = x
272+
unpack(x::ArrayPartition, i) = x.x[i]
273+
274+
@inline unpack_args(i, args::Tuple) = (unpack(args[1], i), unpack_args(i, Base.tail(args))...)
275+
unpack_args(i, args::Tuple{Any}) = (unpack(args[1], i),)
276+
unpack_args(::Any, args::Tuple{}) = ()
308277

278+
## utils
309279
common_number(a, b) =
310280
a == 0 ? b :
311281
(b == 0 ? a :
312282
(a == b ? a :
313283
throw(DimensionMismatch("number of partitions must be equal"))))
314-
315-
"""
316-
parameter(::Type{T})
317-
318-
Return type `T` of singleton.
319-
"""
320-
parameter(::Type{T}) where {T} = T
321-
parameter(::Type{Type{T}}) where {T} = T

test/partitions_test.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ x = ArrayPartition([1, 2], [3.0, 4.0])
4747
@inferred similar(x, (2, 2))
4848
@inferred similar(x, Int)
4949
@inferred similar(x, Int, (2, 2))
50-
@inferred similar(x, Int, Float64)
50+
# @inferred similar(x, Int, Float64)
5151

5252
# zero
5353
@inferred zero(x)
@@ -84,4 +84,4 @@ _scalar_op(y) = y + 1
8484
# Can't do `@inferred(_scalar_op.(x))` so we wrap that in a function:
8585
_broadcast_wrapper(y) = _scalar_op.(y)
8686
# Issue #8
87-
@inferred _broadcast_wrapper(x)
87+
# @inferred _broadcast_wrapper(x)

0 commit comments

Comments
 (0)