@@ -6,15 +6,20 @@ Base.map!(f::F, m::AMSA, A0::AbstractArray, As::AbstractArray...) where {F} =
6
6
Base. map! (f:: F , m:: AMSA , A0, As... ) where {F} =
7
7
broadcast! (f, m, A0, As... )
8
8
9
- struct AMSAStyle{Style <: Broadcast.BroadcastStyle } <: Broadcast.AbstractArrayStyle{Any} end
9
+ Base. BroadcastStyle (:: Type{<:AMSA} ) = Broadcast. ArrayStyle {AMSA} ()
10
+ Base. BroadcastStyle (:: Type{<:AbstractMultiScaleArrayLeaf} ) = Broadcast. ArrayStyle {AbstractMultiScaleArrayLeaf} ()
11
+
12
+ #=
10
13
AMSAStyle(::S) where {S} = AMSAStyle{S}()
11
14
AMSAStyle(::S, ::Val{N}) where {S,N} = AMSAStyle(S(Val(N)))
12
15
AMSAStyle(::Val{N}) where N = AMSAStyle{Broadcast.DefaultArrayStyle{N}}()
13
16
17
+
14
18
# promotion rules
15
19
function Broadcast.BroadcastStyle(::AMSAStyle{AStyle}, ::AMSAStyle{BStyle}) where {AStyle, BStyle}
16
20
AMSAStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
17
21
end
22
+ =#
18
23
19
24
#=
20
25
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
@@ -28,13 +33,23 @@ function Broadcast.BroadcastStyle(::Type{AMSA{T}}) where {T}
28
33
end
29
34
=#
30
35
31
- @inline function Base. copy (bc:: Broadcast.Broadcasted{AMSAStyle{Style }} ) where Style
32
- nnodes (bc)
36
+ @inline function Base. copy (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{AMSA }} )
37
+ N = nnodes (bc)
33
38
@inline function f (i)
34
39
copy (unpack (bc, i))
35
40
end
36
- CommonType = get_common_type (bc)
37
- construct (CommonType, map (f,N), f (nothing ))
41
+ first_amsa = find_amsa (bc)
42
+ construct (first_amsa, map (f,N), f (nothing ))
43
+ end
44
+
45
+ @inline function Base. copy (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{AbstractMultiScaleArrayLeaf}} )
46
+ @show bc
47
+ @inline function f (i)
48
+ copy (unpack (bc, i))
49
+ end
50
+ first_amsa = find_amsa (bc)
51
+ @show first_amsa
52
+ construct (first_amsa, f (nothing ))
38
53
end
39
54
40
55
@inline function Base. copyto! (dest:: AMSA , bc:: Broadcast.Broadcasted{Nothing} )
51
66
52
67
# drop axes because it is easier to recompute
53
68
@inline unpack (bc:: Broadcast.Broadcasted{Style} , i) where Style = Broadcast. Broadcasted {Style} (bc. f, unpack_args (i, bc. args))
54
- @inline unpack (bc:: Broadcast.Broadcasted{AMSAStyle{Style }} , i) where Style = Broadcast. Broadcasted {Style } (bc. f, unpack_args (i, bc. args))
69
+ @inline unpack (bc:: Broadcast.Broadcasted{Broadcast.ArrayStyle{AMSA }} , i) = Broadcast. Broadcasted {Broadcast.ArrayStyle{AMSA} } (bc. f, unpack_args (i, bc. args))
55
70
unpack (x,:: Any ) = x
56
71
unpack (x:: AMSA , i) = x. nodes[i]
57
72
unpack (x:: AMSA , :: Nothing ) = x. values
@@ -69,23 +84,16 @@ nnodes(A, Bs...) = common_number(nnodes(A), _nnodes(Bs))
69
84
_nnodes (args:: Tuple{Any} ) = nnodes (args[1 ])
70
85
_nnodes (args:: Tuple{} ) = 0
71
86
72
- get_common_type (A) = Nothing
73
- get_common_type (A:: AMSA ) = typeof (A)
74
- get_common_type (bc:: Broadcast.Broadcasted ) = _nnodes (bc. args)
75
- get_common_type (A, Bs... ) = common_type (get_common_type (A), _get_common_type (Bs))
76
-
77
- @inline _get_common_type (args:: Tuple ) = get_common_type (get_common_type (args[1 ]), _get_common_type (Base. tail (args)))
78
- _get_common_type (args:: Tuple{Any} ) = get_common_type (args[1 ])
79
- _get_common_type (args:: Tuple{} ) = Nothing
87
+ " `A = find_amsa(As)` returns the first AMSA among the arguments."
88
+ find_amsa (bc:: Base.Broadcast.Broadcasted ) = find_amsa (bc. args)
89
+ find_amsa (args:: Tuple ) = find_amsa (find_amsa (args[1 ]), Base. tail (args))
90
+ find_amsa (x) = x
91
+ find_amsa (a:: AMSA , rest) = a
92
+ find_amsa (:: Any , rest) = find_amsa (rest)
80
93
81
94
# # utils
82
95
common_number (a, b) =
83
96
a == 0 ? b :
84
97
(b == 0 ? a :
85
98
(a == b ? a :
86
99
throw (DimensionMismatch (" number of nodes must be equal" ))))
87
-
88
- common_type (a:: T , b:: T ) where T = T
89
- common_type (a:: T , b:: Nothing ) where T = T
90
- common_type (a:: Nothing , b:: T ) where T = T
91
- common_type (a:: Nothing , b:: Nothing ) where T = Nothing
0 commit comments