Skip to content

Commit 4c4f059

Browse files
fix broadcast fallback and noise as AMSA
1 parent 30b4f98 commit 4c4f059

8 files changed

+202
-62
lines changed

REQUIRE

+1
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ julia 1.0
22
RecursiveArrayTools 0.8.0
33
DiffEqBase 0.11.0
44
TreeViews
5+
StochasticDiffEq

src/MultiScaleArrays.jl

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ abstract type AbstractMultiScaleArrayLeaf{B} <: AbstractMultiScaleArray{B} end
1515
abstract type AbstractMultiScaleArrayHead{B} <: AbstractMultiScaleArray{B} end
1616

1717
using DiffEqBase, Statistics
18+
import StochasticDiffEq
1819

1920
Base.show(io::IO, x::AbstractMultiScaleArray) = invoke(show, Tuple{IO, Any}, io, x)
2021
Base.show(io::IO, ::MIME"text/plain", x::AbstractMultiScaleArray) = show(io, x)

src/diffeq.jl

+155-8
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function remove_node!(integrator::DiffEqBase.DEIntegrator, I...)
88
remove_node!(c, I...)
99
end
1010
end
11-
deleteat_non_user_cache!(integrator, idxs) # required to do noise correctly
11+
remove_node_non_user_cache!(integrator, I) # required to do noise correctly
1212
end
1313

1414
function add_node!(integrator::DiffEqBase.DEIntegrator, x, I...)
@@ -18,14 +18,14 @@ function add_node!(integrator::DiffEqBase.DEIntegrator, x, I...)
1818
idx_start = getindices(integrator.u, last_idx)[end] + 1
1919
idxs = idx_start:idx_start+add_len-1
2020
for c in full_cache(integrator)
21-
add_node!(c, similar(x, eltype(c)), I...)
21+
add_node!(c, fill!(similar(x, eltype(c)),0), I...)
2222
end
2323
if DiffEqBase.is_diagonal_noise(integrator.sol.prob)
2424
for c in DiffEqBase.ratenoise_cache(integrator)
25-
add_node!(c, similar(x, eltype(c)), I...)
25+
add_node!(c, fill!(similar(x, eltype(c)),0), I...)
2626
end
2727
end
28-
addat_non_user_cache!(integrator, idxs) # required to do noise correctly
28+
add_node_non_user_cache!(integrator, idxs, x, I...) # required to do noise correctly
2929
end
3030

3131
function add_node!(integrator::DiffEqBase.DEIntegrator, x)
@@ -34,17 +34,164 @@ function add_node!(integrator::DiffEqBase.DEIntegrator, x)
3434
last_idx = length(integrator.u.nodes)
3535
idx_start = getindices(integrator.u, last_idx)[end] + 1
3636
idxs = idx_start:idx_start+add_len-1
37-
@show idxs
3837
for c in full_cache(integrator)
39-
add_node!(c, similar(x, eltype(c)))
38+
add_node!(c, fill!(similar(x, eltype(c)),0))
4039
end
4140
if DiffEqBase.is_diagonal_noise(integrator.sol.prob)
4241
for c in DiffEqBase.ratenoise_cache(integrator)
43-
add_node!(c, similar(x, eltype(c)))
42+
add_node!(c, fill!(similar(x, eltype(c)),0))
4443
end
4544
end
46-
addat_non_user_cache!(integrator, idxs) # required to do noise correctly
45+
add_node_non_user_cache!(integrator, idxs, fill!(similar(x, eltype(x)),0)) # required to do noise correctly
4746
end
4847

4948

