From 8665bc3046ddd4b1f308402bf4661b32a1c87e81 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 20 Dec 2024 12:01:35 +0530
Subject: [PATCH] fix: update reactant version

---
 docs/src/.vitepress/config.mts  |  4 ----
 docs/tutorials.jl               |  1 -
 examples/ConvMixer/Project.toml |  2 +-
 examples/ConvMixer/main.jl      |  2 +-
 ext/LuxReactantExt/training.jl  | 11 -----------
 5 files changed, 2 insertions(+), 18 deletions(-)

diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts
index 35c5739439..f785f6a316 100644
--- a/docs/src/.vitepress/config.mts
+++ b/docs/src/.vitepress/config.mts
@@ -218,10 +218,6 @@ export default defineConfig({
                         text: "Training a PINN on 2D PDE",
                         link: "/tutorials/intermediate/4_PINN2DPDE",
                     },
-                    {
-                        text: "Conditional VAE for MNIST using Reactant",
-                        link: "/tutorials/intermediate/5_ConditionalVAE",
-                    }
                 ],
             },
             {
diff --git a/docs/tutorials.jl b/docs/tutorials.jl
index b9b9971d3e..d9dad6510b 100644
--- a/docs/tutorials.jl
+++ b/docs/tutorials.jl
@@ -11,7 +11,6 @@ const INTERMEDIATE_TUTORIALS = [
     "BayesianNN/main.jl" => "CPU",
     "HyperNet/main.jl" => "CUDA",
     "PINN2DPDE/main.jl" => "CUDA",
-    "ConditionalVAE/main.jl" => "CUDA",
 ]
 const ADVANCED_TUTORIALS = [
     "GravitationalWaveForm/main.jl" => "CPU",
diff --git a/examples/ConvMixer/Project.toml b/examples/ConvMixer/Project.toml
index 11e2f29d3b..04fec524d2 100644
--- a/examples/ConvMixer/Project.toml
+++ b/examples/ConvMixer/Project.toml
@@ -39,7 +39,7 @@ PreferenceTools = "0.1.2"
 Printf = "1.10"
 ProgressBars = "1.5.1"
 Random = "1.10"
-Reactant = "0.2.8"
+Reactant = "0.2.11"
 StableRNGs = "1.0.2"
 Statistics = "1.10"
 Zygote = "0.6.70"
diff --git a/examples/ConvMixer/main.jl b/examples/ConvMixer/main.jl
index 08e5553e75..ac36b6f570 100644
--- a/examples/ConvMixer/main.jl
+++ b/examples/ConvMixer/main.jl
@@ -73,7 +73,7 @@ function accuracy(model, ps, st, dataloader)
     return total_correct / total
 end
 
-Comonicon.@main function main(; batchsize::Int=64, hidden_dim::Int=256, depth::Int=8,
+Comonicon.@main function main(; batchsize::Int=512, hidden_dim::Int=256, depth::Int=8,
         patch_size::Int=2, kernel_size::Int=5, weight_decay::Float64=1e-5,
         clip_norm::Bool=false, seed::Int=42, epochs::Int=25, lr_max::Float64=0.01,
         backend::String="reactant")
diff --git a/ext/LuxReactantExt/training.jl b/ext/LuxReactantExt/training.jl
index 814d8d3bfb..9ab8880ab3 100644
--- a/ext/LuxReactantExt/training.jl
+++ b/ext/LuxReactantExt/training.jl
@@ -51,20 +51,14 @@ for inplace in ("!", "")
 
     @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
             data, ts::Training.TrainState) where {F}
-        @show 1213
-
         compiled_grad_and_step_function = @compile $(internal_fn)(
             objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)
 
-        @show Lux.Functors.fmap(typeof, ts.states)
-
         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
             objective_function, ts.model, data, ts.parameters, ts.states,
             ts.optimizer_state)
 
-        @show Lux.Functors.fmap(typeof, st)
-
         cache = TrainingBackendCache(
             backend, False(), nothing, (; compiled_grad_and_step_function))
         @set! ts.cache = cache
@@ -74,8 +68,6 @@ for inplace in ("!", "")
         @set! ts.optimizer_state = opt_state
         @set! ts.step = ts.step + 1
 
-        @show Lux.Functors.fmap(typeof, ts.states)
-
         return grads, loss, stats, ts
     end
 
@@ -83,9 +75,6 @@ for inplace in ("!", "")
     # used in the compiled function? We can re-trigger the compilation with a warning
     @eval function Lux.Training.$(fname)(::ReactantBackend, obj_fn::F, data,
             ts::Training.TrainState{<:TrainingBackendCache{ReactantBackend}, F}) where {F}
-        @show Lux.Functors.fmap(typeof, ts.parameters)
-        @show Lux.Functors.fmap(typeof, ts.states)
-
         grads, ps, loss, stats, st, opt_state = ts.cache.extras.compiled_grad_and_step_function(
             obj_fn, ts.model, data, ts.parameters, ts.states, ts.optimizer_state)
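
Note for reviewers: a minimal sketch of the code path the training.jl hunks touch,
assuming the public Lux.Training API with the Reactant backend; the model, sizes,
and optimizer below are illustrative, not taken from this patch. The first
`single_train_step!` call takes the `@compile`-and-cache branch (first `@eval`
method above); the second call dispatches on `TrainingBackendCache` and reuses the
cached compiled function (second `@eval` method):

    using Lux, Enzyme, Optimisers, Random, Reactant

    # Toy model and data; any Lux model follows the same path.
    model = Dense(4 => 2)
    ps, st = Lux.setup(Random.default_rng(), model)

    # Move parameters, states, and data onto the Reactant device.
    dev = reactant_device()
    ps, st = ps |> dev, st |> dev
    x, y = rand(Float32, 4, 8) |> dev, rand(Float32, 2, 8) |> dev

    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

    for iter in 1:2
        global train_state
        # Returns (grads, loss, stats, train_state), matching the
        # `return grads, loss, stats, ts` in the hunks above.
        _, loss, _, train_state = Training.single_train_step!(
            AutoEnzyme(), MSELoss(), (x, y), train_state)
        @info "iteration $iter" loss
    end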