Using the gradient function from Flux / Zygote with a custom rrule #617
Comments
Why do you think it isn't working?
using Flux
using ChainRulesCore

struct Foo
    A::Matrix
    c::Float64
end
Flux.@functor Foo

function foo_mul(foo::Foo, b::AbstractArray)
    return foo.A * b
end

function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
    println("calling foo rrule")
    y = foo_mul(foo, b)
    function foo_mul_pullback(ȳ)
        f̄ = NoTangent()  # no tangent for the function foo_mul itself
        f̄oo = @thunk(Tangent{Foo}(; A=fill!(similar(foo.A), 1), c=ZeroTangent()))
        b̄ = @thunk(foo.A' * ȳ)
        return f̄, f̄oo, b̄
    end
    return y, foo_mul_pullback
end

foo = Foo(randn(2, 2), 1.0)
b = randn(2)
grad = gradient(foo -> sum(foo_mul(foo, b)), foo)[1]
# calling foo rrule
# (A = [1.0 1.0; 1.0 1.0], c = nothing)
Hi Carlo, thanks for your response. I see my mistake now. I was trying to extract the parameters of foo, but that is unnecessary. How can I now use this to update foo with an optimizer? This code is not working properly:
After reviewing this documentation I am still a bit confused how to update the model parameters. For example:
gs looks like this:
How can I update the model using an optimizer? This doesn't work:
As it results in the following error:
Never mind, I figured it out. Will post a full solution later.
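For readers with the same question about the optimizer step, here is one possible sketch. It is not necessarily the solution the author arrived at. It assumes a recent Flux version that provides the explicit-style `Flux.setup` and `Flux.update!`, and it reuses `Foo`, `foo_mul`, and the rrule from the earlier comment (with `A` typed as `Matrix{Float64}` and the `println` dropped). The gradient for `A` is still the all-ones placeholder from that rrule, not the true derivative. On newer Flux releases `Flux.@functor` may be flagged as deprecated in favor of `Flux.@layer`.

```julia
using Flux
using ChainRulesCore

struct Foo
    A::Matrix{Float64}
    c::Float64
end
Flux.@functor Foo

foo_mul(foo::Foo, b::AbstractArray) = foo.A * b

# Same pedagogical rrule as in the earlier comment: the tangent for A is a
# placeholder matrix of ones, not the true derivative.
function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
    y = foo_mul(foo, b)
    function foo_mul_pullback(ȳ)
        f̄oo = @thunk(Tangent{Foo}(; A=fill!(similar(foo.A), 1), c=ZeroTangent()))
        b̄ = @thunk(foo.A' * ȳ)
        return NoTangent(), f̄oo, b̄
    end
    return y, foo_mul_pullback
end

foo = Foo(randn(2, 2), 1.0)
b = randn(2)
A_before = copy(foo.A)  # kept only to inspect the effect of the update

# Build optimizer state that mirrors the structure of foo (assumed API: Flux.setup).
opt_state = Flux.setup(Descent(0.1), foo)

# Differentiate with respect to the model itself; the result is a NamedTuple
# with the same fields as Foo, e.g. (A = ..., c = nothing).
grads = gradient(m -> sum(foo_mul(m, b)), foo)

# Apply one gradient-descent step; foo.A is modified in place.
Flux.update!(opt_state, foo, grads[1])
```

Because the gradient is taken with respect to `foo` directly, there is no need to collect parameters with `Flux.params`, and the structural gradient returned by the rrule lines up field by field with the model.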
Hello all,
I am trying to extend the pedagogical example to work as if Foo is a layer in the Flux ecosystem. I would like to insert Foo into a Flux Chain and train its parameters using the derivatives computed in the rrule instead of the standard Zygote AD. When I use the "gradient" function it returns a Grads(...) struct but the parameters do not match the gradients:
Here is the output of the code:
How can I make this work properly with the Flux "gradient" function?