5049
reshape(m::AbstractMultiScaleArray, i::Int...) = m
50+
51+
function remove_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrator,node)
52+
i = length(integrator.u)
53+
resize_non_user_cache!(integrator,integrator.cache,i)
54+
end
55+
function remove_node_non_user_cache!(integrator::DiffEqBase.AbstractSDEIntegrator,node)
56+
if DiffEqBase.is_diagonal_noise(integrator.sol.prob)
57+
remove_node_noise!(integrator,node)
58+
for c in rand_cache(integrator)
59+
remove_node!(c,node...)
60+
end
61+
end
62+
end
63+
64+
function remove_node_noise!(integrator,node)
65+
for c in integrator.W.S₁
66+
remove_node!(c[2],node...)
67+
if DiffEqBase.alg_needs_extra_process(integrator.alg)
68+
remove_node!(c[3],node...)
69+
end
70+
end
71+
for c in integrator.W.S₂
72+
remove_node!(c[2],node...)
73+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
74+
remove_node!(c[3],node...)
75+
end
76+
end
77+
remove_node!(integrator.W.dW,node...)
78+
remove_node!(integrator.W.dWtilde,node...)
79+
remove_node!(integrator.W.dWtmp,node...)
80+
remove_node!(integrator.W.curW,node...)
81+
82+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
83+
remove_node!(integrator.W.curZ,node...)
84+
remove_node!(integrator.W.dZtmp,node...)
85+
remove_node!(integrator.W.dZtilde,node...)
86+
remove_node!(integrator.W.dZ,node...)
87+
end
88+
end
89+
90+
function add_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrator,idxs,x)
91+
i = length(integrator.u)
92+
resize_non_user_cache!(integrator,integrator.cache,i)
93+
end
94+
function add_node_non_user_cache!(integrator::DiffEqBase.AbstractODEIntegrator,idxs,x,node...)
95+
i = length(integrator.u)
96+
resize_non_user_cache!(integrator,integrator.cache,i)
97+
end
98+
function add_node_non_user_cache!(integrator::DiffEqBase.AbstractSDEIntegrator,idxs,x,node...)
99+
if DiffEqBase.is_diagonal_noise(integrator.sol.prob)
100+
add_node_noise!(integrator,idxs,x,node...)
101+
for c in rand_cache(integrator)
102+
add_node!(c,copy(x),node...)
103+
end
104+
end
105+
end
106+
function add_node_non_user_cache!(integrator::DiffEqBase.AbstractSDEIntegrator,idxs,x)
107+
if DiffEqBase.is_diagonal_noise(integrator.sol.prob)
108+
add_node_noise!(integrator,idxs,x)
109+
for c in rand_cache(integrator)
110+
add_node!(c,copy(x))
111+
end
112+
end
113+
end
114+
115+
function add_node_noise!(integrator,idxs,x,node...)
116+
for c in integrator.W.S₁
117+
add_node!(c[2],copy(x),node...)
118+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
119+
add_node!(c[3],copy(x),node...)
120+
end
121+
StochasticDiffEq.fill_new_noise_caches!(integrator,c,c[1],idxs)
122+
end
123+
for c in integrator.W.S₂
124+
add_node!(c[2],copy(x),node...)
125+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
126+
add_node!(c[3],copy(x),node...)
127+
end
128+
StochasticDiffEq.fill_new_noise_caches!(integrator,c,c[1],idxs)
129+
end
130+
131+
add_node!(integrator.W.dW,copy(x),node...)
132+
integrator.W.dW[idxs] .= zero(eltype(integrator.u))
133+
add_node!(integrator.W.curW,copy(x),node...)
134+
integrator.W.curW[idxs] .= zero(eltype(integrator.u))
135+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
136+
add_node!(integrator.W.dZ,copy(x),node...)
137+
integrator.W.dZ[idxs] .= zero(eltype(integrator.u))
138+
add_node!(integrator.W.curZ,copy(x),node...)
139+
integrator.W.curZ[idxs] .= zero(eltype(integrator.u))
140+
end
141+
142+
i = length(integrator.u)
143+
add_node!(integrator.W.dWtilde,copy(x),node...)
144+
add_node!(integrator.W.dWtmp,copy(x),node...)
145+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
146+
add_node!(integrator.W.dZtmp,copy(x),node...)
147+
add_node!(integrator.W.dZtilde,copy(x),node...)
148+
end
149+
150+
# fill in rands
151+
fill!(@view(integrator.W.curW[idxs]),zero(eltype(integrator.u)))
152+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
153+
fill!(@view(integrator.W.curZ[idxs]),zero(eltype(integrator.u)))
154+
end
155+
end
156+
157+
function add_node_noise!(integrator,idxs,x)
158+
for c in integrator.W.S₁
159+
add_node!(c[2],copy(x))
160+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
161+
add_node!(c[3],copy(x))
162+
end
163+
StochasticDiffEq.fill_new_noise_caches!(integrator,c,c[1],idxs)
164+
end
165+
for c in integrator.W.S₂
166+
add_node!(c[2],copy(x))
167+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
168+
add_node!(c[3],copy(x))
169+
end
170+
StochasticDiffEq.fill_new_noise_caches!(integrator,c,c[1],idxs)
171+
end
172+
173+
add_node!(integrator.W.dW,copy(x))
174+
integrator.W.dW[idxs] .= zero(eltype(integrator.u))
175+
add_node!(integrator.W.curW,copy(x))
176+
integrator.W.curW[idxs] .= zero(eltype(integrator.u))
177+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
178+
add_node!(integrator.W.dZ,copy(x))
179+
integrator.W.dZ[idxs] .= zero(eltype(integrator.u))
180+
add_node!(integrator.W.curZ,copy(x))
181+
integrator.W.curZ[idxs] .= zero(eltype(integrator.u))
182+
end
183+
184+
i = length(integrator.u)
185+
add_node!(integrator.W.dWtilde,copy(x))
186+
add_node!(integrator.W.dWtmp,copy(x))
187+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
188+
add_node!(integrator.W.dZtmp,copy(x))
189+
add_node!(integrator.W.dZtilde,copy(x))
190+
end
191+
192+
# fill in rands
193+
fill!(@view(integrator.W.curW[idxs]),zero(eltype(integrator.u)))
194+
if StochasticDiffEq.alg_needs_extra_process(integrator.alg)
195+
fill!(@view(integrator.W.curZ[idxs]),zero(eltype(integrator.u)))
196+
end
197+
end

