Skip to content

Commit

Permalink
fix: unthunk order
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 23, 2025
1 parent 81ec729 commit ec156a5
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 12 deletions.
2 changes: 1 addition & 1 deletion lib/LuxLib/src/impl/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ CRC.@opt_out rrule(

function ∇fused_conv(Δ′, weight, x, bias, cdims::ConvDims, z, tmp, 𝒫w, 𝒫x, 𝒫b, act)
old_threads = maybe_reduce_BLAS_threads(weight)
Δ = CRC.unthunk(NNlib.colmajor(Δ′))
Δ = NNlib.colmajor(CRC.unthunk(Δ′))
∂y = ∇activation(Δ, z, act, tmp)
∂w, ∂x, ∂b = ∇conv_bias(∂y, weight, x, bias, cdims)
reset_BLAS_threads(old_threads)
Expand Down
17 changes: 11 additions & 6 deletions lib/LuxLib/src/impl/dropout.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,15 @@ function CRC.rrule(::typeof(alpha_dropout), ::LoopedArrayOp, noise::AbstractArra

∇alpha_dropout = let cond = cond, 𝒫x = CRC.ProjectTo(x), x = x
Δ -> begin
∂x = similar(x)
@simd ivdep for I in eachindex(cond, Δ, ∂x)
@inbounds ∂x[I] = cond[I] * Δ[I] * A
∂x = CRC.@thunk begin
∂x_tmp = similar(x)
Δ_ = CRC.unthunk(Δ)
@simd ivdep for I in eachindex(cond, Δ_, ∂x_tmp)
@inbounds ∂x_tmp[I] = cond[I] * Δ_[I] * A
end
𝒫x(∂x_tmp)
end
return (ntuple(Returns(∂∅), 4)..., 𝒫x(∂x), ntuple(Returns(∂∅), 3)...)
return (ntuple(Returns(∂∅), 4)..., ∂x, ntuple(Returns(∂∅), 3)...)
end
end

Expand All @@ -105,7 +109,7 @@ function CRC.rrule(::typeof(alpha_dropout), ::AbstractInternalArrayOpMode,

𝒫x = CRC.ProjectTo(x)
∇alpha_dropout = @closure Δ -> begin
∂x = 𝒫x(Δ .* cond .* A)
∂x = CRC.@thunk 𝒫x(CRC.unthunk(Δ) .* cond .* A)
return (ntuple(Returns(∂∅), 4)..., ∂x, ntuple(Returns(∂∅), 3)...)
end

Expand Down Expand Up @@ -167,7 +171,8 @@ dropout_dot_mul(x::AbstractArray, mask::AbstractArray) = x .* mask

function CRC.rrule(::typeof(dropout_dot_mul), x::AbstractArray, mask::AbstractArray)
∇dropout_dot_mul = @closure Δ -> begin
return ∂∅, (CRC.ProjectTo(x))(dropout_dot_mul(Δ, mask)), ∂∅
∂x = CRC.@thunk CRC.ProjectTo(x)(dropout_dot_mul(CRC.unthunk(Δ), mask))
return ∂∅, ∂x, ∂∅
end
return dropout_dot_mul(x, mask), ∇dropout_dot_mul
end
2 changes: 1 addition & 1 deletion src/autodiff/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ for type in (:Gradient, :Jacobian)
$(rrule_call)
∇forwarddiff_ad = let pb_f = pb_f
Δ -> begin
∂x, ∂y = pb_f(tuple(Δ))[(end - 1):end]
∂x, ∂y = pb_f(tuple(CRC.unthunk(Δ)))[(end - 1):end]
𝒫x, 𝒫y = CRC.ProjectTo(x), CRC.ProjectTo(y)
return (ntuple(Returns(NoTangent()), 4)..., 𝒫x(∂x), 𝒫y(∂y))
end
Expand Down
6 changes: 3 additions & 3 deletions src/autodiff/nested_autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(autodiff_gradient

Δ = CRC.unthunk(Δ′)
# For Zygote and such which return a tuple
(res isa Tuple || Δ isa Tuple) &&= only(Δ))
(res isa Tuple || Δ isa Tuple) &&= CRC.unthunk(only)))
∂x, ∂y = forwarddiff_jvp(@closure((x, y)->grad_fn(f, x, y)), x, Δ, y)
𝒫x, 𝒫y = CRC.ProjectTo(x), CRC.ProjectTo(y)
return NoTangent(), NoTangent(), NoTangent(), 𝒫x(∂x), 𝒫y(∂y)
Expand Down Expand Up @@ -79,7 +79,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(autodiff_pullback

Δ = CRC.unthunk(Δ′)
# For Zygote and such which return a tuple
(res isa Tuple || Δ isa Tuple) &&= only(Δ))
(res isa Tuple || Δ isa Tuple) &&= CRC.unthunk(only)))
∂x, ∂y = forwarddiff_jvp(x, Δ, y) do x_dual, y_
return last(pb_f(f, x_dual, y_))(u)
end
Expand Down Expand Up @@ -113,7 +113,7 @@ function CRC.rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(autodiff_jacobian

Δ = CRC.unthunk(Δ′)
# For Zygote and such which return a tuple
(res isa Tuple || Δ isa Tuple) &&= only(Δ))
(res isa Tuple || Δ isa Tuple) &&= CRC.unthunk(only)))
Δ = compactify_if_structured_matrix(res isa Tuple ? only(res) : res, Δ)

inner_grad_fn = @closure(i->sum Base.Fix2(getindex, i:i) vec f)
Expand Down
2 changes: 1 addition & 1 deletion src/layers/extension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function CRC.rrule(::typeof(apply_simple_chain), layer, x, ps, ::CPUDevice)
res, pb = CRC.rrule(layer, x, ps)
# Safety measure to prevent errors from weird Array types that SimpleChains doesn't support
∇apply_simple_chain = @closure Δ -> begin
_, ∂x, ∂ps = pb(convert(Array, Δ))
_, ∂x, ∂ps = pb(convert(Array, CRC.unthunk(Δ)))
return NoTangent(), NoTangent(), 𝒫x(∂x), 𝒫ps(∂ps), NoTangent()
end
return res, ∇apply_simple_chain
Expand Down

0 comments on commit ec156a5

Please sign in to comment.