Skip to content

Commit

Permalink
feat: update training api to account for thunking
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 23, 2025
1 parent 63750f1 commit e1c3a4a
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions ext/LuxZygoteExt/training.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
function Lux.Training.compute_gradients_impl(
::AutoZygote, objective_function::F, data, ts::Lux.Training.TrainState) where {F}
(loss, st, stats), back = Zygote.pullback(
objective_function, ts.model, ts.parameters, ts.states, data)
grads = back((one(loss), nothing, nothing))[2]
@static if pkgversion(Zygote) v"0.7-"
# Zygote 0.7 doesn't aggressively unthunk everything, so it is better to use a
# closure here
(loss, st, stats), back = Zygote.pullback(
ps -> objective_function(ts.model, ps, ts.states, data), ts.parameters)
grads = only(back((one(loss), nothing, nothing)))
else
(loss, st, stats), back = Zygote.pullback(
objective_function, ts.model, ts.parameters, ts.states, data
)
grads = back((one(loss), nothing, nothing))[2]
end
@set! ts.states = st
return grads, loss, stats, ts
end

0 comments on commit e1c3a4a

Please sign in to comment.