diff --git a/Project.toml b/Project.toml index 6891c4fc4..5207a51ce 100644 --- a/Project.toml +++ b/Project.toml @@ -77,7 +77,7 @@ LinearAlgebra = "1.10" LinearSolve = "2, 3" Lux = "1" Markdown = "1.10" -ModelingToolkit = "9.74" +ModelingToolkit = "9.78" ModelingToolkitStandardLibrary = "2" Mooncake = "0.4.52" NLsolve = "4.5.1" diff --git a/src/SciMLSensitivity.jl b/src/SciMLSensitivity.jl index a3569f384..c239f871a 100644 --- a/src/SciMLSensitivity.jl +++ b/src/SciMLSensitivity.jl @@ -33,7 +33,7 @@ using SciMLBase: SciMLBase, AbstractOverloadingSensitivityAlgorithm, AbstractNonlinearProblem, AbstractSensitivityAlgorithm, AbstractDiffEqFunction, AbstractODEFunction, unwrapped_f, CallbackSet, ContinuousCallback, DESolution, NonlinearFunction, NonlinearProblem, - DiscreteCallback, LinearProblem, ODEFunction, ODEProblem, + DiscreteCallback, LinearProblem, ODEFunction, ODEProblem, DAEProblem, RODEFunction, RODEProblem, ReturnCode, SDEFunction, SDEProblem, VectorContinuousCallback, deleteat!, get_tmp_cache, has_adjoint, isinplace, reinit!, remake, diff --git a/src/adjoint_common.jl b/src/adjoint_common.jl index 01838ae70..609d26528 100644 --- a/src/adjoint_common.jl +++ b/src/adjoint_common.jl @@ -761,4 +761,4 @@ if !hasmethod(Zygote.adjoint, end sol.u, solu_adjoint end -end +end \ No newline at end of file diff --git a/src/concrete_solve.jl b/src/concrete_solve.jl index ba6e0a1bc..2bc4d65be 100644 --- a/src/concrete_solve.jl +++ b/src/concrete_solve.jl @@ -408,7 +408,7 @@ function DiffEqBase._concrete_solve_adjoint( end # Remove callbacks, saveat, etc. from kwargs since it's handled separately - kwargs_fwd = NamedTuple{Base.diff_names(Base._nt_names(values(kwargs)), (:callback,))}(values(kwargs)) + kwargs_fwd = NamedTuple{Base.diff_names(Base._nt_names(values(kwargs)), (:callback, :initializealg))}(values(kwargs)) # Capture the callback_adj for the reverse pass and remove both callbacks kwargs_adj = NamedTuple{ @@ -454,10 +454,11 @@ function DiffEqBase._concrete_solve_adjoint( end igs = back(one(iy))[1] .- one(eltype(tunables)) - igs, new_u0, new_p, SciMLBase.NoInit() + igs, new_u0, new_p, SciMLBase.CheckInit() else nothing, u0, p, initializealg end + _prob = remake(_prob, u0 = new_u0, p = new_p) if sensealg isa BacksolveAdjoint @@ -672,18 +673,20 @@ function DiffEqBase._concrete_solve_adjoint( else cb2 = cb end - if ArrayInterface.ismutable(eltype(state_values(sol))) + + if prob isa Union{ODEProblem, DAEProblem} du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, - dgdu_discrete = df_iip, - sensealg = sensealg, - callback = cb2, - kwargs_init...) + dgdu_discrete = ArrayInterface.ismutable(eltype(state_values(sol))) ? df_iip : df_oop, + sensealg = sensealg, + callback = cb2, + initializealg = BrownFullBasicInit(), + kwargs_init...) else du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts, - dgdu_discrete = df_oop, - sensealg = sensealg, - callback = cb2, - kwargs_init...) + dgdu_discrete = ArrayInterface.ismutable(eltype(state_values(sol))) ? df_iip : df_oop, + sensealg = sensealg, + callback = cb2, + kwargs_init...) end du0 = reshape(du0, size(u0)) @@ -1581,6 +1584,8 @@ function DiffEqBase._concrete_solve_adjoint( Array(ybar) elseif eltype(ybar) <: AbstractArray Array(VectorOfArray(ybar)) + elseif ybar isa Tangent + Array(VectorOfArray(ybar.u)) else ybar end @@ -1769,7 +1774,8 @@ function DiffEqBase._concrete_solve_adjoint( @. _out[_save_idxs] = Δ.u[_save_idxs] end end - dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df, initializealg = BrownFullBasicInit()) + + dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df) dp, Δtunables = if Δ isa AbstractArray || Δ isa Number # if Δ isa AbstractArray, the gradients correspond to `u` diff --git a/test/desauty_dae_mwe.jl b/test/desauty_dae_mwe.jl index 68baa0277..c40539675 100644 --- a/test/desauty_dae_mwe.jl +++ b/test/desauty_dae_mwe.jl @@ -36,7 +36,7 @@ desauty_model = create_model() sys = structural_simplify(desauty_model) -prob = ODEProblem(sys, [], (0.0, 0.1), guesses = [sys.resistor1.v => 1.]) +prob = ODEProblem(sys, [sys.resistor1.v => 1.], (0.0, 0.1)) iprob = prob.f.initialization_data.initializeprob isys = iprob.f.sys diff --git a/test/mtk.jl b/test/mtk.jl index 1f2e3d237..fbd088dbf 100644 --- a/test/mtk.jl +++ b/test/mtk.jl @@ -70,27 +70,26 @@ tspan = (0.0, 100.0) # and with the initialization corrected to satisfy the algebraic equation prob_incorrectu0 = ODEProblem(sys, u0_incorrect, tspan, p, jac = true, guesses = [w2 => 0.0]) mtkparams_incorrectu0 = SciMLSensitivity.parameter_values(prob_incorrectu0) +test_sol = solve(prob_incorrectu0, Rodas5P(), abstol = 1e-6, reltol = 1e-3) u0_timedep = [D(x) => 2.0, x => 1.0, y => t, - z => 0.0, - w2 => 0.0,] + z => 0.0] # this ensures that `y => t` is not applied in the adjoint equation # If the MTK init is called for the reverse, then `y0` in the backwards # pass will be extremely far off and cause an incorrect gradient prob_timedepu0 = ODEProblem(sys, u0_timedep, tspan, p, jac = true, guesses = [w2 => 0.0]) mtkparams_timedepu0 = SciMLSensitivity.parameter_values(prob_incorrectu0) +test_sol = solve(prob_timedepu0, Rodas5P(), abstol = 1e-6, reltol = 1e-3) u0_correct = [D(x) => 2.0, x => 1.0, y => 0.0, - z => 0.0, - w2 => -1.0,] + z => 0.0,] prob_correctu0 = ODEProblem(sys, u0_correct, tspan, p, jac = true, guesses = [w2 => -1.0]) mtkparams_correctu0 = SciMLSensitivity.parameter_values(prob_correctu0) -prob_correctu0.u0[5] = -1.0 - +test_sol = solve(prob_correctu0, Rodas5P(), abstol = 1e-6, reltol = 1e-3) u0_overdetermined = [D(x) => 2.0, x => 1.0, y => 0.0, @@ -98,6 +97,7 @@ u0_overdetermined = [D(x) => 2.0, w2 => -1.0,] prob_overdetermined = ODEProblem(sys, u0_overdetermined, tspan, p, jac = true) mtkparams_overdetermined = SciMLSensitivity.parameter_values(prob_overdetermined) +test_sol = solve(prob_overdetermined, Rodas5P(), abstol = 1e-6, reltol = 1e-3) sensealg = GaussAdjoint(; autojacvec = SciMLSensitivity.ZygoteVJP()) @@ -115,7 +115,7 @@ setups = [ (prob_correctu0, mtkparams_correctu0, BrownFullBasicInit()), (prob_correctu0, mtkparams_correctu0, OrdinaryDiffEqCore.DefaultInit()), - (prob_correctu0, mtkparams_correctu0, NoInit()), + (prob_correctu0, mtkparams_correctu0, NoInit()), (prob_correctu0, mtkparams_correctu0, nothing), (prob_overdetermined, mtkparams_overdetermined, BrownFullBasicInit()), @@ -123,17 +123,18 @@ setups = [ (prob_overdetermined, mtkparams_overdetermined, NoInit()), (prob_overdetermined, mtkparams_overdetermined, nothing), -] +]; grads = map(setups) do setup prob, ps, init = setup @show init u0 = prob.u0 Zygote.gradient(u0, ps) do u0,p + new_prob = remake(prob, u0 = u0, p = p) if init === nothing - new_sol = solve(prob, Rodas5P(); u0 = u0, p = ps, sensealg, abstol = 1e-6, reltol = 1e-3) + new_sol = solve(new_prob, Rodas5P(); sensealg, abstol = 1e-6, reltol = 1e-3) else - new_sol = solve(prob, Rodas5P(); u0 = u0, p = ps, initializealg = init, sensealg, abstol = 1e-6, reltol = 1e-3) + new_sol = solve(new_prob, Rodas5P(); initializealg = init, sensealg, abstol = 1e-6, reltol = 1e-3) end gt = Zygote.ChainRules.ChainRulesCore.ignore_derivatives() do @test new_sol.retcode == SciMLBase.ReturnCode.Success @@ -148,5 +149,5 @@ end u0grads = getindex.(grads,1) pgrads = getproperty.(getindex.(grads, 2), (:tunable,)) -@test all(x ≈ u0grads[1] for x in grads) -@test all(x ≈ pgrads[1] for x in grads) +@test all(x ≈ u0grads[1] for x in u0grads) +@test all(x ≈ pgrads[1] for x in pgrads)