From 29f67182eea8b8e20a8b3f064ddfd3cb112f07d3 Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Sun, 1 Oct 2023 17:40:12 -0700 Subject: [PATCH 1/5] replace DiffEqSensitivity with SciMLSensitivity --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index b35e6284..66686a8b 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" @@ -28,6 +27,7 @@ QuartzImageIO = "dca85d43-d64c-5e67-8c65-017450d5d020" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" SmoothingSplines = "102930c3-cf33-599f-b3b1-9a29a5acab30" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -47,7 +47,7 @@ Colors = "0.11,0.12" Conda = "1" CSV = "0" DataFrames = "1" -DiffEqSensitivity = "6" +SciMLSensitivity = "^7" ForwardDiff = "0.10" Images = "0.24" IncompleteLU = "0.2" From 97cbe1697e56663c9df9731b96f6e5d19a9dc916 Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Sun, 1 Oct 2023 17:40:39 -0700 Subject: [PATCH 2/5] adapt adjoint_sensitivity syntax --- src/Simulation.jl | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/Simulation.jl b/src/Simulation.jl index 206325c6..d80012d6 100644 --- a/src/Simulation.jl +++ b/src/Simulation.jl @@ -1,6 +1,6 @@ using SciMLBase -import SciMLBase: AbstractODESolution, HermiteInterpolation -using DiffEqSensitivity +import SciMLBase: AbstractODESolution, HermiteInterpolation, AbstractDiffEqInterpolation +using SciMLSensitivity using ForwardDiff using PreallocationTools @@ -538,15 +538,19 @@ function getadjointsensitivities(bsol::Q, target::String, solver::W; sensalg::W2 if length(bsol.domain.p) <= pethane if target in ["T", "V", "P", "mass"] || !isempty(bsol.interfaces) - du0, dpadj = adjoint_sensitivities(bsol.sol, solver, g, nothing, (dgdu, dgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) + du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=g, dgdu_continuous=dgdu, + dgdp_continuous=dgdp, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) else - du0, dpadj = adjoint_sensitivities(bsol.sol, solver, sensg, nothing, (dsensgdu, dsensgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) + du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=sensg, dgdu_continuous=dsensgdu, + dgdp_continuous=dsensgdp, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) end else if target in ["T", "V", "P", "mass"] || !isempty(bsol.interfaces) - du0, dpadj = adjoint_sensitivities(bsol.sol, solver, g, nothing, (dgdurevdiff, dgdprevdiff); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) + du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=g, dgdu_continuous=gdurevdiff, + dgdp_continuous=dgdprevdiff, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) else - du0, dpadj = adjoint_sensitivities(bsol.sol, solver, sensg, nothing, (dsensgdurevdiff, dsensgdprevdiff); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) + du0, dpadj = adjoint_sensitivities(bsol.sol, solver; g=sensg, dgdu_continuous=dsensgdurevdiff, + dgdp_continuous=dsensgdprevdiff, sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) end end if normalize @@ -591,7 +595,8 @@ function getadjointsensitivities(syssim::Q, bsol::W3, target::String, solver::W; end dgdu(out, y, p, t) = ForwardDiff.gradient!(out, y -> g(y, p, t), y) dgdp(out, y, p, t) = ForwardDiff.gradient!(out, p -> g(y, p, t), p) - du0, dpadj = adjoint_sensitivities(syssim.sol, solver, g, nothing, (dgdu, dgdp); sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) + du0, dpadj = adjoint_sensitivities(syssim.sol, solver; g=g, dgdu_continuous=dgdu, dgdp_continuous=dgdp, + sensealg=sensalg, abstol=abstol, reltol=reltol, kwargs...) if normalize for domain in domains dpadj[domain.parameterindexes[1]+length(domain.phase.species):domain.parameterindexes[2]] .*= syssim.p[domain.parameterindexes[1]+length(domain.phase.species):domain.parameterindexes[2]] From 0cd360123379cd753f9b2f03c2f322d2527ca456 Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Wed, 7 Feb 2024 12:28:36 -0800 Subject: [PATCH 3/5] increase absolute tolerances for test runs in adjoint sensitivities --- src/TestReactors.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/TestReactors.jl b/src/TestReactors.jl index 2612bd29..dc20d76f 100644 --- a/src/TestReactors.jl +++ b/src/TestReactors.jl @@ -220,7 +220,7 @@ jp=jacobianpforwarddiff(y,p,t,domain,[],nothing); @test all((abs.(jpa.-jp) .> 1e-4.*abs.(jp).+1e-16).==false) #sensitivities -dps = getadjointsensitivities(sim,"H2",CVODE_BDF(linear_solver=:GMRES);sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)),abstol=1e-16,reltol=1e-6) +dps = getadjointsensitivities(sim,"H2",CVODE_BDF(linear_solver=:GMRES);sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)),abstol=1e-12,reltol=1e-6) react2 = Reactor(domain,y0,(0.0,150.11094);p=p,forwardsensitivities=true) sol2 = solve(react2.ode,CVODE_BDF(linear_solver=:GMRES),abstol=1e-21,reltol=1e-7); #solve the ode associated with the reactor sim2 = Simulation(sol2,domain) @@ -272,7 +272,7 @@ end; @test all((abs.(jpa.-jp) .> 1e-4.*abs.(jp).+1e-16).==false) #sensitivities - dps = getadjointsensitivities(sim,"H2",CVODE_BDF(linear_solver=:GMRES);sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)),abstol=1e-16,reltol=1e-6) + dps = getadjointsensitivities(sim,"H2",CVODE_BDF(linear_solver=:GMRES);sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true)),abstol=1e-12,reltol=1e-6) react2 = Reactor(domain,y0,(0.0,150.11094),interfaces;p=p,forwardsensitivities=true) sol2 = solve(react2.ode,CVODE_BDF(linear_solver=:GMRES),abstol=1e-21,reltol=1e-7); #solve the ode associated with the reactor sim2 = Simulation(sol2,domain,interfaces) @@ -316,7 +316,7 @@ jp=jacobianpforwarddiff(y,p,t,domain,[],nothing); react = Reactor(domain,y0,(0.0,0.02),p=p) #Create the reactor object sol = solve(react.ode,CVODE_BDF(),abstol=1e-20,reltol=1e-12); #solve the ode associated with the reactor sim = Simulation(sol,domain) -dps = getadjointsensitivities(sim,"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-16,reltol=1e-6) +dps = getadjointsensitivities(sim,"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-12,reltol=1e-6) react2 = Reactor(domain,y0,(0.0,0.02);p=p,forwardsensitivities=true) sol2 = solve(react2.ode,CVODE_BDF(),abstol=1e-16,reltol=1e-6); #solve the ode associated with the reactor sim2 = Simulation(sol2,domain) @@ -488,8 +488,8 @@ end; @test sol(t)[1:length(spcs)] ≈ solV(t)[1:end-2] rtol=1e-5 @test sol(t)[length(spcs)+1:end-4] ≈ solV(t)[1:end-2] rtol=1e-5 - dpsV = getadjointsensitivities(simV,"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-16,reltol=1e-6) - dps = getadjointsensitivities(sysim,sysim.sims[1],"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-16,reltol=1e-6) + dpsV = getadjointsensitivities(simV,"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-12,reltol=1e-6) + dps = getadjointsensitivities(sysim,sysim.sims[1],"H2",CVODE_BDF();sensealg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol=1e-12,reltol=1e-6) @test dpsV ≈ dps rtol=1e-4 end; From f0aa7109ef08491c0b2b8299dc924e27a0391132 Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Wed, 7 Feb 2024 13:03:00 -0800 Subject: [PATCH 4/5] Remove parametrized typing from getadjointsensitivities --- src/Simulation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Simulation.jl b/src/Simulation.jl index d80012d6..d9035a93 100644 --- a/src/Simulation.jl +++ b/src/Simulation.jl @@ -459,8 +459,8 @@ By default uses the InterpolatingAdjoint algorithm with vector Jacobian products this assumes no changes in code branching during simulation, if that were to become no longer true, the Tracker based alternative algorithm is slower, but avoids this concern. """ -function getadjointsensitivities(bsol::Q, target::String, solver::W; sensalg::W2=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)), - abstol::Float64=1e-6, reltol::Float64=1e-3, normalize=true, kwargs...) where {Q,W,W2} +function getadjointsensitivities(bsol::Simulation, target::String, solver; sensalg=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)), + abstol::Float64=1e-6, reltol::Float64=1e-3, normalize=true, kwargs...) @assert target in bsol.names || target in ["T", "V", "P", "mass"] pethane = 160 From e80bb706578da17d76881c6f74e582987ad17884 Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Wed, 7 Feb 2024 13:28:06 -0800 Subject: [PATCH 5/5] Fix type stability in getadjointsensitivities --- src/Simulation.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/Simulation.jl b/src/Simulation.jl index d9035a93..95668fa7 100644 --- a/src/Simulation.jl +++ b/src/Simulation.jl @@ -3,6 +3,7 @@ import SciMLBase: AbstractODESolution, HermiteInterpolation, AbstractDiffEqInter using SciMLSensitivity using ForwardDiff using PreallocationTools +using LinearAlgebra abstract type AbstractSimulation end export AbstractSimulation @@ -561,7 +562,7 @@ function getadjointsensitivities(bsol::Simulation, target::String, solver; sensa dpadj ./= bsol.sol(bsol.sol.t[end])[ind] end end - return dpadj + return dpadj::LinearAlgebra.Adjoint{Float64, Vector{Float64}} end function getadjointsensitivities(syssim::Q, bsol::W3, target::String, solver::W; sensalg::W2=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)), @@ -605,7 +606,7 @@ function getadjointsensitivities(syssim::Q, bsol::W3, target::String, solver::W; dpadj ./= bsol.sol(bsol.sol.t[end])[ind] end end - return dpadj + return dpadj::LinearAlgebra.Adjoint{Float64, Vector{Float64}} end export getadjointsensitivities