
Commit 8665bc3
fix: update reactant version
avik-pal committed Dec 21, 2024
1 parent: 24a0ca5
Showing 5 changed files with 2 additions and 18 deletions.
docs/src/.vitepress/config.mts (4 changes: 0 additions & 4 deletions)
@@ -218,10 +218,6 @@ export default defineConfig({
           text: "Training a PINN on 2D PDE",
           link: "/tutorials/intermediate/4_PINN2DPDE",
         },
-        {
-          text: "Conditional VAE for MNIST using Reactant",
-          link: "/tutorials/intermediate/5_ConditionalVAE",
-        }
       ],
     },
     {
docs/tutorials.jl (1 change: 0 additions & 1 deletion)
@@ -11,7 +11,6 @@ const INTERMEDIATE_TUTORIALS = [
     "BayesianNN/main.jl" => "CPU",
     "HyperNet/main.jl" => "CUDA",
     "PINN2DPDE/main.jl" => "CUDA",
-    "ConditionalVAE/main.jl" => "CUDA",
 ]
 const ADVANCED_TUTORIALS = [
     "GravitationalWaveForm/main.jl" => "CPU",
examples/ConvMixer/Project.toml (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@ PreferenceTools = "0.1.2"
 Printf = "1.10"
 ProgressBars = "1.5.1"
 Random = "1.10"
-Reactant = "0.2.8"
+Reactant = "0.2.11"
 StableRNGs = "1.0.2"
 Statistics = "1.10"
 Zygote = "0.6.70"
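
In Pkg's [compat] notation, "0.2.11" is a caret bound: any Reactant version from 0.2.11 up to (but excluding) 0.3.0 satisfies it. A minimal sketch of making the same bump from the REPL instead of editing Project.toml by hand (the project path is illustrative; Pkg.compat requires Julia 1.8 or newer):

```julia
# Sketch: raise the Reactant compat bound and re-resolve, assuming the
# environment lives at examples/ConvMixer. Equivalent to the manual edit above.
using Pkg
Pkg.activate("examples/ConvMixer")
Pkg.compat("Reactant", "0.2.11")  # caret bound: allows 0.2.11 <= v < 0.3.0
Pkg.update("Reactant")            # resolve to the newest version the bound allows
```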
examples/ConvMixer/main.jl (2 changes: 1 addition & 1 deletion)
@@ -73,7 +73,7 @@ function accuracy(model, ps, st, dataloader)
     return total_correct / total
 end

-Comonicon.@main function main(; batchsize::Int=64, hidden_dim::Int=256, depth::Int=8,
+Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
         patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
         clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
         backend::String="reactant")
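
Comonicon.@main leaves main callable as an ordinary Julia function, so the new default of 512 applies whenever batchsize is not passed explicitly. A usage sketch (the include path is illustrative, assumed relative to the repository root):

```julia
# Sketch: calling the ConvMixer entry point directly from a Julia session.
# With no keyword overrides, the new default batchsize of 512 is used.
include("examples/ConvMixer/main.jl")        # path assumed from the repo root
main(; backend="reactant")                   # batchsize defaults to 512
main(; batchsize=256, backend="reactant")    # an explicit override still works
```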
ext/LuxReactantExt/training.jl (11 changes: 0 additions & 11 deletions)
@@ -51,20 +51,14 @@ for inplace in ("!", "")

     @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
             data, ts::Training.TrainState) where {F}
-        @show 1213
-
         compiled_grad_and_step_function = @compile $(internal_fn)(
             objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)

-        @show Lux.Functors.fmap(typeof, ts.states)
-
         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
             objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)

-        @show Lux.Functors.fmap(typeof, st)
-
         cache = TrainingBackendCache(
             backend, False(), nothing, (; compiled_grad_and_step_function))
         @set! ts.cache = cache
@@ -74,18 +68,13 @@ for inplace in ("!", "")
         @set! ts.optimizer_state = opt_state
         @set! ts.step = ts.step + 1

-        @show Lux.Functors.fmap(typeof, ts.states)
-
         return grads, loss, stats, ts
     end

     # XXX: Should we add a check to ensure the inputs to this function is same as the one
     # used in the compiled function? We can re-trigger the compilation with a warning
     @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
             ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
-        @show Lux.Functors.fmap(typeof, ts.parameters)
-        @show Lux.Functors.fmap(typeof, ts.states)
-
         grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
             obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
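
The XXX comment in the last hunk asks whether the cached compiled function should verify that later calls use the same input structure it was traced for. A minimal sketch of one way such a guard could look, assuming a hypothetical cached_structure field stored alongside the compiled function (Lux does not ship this check in this commit):

```julia
# Sketch only: compare the type structure of the current inputs against the
# structure recorded at compile time, and ask the caller to re-trace on mismatch.
# `cached_structure` is a hypothetical extra field, not part of Lux's cache.
using Lux

function needs_retrace(ts, data)
    current = Lux.Functors.fmap(typeof, (data, ts.parameters, ts.states))
    if current != ts.cache.extras.cached_structure
        @warn "Input types changed since compilation; re-tracing the step function." maxlog=1
        return true   # caller falls back to the first (compiling) method
    end
    return false
end
```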
