@@ -3,7 +3,7 @@ struct ArrayPartition{T,S<:Tuple} <: AbstractVector{T}
3
3
end
4
4
5
5
# # constructors
6
-
6
+ @inline ArrayPartition (f :: F , N) where F <: Function = ArrayPartition ( ntuple (f, Val (N)))
7
7
ArrayPartition (x... ) = ArrayPartition ((x... ,))
8
8
9
9
function ArrayPartition (x:: S , :: Type{Val{copy_x}} = Val{false }) where {S<: Tuple ,copy_x}
@@ -23,26 +23,25 @@ Base.similar(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(similar.(
23
23
Base. similar (A:: ArrayPartition , dims:: NTuple{N,Int} ) where {N} = similar (A)
24
24
25
25
# similar array partition of common type
26
- @generated function Base. similar (A:: ArrayPartition , :: Type{T} ) where {T}
26
+ @inline function Base. similar (A:: ArrayPartition , :: Type{T} ) where {T}
27
27
N = npartitions (A)
28
- expr = :(similar (A. x[i], T))
29
-
30
- build_arraypartition (N, expr)
28
+ ArrayPartition (i-> similar (A. x[i], T), N)
31
29
end
32
30
33
31
# ignore dims since array partitions are vectors
34
32
Base. similar (A:: ArrayPartition , :: Type{T} , dims:: NTuple{N,Int} ) where {T,N} = similar (A, T)
35
33
36
34
# similar array partition with different types
37
- @generated function Base. similar (A:: ArrayPartition , :: Type{T} , :: Type{S} ,
38
- R:: Vararg{Type} ) where {T,S}
35
+ function Base. similar (A:: ArrayPartition , :: Type{T} , :: Type{S} , R:: DataType... ) where {T, S}
39
36
N = npartitions (A)
40
37
N != length (R) + 2 &&
41
38
throw (DimensionMismatch (" number of types must be equal to number of partitions" ))
42
39
43
- types = (T, S, parameter .(R)... ) # new types
44
- expr = :(similar (A. x[i], ($ types)[i]))
45
- build_arraypartition (N, expr)
40
+ types = (T, S, R... ) # new types
41
+ @inline function f (i)
42
+ similar (A. x[i], types[i])
43
+ end
44
+ ArrayPartition (f, N)
46
45
end
47
46
48
47
Base. copy (A:: ArrayPartition{T,S} ) where {T,S} = ArrayPartition {T,S} (copy .(A. x))
@@ -52,17 +51,16 @@ Base.zero(A::ArrayPartition{T,S}) where {T,S} = ArrayPartition{T,S}(zero.(A.x))
52
51
# ignore dims since array partitions are vectors
53
52
Base. zero (A:: ArrayPartition , dims:: NTuple{N,Int} ) where {N} = zero (A)
54
53
55
-
56
-
57
54
# # ones
58
55
59
56
# special to work with units
60
- @generated function Base. ones (A:: ArrayPartition )
57
+ function Base. ones (A:: ArrayPartition )
61
58
N = npartitions (A)
62
-
63
- expr = :(fill! (similar (A. x[i]), oneunit (eltype (A. x[i]))))
64
-
65
- build_arraypartition (N, expr)
59
+ out = similar (A)
60
+ for i in 1 : N
61
+ fill! (out. x[i], oneunit (eltype (out. x[i])))
62
+ end
63
+ out
66
64
end
67
65
68
66
# ignore dims since array partitions are vectors
@@ -72,50 +70,32 @@ Base.ones(A::ArrayPartition, dims::NTuple{N,Int}) where {N} = ones(A)
72
70
73
71
for op in (:+ , :- )
74
72
@eval begin
75
- @generated function Base. $op (A:: ArrayPartition , B:: ArrayPartition )
76
- N = npartitions (A, B)
77
- expr = :($ ($ op). (A. x[i], B. x[i]))
78
-
79
- build_arraypartition (N, expr)
73
+ function Base. $op (A:: ArrayPartition , B:: ArrayPartition )
74
+ Base. broadcast ($ op, A, B)
80
75
end
81
76
82
- @generated function Base. $op (A:: ArrayPartition , B:: Number )
83
- N = npartitions (A)
84
- expr = :($ ($ op). (A. x[i], B))
85
-
86
- build_arraypartition (N, expr)
77
+ function Base. $op (A:: ArrayPartition , B:: Number )
78
+ Base. broadcast ($ op, A, B)
87
79
end
88
80
89
- @generated function Base. $op (A:: Number , B:: ArrayPartition )
90
- N = npartitions (B)
91
- expr = :($ ($ op). (A, B. x[i]))
92
-
93
- build_arraypartition (N, expr)
81
+ function Base. $op (A:: Number , B:: ArrayPartition )
82
+ Base. broadcast ($ op, A, B)
94
83
end
95
84
end
96
85
end
97
86
98
87
for op in (:* , :/ )
99
- @eval @generated function Base. $op (A:: ArrayPartition , B:: Number )
100
- N = npartitions (A)
101
- expr = :($ ($ op). (A. x[i], B))
102
-
103
- build_arraypartition (N, expr)
88
+ @eval function Base. $op (A:: ArrayPartition , B:: Number )
89
+ Base. broadcast ($ op, A, B)
104
90
end
105
91
end
106
92
107
- @generated function Base.:* (A:: Number , B:: ArrayPartition )
108
- N = npartitions (B)
109
- expr = :((* ). (A, B. x[i]))
110
-
111
- build_arraypartition (N, expr)
93
+ function Base.:* (A:: Number , B:: ArrayPartition )
94
+ Base. broadcast (* , A, B)
112
95
end
113
96
114
- @generated function Base.:\ (A:: Number , B:: ArrayPartition )
115
- N = npartitions (B)
116
- expr = :((/ ). (B. x[i], A))
117
-
118
- build_arraypartition (N, expr)
97
+ function Base.:\ (A:: Number , B:: ArrayPartition )
98
+ Base. broadcast (/ , B, A)
119
99
end
120
100
121
101
# # Functional Constructs
@@ -232,90 +212,72 @@ Base.show(io::IO, m::MIME"text/plain", A::ArrayPartition) = show(io, m, A.x)
232
212
233
213
# # broadcasting
234
214
235
- struct APStyle <: Broadcast.BroadcastStyle end
236
- Base. BroadcastStyle (:: Type{<:ArrayPartition} ) = Broadcast. ArrayStyle {ArrayPartition} ()
237
- Base. BroadcastStyle (:: Broadcast.ArrayStyle{ArrayPartition} ,:: Broadcast.ArrayStyle ) = Broadcast. ArrayStyle {ArrayPartition} ()
238
- Base. BroadcastStyle (:: Broadcast.ArrayStyle ,:: Broadcast.ArrayStyle{ArrayPartition} ) = Broadcast. ArrayStyle {ArrayPartition} ()
239
- Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}} ,:: Type{ElType} ) where ElType = similar (bc)
215
+ struct ArrayPartitionStyle{Style <: Broadcast.BroadcastStyle } <: Broadcast.AbstractArrayStyle{Any} end
216
+ ArrayPartitionStyle (:: S ) where {S} = ArrayPartitionStyle {S} ()
217
+ ArrayPartitionStyle (:: S , :: Val{N} ) where {S,N} = ArrayPartitionStyle (S (Val (N)))
218
+ ArrayPartitionStyle (:: Val{N} ) where N = ArrayPartitionStyle {Broadcast.DefaultArrayStyle{N}} ()
240
219
241
- function Base . copy (bc :: Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}} )
242
- ret = Broadcast. flatten (bc)
243
- __broadcast (ret . f,ret . args ... )
220
+ # promotion rules
221
+ function Broadcast. BroadcastStyle ( :: ArrayPartitionStyle{AStyle} , :: ArrayPartitionStyle{BStyle} ) where {AStyle, BStyle}
222
+ ArrayPartitionStyle (Broadcast . BroadcastStyle ( AStyle (), BStyle ()) )
244
223
end
245
224
246
- @generated function __broadcast (f,as ... )
247
-
248
- # common number of partitions
249
- N = npartitions (as ... )
225
+ combine_styles (args :: Tuple{} ) = Broadcast . DefaultArrayStyle {0} ( )
226
+ combine_styles (args :: Tuple{Any} ) = Broadcast . result_style (Broadcast . BroadcastStyle (args[ 1 ]))
227
+ combine_styles (args :: Tuple{Any, Any} ) = Broadcast . result_style (Broadcast . BroadcastStyle (args[ 1 ]), Broadcast . BroadcastStyle (args[ 2 ]))
228
+ @inline combine_styles (args :: Tuple ) = Broadcast . result_style (Broadcast . BroadcastStyle (args[ 1 ]), combine_styles (Base . tail (args)) )
250
229
251
- # broadcast partitions separately
252
- expr = :(broadcast (f,
253
- # index partitions
254
- $ ((as[d] <: ArrayPartition ? :(as[$ d]. x[i]) : :(as[$ d])
255
- for d in 1 : length (as)). .. )))
256
- build_arraypartition (N, expr)
230
+ function Broadcast. BroadcastStyle (:: Type{ArrayPartition{T,S}} ) where {T, S}
231
+ Style = combine_styles ((S. parameters... ,))
232
+ ArrayPartitionStyle (Style)
257
233
end
258
234
259
- function Base. copyto! (dest:: AbstractArray ,bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{ArrayPartition}} )
260
- ret = Broadcast. flatten (bc)
261
- __broadcast! (ret. f,dest,ret. args... )
262
- end
263
-
264
- @generated function __broadcast! (f, dest, as... )
265
- # common number of partitions
266
- N = npartitions (dest, as... )
267
-
268
- # broadcast partitions separately
269
- quote
270
- for i in 1 : $ N
271
- broadcast! (f, dest. x[i],
272
- # index partitions
273
- $ ((as[d] <: ArrayPartition ? :(as[$ d]. x[i]) : :(as[$ d])
274
- for d in 1 : length (as)). .. ))
275
- end
276
- dest
235
+ @inline function Base. copy (bc:: Broadcast.Broadcasted{ArrayPartitionStyle{Style}} ) where Style
236
+ N = npartitions (bc)
237
+ @inline function f (i)
238
+ copy (unpack (bc, i))
277
239
end
240
+ ArrayPartition (f, N)
278
241
end
279
242
280
- # # utils
281
-
282
- """
283
- build_arraypartition(N::Int, expr::Expr)
284
-
285
- Build `ArrayPartition` consisting of `N` partitions, each the result of an evaluation of
286
- `expr` with variable `i` set to the partition index in the range of 1 to `N`.
287
-
288
- This can help to write a type-stable method in cases in which the correct return type can
289
- can not be inferred for a simpler implementation with generators.
290
- """
291
- function build_arraypartition (N:: Int , expr:: Expr )
292
- quote
293
- @Base . nexprs $ N i-> (A_i = $ expr)
294
- partitions = @Base . ncall $ N tuple i-> A_i
295
- ArrayPartition (partitions)
243
+ @inline function Base. copyto! (dest:: ArrayPartition , bc:: Broadcast.Broadcasted )
244
+ N = npartitions (dest, bc)
245
+ for i in 1 : N
246
+ copyto! (dest. x[i], unpack (bc, i))
296
247
end
248
+ dest
297
249
end
298
250
251
+ # # broadcasting utils
252
+
299
253
"""
300
254
npartitions(A...)
301
255
302
256
Retrieve number of partitions of `ArrayPartitions` in `A...`, or throw an error if there are
303
257
`ArrayPartitions` with a different number of partitions.
304
258
"""
305
259
npartitions (A) = 0
306
- npartitions (:: Type{ArrayPartition{T,S}} ) where {T,S} = length (S. parameters)
307
- npartitions (A, B... ) = common_number (npartitions (A), npartitions (B... ))
260
+ npartitions (A:: ArrayPartition ) = length (A. x)
261
+ npartitions (bc:: Broadcast.Broadcasted ) = _npartitions (bc. args)
262
+ npartitions (A, Bs... ) = common_number (npartitions (A), _npartitions (Bs))
263
+
264
+ @inline _npartitions (args:: Tuple ) = common_number (npartitions (args[1 ]), _npartitions (Base. tail (args)))
265
+ _npartitions (args:: Tuple{Any} ) = npartitions (args[1 ])
266
+ _npartitions (args:: Tuple{} ) = 0
267
+
268
+ # drop axes because it is easier to recompute
269
+ @inline unpack (bc:: Broadcast.Broadcasted{Style} , i) where Style = Broadcast. Broadcasted {Style} (bc. f, unpack_args (i, bc. args))
270
+ @inline unpack (bc:: Broadcast.Broadcasted{ArrayPartitionStyle{Style}} , i) where Style = Broadcast. Broadcasted {Style} (bc. f, unpack_args (i, bc. args))
271
+ unpack (x,:: Any ) = x
272
+ unpack (x:: ArrayPartition , i) = x. x[i]
273
+
274
+ @inline unpack_args (i, args:: Tuple ) = (unpack (args[1 ], i), unpack_args (i, Base. tail (args))... )
275
+ unpack_args (i, args:: Tuple{Any} ) = (unpack (args[1 ], i),)
276
+ unpack_args (:: Any , args:: Tuple{} ) = ()
308
277
278
+ # # utils
309
279
common_number (a, b) =
310
280
a == 0 ? b :
311
281
(b == 0 ? a :
312
282
(a == b ? a :
313
283
throw (DimensionMismatch (" number of partitions must be equal" ))))
314
-
315
- """
316
- parameter(::Type{T})
317
-
318
- Return type `T` of singleton.
319
- """
320
- parameter (:: Type{T} ) where {T} = T
321
- parameter (:: Type{Type{T}} ) where {T} = T
0 commit comments