
Support for Transfer-Learning/Layer-Freezing #150

Open
JoshuaBillson opened this issue Jul 14, 2023 · 2 comments
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments


JoshuaBillson commented Jul 14, 2023

Motivation and description

A common practice in machine learning is to take a pre-trained model and fine-tune it on a particular dataset. This typically involves freezing the weights in some layers while fitting the output layer(s) on the new data.

Unfortunately, this functionality appears to be incompatible with the current implementation of the ToDevice callback, based on the following code:

# ToDevice moves the model at the start of every epoch...
function on(::EpochBegin, ::Phase, cb::ToDevice, learner)
    model!(learner, cb.movemodelfn(learner.model))
end

# ...and model! then rebuilds the parameters/optimiser state from scratch
function model!(learner, model)
    learner.model = model
    learner.params = setupoptimstate(model, learner.optimizer)
end

# Implicit (Flux.params) optimisers
setupoptimstate(model, ::Flux.Optimise.AbstractOptimiser) = Flux.params(model)

# Explicit (Optimisers.jl) optimisers
setupoptimstate(model, optim) = Optimisers.setup(optim, model)

This essentially means that learner.params is set to the parameters of the full model at the start of each epoch. Thus, even if we try to freeze the layers manually with Flux.freeze!(learner.params.layers[1:end-1]), this will be undone by ToDevice.
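
For illustration, here is a minimal sketch of what freezing looks like with Flux's explicit optimizers and why ToDevice undoes it; the model and hyperparameters are made up for this example and are not part of FluxTraining.jl:

using Flux, Optimisers

# Toy two-layer model (illustrative only)
model = Chain(Dense(4 => 8, relu), Dense(8 => 2))

# Explicit optimiser state, as produced by setupoptimstate(model, optim)
opt_state = Optimisers.setup(Optimisers.Adam(1e-3), model)

# Freeze the first layer so only the output layer is trained
Flux.freeze!(opt_state.layers[1])

# At the next EpochBegin, ToDevice calls model!(learner, ...), which rebuilds
# learner.params via Optimisers.setup and silently discards the frozen flags.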

Possible Implementation

One solution that works with Flux's new explicit optimizers is to create a callback that freezes layers after ToDevice has run. An example is given below:

mutable struct LayerFreezing{F} <: FluxTraining.Callback
    accessor::F  # maps learner.params to the sub-tree of optimiser state to freeze
end

# The callback needs write access to the optimiser state
function FluxTraining.stateaccess(::LayerFreezing)
    return (; params = FluxTraining.Write())
end

# Re-apply the freezing at the start of every training epoch
function FluxTraining.on(
    event::FluxTraining.EpochBegin,
    phase::FluxTraining.AbstractTrainingPhase,
    freezer::LayerFreezing,
    learner)
    Flux.freeze!(freezer.accessor(learner.params))
end

# Run after ToDevice so the freshly rebuilt optimiser state gets frozen again
FluxTraining.runafter(::LayerFreezing) = (FluxTraining.ToDevice,)
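
For completeness, a hypothetical way to wire this callback into a learner might look like the following; the keyword-based Learner constructor and the ToGPU() convenience callback are assumptions based on recent FluxTraining.jl versions, and the model, data, and accessor are placeholders:

using Flux, Optimisers, FluxTraining

model = Chain(Dense(4 => 8, relu), Dense(8 => 2))

# Freeze everything except the final layer (accessor is illustrative)
freezer = LayerFreezing(params -> params.layers[1:end-1])

learner = Learner(
    model,
    Flux.Losses.logitcrossentropy;
    optimizer = Optimisers.Adam(1e-3),
    callbacks = [ToGPU(), freezer],
)

# Dummy data iterators, just to make the sketch runnable
xs = rand(Float32, 4, 64)
ys = Flux.onehotbatch(rand(1:2, 64), 1:2)
data = Flux.DataLoader((xs, ys); batchsize = 16)

FluxTraining.fit!(learner, 5, (data, data))

Because runafter places LayerFreezing after ToDevice, the frozen flags are re-applied at each EpochBegin right after the optimiser state has been rebuilt.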

However, perhaps we should consider whether it's necessary for ToDevice to move the model to the GPU at the start of every epoch. Maybe we could extend the Callback interface to allow for some one-time setup code to run before the first epoch is executed?

JoshuaBillson (Author)

I think this issue may also be related to #148. In particular, the memory leak appears to be caused by ToDevice resetting the optimizer in each epoch. We could potentially kill two birds with one stone by changing this behaviour.

ToucheSir added the enhancement and help wanted labels on Jul 15, 2023

vargonis commented Jun 3, 2024

Any update on this? Also, I'd really appreciate it if the potential implementation above were turned into a complete example to build upon (for users like me who know nothing about the internals of FluxTraining.jl). Thanks!
