From 1bc0ddee1e21452e86f25988e949ec4dd7ad8820 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 17 Jul 2024 08:24:37 +0200 Subject: [PATCH] Use dy=true for gradient (smallest possible one) --- .../src/first_order/gradient.jl | 12 ++++------ .../src/first_order/jacobian.jl | 24 ++++++++----------- .../src/first_order/pullback.jl | 16 ++++++++----- .../src/first_order/pushforward.jl | 16 ++++++++----- .../src/sparse/jacobian.jl | 4 ++-- 5 files changed, 37 insertions(+), 35 deletions(-) diff --git a/DifferentiationInterface/src/first_order/gradient.jl b/DifferentiationInterface/src/first_order/gradient.jl index 45bb78ce1..7fb53ea7e 100644 --- a/DifferentiationInterface/src/first_order/gradient.jl +++ b/DifferentiationInterface/src/first_order/gradient.jl @@ -62,9 +62,7 @@ struct PullbackGradientExtras{E<:PullbackExtras} <: GradientExtras end function prepare_gradient(f::F, backend::AbstractADType, x) where {F} - y = f(x) - dy = one(y) - pullback_extras = prepare_pullback(f, backend, x, dy) + pullback_extras = prepare_pullback(f, backend, x, true) return PullbackGradientExtras(pullback_extras) end @@ -93,23 +91,23 @@ end function value_and_gradient( f::F, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} - return value_and_pullback(f, backend, x, one(eltype(x)), extras.pullback_extras) + return value_and_pullback(f, backend, x, true, extras.pullback_extras) end function value_and_gradient!( f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} - return value_and_pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras) + return value_and_pullback!(f, grad, backend, x, true, extras.pullback_extras) end function gradient( f::F, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} - return pullback(f, backend, x, one(eltype(x)), extras.pullback_extras) + return pullback(f, backend, x, true, extras.pullback_extras) end function gradient!( f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras ) where {F} - return pullback!(f, grad, backend, x, one(eltype(x)), extras.pullback_extras) + return pullback!(f, grad, backend, x, true, extras.pullback_extras) end diff --git a/DifferentiationInterface/src/first_order/jacobian.jl b/DifferentiationInterface/src/first_order/jacobian.jl index 90362596a..802f7354f 100644 --- a/DifferentiationInterface/src/first_order/jacobian.jl +++ b/DifferentiationInterface/src/first_order/jacobian.jl @@ -86,7 +86,9 @@ function prepare_jacobian(f!::F, y, backend::AbstractADType, x) where {F} return prepare_jacobian_aux((f!, y), backend, x, y, pushforward_performance(backend)) end -function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) where {FY} +function prepare_jacobian_aux( + f_or_f!y::FY, backend::AbstractADType, x, y, ::PushforwardFast +) where {FY} N = length(x) B = pick_batchsize(backend, N) seeds = [basis(backend, x, ind) for ind in CartesianIndices(x)] @@ -107,7 +109,9 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) wh ) end -function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardSlow) where {FY} +function prepare_jacobian_aux( + f_or_f!y::FY, backend::AbstractADType, x, y, ::PushforwardSlow +) where {FY} M = length(y) B = pick_batchsize(backend, M) seeds = [basis(backend, y, ind) for ind in CartesianIndices(y)] @@ -221,7 +225,7 @@ end ## Common auxiliaries function jacobian_aux( - f_or_f!y::FY, backend, x::AbstractArray, extras::PushforwardJacobianExtras{B} + f_or_f!y::FY, backend::AbstractADType, x, extras::PushforwardJacobianExtras{B} ) where {FY,B} @compat (; batched_seeds, pushforward_batched_extras, N) = extras @@ -244,7 +248,7 @@ function jacobian_aux( end function jacobian_aux( - f_or_f!y::FY, backend, x::AbstractArray, extras::PullbackJacobianExtras{B} + f_or_f!y::FY, backend::AbstractADType, x, extras::PullbackJacobianExtras{B} ) where {FY,B} @compat (; batched_seeds, pullback_batched_extras, M) = extras @@ -267,11 +271,7 @@ function jacobian_aux( end function jacobian_aux!( - f_or_f!y::FY, - jac::AbstractMatrix, - backend, - x::AbstractArray, - extras::PushforwardJacobianExtras{B}, + f_or_f!y::FY, jac, backend::AbstractADType, x, extras::PushforwardJacobianExtras{B} ) where {FY,B} @compat (; batched_seeds, batched_results, pushforward_batched_extras, N) = extras @@ -303,11 +303,7 @@ function jacobian_aux!( end function jacobian_aux!( - f_or_f!y::FY, - jac::AbstractMatrix, - backend, - x::AbstractArray, - extras::PullbackJacobianExtras{B}, + f_or_f!y::FY, jac, backend::AbstractADType, x, extras::PullbackJacobianExtras{B} ) where {FY,B} @compat (; batched_seeds, batched_results, pullback_batched_extras, M) = extras diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index 0c49c759d..a50cf77f2 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -110,23 +110,27 @@ function prepare_pullback(f!::F, y, backend::AbstractADType, x, dy) where {F} return prepare_pullback_aux(f!, y, backend, x, dy, pullback_performance(backend)) end -function prepare_pullback_aux(f::F, backend, x, dy, ::PullbackSlow) where {F} +function prepare_pullback_aux( + f::F, backend::AbstractADType, x, dy, ::PullbackSlow +) where {F} dx = x isa Number ? one(x) : basis(backend, x, first(CartesianIndices(x))) pushforward_extras = prepare_pushforward(f, backend, x, dx) return PushforwardPullbackExtras(pushforward_extras) end -function prepare_pullback_aux(f!::F, y, backend, x, dy, ::PullbackSlow) where {F} +function prepare_pullback_aux( + f!::F, y, backend::AbstractADType, x, dy, ::PullbackSlow +) where {F} dx = x isa Number ? one(x) : basis(backend, x, first(CartesianIndices(x))) pushforward_extras = prepare_pushforward(f!, y, backend, x, dx) return PushforwardPullbackExtras(pushforward_extras) end -function prepare_pullback_aux(f, backend, x, dy, ::PullbackFast) +function prepare_pullback_aux(f, backend::AbstractADType, x, dy, ::PullbackFast) throw(MissingBackendError(backend)) end -function prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackFast) +function prepare_pullback_aux(f!, y, backend::AbstractADType, x, dy, ::PullbackFast) throw(MissingBackendError(backend)) end @@ -177,7 +181,7 @@ end ### With extras function value_and_pullback( - f::F, backend, x, dy, extras::PushforwardPullbackExtras + f::F, backend::AbstractADType, x, dy, extras::PushforwardPullbackExtras ) where {F} @compat (; pushforward_extras) = extras y = f(x) @@ -241,7 +245,7 @@ end ### With extras function value_and_pullback( - f!::F, y, backend, x, dy, extras::PushforwardPullbackExtras + f!::F, y, backend::AbstractADType, x, dy, extras::PushforwardPullbackExtras ) where {F} @compat (; pushforward_extras) = extras dx = if x isa Number && y isa AbstractArray diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 940622d8a..a7e1e0617 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -110,24 +110,28 @@ function prepare_pushforward(f!::F, y, backend::AbstractADType, x, dx) where {F} return prepare_pushforward_aux(f!, y, backend, x, dx, pushforward_performance(backend)) end -function prepare_pushforward_aux(f::F, backend, x, dx, ::PushforwardSlow) where {F} +function prepare_pushforward_aux( + f::F, backend::AbstractADType, x, dx, ::PushforwardSlow +) where {F} y = f(x) dy = y isa Number ? one(y) : basis(backend, y, first(CartesianIndices(y))) pullback_extras = prepare_pullback(f, backend, x, dy) return PullbackPushforwardExtras(pullback_extras) end -function prepare_pushforward_aux(f!::F, y, backend, x, dx, ::PushforwardSlow) where {F} +function prepare_pushforward_aux( + f!::F, y, backend::AbstractADType, x, dx, ::PushforwardSlow +) where {F} dy = y isa Number ? one(y) : basis(backend, y, first(CartesianIndices(y))) pullback_extras = prepare_pullback(f!, y, backend, x, dy) return PullbackPushforwardExtras(pullback_extras) end -function prepare_pushforward_aux(f, backend, x, dy, ::PushforwardFast) +function prepare_pushforward_aux(f, backend::AbstractADType, x, dx, ::PushforwardFast) throw(MissingBackendError(backend)) end -function prepare_pushforward_aux(f!, y, backend, x, dy, ::PushforwardFast) +function prepare_pushforward_aux(f!, y, backend::AbstractADType, x, dx, ::PushforwardFast) throw(MissingBackendError(backend)) end @@ -180,7 +184,7 @@ end ### With extras function value_and_pushforward( - f::F, backend, x, dx, extras::PullbackPushforwardExtras + f::F, backend::AbstractADType, x, dx, extras::PullbackPushforwardExtras ) where {F} @compat (; pullback_extras) = extras y = f(x) @@ -248,7 +252,7 @@ end ### With extras function value_and_pushforward( - f!::F, y, backend, x, dx, extras::PullbackPushforwardExtras + f!::F, y, backend::AbstractADType, x, dx, extras::PullbackPushforwardExtras ) where {F} @compat (; pullback_extras) = extras dy = if x isa Number && y isa AbstractArray diff --git a/DifferentiationInterface/src/sparse/jacobian.jl b/DifferentiationInterface/src/sparse/jacobian.jl index 64c478e8d..d9a491e35 100644 --- a/DifferentiationInterface/src/sparse/jacobian.jl +++ b/DifferentiationInterface/src/sparse/jacobian.jl @@ -84,7 +84,7 @@ function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F} end function prepare_sparse_jacobian_aux( - f_or_f!y::FY, backend, x, y, ::PushforwardFast + f_or_f!y::FY, backend::AutoSparse, x, y, ::PushforwardFast ) where {FY} dense_backend = dense_ad(backend) initial_sparsity = jacobian_sparsity(f_or_f!y..., x, sparsity_detector(backend)) @@ -116,7 +116,7 @@ function prepare_sparse_jacobian_aux( end function prepare_sparse_jacobian_aux( - f_or_f!y::FY, backend, x, y, ::PushforwardSlow + f_or_f!y::FY, backend::AutoSparse, x, y, ::PushforwardSlow ) where {FY} dense_backend = dense_ad(backend) initial_sparsity = jacobian_sparsity(f_or_f!y..., x, sparsity_detector(backend))