Hello all,
I am trying to extend the pedagogical example so that Foo behaves like a layer in the Flux ecosystem. I would like to insert Foo into a Flux Chain and train its parameters using the derivatives computed in the rrule instead of those from standard Zygote AD. When I call the "gradient" function, it returns a Grads(...) struct, but the gradients are not matched to the parameters: the entry keyed by the parameter array foo.A is nothing, while the gradient computed by the rrule is stored under a separate key for foo:
using Flux
using Flux: params
using ChainRulesCore

struct Foo
    A::Matrix
    c::Float64
end
Flux.@functor Foo

function foo_mul(foo::Foo, b::AbstractArray)
    return foo.A * b
end

function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
    println("calling foo rrule")
    y = foo_mul(foo, b)
    function foo_mul_pullback(ȳ)
        f̄ = NoTangent()
        f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
        b̄ = @thunk(foo.A' * ȳ)
        return f̄, f̄oo, b̄
    end
    return y, foo_mul_pullback
end

foo = Foo(randn(2, 2), 1.0)
b = randn(2)
ps = params(foo)
gs = gradient(() -> sum(foo_mul(foo, b)), ps)
display(gs.grads)
Here is the output of the code:
julia> include("scripts/foo.jl")
calling foo rrule
IdDict{Any, Any} with 3 entries:
:(Main.b) => [1.5824, -0.280951]
[0.795399 0.392306; 0.786999 -0.673257] => nothing
:(Main.foo) => (A = [-0.728823 0.776745; -0.728823 0.776745], c = nothing)
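As a sanity check, the rrule can also be called directly through ChainRulesCore, without Flux or Zygote. This is only a minimal sketch: it assumes a seed of ones(2) (the gradient of sum with respect to its input), the names dy, foo_tangent and b_tangent are illustrative, and Flux.@functor and the println are left out because they are not needed here.

```julia
using ChainRulesCore

struct Foo
    A::Matrix
    c::Float64
end

foo_mul(foo::Foo, b::AbstractArray) = foo.A * b

# Same rule as above, minus the println.
function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
    y = foo_mul(foo, b)
    function foo_mul_pullback(ȳ)
        f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
        b̄ = @thunk(foo.A' * ȳ)
        return NoTangent(), f̄oo, b̄
    end
    return y, foo_mul_pullback
end

foo = Foo(randn(2, 2), 1.0)
b = randn(2)

# Forward pass plus pullback, obtained straight from the rule.
y, back = rrule(foo_mul, foo, b)

# Assumed seed: d(sum(y))/dy is a vector of ones.
dy = ones(2)
_, foo_tangent, b_tangent = back(dy)

display(foo_tangent.A)        # tangent for the field A
display(unthunk(b_tangent))   # force the thunked tangent for b
```

If the rule is behaving as intended, foo_tangent.A should equal ones(2) * b' and unthunk(b_tangent) should equal foo.A' * ones(2). That appears consistent with the values printed above, so the rule itself seems to be fine and the question is only about how the result is keyed in Grads.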
How can I make this work properly with the Flux "gradient" function?