Skip to content

Commit

Permalink
feat: emit batch norm training op
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 11, 2025
1 parent 044efef commit 5f1abca
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
2 changes: 1 addition & 1 deletion lib/LuxLib/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.6.1"
version = "1.7.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
8 changes: 2 additions & 6 deletions lib/LuxLib/ext/LuxLibReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module LuxLibReactantExt

using Reactant: Reactant, MLIR, Ops, TracedUtils, TracedRArray, AnyTracedRArray,
AnyTracedRVector, TracedRNumber
using Static: False
using Static: True, False

using LuxLib: LuxLib, Impl, Optional, Utils

Expand Down Expand Up @@ -56,14 +56,11 @@ function Impl.batchnorm(
return act.(TracedRArray{T, ndims(x)}((), res, size(x))), rμ, rσ²
end

# The following code is commented out since we don't have Batchnorm Op Adjoint registered
# for EnzymeJAX yet
#=
function Impl.batchnorm(
x::AnyTracedRArray{T},
γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector},
::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector},
training::StaticBool, act::F, momentum, ϵ
::True, act::F, momentum, ϵ
) where {T, F}
x = TracedUtils.materialize_traced_array(x)

Expand Down Expand Up @@ -101,6 +98,5 @@ function Impl.batchnorm(
return res, rμ, rσ²
end
end
=#

end

0 comments on commit 5f1abca

Please sign in to comment.