Skip to content

Commit

Permalink
Remove allocations by stronger function typing (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Mar 19, 2024
1 parent 51b7200 commit 6a771fb
Show file tree
Hide file tree
Showing 12 changed files with 393 additions and 341 deletions.

Large diffs are not rendered by default.

13 changes: 8 additions & 5 deletions src/DifferentiationTest/test_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,16 @@ function test_operators(
test_type_stability(backends, operators, scenarios)
end
if benchmark || allocations
result = run_benchmark(backends, operators, scenarios)
if allocations
test_allocations(result)
end
result = run_benchmark(
backends, operators, scenarios; test_allocations=allocations
)
end
end
return result
if benchmark
return result
else
return nothing
end
end

"""
Expand Down
16 changes: 8 additions & 8 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
Compute the primal value `y = f(x)` and the derivative `der = f'(x)` of a scalar-to-scalar function.
"""
function value_and_derivative(
backend::AbstractADType, f, x::Number, extras=prepare_derivative(backend, f, x)
)
backend::AbstractADType, f::F, x::Number, extras=prepare_derivative(backend, f, x)
) where {F}
return value_and_derivative_aux(backend, f, x, extras, mode(backend))
end

function value_and_derivative_aux(backend, f, x, extras, ::ForwardMode)
function value_and_derivative_aux(backend, f::F, x, extras, ::ForwardMode) where {F}
return value_and_pushforward(backend, f, x, one(x), extras)
end

function value_and_derivative_aux(backend, f, x, extras, ::ReverseMode)
function value_and_derivative_aux(backend, f::F, x, extras, ::ReverseMode) where {F}
return value_and_pullback(backend, f, x, one(x), extras)
end

Expand All @@ -23,15 +23,15 @@ end
Compute the derivative `der = f'(x)` of a scalar-to-scalar function.
"""
function derivative(
backend::AbstractADType, f, x::Number, extras=prepare_derivative(backend, f, x)
)
backend::AbstractADType, f::F, x::Number, extras=prepare_derivative(backend, f, x)
) where {F}
return derivative_aux(backend, f, x, extras, mode(backend))
end

function derivative_aux(backend, f, x, extras, ::ForwardMode)
function derivative_aux(backend, f::F, x, extras, ::ForwardMode) where {F}
return pushforward(backend, f, x, one(x), extras)
end

function derivative_aux(backend, f, x, extras, ::ReverseMode)
function derivative_aux(backend, f::F, x, extras, ::ReverseMode) where {F}
return pullback(backend, f, x, one(x), extras)
end
34 changes: 18 additions & 16 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@ Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an arra
function value_and_gradient!(
grad::AbstractArray,
backend::AbstractADType,
f,
f::F,
x::AbstractArray,
extras=prepare_gradient(backend, f, x),
)
) where {F}
return value_and_gradient_aux!(grad, backend, f, x, extras, mode(backend))
end

function value_and_gradient_aux!(grad, backend::AbstractADType, f, x, extras, ::ForwardMode)
function value_and_gradient_aux!(
grad, backend::AbstractADType, f::F, x, extras, ::ForwardMode
) where {F}
y = f(x)
for j in eachindex(IndexCartesian(), grad)
dx_j = basisarray(backend, grad, j)
Expand All @@ -22,7 +24,7 @@ function value_and_gradient_aux!(grad, backend::AbstractADType, f, x, extras, ::
return y, grad
end

function value_and_gradient_aux!(grad, backend, f, x, extras, ::ReverseMode)
function value_and_gradient_aux!(grad, backend, f::F, x, extras, ::ReverseMode) where {F}
return value_and_pullback!(grad, backend, f, x, one(eltype(x)), extras)
end

Expand All @@ -32,17 +34,17 @@ end
Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function.
"""
function value_and_gradient(
backend::AbstractADType, f, x::AbstractArray, extras=prepare_gradient(backend, f, x)
)
backend::AbstractADType, f::F, x::AbstractArray, extras=prepare_gradient(backend, f, x)
) where {F}
return value_and_gradient_aux(backend, f, x, extras, mode(backend))
end

function value_and_gradient_aux(backend, f, x, extras, ::AbstractMode)
function value_and_gradient_aux(backend, f::F, x, extras, ::AbstractMode) where {F}
grad = similar(x)
return value_and_gradient!(grad, backend, f, x, extras)
end

function value_and_gradient_aux(backend, f, x, extras, ::ReverseMode)
function value_and_gradient_aux(backend, f::F, x, extras, ::ReverseMode) where {F}
return value_and_pullback(backend, f, x, one(eltype(x)), extras)
end

Expand All @@ -54,18 +56,18 @@ Compute the gradient `grad = ∇f(x)` of an array-to-scalar function, overwritin
function gradient!(
grad::AbstractArray,
backend::AbstractADType,
f,
f::F,
x::AbstractArray,
extras=prepare_gradient(backend, f, x),
)
) where {F}
return gradient_aux!(grad, backend, f, x, extras, mode(backend))
end

