@@ -6,47 +6,65 @@ Base.map!(f::F, m::AMSA, A0::AbstractArray, As::AbstractArray...) where {F} =
6
6
Base. map! (f:: F , m:: AMSA , A0, As... ) where {F} =
7
7
broadcast! (f, m, A0, As... )
8
8
9
- const AMSAStyle = Broadcast. ArrayStyle{AMSA}
10
- Base. BroadcastStyle (:: Type{<:AMSA} ) = Broadcast. ArrayStyle {AMSA} ()
11
- Base. BroadcastStyle (:: Broadcast.ArrayStyle{AMSA} ,:: Broadcast.DefaultArrayStyle{1} ) = Broadcast. DefaultArrayStyle {1} ()
12
- Base. BroadcastStyle (:: Broadcast.DefaultArrayStyle{1} ,:: Broadcast.ArrayStyle{AMSA} ) = Broadcast. DefaultArrayStyle {1} ()
13
- Base. similar (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{AMSA}} ,:: Type{ElType} ) where ElType = similar (bc)
14
-
15
- function Base. copy (bc:: Broadcast.Broadcasted{AMSAStyle} )
16
- ret = Broadcast. flatten (bc)
17
- __broadcast (ret. f,ret. args... )
18
- end
9
+ struct AMSAStyle{Style <: Broadcast.BroadcastStyle } <: Broadcast.AbstractArrayStyle{Any} end
10
+ AMSAStyle (:: S ) where {S} = AMSAStyle {S} ()
11
+ AMSAStyle (:: S , :: Val{N} ) where {S,N} = AMSAStyle (S (Val (N)))
12
+ AMSAStyle (:: Val{N} ) where N = AMSAStyle {Broadcast.DefaultArrayStyle{N}} ()
19
13
20
- function Base . copyto! (dest :: AMSA , bc :: Broadcast.Broadcasted{AMSAStyle} )
21
- ret = Broadcast. flatten (bc)
22
- __broadcast! (ret . f,dest,ret . args ... )
14
+ # promotion rules
15
+ function Broadcast. BroadcastStyle ( :: AMSAStyle{AStyle} , :: AMSAStyle{BStyle} ) where {AStyle, BStyle}
16
+ AMSAStyle (Broadcast . BroadcastStyle ( AStyle (), BStyle ()) )
23
17
end
24
18
25
- function __broadcast (f, A:: AMSA , Bs... )
26
- broadcast! (f, similar (A), A, Bs... )
19
+ #=
20
+ combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
21
+ combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
22
+ combine_styles(args::Tuple{Any, Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), Broadcast.BroadcastStyle(args[2]))
23
+ @inline combine_styles(args::Tuple) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), combine_styles(Base.tail(args)))
24
+
25
+ function Broadcast.BroadcastStyle(::Type{AMSA{T}}) where {T}
26
+ Style = combine_styles((T.parameters...,))
27
+ AMSAStyle(Style)
27
28
end
29
+ =#
28
30
29
- function __broadcast! (f, A:: AbstractMultiScaleArrayLeaf , Bs:: Union{Number,AbstractMultiScaleArrayLeaf} ...)
30
- broadcast! (f, A. values, (typeof (B)<: AbstractMultiScaleArrayLeaf ? B. values : B for B in Bs). .. )
31
- A
31
+ @inline function Base. copy (bc:: Broadcast.Broadcasted{AMSAStyle{Style}} ) where Style
32
+ nnodes (bc)
33
+ @inline function f (i)
34
+ copy (unpack (bc, i))
35
+ end
36
+ CommonType = get_common_type (bc)
37
+ construct (CommonType, map (f,N), f (nothing ))
32
38
end
33
39
34
- function __broadcast! (f, A:: AMSA , Bs:: Union{Number,AMSA} ...)
35
- for i in eachindex (A. nodes)
36
- broadcast! (f, A. nodes[i], (typeof (B)<: AMSA ? B. nodes[i] : B for B in Bs). .. )
40
+ @inline function Base. copyto! (dest:: AMSA , bc:: Broadcast.Broadcasted )
41
+ N = length (dest. nodes)
42
+ for i in 1 : N
43
+ copyto! (dest. nodes[i], unpack (bc, i))
37
44
end
38
- broadcast! (f, A. values, (typeof (B)<: AMSA ? B. values : B for B in Bs). .. )
39
- A
45
+ copyto! (dest. values,unpack (bc, nothing ))
46
+ end
47
+
48
+ @inline function Base. copyto! (dest:: AbstractMultiScaleArrayLeaf , bc:: Broadcast.Broadcasted )
49
+ copyto! (dest. values,unpack (bc,nothing ))
40
50
end
41
51
42
- + (m:: AbstractMultiScaleArray , y:: Number ) = m .+ y
43
- + (y:: Number , m:: AbstractMultiScaleArray ) = m .+ y
52
+ # drop axes because it is easier to recompute
53
+ @inline unpack (bc:: Broadcast.Broadcasted{Style} , i) where Style = Broadcast. Broadcasted {Style} (bc. f, unpack_args (i, bc. args))
54
+ @inline unpack (bc:: Broadcast.Broadcasted{AMSAStyle{Style}} , i) where Style = Broadcast. Broadcasted {Style} (bc. f, unpack_args (i, bc. args))
55
+ unpack (x,:: Any ) = x
56
+ unpack (x:: AMSA , i) = x. nodes[i]
57
+ unpack (x:: AMSA , :: Nothing ) = x. values
44
58
45
- - (m:: AbstractMultiScaleArray , y:: Number ) = m .- y
46
- - (y:: Number , m:: AbstractMultiScaleArray ) = y .- m
59
+ @inline unpack_args (i, args:: Tuple ) = (unpack (args[1 ], i), unpack_args (i, Base. tail (args))... )
60
+ unpack_args (i, args:: Tuple{Any} ) = (unpack (args[1 ], i),)
61
+ unpack_args (:: Any , args:: Tuple{} ) = ()
47
62
48
- * (m:: AbstractMultiScaleArray , y:: Number ) = m .* y
49
- * (y:: Number , m:: AbstractMultiScaleArray ) = m .* y
63
+ nnodes (A) = 0
64
+ nnodes (A:: ASMA ) = length (A. nodes)
65
+ nnodes (bc:: Broadcast.Broadcasted ) = _nnodes (bc. args)
66
+ nnodes (A, Bs... ) = common_number (nnodes (A), _nnodes (Bs))
50
67
51
- / (m:: AbstractMultiScaleArray , y:: Number ) = m ./ y
52
- / (y:: Number , m:: AbstractMultiScaleArray ) = y ./ m
68
+ @inline _nnodes (args:: Tuple ) = common_number (nnodes (args[1 ]), _nnodes (Base. tail (args)))
69
+ _nnodes (args:: Tuple{Any} ) = nnodes (args[1 ])
70
+ _nnodes (args:: Tuple{} ) = 0
0 commit comments