src/math.jl

+25-31
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,6 @@ Base.map!(f::F, m::AMSA, A0, As...) where {F} =
88

99
Base.BroadcastStyle(::Type{<:AMSA}) = Broadcast.ArrayStyle{AMSA}()
1010
Base.BroadcastStyle(::Type{<:AbstractMultiScaleArrayLeaf}) = Broadcast.ArrayStyle{AbstractMultiScaleArrayLeaf}()
11-
Base.BroadcastStyle(a::Broadcast.ArrayStyle{AMSA}, b::Base.Broadcast.DefaultArrayStyle) = b
12-
#=
13-
AMSAStyle(::S) where {S} = AMSAStyle{S}()
14-
AMSAStyle(::S, ::Val{N}) where {S,N} = AMSAStyle(S(Val(N)))
15-
AMSAStyle(::Val{N}) where N = AMSAStyle{Broadcast.DefaultArrayStyle{N}}()
16-
17-
18-
# promotion rules
19-
function Broadcast.BroadcastStyle(::AMSAStyle{AStyle}, ::AMSAStyle{BStyle}) where {AStyle, BStyle}
20-
AMSAStyle(Broadcast.BroadcastStyle(AStyle(), BStyle()))
21-
end
22-
=#
23-
24-
#=
25-
combine_styles(args::Tuple{}) = Broadcast.DefaultArrayStyle{0}()
26-
combine_styles(args::Tuple{Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]))
27-
combine_styles(args::Tuple{Any, Any}) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), Broadcast.BroadcastStyle(args[2]))
28-
@inline combine_styles(args::Tuple) = Broadcast.result_style(Broadcast.BroadcastStyle(args[1]), combine_styles(Base.tail(args)))
29-
30-
function Broadcast.BroadcastStyle(::Type{AMSA{T}}) where {T}
31-
Style = combine_styles((T.parameters...,))
32-
AMSAStyle(Style)
33-
end
34-
=#
3511

3612
@inline function Base.copy(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{AMSA}})
3713
first_amsa = find_amsa(bc)
@@ -47,16 +23,26 @@ end
4723
out
4824
end
4925

50-
@inline function Base.copyto!(dest::AMSA, bc::Broadcast.Broadcasted{Nothing})
51-
N = length(dest.nodes)
52-
for i in 1:N
53-
copyto!(dest.nodes[i], unpack(bc, i))
26+
@inline function Base.copyto!(dest::AMSA, bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{AMSA}})
27+
if !any_non_amsa(bc)
28+
N = length(dest.nodes)
29+
for i in 1:N
30+
copyto!(dest.nodes[i], unpack(bc, i))
31+
end
32+
copyto!(dest.values,unpack(bc, nothing))
33+
else
34+
copyto!(dest,convert(Base.Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{length(axes(bc))}}, bc))
5435
end
55-
copyto!(dest.values,unpack(bc, nothing))
36+
dest
5637
end
5738

