Skip to content

Commit 8d56e48

Browse files
fix for newest OrdinaryDiffEq
1 parent c3cb313 commit 8d56e48

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MultiScaleArrays"
22
uuid = "f9640e96-87f6-5992-9c3b-0743c6a49ffa"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "1.8.0"
4+
version = "1.8.1"
55

66
[deps]
77
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"

src/diffeq.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,26 @@ function remove_node_grad_config!(cache,grad_config::ForwardDiff.DerivativeConfi
169169
nothing
170170
end
171171

172+
function add_node_grad_config!(cache,grad_config::AbstractArray,i,x)
173+
cache.grad_config = ForwardDiff.Dual{typeof(ForwardDiff.Tag(cache.tf,eltype(cache.du1)))}.(cache.du1, cache.du1)
174+
nothing
175+
end
176+
177+
function add_node_grad_config!(cache,grad_config::AbstractArray,i,x,I...)
178+
cache.grad_config = ForwardDiff.Dual{typeof(ForwardDiff.Tag(cache.tf,eltype(cache.du1)))}.(cache.du1, cache.du1)
179+
nothing
180+
end
181+
182+
function remove_node_grad_config!(cache,grad_config::AbstractArray,i,x)
183+
cache.grad_config = ForwardDiff.Dual{typeof(ForwardDiff.Tag(cache.tf,eltype(cache.du1)))}.(cache.du1, cache.du1)
184+
nothing
185+
end
186+
187+
function remove_node_grad_config!(cache,grad_config::AbstractArray,i,x,I...)
188+
cache.grad_config = ForwardDiff.Dual{typeof(ForwardDiff.Tag(cache.tf,eltype(cache.du1)))}.(cache.du1, cache.du1)
189+
nothing
190+
end
191+
172192
function add_node_grad_config!(cache,grad_config::FiniteDiff.GradientCache,i,x,I...)
173193
grad_config.fx !== nothing && add_node!(grad_config.fx, recursivecopy(x), I...)
174194
grad_config.c1 !== nothing && add_node!(grad_config.c1, recursivecopy(x), I...)

src/math.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,17 @@ Broadcast.BroadcastStyle(::Type{<:AMSA}) = AMSAStyle()
1414

1515
@inline function Base.copy(bc::Broadcast.Broadcasted{<:AMSAStyle})
1616
first_amsa = find_amsa(bc)
17-
out = similar(first_amsa)
17+
18+
out = similar(first_amsa,Base.Broadcast._broadcast_getindex_eltype(bc))
19+
20+
#=
21+
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
22+
if Base.isconcretetype(ElType)
23+
# We can trust it and defer to the simpler `copyto!`
24+
return copyto!(similar(bc, ElType), bc)
25+
end
26+
=#
27+
1828
copyto!(out,bc)
1929
out
2030
end

0 commit comments

Comments
 (0)