Using the gradient function from Flux / Zygote with a custom rrule #617

Closed
gladisor opened this issue Apr 13, 2023 · 5 comments
gladisor commented Apr 13, 2023

Hello all,

I am trying to extend the pedagogical example so that Foo works as a layer in the Flux ecosystem. I would like to insert Foo into a Flux Chain and train its parameters using the derivatives computed in the rrule instead of the standard Zygote AD. When I use the "gradient" function it returns a Grads(...) struct, but the parameters do not match the gradients:

using Flux
using Flux: params
using ChainRulesCore

struct Foo
    A::Matrix
    c::Float64
end

Flux.@functor Foo

function foo_mul(foo::Foo, b::AbstractArray)
    return foo.A * b
end

function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
    println("calling foo rrule")
    y = foo_mul(foo, b)

    function foo_mul_pullback(ȳ)

        f̄ = NoTangent()
        f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
        b̄ = @thunk(foo.A' * ȳ)

        return f̄, f̄oo, b̄
    end

    return y, foo_mul_pullback
end

foo = Foo(randn(2, 2), 1.0)
b = randn(2)

ps = params(foo)
gs = gradient(() -> sum(foo_mul(foo, b)), ps)

display(gs.grads)

Here is the output of the code:

julia> include("scripts/foo.jl")
calling foo rrule
IdDict{Any, Any} with 3 entries:
  :(Main.b)                               => [1.5824, -0.280951]
  [0.795399 0.392306; 0.786999 -0.673257] => nothing
  :(Main.foo)                             => (A = [-0.728823 0.776745; -0.728823 0.776745], c = nothing)

How can I make this work properly with the Flux "gradient" function?

CarloLucibello commented Apr 13, 2023

Why do you think it isn't working?
I changed the example slightly to show more clearly that it is working:

using Flux
using ChainRulesCore

struct Foo
    A::Matrix
    c::Float64
end

Flux.@functor Foo

function foo_mul(foo::Foo, b::AbstractArray)
    return foo.A * b
end

function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
    println("calling foo rrule")
    y = foo_mul(foo, b)

    function foo_mul_pullback(ȳ)

        f̄ = NoTangent()
        f̄oo = @thunk(Tangent{Foo}(; A=fill!(similar(foo.A), 1), c=ZeroTangent()))
        b̄ = @thunk(foo.A' * ȳ)

        return f̄, f̄oo, b̄
    end

    return y, foo_mul_pullback
end

foo = Foo(randn(2, 2), 1.0)
b = randn(2)

grad = gradient(foo -> sum(foo_mul(foo, b)), foo)[1]
# calling foo rrule
# (A = [1.0 1.0; 1.0 1.0], c = nothing)

gladisor commented

Hi Carlo,

Thanks for your response. I see my mistake now. I was trying to extract the parameters of foo but that is unnecessary.

How can I use this to now update foo with an optimizer?

This code is not working properly:

gs = gradient(foo -> sum(foo_mul(foo, b)), foo)[1]

opt = Descent(0.01)
Flux.Optimise.update!(opt, foo, gs)
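
In case it helps, here is a minimal sketch of applying the update through the explicit-style Optimisers.jl interface, continuing from the definitions above (Optimisers.jl is an assumption here; the older Flux.Optimise.update! is designed for arrays and for Params/Grads pairs, so it has no method for a struct together with a NamedTuple gradient):

using Optimisers  # assumed available; provides setup/update! for explicit gradients

# Explicit gradient with respect to the whole struct: a NamedTuple (A = ..., c = nothing).
gs = gradient(foo -> sum(foo_mul(foo, b)), foo)[1]

# Build optimiser state matching foo's structure, then apply one descent step.
opt_state = Optimisers.setup(Optimisers.Descent(0.01), foo)
opt_state, foo = Optimisers.update!(opt_state, foo, gs)  # foo.A is updated; foo.c is left alone (non-array field, gradient is nothing)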

CarloLucibello commented, linking to documentation.

gladisor commented

After reviewing this documentation I am still a bit confused about how to update the model parameters. For example:

using Flux
using ChainRulesCore

struct Linear
    W::Matrix
    b::Vector
end

function (l::Linear)(x::Vector)
    return l.W * x .+ l.b
end

function ChainRulesCore.rrule(l::Linear, x::Vector)
    println("calling linear rrule")

    y = l(x)

    function linear_back(Δ)
        println("calling linear back")

        dW = Δ * x'
        db = Δ
        dx = (Δ' * l.W)'

        tangent = Tangent{Linear}(;W = dW, b = db)
        return tangent, dx
    end

    return y, linear_back
end

model = Flux.Chain(
    Linear(randn(2, 2), zeros(2)),
    sum)

x = randn(2)
opt = Descent(0.01)
gs = gradient(m -> m(x), model)

gs looks like this:

((layers = ((W = [-1.3399043659000172 -2.3293859097721454; -1.3399043659000172 -2.3293859097721454], b = Fill(1.0, 2)), nothing),),)

How can I update the model using an optimizer? This doesn't work:

Flux.Optimise.update!(opt, model, gs[1])

As it results in the following error:

ERROR: MethodError: no method matching similar(::Tuple{Linear, typeof(sum)}, ::Type{Tuple{NamedTuple{(:W, :b), Tuple{Matrix{Float64}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}}, Nothing}})

gladisor commented

Never mind, I figured it out. I will post a full solution later.
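
For completeness, a sketch of one possible way to make the update above work, assuming the Optimisers.jl interface and continuing from the Linear / Chain definitions above (this is not necessarily the solution referred to; the key missing piece is that Linear must be registered with Flux.@functor so the optimiser can reach its fields):

using Optimisers  # assumed available; explicit-style optimiser interface

Flux.@functor Linear  # expose W and b to Flux / Optimisers; without this the struct is an opaque leaf

gs = gradient(m -> m(x), model)[1]                             # explicit gradient: a nested NamedTuple
opt_state = Optimisers.setup(Optimisers.Descent(0.01), model)  # state tree matching the Chain
opt_state, model = Optimisers.update!(opt_state, model, gs)    # updates W and b in place; the parameter-free sum layer is skipped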
