From ec156a5b7fb03256abe164f099a426351bd4cf0d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Jan 2025 22:22:47 -0500 Subject: [PATCH] fix: unthunk order --- lib/LuxLib/src/impl/conv.jl | 2 +- lib/LuxLib/src/impl/dropout.jl | 17 +++++++++++------ src/autodiff/forwarddiff.jl | 2 +- src/autodiff/nested_autodiff.jl | 6 +++--- src/layers/extension.jl | 2 +- 5 files changed, 17 insertions(+), 12 deletions(-) diff --git a/lib/LuxLib/src/impl/conv.jl b/lib/LuxLib/src/impl/conv.jl index eda9eda134..8002061d25 100644 --- a/lib/LuxLib/src/impl/conv.jl +++ b/lib/LuxLib/src/impl/conv.jl @@ -196,7 +196,7 @@ CRC.@opt_out rrule( function ∇fused_conv(Δ′, weight, x, bias, cdims::ConvDims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act) old_threads = maybe_reduce_BLAS_threads(weight) - Δ = CRC.unthunk(NNlib.colmajor(Δ′)) + Δ = NNlib.colmajor(CRC.unthunk(Δ′)) ∂y = ∇activation(Δ, z, act, tmp) ∂w, ∂x, ∂b = ∇conv_bias(∂y, weight, x, bias, cdims) reset_BLAS_threads(old_threads) diff --git a/lib/LuxLib/src/impl/dropout.jl b/lib/LuxLib/src/impl/dropout.jl index 10dda2f69e..6d01b913fe 100644 --- a/lib/LuxLib/src/impl/dropout.jl +++ b/lib/LuxLib/src/impl/dropout.jl @@ -87,11 +87,15 @@ function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArra ∇alpha_dropout = let cond = cond, 𝒫x = CRC.ProjectTo(x), x = x Δ -> begin - ∂x = similar(x) - @simd ivdep for I in eachindex(cond, Δ, ∂x) - @inbounds ∂x[I] = cond[I] * Δ[I] * A + ∂x = CRC.@thunk begin + ∂x_tmp = similar(x) + Δ_ = CRC.unthunk(Δ) + @simd ivdep for I in eachindex(cond, Δ_, ∂x_tmp) + @inbounds ∂x_tmp[I] = cond[I] * Δ_[I] * A + end + 𝒫x(∂x_tmp) end - return (ntuple(Returns(∂∅), 4)..., 𝒫x(∂x), ntuple(Returns(∂∅), 3)...) + return (ntuple(Returns(∂∅), 4)..., ∂x, ntuple(Returns(∂∅), 3)...) end end @@ -105,7 +109,7 @@ function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode, 𝒫x = CRC.ProjectTo(x) ∇alpha_dropout = @closure Δ -> begin - ∂x = 𝒫x(Δ .* cond .* A) + ∂x = CRC.@thunk 𝒫x(CRC.unthunk(Δ) .* cond .* A) return (ntuple(Returns(∂∅), 4)..., ∂x, ntuple(Returns(∂∅), 3)...) end @@ -167,7 +171,8 @@ dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask function CRC.rrule(::typeof(dropout_dot_mul), x::AbstractArray, mask::AbstractArray) ∇dropout_dot_mul = @closure Δ -> begin - return ∂∅, (CRC.ProjectTo(x))(dropout_dot_mul(Δ, mask)), ∂∅ + ∂x = CRC.@thunk CRC.ProjectTo(x)(dropout_dot_mul(CRC.unthunk(Δ), mask)) + return ∂∅, ∂x, ∂∅ end return dropout_dot_mul(x, mask), ∇dropout_dot_mul end diff --git a/src/autodiff/forwarddiff.jl b/src/autodiff/forwarddiff.jl index f403120942..848cf3cd2f 100644 --- a/src/autodiff/forwarddiff.jl +++ b/src/autodiff/forwarddiff.jl @@ -49,7 +49,7 @@ for type in (:Gradient, :Jacobian) $(rrule_call) ∇forwarddiff_ad = let pb_f = pb_f Δ -> begin - ∂x, ∂y = pb_f(tuple(Δ))[(end - 1):end] + ∂x, ∂y = pb_f(tuple(CRC.unthunk(Δ)))[(end - 1):end] 𝒫x, 𝒫y = CRC.ProjectTo(x), CRC.ProjectTo(y) return (ntuple(Returns(NoTangent()), 4)..., 𝒫x(∂x), 𝒫y(∂y)) end diff --git a/src/autodiff/nested_autodiff.jl b/src/autodiff/nested_autodiff.jl index dfc94ad6f6..e94a9b8e4d 100644 --- a/src/autodiff/nested_autodiff.jl +++ b/src/autodiff/nested_autodiff.jl @@ -49,7 +49,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(autodiff_gradient Δ = CRC.unthunk(Δ′) # For Zygote and such which return a tuple - (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) + (res isa Tuple || Δ isa Tuple) && (Δ = CRC.unthunk(only(Δ))) ∂x, ∂y = forwarddiff_jvp(@closure((x, y)->grad_fn(f, x, y)), x, Δ, y) 𝒫x, 𝒫y = CRC.ProjectTo(x), CRC.ProjectTo(y) return NoTangent(), NoTangent(), NoTangent(), 𝒫x(∂x), 𝒫y(∂y) @@ -79,7 +79,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(autodiff_pullback Δ = CRC.unthunk(Δ′) # For Zygote and such which return a tuple - (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) + (res isa Tuple || Δ isa Tuple) && (Δ = CRC.unthunk(only(Δ))) ∂x, ∂y = forwarddiff_jvp(x, Δ, y) do x_dual, y_ return last(pb_f(f, x_dual, y_))(u) end @@ -113,7 +113,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(autodiff_jacobian Δ = CRC.unthunk(Δ′) # For Zygote and such which return a tuple - (res isa Tuple || Δ isa Tuple) && (Δ = only(Δ)) + (res isa Tuple || Δ isa Tuple) && (Δ = CRC.unthunk(only(Δ))) Δ = compactify_if_structured_matrix(res isa Tuple ? only(res) : res, Δ) inner_grad_fn = @closure(i->sum ∘ Base.Fix2(getindex, i:i) ∘ vec ∘ f) diff --git a/src/layers/extension.jl b/src/layers/extension.jl index 8242790a86..ed1a2123e4 100644 --- a/src/layers/extension.jl +++ b/src/layers/extension.jl @@ -100,7 +100,7 @@ function CRC.rrule(::typeof(apply_simple_chain), layer, x, ps, ::CPUDevice) res, pb = CRC.rrule(layer, x, ps) # Safety measure to prevent errors from weird Array types that SimpleChains doesn't support ∇apply_simple_chain = @closure Δ -> begin - _, ∂x, ∂ps = pb(convert(Array, Δ)) + _, ∂x, ∂ps = pb(convert(Array, CRC.unthunk(Δ))) return NoTangent(), NoTangent(), 𝒫x(∂x), 𝒫ps(∂ps), NoTangent() end return res, ∇apply_simple_chain