Skip to content

Commit

Permalink
Use dy=true for gradient (smallest possible one)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jul 17, 2024
1 parent 63fa856 commit 1bc0dde
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 35 deletions.
12 changes: 5 additions & 7 deletions DifferentiationInterface/src/first_order/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ struct PullbackGradientExtras{E<:PullbackExtras} <: GradientExtras
end

# Prepare the extras needed for gradient computation by preparing the underlying
# pullback of `f` at `x` and wrapping the result in `PullbackGradientExtras`.
#
# The pullback is prepared with the seed `dy = true`, so `f(x)` does not have to be
# evaluated merely to build a seed with `one(y)`, and `prepare_pullback` is called
# exactly once.
function prepare_gradient(f::F, backend::AbstractADType, x) where {F}
    pullback_extras = prepare_pullback(f, backend, x, true)
    return PullbackGradientExtras(pullback_extras)
end

Expand Down Expand Up @@ -93,23 +91,23 @@ end
# Gradient with prepared extras: forward to `value_and_pullback` with the seed `true`
# and the pullback extras stored in `extras`, returning whatever that call returns.
#
# The seed is `true` (not `one(eltype(x))`) so that it matches the seed with which
# `prepare_gradient` prepares the pullback extras.
function value_and_gradient(
    f::F, backend::AbstractADType, x, extras::PullbackGradientExtras
) where {F}
    return value_and_pullback(f, backend, x, true, extras.pullback_extras)
end

# Variant of `value_and_gradient` taking a `grad` argument: forward to
# `value_and_pullback!` with `grad`, the seed `true`, and the pullback extras stored
# in `extras`, returning whatever that call returns.
#
# The seed is `true` (not `one(eltype(x))`) so that it matches the seed with which
# `prepare_gradient` prepares the pullback extras.
function value_and_gradient!(
    f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras
) where {F}
    return value_and_pullback!(f, grad, backend, x, true, extras.pullback_extras)
end

# Gradient with prepared extras: forward to `pullback` with the seed `true` and the
# pullback extras stored in `extras`, returning whatever that call returns.
#
# The seed is `true` (not `one(eltype(x))`) so that it matches the seed with which
# `prepare_gradient` prepares the pullback extras.
function gradient(
    f::F, backend::AbstractADType, x, extras::PullbackGradientExtras
) where {F}
    return pullback(f, backend, x, true, extras.pullback_extras)
end

# Variant of `gradient` taking a `grad` argument: forward to `pullback!` with `grad`,
# the seed `true`, and the pullback extras stored in `extras`, returning whatever that
# call returns.
#
# The seed is `true` (not `one(eltype(x))`) so that it matches the seed with which
# `prepare_gradient` prepares the pullback extras.
function gradient!(
    f::F, grad, backend::AbstractADType, x, extras::PullbackGradientExtras
) where {F}
    return pullback!(f, grad, backend, x, true, extras.pullback_extras)
end
24 changes: 10 additions & 14 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ function prepare_jacobian(f!::F, y, backend::AbstractADType, x) where {F}
return prepare_jacobian_aux((f!, y), backend, x, y, pushforward_performance(backend))
end

function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) where {FY}
function prepare_jacobian_aux(
f_or_f!y::FY, backend::AbstractADType, x, y, ::PushforwardFast
) where {FY}
N = length(x)
B = pick_batchsize(backend, N)
seeds = [basis(backend, x, ind) for ind in CartesianIndices(x)]
Expand All @@ -107,7 +109,9 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) wh
)
end

function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardSlow) where {FY}
function prepare_jacobian_aux(
f_or_f!y::FY, backend::AbstractADType, x, y, ::PushforwardSlow
) where {FY}
M = length(y)
B = pick_batchsize(backend, M)
seeds = [basis(backend, y, ind) for ind in CartesianIndices(y)]
Expand Down Expand Up @@ -221,7 +225,7 @@ end
## Common auxiliaries

function jacobian_aux(
f_or_f!y::FY, backend, x::AbstractArray, extras::PushforwardJacobianExtras{B}
f_or_f!y::FY, backend::AbstractADType, x, extras::PushforwardJacobianExtras{B}
) where {FY,B}
@compat (; batched_seeds, pushforward_batched_extras, N) = extras