58-
@inline function Base.copyto!(dest::AbstractMultiScaleArrayLeaf, bc::Broadcast.Broadcasted{Nothing})
59-
copyto!(dest.values,unpack(bc,nothing))
39+
@inline function Base.copyto!(dest::AbstractMultiScaleArrayLeaf, bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{AbstractMultiScaleArrayLeaf}})
40+
if !any_non_amsa(bc)
41+
copyto!(dest.values,unpack(bc,nothing))
42+
else
43+
copyto!(dest,convert(Base.Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{length(axes(bc))}}, bc))
44+
end
45+
dest
6046
end
6147

6248
# drop axes because it is easier to recompute
@@ -85,6 +71,14 @@ find_amsa(x) = x
8571
find_amsa(a::AMSA, rest) = a
8672
find_amsa(::Any, rest) = find_amsa(rest)
8773

74+
any_non_amsa(bc::Base.Broadcast.Broadcasted) = any_non_amsa(bc.args)
75+
any_non_amsa(args::Tuple) = any_non_amsa(any_non_amsa(args[1]), Base.tail(args))
76+
any_non_amsa(x::AMSA) = false
77+
any_non_amsa(x::Number) = false
78+
any_non_amsa(x::Any) = true
79+
any_non_amsa(x::AbstractArray) = true
80+
any_non_amsa(x::Bool, rest) = isempty(rest) ? x : x || any_non_amsa(rest)
81+
8882
## utils
8983
common_number(a, b) =
9084
a == 0 ? b :

src/shape_construction.jl

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
length(m::AbstractMultiScaleArrayLeaf) = length(m.values)
22
length(m::AbstractMultiScaleArray) = m.end_idxs[end]
3+
Base.isempty(m::AbstractMultiScaleArray) = isempty(m.nodes) && isempty(m.values)
4+
Base.isempty(m::AbstractMultiScaleArrayLeaf) = isempty(m.values)
35
num_nodes(m::AbstractMultiScaleArrayLeaf) = 0
46
num_nodes(m::AbstractMultiScaleArray) = size(m.nodes, 1)
57
ndims(m::AbstractMultiScaleArray) = 1

test/dynamic_diffeq.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ cell4 = Cell([4.0; 6])
3030
population2 = construct(Population, deepcopy([cell3, cell4]))
3131
tissue1 = construct(Tissue, deepcopy([population, population2])) # Make a Tissue from Populations
3232
tissue2 = construct(Tissue, deepcopy([population2, population]))
33-
embryo = construct(Embryo, deepcopy([tissue1, tissue2])) # Make an embryo from Tissues
33+
_embryo = construct(Embryo, deepcopy([tissue1, tissue2])) # Make an embryo from Tissues
34+
embryo = deepcopy(_embryo)
3435

3536
cell_ode = function (dcell,cell,p,t)
3637
m = mean(cell)

test/indexing_and_creation_tests.jl

+1-7
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,7 @@ size(cell1)
144144

145145
t = 2
146146
p/zero(t)
147-
148-
size(p)
149-
150-
p ./ rand(length(p))
151-
147+
p ./ randn(length(p))
152148

153149
f = function (du,u,p,t)
154150
for i in eachindex(u)
@@ -179,8 +175,6 @@ Random.seed!(100)
179175
prob = SDEProblem(f, g, em, (0.0, 1000.0))
180176
@time sol1 = solve(prob, SRIW1(), progress=false, abstol=1e-2, reltol=1e-2, save_everystep=false)
181177

182-
cell1 .= randn.()
183-
184178
Random.seed!(100)
185179
prob = SDEProblem(f, g, em[:], (0.0, 1000.0))
186180
@time sol2 = solve(prob, SRIW1(), progress=false, abstol=1e-2, reltol=1e-2, save_everystep=false)

0 commit comments

Comments
 (0)