Skip to content

Commit 1873e22

Browse files
almost have it
1 parent 68c9ab0 commit 1873e22

File tree

2 files changed

+29
-19
lines changed

2 files changed

+29
-19
lines changed

src/math.jl

+27-19
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,20 @@ Base.map!(f::F, m::AMSA, A0::AbstractArray, As::AbstractArray...) where {F} =
66
Base.map!(f::F, m::AMSA, A0, As...) where {F} =
77
broadcast!(f, m, A0, As...)
88

9-
struct AMSAStyle{Style <: Broadcast.BroadcastStyle} <: Broadcast.AbstractArrayStyle{Any} end
9+
Base.BroadcastStyle(::Type{<:AMSA}) = Broadcast.ArrayStyle{AMSA}()
10+
Base.BroadcastStyle(::Type{<:AbstractMultiScaleArrayLeaf}) = Broadcast.ArrayStyle{AbstractMultiScaleArrayLeaf}()
11+
12+
#=
1013
AMSAStyle(::S) where {S} = AMSAStyle{S}()
1114
AMSAStyle(::S, ::Val{N}) where {S,N} = AMSAStyle(S(Val(N)))
1215
AMSAStyle(::Val{N}) where N = AMSAStyle{Broadcast.DefaultArrayStyle{N}}()
1316
17+
1418
# promotion rules
1519
function Broadcast.BroadcastStyle(::AMSAStyle{AStyle}, ::AMSAStyle{BStyle}) where {AStyle, BStyle}
1620
AMSAStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
1721
end
22+
=#
1823

1924
#=
2025
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
@@ -28,13 +33,23 @@ function Broadcast.BroadcastStyle(::Type{AMSA{T}}) where {T}
2833
end
2934
=#
3035

31-
@inline function Base.copy(bc::Broadcast.Broadcasted{AMSAStyle{Style}}) where Style
32-
nnodes(bc)
36+
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{AMSA}})
37+
N = nnodes(bc)
3338
@inline function f(i)
3439
copy(unpack(bc, i))
3540
end
36-
CommonType = get_common_type(bc)
37-
construct(CommonType, map(f,N), f(nothing))
41+
first_amsa = find_amsa(bc)
42+
construct(first_amsa, map(f,N), f(nothing))
43+
end
44+
45+
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{AbstractMultiScaleArrayLeaf}})
46+
@show bc
47+
@inline function f(i)
48+
copy(unpack(bc, i))
49+
end
50+
first_amsa = find_amsa(bc)
51+
@show first_amsa
52+
construct(first_amsa, f(nothing))
3853
end
3954

4055
@inline function Base.copyto!(dest::AMSA, bc::Broadcast.Broadcasted{Nothing})
@@ -51,7 +66,7 @@ end
5166

5267
# drop axes because it is easier to recompute
5368
@inline unpack(bc::Broadcast.Broadcasted{Style}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
54-
@inline unpack(bc::Broadcast.Broadcasted{AMSAStyle{Style}}, i) where Style = Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
69+
@inline unpack(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{AMSA}}, i) = Broadcast.Broadcasted{Broadcast.ArrayStyle{AMSA}}(bc.f, unpack_args(i, bc.args))
5570
unpack(x,::Any) = x
5671
unpack(x::AMSA, i) = x.nodes[i]
5772
unpack(x::AMSA, ::Nothing) = x.values
@@ -69,23 +84,16 @@ nnodes(A, Bs...) = common_number(nnodes(A), _nnodes(Bs))
6984
_nnodes(args::Tuple{Any}) = nnodes(args[1])
7085
_nnodes(args::Tuple{}) = 0
7186

72-
get_common_type(A) = Nothing
73-
get_common_type(A::AMSA) = typeof(A)
74-
get_common_type(bc::Broadcast.Broadcasted) = _nnodes(bc.args)
75-
get_common_type(A, Bs...) = common_type(get_common_type(A), _get_common_type(Bs))
76-
77-
@inline _get_common_type(args::Tuple) = get_common_type(get_common_type(args[1]), _get_common_type(Base.tail(args)))
78-
_get_common_type(args::Tuple{Any}) = get_common_type(args[1])
79-
_get_common_type(args::Tuple{}) = Nothing
87+
"`A = find_amsa(As)` returns the first AMSA among the arguments."
88+
find_amsa(bc::Base.Broadcast.Broadcasted) = find_amsa(bc.args)
89+
find_amsa(args::Tuple) = find_amsa(find_amsa(args[1]), Base.tail(args))
90+
find_amsa(x) = x
91+
find_amsa(a::AMSA, rest) = a
92+
find_amsa(::Any, rest) = find_amsa(rest)
8093

8194
## utils
8295
common_number(a, b) =
8396
a == 0 ? b :
8497
(b == 0 ? a :
8598
(a == b ? a :
8699
throw(DimensionMismatch("number of nodes must be equal"))))
87-
88-
common_type(a::T, b::T) where T = T
89-
common_type(a::T, b::Nothing) where T = T
90-
common_type(a::Nothing, b::T) where T = T
91-
common_type(a::Nothing, b::Nothing) where T = Nothing

test/indexing_and_creation_tests.jl

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ end
2828
cell1 = Cell([1.0; 2.0; 3.0])
2929
cell2 = Cell([4.0; 5])
3030

31+
cell1 .+ cell1
32+
3133
sim_cell = similar(cell1)
3234

3335
@test length(cell1) == 3

0 commit comments

Comments
 (0)