From 3df1fa1bcf3808f04129af9322d3c9c69afc58f5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 10 Feb 2025 23:46:02 -0500 Subject: [PATCH] feat: emit batch norm training op --- lib/LuxLib/Project.toml | 2 +- lib/LuxLib/ext/LuxLibReactantExt.jl | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml index 420bd0956..7905bc21b 100644 --- a/lib/LuxLib/Project.toml +++ b/lib/LuxLib/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.6.1" +version = "1.7.0" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/lib/LuxLib/ext/LuxLibReactantExt.jl b/lib/LuxLib/ext/LuxLibReactantExt.jl index 36b79cccd..1787b8278 100644 --- a/lib/LuxLib/ext/LuxLibReactantExt.jl +++ b/lib/LuxLib/ext/LuxLibReactantExt.jl @@ -2,7 +2,7 @@ module LuxLibReactantExt using Reactant: Reactant, MLIR, Ops, TracedUtils, TracedRArray, AnyTracedRArray, AnyTracedRVector, TracedRNumber -using Static: False +using Static: True, False using LuxLib: LuxLib, Impl, Optional, Utils @@ -56,14 +56,11 @@ function Impl.batchnorm( return act.(TracedRArray{T, ndims(x)}((), res, size(x))), rμ, rσ² end -# The following code is commented out since we don't have Batchnorm Op Adjoint registered -# for EnzymeJAX yet -#= function Impl.batchnorm( x::AnyTracedRArray{T}, γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector}, rμ::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector}, - training::StaticBool, act::F, momentum, ϵ + ::True, act::F, momentum, ϵ ) where {T, F} x = TracedUtils.materialize_traced_array(x) @@ -101,6 +98,5 @@ function Impl.batchnorm( return res, rμ, rσ² end end -=# end