Gradient w.r.t parameters not working with MTKParameters #1130

Open
AayushSabharwal opened this issue Oct 7, 2024 · 4 comments

@AayushSabharwal
Member

Describe the bug 🐞

Taking the gradient with respect to a vector of parameter values (which are substituted into the parameter object) fails when the problem uses MTKParameters.

Expected behavior

The gradient is computed without error.

Minimal Reproducible Example 👇

using ModelingToolkit, OrdinaryDiffEq, Zygote, SciMLSensitivity
using SymbolicIndexingInterface: setp_oop
using ModelingToolkit: t_nounits as t, D_nounits as D

@variables x(t) o(t)
function lotka_volterra(; name)
    unknowns = @variables x(t)=1.0 y(t)=1.0 o(t)
    params = @parameters p1=1.5 p2=1.0 p3=3.0 p4=1.0
    eqs = [
        D(x) ~ p1 * x - p2 * x * y,
        D(y) ~ -p3 * y + p4 * x * y,
        o ~ x * y
    ]
    return ODESystem(eqs, t, unknowns, params; name = name)
end

@mtkbuild lotka_volterra_sys = lotka_volterra()
prob = ODEProblem(lotka_volterra_sys, [], (0.0, 10.0), [])
u0 = [1.0, 1.0]
p = [1.5, 1.0, 1.0, 1.0]

oop_setter = setp_oop(prob, [lotka_volterra_sys.p1, lotka_volterra_sys.p2, lotka_volterra_sys.p3, lotka_volterra_sys.p4])

function symbolic_indexing(u0, p)
    _p = oop_setter(prob, p)
    _prob = remake(prob, u0 = u0, p = _p)
    soln = solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
        sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))
    sum(soln[x])
end

du01, dp1 = Zygote.gradient(symbolic_indexing, u0, p)

Error & Stacktrace ⚠️

Adjoint sensitivity analysis functionality requires being able to solve
a differential equation defined by the parameter struct `p`. Thus while
DifferentialEquations.jl can support any parameter struct type, usage
with adjoint sensitivity analysis requires that `p` could be a valid
type for being the initial condition `u0` of an array. This means that
many simple types, such as `Tuple`s and `NamedTuple`s, will work as
parameters in normal contexts but will fail during adjoint differentiation.
To work around this issue for complicated cases like nested structs, look
into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
or ComponentArrays.jl so that `p` is an `AbstractArray` with a concrete element type.

Stacktrace:
  [1] _concrete_solve_adjoint(::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, FunctionWrappersWrappers.FunctionWrappersWrapper{Tuple{FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, Float64}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{Float64}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}, FunctionWrappers.FunctionWrapper{Nothing, Tuple{Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, Vector{ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ForwardDiff.Dual{ForwardDiff.Tag{DiffEqBase.OrdinaryDiffEqTag, Float64}, Float64, 1}}}}, false}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}, ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Vector{Float64}, ::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ::SciMLBase.ChainRulesOriginator; save_start::Bool, save_end::Bool, 
saveat::Float64, save_idxs::Nothing, kwargs::@Kwargs{reltol::Float64, abstol::Float64})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/concrete_solve.jl:378
  [2] _solve_adjoint(prob::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, u0::Vector{Float64}, p::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, originator::SciMLBase.ChainRulesOriginator, args::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}; merge_callbacks::Bool, kwargs::@Kwargs{reltol::Float64, abstol::Float64, saveat::Float64})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1537
  [3] rrule(::typeof(DiffEqBase.solve_up), prob::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, u0::Vector{Float64}, p::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, args::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False}; kwargs::@Kwargs{reltol::Float64, abstol::Float64, saveat::Float64})
    @ DiffEqBaseChainRulesCoreExt ~/.julia/packages/DiffEqBase/DdIeW/ext/DiffEqBaseChainRulesCoreExt.jl:26
  [4] kwcall(::@NamedTuple{reltol::Float64, abstol::Float64, saveat::Float64}, ::typeof(ChainRulesCore.rrule), ::Zygote.ZygoteRuleConfig{Zygote.Context{false}}, ::Function, ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Vector{Float64}, ::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
    @ ChainRulesCore ~/.julia/packages/ChainRulesCore/6Pucz/src/rules.jl:144
  [5] chain_rrule_kw
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:236 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0 [inlined]
  [7] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{reltol::Float64, abstol::Float64, saveat::Float64}, ::typeof(DiffEqBase.solve_up), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Vector{Float64}, ::MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:87
  [8] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
  [9] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [10] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [11] #solve#51
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1003 [inlined]
 [12] _pullback(::Zygote.Context{false}, ::DiffEqBase.var"##solve#51", ::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}, ::Nothing, ::Nothing, ::Val{true}, ::@Kwargs{reltol::Float64, abstol::Float64, saveat::Float64}, ::typeof(solve), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [13] _apply(::Function, ::Vararg{Any})
    @ Core ./boot.jl:838
 [14] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [15] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [16] solve
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:993 [inlined]
 [17] _pullback(::Zygote.Context{false}, ::typeof(Core.kwcall), ::@NamedTuple{reltol::Float64, abstol::Float64, saveat::Float64, sensealg::BacksolveAdjoint{0, true, Val{:central}, ZygoteVJP}}, ::typeof(solve), ::ODEProblem{Vector{Float64}, Tuple{Float64, Float64}, true, MTKParameters{Vector{Float64}, Tuple{}, Tuple{}, Tuple{}}, ODEFunction{true, SciMLBase.AutoSpecialize, ModelingToolkit.var"#f#835"{RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x4abac1fc, 0xb17c7b66, 0x95b8fd42, 0xf9281edf, 0xa2a10f56), Nothing}, RuntimeGeneratedFunctions.RuntimeGeneratedFunction{(:ˍ₋out, :ˍ₋arg1, :ˍ₋arg2, :t), ModelingToolkit.var"#_RGF_ModTag", ModelingToolkit.var"#_RGF_ModTag", (0x25ed63bf, 0xcf4bccc8, 0x9286cc6c, 0x7330a30a, 0xd77e77e0), Nothing}}, LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, ModelingToolkit.ObservedFunctionCache{ODESystem}, Nothing, ODESystem, Nothing, Nothing}, @Kwargs{}, SciMLBase.StandardODEProblem}, ::Tsit5{typeof(OrdinaryDiffEqCore.trivial_limiter!), typeof(OrdinaryDiffEqCore.trivial_limiter!), Static.False})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [18] symbolic_indexing
    @ ./REPL[125]:4 [inlined]
 [19] _pullback(::Zygote.Context{false}, ::typeof(symbolic_indexing), ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [20] pullback(::Function, ::Zygote.Context{false}, ::Vector{Float64}, ::Vararg{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:90
 [21] pullback(::Function, ::Vector{Float64}, ::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:88
 [22] gradient(::Function, ::Vector{Float64}, ::Vararg{Any})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:147
 [23] top-level scope
    @ REPL[126]:1
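For comparison, here is a minimal sketch of the same computation with the parameters held in a plain `Vector{Float64}` instead of `MTKParameters`. It is not part of the original report. The out-of-place right-hand side `lv` and the names `plain_prob` and `plain_loss` are hypothetical, and `sum(Array(soln)[1, :])` stands in for `sum(soln[x])` by summing the first state over the saved points. The error message above says adjoint methods need `p` to behave like an array. Because this version meets that requirement directly, it can serve as a control that separates failures in the `MTKParameters` path from problems in the adjoint setup itself.

```julia
using OrdinaryDiffEq, SciMLSensitivity, Zygote

# Hypothetical out-of-place Lotka-Volterra RHS; p is a plain Vector{Float64}
lv(u, p, t) = [p[1] * u[1] - p[2] * u[1] * u[2],
               -p[3] * u[2] + p[4] * u[1] * u[2]]

plain_prob = ODEProblem(lv, [1.0, 1.0], (0.0, 10.0), [1.5, 1.0, 3.0, 1.0])

# Same loss as symbolic_indexing, but without MTKParameters or setp_oop
function plain_loss(u0, p)
    _prob = remake(plain_prob, u0 = u0, p = p)
    soln = solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
        sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))
    # Sum the first state (x) over all saved time points
    sum(Array(soln)[1, :])
end

du0, dp = Zygote.gradient(plain_loss, [1.0, 1.0], [1.5, 1.0, 1.0, 1.0])
```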

Environment

  • Output of using Pkg; Pkg.status()

(SciMLBase test/downstream environment)

Status `~/Julia/SciML/SciMLBase.jl/test/downstream/Project.toml`
  [764a87c0] BoundaryValueDiffEq v5.10.0
  [bcd4f6db] DelayDiffEq v5.48.1
⌅ [459566f4] DiffEqCallbacks v3.9.1
  [f6369f11] ForwardDiff v0.10.36
  [ccbc3e58] JumpProcesses v9.13.7
  [961ee093] ModelingToolkit v9.42.0
  [16a59e39] ModelingToolkitStandardLibrary v2.15.0
⌃ [8913a72c] NonlinearSolve v3.14.0
⌅ [7f7a1694] Optimization v3.28.0
⌅ [fd9f6733] OptimizationMOI v0.4.3
⌅ [36348300] OptimizationOptimJL v0.3.2
  [1dea7af3] OrdinaryDiffEq v6.89.0
  [91a5bcdd] Plots v1.40.8
  [731186ca] RecursiveArrayTools v3.27.0
  [0bca4576] SciMLBase v2.55.0 `../..`
  [1ed8b502] SciMLSensitivity v7.68.0
  [53ae85a6] SciMLStructures v1.5.0
  [860ef19b] StableRNGs v1.0.2
  [9672c7b4] SteadyStateDiffEq v2.4.1
  [789caeaf] StochasticDiffEq v6.69.1
  [c3572dad] Sundials v4.25.0
  [2efcf032] SymbolicIndexingInterface v0.3.31
  [d1185830] SymbolicUtils v3.7.1
  [1986cc42] Unitful v1.21.0
  [e88e6eb3] Zygote v0.6.71
@DhairyaLGandhi
Member

I get a different error, raised by oop_setter:

julia> oop_setter(prob, p)
ERROR: TypeError: in validate_parameter_type, in Parameter ModelingToolkit.ParameterIndex{SciMLStructures.Tunable, Int64}(SciMLStructures.Tunable(), 3, false), expected ModelingToolkit.ParameterIndex, got a value of type Float64
Stacktrace:
 [1] validate_parameter_type(ic::ModelingToolkit.IndexCache, p::ModelingToolkit.ParameterIndex{…}, index::ModelingToolkit.ParameterIndex{…}, val::Float64)
   @ ModelingToolkit ~/.julia/packages/ModelingToolkit/hZRaA/src/systems/parameter_buffer.jl:489
 [2] remake_buffer(indp::ODESystem, oldbuf::MTKParameters{…}, vals::Dict{…})
   @ ModelingToolkit ~/.julia/packages/ModelingToolkit/hZRaA/src/systems/parameter_buffer.jl:529
 [3] remake_buffer(sys::ODESystem, oldbuffer::MTKParameters{…}, idxs::Vector{…}, vals::Vector{…})
   @ SymbolicIndexingInterface ~/.julia/packages/SymbolicIndexingInterface/cwAFH/src/remake.jl:59
 [4] (::SymbolicIndexingInterface.OOPSetter{…})(valp::ODEProblem{…}, val::Vector{…})
   @ SymbolicIndexingInterface ~/.julia/packages/SymbolicIndexingInterface/cwAFH/src/parameter_indexing.jl:740
 [5] top-level scope
   @ REPL[5]:1
Some type information was truncated. Use `show(err)` to see complete types.

@AayushSabharwal
Member Author

With GaussAdjoint I get the following:

julia> function symbolic_indexing(u0, p)
                  _p = oop_setter(prob, p)
                  _prob = remake(prob, u0 = u0, p = _p)
                  soln = solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
                      sensealg = GaussAdjoint(autojacvec = ZygoteVJP()))
                  sum(soln[x])
              end
symbolic_indexing (generic function with 1 method)

julia> du01, dp1 = Zygote.gradient(symbolic_indexing, u0, p)
ERROR: No matching function wrapper was found!
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(ctx::Zygote.Context{false}, f::typeof(throw), args::FunctionWrappersWrappers.NoFunctionWrapperFoundError)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:87
  [3] _call
    @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:23 [inlined]
  [4] _pullback(::Zygote.Context{…}, ::typeof(FunctionWrappersWrappers._call), ::Tuple{}, ::Tuple{…}, ::FunctionWrappersWrappers.FunctionWrappersWrapper{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
  [5] _call
    @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:13 [inlined]
--- the last 2 lines are repeated 3 more times ---
 [12] _pullback(::Zygote.Context{…}, ::typeof(FunctionWrappersWrappers._call), ::Tuple{…}, ::Tuple{…}, ::FunctionWrappersWrappers.FunctionWrappersWrapper{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [13] FunctionWrappersWrapper
    @ ~/.julia/packages/FunctionWrappersWrappers/9XR0m/src/FunctionWrappersWrappers.jl:10 [inlined]
 [14] _pullback(::Zygote.Context{…}, ::FunctionWrappersWrappers.FunctionWrappersWrapper{…}, ::Vector{…}, ::MTKParameters{…}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [15] _apply
    @ ./boot.jl:838 [inlined]
 [16] adjoint
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:203 [inlined]
 [17] _pullback
    @ ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:67 [inlined]
 [18] ODEFunction
    @ ~/Julia/SciML/SciMLBase.jl/src/scimlfunctions.jl:2330 [inlined]
 [19] _pullback(::Zygote.Context{…}, ::ODEFunction{…}, ::Vector{…}, ::MTKParameters{…}, ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [20] #262
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:488 [inlined]
 [21] _pullback(ctx::Zygote.Context{…}, f::SciMLSensitivity.var"#262#263"{}, args::MTKParameters{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [22] pullback
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:90 [inlined]
 [23] pullback
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:88 [inlined]
 [24] vec_pjac!(out::Vector{…}, λ::Vector{…}, y::Vector{…}, t::Float64, S::SciMLSensitivity.GaussIntegrand{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:487
 [25] GaussIntegrand
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:519 [inlined]
 [26] (::SciMLSensitivity.var"#265#266"{})(out::Vector{…}, u::Vector{…}, t::Float64, integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:560
 [27] (::DiffEqCallbacks.SavingIntegrandSumAffect{…})(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ DiffEqCallbacks ~/.julia/packages/DiffEqCallbacks/n5zrr/src/integrating_sum.jl:50
 [28] apply_discrete_callback!
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/callbacks.jl:618 [inlined]
 [29] apply_discrete_callback!
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/callbacks.jl:637 [inlined]
 [30] handle_callbacks!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/integrators/integrator_utils.jl:355
 [31] _loopfooter!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/integrators/integrator_utils.jl:243
 [32] loopfooter!
    @ ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/integrators/integrator_utils.jl:207 [inlined]
 [33] solve!(integrator::OrdinaryDiffEqCore.ODEIntegrator{…})
    @ OrdinaryDiffEqCore ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/solve.jl:552
 [34] #__solve#75
    @ ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/solve.jl:7 [inlined]
 [35] __solve
    @ ~/.julia/packages/OrdinaryDiffEqCore/55UVY/src/solve.jl:1 [inlined]
 [36] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{})
    @ DiffEqBase ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:612
 [37] solve_call
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:569 [inlined]
 [38] #solve_up#53
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1080 [inlined]
 [39] solve_up
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1066 [inlined]
 [40] #solve#51
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1003 [inlined]
 [41] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::GaussAdjoint{…}, alg::Tsit5{…}; t::StepRangeLen{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Bool, callback::Nothing, kwargs::@Kwargs{})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:580
 [42] _adjoint_sensitivities
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/gauss_adjoint.jl:533 [inlined]
 [43] #adjoint_sensitivities#63
    @ ~/.julia/packages/SciMLSensitivity/HRhwU/src/sensitivity_interface.jl:401 [inlined]
 [44] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#313"{})(Δ::ODESolution{…})
    @ SciMLSensitivity ~/.julia/packages/SciMLSensitivity/HRhwU/src/concrete_solve.jl:627
 [45] ZBack
    @ ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:212 [inlined]
 [46] (::Zygote.var"#kw_zpullback#56"{})(dy::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/chainrules.jl:238
 [47] #294
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:206 [inlined]
 [48] (::Zygote.var"#2169#back#296"{})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [49] #solve#51
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:1003 [inlined]
 [50] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [51] #294
    @ ~/.julia/packages/Zygote/Tt5Gx/src/lib/lib.jl:206 [inlined]
 [52] (::Zygote.var"#2169#back#296"{})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
 [53] solve
    @ ~/.julia/packages/DiffEqBase/DdIeW/src/solve.jl:993 [inlined]
 [54] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [55] symbolic_indexing
    @ ./REPL[16]:4 [inlined]
 [56] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface2.jl:0
 [57] (::Zygote.var"#78#79"{Zygote.Pullback{Tuple{}, Tuple{}}})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:91
 [58] gradient(::Function, ::Vector{Float64}, ::Vararg{Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/Tt5Gx/src/compiler/interface.jl:148

@DhairyaLGandhi
Member

With SciML/ModelingToolkit.jl#3100 and #1131, the following works:

import SciMLStructures

function symbolic_indexing(u0, p)
    _p = SciMLStructures.replace(SciMLStructures.Tunable(), prob.p, p)
    _prob = remake(prob, u0 = u0, p = _p)
    soln = solve(_prob, Tsit5(), reltol = 1e-6, abstol = 1e-6, saveat = 0.1,
        sensealg = BacksolveAdjoint(autojacvec = ZygoteVJP()))
    sum(soln[x])
end
julia> du01, dp1 = Zygote.gradient(symbolic_indexing, u0, p)
([-6.439844108628638, -0.699257997562027], [90.44500786842111, -0.6992683768440167, -23.971135312882016, -159.4738415565799])

The adjoint defined at https://github.com/SciML/ModelingToolkit.jl/blob/d7fa2b9a03fa964c214c7fa5fd23574de1fd0db5/ext/MTKChainRulesCoreExt.jl#L87 does not handle the case where the tangents are arrays.

@AayushSabharwal
Member Author

I've updated the remake_buffer adjoint in SciML/ModelingToolkit.jl#3104. I'm waiting for the SciMLSensitivity PR to merge so I can add tests.
