Training edge weights as well as neural network #443

Open
jarroyoe opened this issue Jun 27, 2024 · 0 comments

I’m trying to model Graph NODEs by integrating GraphNeuralNetworks.jl with OrdinaryDiffEq.jl. I want to learn both the neural network parameters and the edge weights, so I have to manually overwrite the Flux parameters during prediction. When I run the following MWE:

using Graphs, GraphNeuralNetworks, Flux, OrdinaryDiffEq, ComponentArrays, Zygote, SciMLSensitivity

time = 1:10
x0 = rand(9)
obs = rand(9,10)

fullGraph = GNNGraph(complete_digraph(3))

layer1 = GCNConv(3 => 10,tanh,use_edge_weight=true)
layer2 = GCNConv(10 => 3,use_edge_weight=true)

chain = GNNChain(layer1,layer2)
pinit = ComponentArray{Float32}(weights = rand(ne(fullGraph)),
        layer1 = f64(layer1.weight),layer2 = f64(layer2.weight))

function predict(p)
    fullGraph = GNNGraph(complete_digraph(3))
    fullGraph = set_edge_weight(fullGraph,p.weights)
    chain.layers[1].weight .= p.layer1
    chain.layers[2].weight .= p.layer2

    function nn!(du,u,p,t)
        uGraph = reshape(u,(3,3))
        dGraph = reshape(chain(fullGraph,uGraph),(3*3))
        du .= dGraph
    end
    prob = ODEProblem(nn!,x0,(time[1],time[end]),saveat=time)
    sol = solve(prob)
    return Array(sol)
end

function loss_function(p)
    pred = predict(p)

    sum(abs2,pred .- obs)
end

Zygote.gradient(loss_function,pinit)

I get the following error:

ERROR: BoundsError: attempt to access 10-element UnitRange{Int64} at index [0]
Stacktrace:
  [1] throw_boundserror(A::UnitRange{Int64}, I::Int64)
    @ Base .\abstractarray.jl:737
  [2] getindex
    @ .\range.jl:930 [inlined]
  [3] (::SciMLSensitivity.ReverseLossCallback{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ SciMLSensitivity C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\adjoint_common.jl:530
  [4] #111
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqCallbacks\9fKPq\src\preset_time.jl:58 [inlined]
  [5] apply_discrete_callback!
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\callbacks.jl:613 [inlined]
  [6] apply_discrete_callback!
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\callbacks.jl:628 [inlined]
  [7] handle_callbacks!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\integrators\integrator_utils.jl:349
  [8] _loopfooter!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\integrators\integrator_utils.jl:254
  [9] loopfooter!
    @ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\integrators\integrator_utils.jl:207 [inlined]
 [10] solve!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
    @ OrdinaryDiffEq C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:558
 [11] #__solve#670
    @ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:7 [inlined]
 [12] __solve
    @ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:1 [inlined]
 [13] solve_call(_prob::ODEProblem{…}, args::CompositeAlgorithm{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:612
 [14] solve_call
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:569 [inlined]
 [15] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::SciMLBase.NullParameters, args::CompositeAlgorithm{…}; kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1080
 [16] solve_up
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1066 [inlined]
 [17] #solve#51
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1003 [inlined]
 [18] solve
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:993 [inlined]
 [19] #__solve#675
    @ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:547 [inlined]
 [20] __solve
    @ C:\Users\JArroyo-Esquivel\.julia\packages\OrdinaryDiffEq\XAFCL\src\solve.jl:546 [inlined]
 [21] solve_call(_prob::ODEProblem{…}, args::Nothing; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:612
 [22] solve_up(prob::ODEProblem{…}, sensealg::Nothing, u0::Vector{…}, p::SciMLBase.NullParameters, args::Nothing; kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1072
 [23] solve_up
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1066 [inlined]
 [24] solve(prob::ODEProblem{…}, args::Nothing; sensealg::Nothing, u0::Nothing, p::Nothing, wrap::Val{…}, kwargs::@Kwargs{…})
    @ DiffEqBase C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1003
 [25] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::QuadratureAdjoint{…}, alg::Nothing; t::UnitRange{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, callback::Nothing, kwargs::@Kwargs{…})
    @ SciMLSensitivity C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\quadrature_adjoint.jl:340
 [26] _adjoint_sensitivities
    @ C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\quadrature_adjoint.jl:328 [inlined]
 [27] #adjoint_sensitivities#63
    @ C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\sensitivity_interface.jl:386 [inlined]
 [28] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#314"{…})(Δ::ODESolution{…})
    @ SciMLSensitivity C:\Users\JArroyo-Esquivel\.julia\packages\SciMLSensitivity\waEMv\src\concrete_solve.jl:582
 [29] ZBack
    @ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\chainrules.jl:211 [inlined]
 [30] (::Zygote.var"#291#292"{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\lib\lib.jl:206
 [31] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72
 [32] #solve#51
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:1003 [inlined]
 [33] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [34] #291
    @ C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\lib\lib.jl:206 [inlined]
 [35] #2169#back
    @ C:\Users\JArroyo-Esquivel\.julia\packages\ZygoteRules\M4xmc\src\adjoint.jl:72 [inlined]
 [36] solve
    @ C:\Users\JArroyo-Esquivel\.julia\packages\DiffEqBase\uRikl\src\solve.jl:993 [inlined]
 [37] (::Zygote.Pullback{…})(Δ::ODESolution{…})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [38] predict
    @ .\REPL[11]:13 [inlined]
 [39] (::Zygote.Pullback{Tuple{typeof(predict), ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}}}, Any})(Δ::Matrix{Float64})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface2.jl:0
 [40] loss_function
    @ .\REPL[12]:2 [inlined]
 [41] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:91
 [42] gradient(f::Function, args::ComponentVector{Float32, Vector{Float32}, Tuple{Axis{…}}})
    @ Zygote C:\Users\JArroyo-Esquivel\.julia\packages\Zygote\nsBv0\src\compiler\interface.jl:148
 [43] top-level scope
    @ REPL[17]:1

I'm cross-posting this from Discourse because I don't know whether this is a bug in GraphNeuralNetworks.jl or whether the devs know a better way to train edge weights together with the network.

Thanks!
