Skip to content

Commit 9bdbec9

Browse files
committed
add upstream test?
1 parent 1c90fdc commit 9bdbec9

File tree

4 files changed

+51
-2
lines changed

4 files changed

+51
-2
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ ForwardDiff = "0.10.3"
1616
julia = "1.6"
1717

1818
[extras]
19+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1920
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
2021
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2122
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -24,7 +25,9 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2425
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2526
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2627
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
28+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2729
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
30+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2831

2932
[targets]
30-
test = ["LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL"]
33+
test = ["FiniteDiff", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SciMLSensitivity", "Zygote"]

src/PreallocationTools.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,10 @@ end
9494

9595
function Base.getindex(b::LazyBufferCache, u::ReverseDiff.TrackedArray)
9696
s = b.sizemap(size(u)) # required buffer size
97+
T = ReverseDiff.TrackedArray
9798
buf = get!(b.bufs, (T, s)) do
9899
# declare type since b.bufs dictionary is untyped
99-
similar(u, s)::T # buffer to allocate if it was not found in b.bufs
100+
similar(u, s)
100101
end
101102
return buf
102103
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Core"
1515
@safetestset "ODE tests" begin include("core_odes.jl") end
1616
@safetestset "Resizing" begin include("core_resizing.jl") end
1717
@safetestset "Nested Duals" begin include("core_nesteddual.jl") end
18+
@safetestset "ODE Sensitivity analysis" begin include("upstream/sensitivity_analysis.jl") end
1819
end
1920

2021
if !is_APPVEYOR && GROUP == "GPU"

test/upstream/sensitivity_analysis.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools
2+
using Random, FiniteDiff, ForwardDiff, ReverseDiff, SciMLSensitivity, Zygote
3+
4+
# see https://github.com/SciML/PreallocationTools.jl/issues/29
5+
@testset "VJP computation with LazyBuffer" begin
6+
u0 = rand(2, 2)
7+
p = rand(2, 2)
8+
struct foo{T}
9+
lbc::T
10+
end
11+
12+
f = foo(LazyBufferCache())
13+
14+
function (f::foo)(du, u, p, t)
15+
tmp = f.lbc[u]
16+
mul!(tmp, p, u) # avoid tmp = p*u
17+
@. du = u + tmp
18+
nothing
19+
end
20+
21+
prob = ODEProblem(f, u0, (0.0, 1.0), p)
22+
23+
function loss(u0, p; sensealg = nothing)
24+
_prob = remake(prob, u0 = u0, p = p)
25+
_sol = solve(_prob, Tsit5(), sensealg = sensealg, saveat = 0.1, abstol = 1e-14,
26+
reltol = 1e-14)
27+
sum(abs2, _sol)
28+
end
29+
30+
loss(u0, p)
31+
32+
du0 = FiniteDiff.finite_difference_gradient(u0 -> loss(u0, p), u0)
33+
dp = FiniteDiff.finite_difference_gradient(p -> loss(u0, p), p)
34+
Fdu0 = ForwardDiff.gradient(u0 -> loss(u0, p), u0)
35+
Fdp = ForwardDiff.gradient(p -> loss(u0, p), p)
36+
@test du0Fdu0 rtol=1e-8
37+
@test dpFdp rtol=1e-8
38+
39+
Zdu0, Zdp = Zygote.gradient((u0, p) -> loss(u0, p;
40+
sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())),
41+
u0, p)
42+
@test du0Zdu0 rtol=1e-8
43+
@test dpZdp rtol=1e-8
44+
end

0 commit comments

Comments
 (0)