Expand All @@ -244,7 +248,7 @@ function jacobian_aux(
end

function jacobian_aux(
f_or_f!y::FY, backend, x::AbstractArray, extras::PullbackJacobianExtras{B}
f_or_f!y::FY, backend::AbstractADType, x, extras::PullbackJacobianExtras{B}
) where {FY,B}
@compat (; batched_seeds, pullback_batched_extras, M) = extras

Expand All @@ -267,11 +271,7 @@ function jacobian_aux(
end

function jacobian_aux!(
f_or_f!y::FY,
jac::AbstractMatrix,
backend,
x::AbstractArray,
extras::PushforwardJacobianExtras{B},
f_or_f!y::FY, jac, backend::AbstractADType, x, extras::PushforwardJacobianExtras{B}
) where {FY,B}
@compat (; batched_seeds, batched_results, pushforward_batched_extras, N) = extras

Expand Down Expand Up @@ -303,11 +303,7 @@ function jacobian_aux!(
end

function jacobian_aux!(
f_or_f!y::FY,
jac::AbstractMatrix,
backend,
x::AbstractArray,
extras::PullbackJacobianExtras{B},
f_or_f!y::FY, jac, backend::AbstractADType, x, extras::PullbackJacobianExtras{B}
) where {FY,B}
@compat (; batched_seeds, batched_results, pullback_batched_extras, M) = extras

Expand Down
16 changes: 10 additions & 6 deletions DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,27 @@ function prepare_pullback(f!::F, y, backend::AbstractADType, x, dy) where {F}
return prepare_pullback_aux(f!, y, backend, x, dy, pullback_performance(backend))
end

# Pullback preparation for the `PullbackSlow` case (one-argument `f`): prepare a
# pushforward instead and wrap its extras in `PushforwardPullbackExtras`.
#
# `dy` is accepted for signature uniformity but is not used here. The pushforward seed
# `dx` is `one(x)` when `x` is a number, otherwise `basis(backend, x, ind)` at the first
# Cartesian index of `x`.
function prepare_pullback_aux(
    f::F, backend::AbstractADType, x, dy, ::PullbackSlow
) where {F}
    dx = x isa Number ? one(x) : basis(backend, x, first(CartesianIndices(x)))
    pushforward_extras = prepare_pushforward(f, backend, x, dx)
    return PushforwardPullbackExtras(pushforward_extras)
end

# Pullback preparation for the `PullbackSlow` case (two-argument `f!` with `y`):
# prepare a pushforward instead and wrap its extras in `PushforwardPullbackExtras`.
#
# `dy` is accepted for signature uniformity but is not used here. The pushforward seed
# `dx` is `one(x)` when `x` is a number, otherwise `basis(backend, x, ind)` at the first
# Cartesian index of `x`.
function prepare_pullback_aux(
    f!::F, y, backend::AbstractADType, x, dy, ::PullbackSlow
) where {F}
    dx = x isa Number ? one(x) : basis(backend, x, first(CartesianIndices(x)))
    pushforward_extras = prepare_pushforward(f!, y, backend, x, dx)
    return PushforwardPullbackExtras(pushforward_extras)
end

# Method for the `PullbackFast` case (one-argument `f`): always throws
# `MissingBackendError(backend)`.
# NOTE(review): presumably reached only when a backend declaring fast pullbacks has no
# specialized `prepare_pullback` method loaded — confirm against the callers.
function prepare_pullback_aux(f, backend::AbstractADType, x, dy, ::PullbackFast)
    throw(MissingBackendError(backend))
end

# Method for the `PullbackFast` case (two-argument `f!` with `y`): always throws
# `MissingBackendError(backend)`.
# NOTE(review): presumably reached only when a backend declaring fast pullbacks has no
# specialized `prepare_pullback` method loaded — confirm against the callers.
function prepare_pullback_aux(f!, y, backend::AbstractADType, x, dy, ::PullbackFast)
    throw(MissingBackendError(backend))
end

Expand Down Expand Up @@ -177,7 +181,7 @@ end
### With extras

function value_and_pullback(
f::F, backend, x, dy, extras::PushforwardPullbackExtras
f::F, backend::AbstractADType, x, dy, extras::PushforwardPullbackExtras
) where {F}
@compat (; pushforward_extras) = extras
y = f(x)
Expand Down Expand Up @@ -241,7 +245,7 @@ end
### With extras

function value_and_pullback(
f!::F, y, backend, x, dy, extras::PushforwardPullbackExtras
f!::F, y, backend::AbstractADType, x, dy, extras::PushforwardPullbackExtras
) where {F}
@compat (; pushforward_extras) = extras
dx = if x isa Number && y isa AbstractArray
Expand Down
16 changes: 10 additions & 6 deletions DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,24 +110,28 @@ function prepare_pushforward(f!::F, y, backend::AbstractADType, x, dx) where {F}
return prepare_pushforward_aux(f!, y, backend, x, dx, pushforward_performance(backend))
end

# Pushforward preparation for the `PushforwardSlow` case (one-argument `f`): prepare a
# pullback instead and wrap its extras in `PullbackPushforwardExtras`.
#
# `dx` is accepted for signature uniformity but is not used here. `f(x)` is evaluated
# once to obtain `y`; the pullback seed `dy` is `one(y)` when `y` is a number, otherwise
# `basis(backend, y, ind)` at the first Cartesian index of `y`.
function prepare_pushforward_aux(
    f::F, backend::AbstractADType, x, dx, ::PushforwardSlow
) where {F}
    y = f(x)
    dy = y isa Number ? one(y) : basis(backend, y, first(CartesianIndices(y)))
    pullback_extras = prepare_pullback(f, backend, x, dy)
    return PullbackPushforwardExtras(pullback_extras)
end

# Pushforward preparation for the `PushforwardSlow` case (two-argument `f!` with `y`):
# prepare a pullback instead and wrap its extras in `PullbackPushforwardExtras`.
#
# `dx` is accepted for signature uniformity but is not used here. The pullback seed `dy`
# is `one(y)` when `y` is a number, otherwise `basis(backend, y, ind)` at the first
# Cartesian index of `y`.
function prepare_pushforward_aux(
    f!::F, y, backend::AbstractADType, x, dx, ::PushforwardSlow
) where {F}
    dy = y isa Number ? one(y) : basis(backend, y, first(CartesianIndices(y)))
    pullback_extras = prepare_pullback(f!, y, backend, x, dy)
    return PullbackPushforwardExtras(pullback_extras)
end

# Method for the `PushforwardFast` case (one-argument `f`): always throws
# `MissingBackendError(backend)`. The seed argument is named `dx`, like in the other
# pushforward methods.
# NOTE(review): presumably reached only when a backend declaring fast pushforwards has
# no specialized `prepare_pushforward` method loaded — confirm against the callers.
function prepare_pushforward_aux(f, backend::AbstractADType, x, dx, ::PushforwardFast)
    throw(MissingBackendError(backend))
end

# Method for the `PushforwardFast` case (two-argument `f!` with `y`): always throws
# `MissingBackendError(backend)`. The seed argument is named `dx`, like in the other
# pushforward methods.
# NOTE(review): presumably reached only when a backend declaring fast pushforwards has
# no specialized `prepare_pushforward` method loaded — confirm against the callers.
function prepare_pushforward_aux(f!, y, backend::AbstractADType, x, dx, ::PushforwardFast)
    throw(MissingBackendError(backend))
end

Expand Down Expand Up @@ -180,7 +184,7 @@ end
### With extras

function value_and_pushforward(
f::F, backend, x, dx, extras::PullbackPushforwardExtras
f::F, backend::AbstractADType, x, dx, extras::PullbackPushforwardExtras
) where {F}
@compat (; pullback_extras) = extras
y = f(x)
Expand Down Expand Up @@ -248,7 +252,7 @@ end
### With extras

function value_and_pushforward(
f!::F, y, backend, x, dx, extras::PullbackPushforwardExtras
f!::F, y, backend::AbstractADType, x, dx, extras::PullbackPushforwardExtras
) where {F}
@compat (; pullback_extras) = extras
dy = if x isa Number && y isa AbstractArray
Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterface/src/sparse/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function prepare_jacobian(f!::F, y, backend::AutoSparse, x) where {F}
end

function prepare_sparse_jacobian_aux(
f_or_f!y::FY, backend, x, y, ::PushforwardFast
f_or_f!y::FY, backend::AutoSparse, x, y, ::PushforwardFast
) where {FY}
dense_backend = dense_ad(backend)
initial_sparsity = jacobian_sparsity(f_or_f!y..., x, sparsity_detector(backend))
Expand Down Expand Up @@ -116,7 +116,7 @@ function prepare_sparse_jacobian_aux(
end

function prepare_sparse_jacobian_aux(
f_or_f!y::FY, backend, x, y, ::PushforwardSlow
f_or_f!y::FY, backend::AutoSparse, x, y, ::PushforwardSlow
) where {FY}
dense_backend = dense_ad(backend)
initial_sparsity = jacobian_sparsity(f_or_f!y..., x, sparsity_detector(backend))
Expand Down

0 comments on commit 1bc0dde

Please sign in to comment.