function gradient_aux!(grad, backend, f, x, extras, ::AbstractMode)
function gradient_aux!(grad, backend, f::F, x, extras, ::AbstractMode) where {F}
return last(value_and_gradient!(grad, backend, f, x, extras))
end

function gradient_aux!(grad, backend, f, x, extras, ::ReverseMode)
function gradient_aux!(grad, backend, f::F, x, extras, ::ReverseMode) where {F}
return pullback!(grad, backend, f, x, one(eltype(x)), extras)
end

Expand All @@ -75,15 +77,15 @@ end
Compute the gradient `grad = ∇f(x)` of an array-to-scalar function.
"""
function gradient(
backend::AbstractADType, f, x::AbstractArray, extras=prepare_gradient(backend, f, x)
)
backend::AbstractADType, f::F, x::AbstractArray, extras=prepare_gradient(backend, f, x)
) where {F}
return gradient_aux(backend, f, x, extras, mode(backend))
end

function gradient_aux(backend, f, x, extras, ::AbstractMode)
function gradient_aux(backend, f::F, x, extras, ::AbstractMode) where {F}
return last(value_and_gradient(backend, f, x, extras))
end

function gradient_aux(backend, f, x, extras, ::ReverseMode)
function gradient_aux(backend, f::F, x, extras, ::ReverseMode) where {F}
return pullback(backend, f, x, one(eltype(x)), extras)
end
28 changes: 14 additions & 14 deletions src/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ function value_gradient_and_hessian!(
grad::AbstractArray,
hess::AbstractMatrix,
backend::AbstractADType,
f,
f::F,
x::AbstractArray,
extras=prepare_hessian(backend, f, x),
)
) where {F}
return value_gradient_and_hessian!(
grad, hess, SecondOrder(backend, backend), f, x, extras
)
Expand All @@ -38,18 +38,18 @@ function value_gradient_and_hessian!(
grad::AbstractArray,
hess::AbstractMatrix,
backend::SecondOrder,
f,
f::F,
x::AbstractArray,
extras=prepare_hessian(backend, f, x),
)
) where {F}
return value_gradient_and_hessian_aux!(
grad, hess, backend, f, x, extras, mode(inner(backend)), mode(outer(backend))
)
end

function value_gradient_and_hessian_aux!(
grad, hess, backend, f, x, extras, ::AbstractMode, ::ForwardMode
)
grad, hess, backend, f::F, x, extras, ::AbstractMode, ::ForwardMode
) where {F}
y = f(x)
check_hess(hess, x)
for (k, j) in enumerate(eachindex(IndexCartesian(), x))
Expand All @@ -61,8 +61,8 @@ function value_gradient_and_hessian_aux!(
end

function value_gradient_and_hessian_aux!(
grad, hess, backend, f, x, extras, ::AbstractMode, ::ReverseMode
)
grad, hess, backend, f::F, x, extras, ::AbstractMode, ::ReverseMode
) where {F}
y, _ = value_and_gradient!(grad, inner(backend), f, x, extras)
check_hess(hess, x)
for (k, j) in enumerate(eachindex(IndexCartesian(), x))
Expand All @@ -81,8 +81,8 @@ Compute the primal value `y = f(x)`, the gradient `grad = ∇f(x)` and the Hessi
$HESS_NOTES
"""
function value_gradient_and_hessian(
backend::AbstractADType, f, x::AbstractArray, extras=prepare_hessian(backend, f, x)
)
backend::AbstractADType, f::F, x::AbstractArray, extras=prepare_hessian(backend, f, x)
) where {F}
grad = similar(x)
hess = similar(x, length(x), length(x))
return value_gradient_and_hessian!(grad, hess, backend, f, x, extras)
Expand All @@ -98,10 +98,10 @@ $HESS_NOTES
function hessian!(
hess::AbstractMatrix,
backend::AbstractADType,
f,
f::F,
x::AbstractArray,
extras=prepare_hessian(backend, f, x),
)
) where {F}
grad = similar(x)
return last(value_gradient_and_hessian!(grad, hess, backend, f, x, extras))
end
Expand All @@ -114,7 +114,7 @@ Compute the Hessian `hess = ∇²f(x)` of an array-to-scalar function.
$HESS_NOTES
"""
function hessian(
backend::AbstractADType, f, x::AbstractArray, extras=prepare_hessian(backend, f, x)
)
backend::AbstractADType, f::F, x::AbstractArray, extras=prepare_hessian(backend, f, x)
) where {F}
return last(value_gradient_and_hessian(backend, f, x, extras))
end
Loading

0 comments on commit 6a771fb

Please sign in to comment.