Skip to content

Commit 6a771fb

Browse files
authored
Remove allocations by stronger function typing (#75)
1 parent 51b7200 commit 6a771fb

File tree

12 files changed

+393
-341
lines changed

12 files changed

+393
-341
lines changed

ext/DifferentiationInterfaceChairmarksExt/DifferentiationInterfaceChairmarksExt.jl

Lines changed: 236 additions & 212 deletions
Large diffs are not rendered by default.

src/DifferentiationTest/test_operators.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,16 @@ function test_operators(
118118
test_type_stability(backends, operators, scenarios)
119119
end
120120
if benchmark || allocations
121-
result = run_benchmark(backends, operators, scenarios)
122-
if allocations
123-
test_allocations(result)
124-
end
121+
result = run_benchmark(
122+
backends, operators, scenarios; test_allocations=allocations
123+
)
125124
end
126125
end
127-
return result
126+
if benchmark
127+
return result
128+
else
129+
return nothing
130+
end
128131
end
129132

130133
"""

src/derivative.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
Compute the primal value `y = f(x)` and the derivative `der = f'(x)` of a scalar-to-scalar function.
55
"""
66
function value_and_derivative(
7-
backend::AbstractADType, f, x::Number, extras=prepare_derivative(backend, f, x)
8-
)
7+
backend::AbstractADType, f::F, x::Number, extras=prepare_derivative(backend, f, x)
8+
) where {F}
99
return value_and_derivative_aux(backend, f, x, extras, mode(backend))
1010
end
1111

12-
function value_and_derivative_aux(backend, f, x, extras, ::ForwardMode)
12+
function value_and_derivative_aux(backend, f::F, x, extras, ::ForwardMode) where {F}
1313
return value_and_pushforward(backend, f, x, one(x), extras)
1414
end
1515

16-
function value_and_derivative_aux(backend, f, x, extras, ::ReverseMode)
16+
function value_and_derivative_aux(backend, f::F, x, extras, ::ReverseMode) where {F}
1717
return value_and_pullback(backend, f, x, one(x), extras)
1818
end
1919

@@ -23,15 +23,15 @@ end
2323
Compute the derivative `der = f'(x)` of a scalar-to-scalar function.
2424
"""
2525
function derivative(
26-
backend::AbstractADType, f, x::Number, extras=prepare_derivative(backend, f, x)
27-
)
26+
backend::AbstractADType, f::F, x::Number, extras=prepare_derivative(backend, f, x)
27+
) where {F}
2828
return derivative_aux(backend, f, x, extras, mode(backend))
2929
end
3030

31-
function derivative_aux(backend, f, x, extras, ::ForwardMode)
31+
function derivative_aux(backend, f::F, x, extras, ::ForwardMode) where {F}
3232
return pushforward(backend, f, x, one(x), extras)
3333
end
3434

35-
function derivative_aux(backend, f, x, extras, ::ReverseMode)
35+
function derivative_aux(backend, f::F, x, extras, ::ReverseMode) where {F}
3636
return pullback(backend, f, x, one(x), extras)
3737
end

src/gradient.jl

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@ Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an arra
66
function value_and_gradient!(
77
grad::AbstractArray,
88
backend::AbstractADType,
9-
f,
9+
f::F,
1010
x::AbstractArray,
1111
extras=prepare_gradient(backend, f, x),
12-
)
12+
) where {F}
1313
return value_and_gradient_aux!(grad, backend, f, x, extras, mode(backend))
1414
end
1515

16-
function value_and_gradient_aux!(grad, backend::AbstractADType, f, x, extras, ::ForwardMode)
16+
function value_and_gradient_aux!(
17+
grad, backend::AbstractADType, f::F, x, extras, ::ForwardMode
18+
) where {F}
1719
y = f(x)
1820
for j in eachindex(IndexCartesian(), grad)
1921
dx_j = basisarray(backend, grad, j)
@@ -22,7 +24,7 @@ function value_and_gradient_aux!(grad, backend::AbstractADType, f, x, extras, ::
2224
return y, grad
2325
end
2426

25-
function value_and_gradient_aux!(grad, backend, f, x, extras, ::ReverseMode)
27+
function value_and_gradient_aux!(grad, backend, f::F, x, extras, ::ReverseMode) where {F}
2628
return value_and_pullback!(grad, backend, f, x, one(eltype(x)), extras)
2729
end
2830

@@ -32,17 +34,17 @@ end
3234
Compute the primal value `y = f(x)` and the gradient `grad = ∇f(x)` of an array-to-scalar function.
3335
"""
3436
function value_and_gradient(
35-
backend::AbstractADType, f, x::AbstractArray, extras=prepare_gradient(backend, f, x)
36-
)
37+
backend::AbstractADType, f::F, x::AbstractArray, extras=prepare_gradient(backend, f, x)
38+
) where {F}
3739
return value_and_gradient_aux(backend, f, x, extras, mode(backend))
3840
end
3941

40-
function value_and_gradient_aux(backend, f, x, extras, ::AbstractMode)
42+
function value_and_gradient_aux(backend, f::F, x, extras, ::AbstractMode) where {F}
4143
grad = similar(x)
4244
return value_and_gradient!(grad, backend, f, x, extras)
4345
end
4446

45-
function value_and_gradient_aux(backend, f, x, extras, ::ReverseMode)
47+
function value_and_gradient_aux(backend, f::F, x, extras, ::ReverseMode) where {F}
4648
return value_and_pullback(backend, f, x, one(eltype(x)), extras)
4749
end
4850

@@ -54,18 +56,18 @@ Compute the gradient `grad = ∇f(x)` of an array-to-scalar function, overwritin
5456
function gradient!(
5557
grad::AbstractArray,
5658
backend::AbstractADType,
57-
f,
59+
f::F,
5860
x::AbstractArray,
5961
extras=prepare_gradient(backend, f, x),
60-
)
62+
) where {F}
6163
return gradient_aux!(grad, backend, f, x, extras, mode(backend))
6264
end
6365

64-
function gradient_aux!(grad, backend, f, x, extras, ::AbstractMode)
66+
function gradient_aux!(grad, backend, f::F, x, extras, ::AbstractMode) where {F}
6567
return last(value_and_gradient!(grad, backend, f, x, extras))
6668
end
6769

68-
function gradient_aux!(grad, backend, f, x, extras, ::ReverseMode)
70+
function gradient_aux!(grad, backend, f::F, x, extras, ::ReverseMode) where {F}
6971
return pullback!(grad, backend, f, x, one(eltype(x)), extras)
7072
end
7173

@@ -75,15 +77,15 @@ end
7577
Compute the gradient `grad = ∇f(x)` of an array-to-scalar function.
7678
"""
7779
function gradient(
78-
backend::AbstractADType, f, x::AbstractArray, extras=prepare_gradient(backend, f, x)
79-
)
80+
backend::AbstractADType, f::F, x::AbstractArray, extras=prepare_gradient(backend, f, x)
81+
) where {F}
8082
return gradient_aux(backend, f, x, extras, mode(backend))
8183
end
8284

83-
function gradient_aux(backend, f, x, extras, ::AbstractMode)
85+
function gradient_aux(backend, f::F, x, extras, ::AbstractMode) where {F}
8486
return last(value_and_gradient(backend, f, x, extras))
8587
end
8688

87-
function gradient_aux(backend, f, x, extras, ::ReverseMode)
89+
function gradient_aux(backend, f::F, x, extras, ::ReverseMode) where {F}
8890
return pullback(backend, f, x, one(eltype(x)), extras)
8991
end

src/hessian.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ function value_gradient_and_hessian!(
2525
grad::AbstractArray,
2626
hess::AbstractMatrix,
2727
backend::AbstractADType,
28-
f,
28+
f::F,
2929
x::AbstractArray,
3030
extras=prepare_hessian(backend, f, x),
31-
)
31+
) where {F}
3232
return value_gradient_and_hessian!(
3333
grad, hess, SecondOrder(backend, backend), f, x, extras
3434
)
@@ -38,18 +38,18 @@ function value_gradient_and_hessian!(
3838
grad::AbstractArray,
3939
hess::AbstractMatrix,
4040
backend::SecondOrder,
41-
f,
41+
f::F,
4242
x::AbstractArray,
4343
extras=prepare_hessian(backend, f, x),
44-
)
44+
) where {F}
4545
return value_gradient_and_hessian_aux!(
4646
grad, hess, backend, f, x, extras, mode(inner(backend)), mode(outer(backend))
4747
)
4848
end
4949

5050
function value_gradient_and_hessian_aux!(
51-
grad, hess, backend, f, x, extras, ::AbstractMode, ::ForwardMode
52-
)
51+
grad, hess, backend, f::F, x, extras, ::AbstractMode, ::ForwardMode
52+
) where {F}
5353
y = f(x)
5454
check_hess(hess, x)
5555
for (k, j) in enumerate(eachindex(IndexCartesian(), x))
@@ -61,8 +61,8 @@ function value_gradient_and_hessian_aux!(
6161
end
6262

6363
function value_gradient_and_hessian_aux!(
64-
grad, hess, backend, f, x, extras, ::AbstractMode, ::ReverseMode
65-
)
64+
grad, hess, backend, f::F, x, extras, ::AbstractMode, ::ReverseMode
65+
) where {F}
6666
y, _ = value_and_gradient!(grad, inner(backend), f, x, extras)
6767
check_hess(hess, x)
6868
for (k, j) in enumerate(eachindex(IndexCartesian(), x))
@@ -81,8 +81,8 @@ Compute the primal value `y = f(x)`, the gradient `grad = ∇f(x)` and the Hessi
8181
$HESS_NOTES
8282
"""
8383
function value_gradient_and_hessian(
84-
backend::AbstractADType, f, x::AbstractArray, extras=prepare_hessian(backend, f, x)
85-
)
84+
backend::AbstractADType, f::F, x::AbstractArray, extras=prepare_hessian(backend, f, x)
85+
) where {F}
8686
grad = similar(x)
8787
hess = similar(x, length(x), length(x))
8888
return value_gradient_and_hessian!(grad, hess, backend, f, x, extras)
@@ -98,10 +98,10 @@ $HESS_NOTES
9898
function hessian!(
9999
hess::AbstractMatrix,
100100
backend::AbstractADType,
101-
f,
101+
f::F,
102102
x::AbstractArray,
103103
extras=prepare_hessian(backend, f, x),
104-
)
104+
) where {F}
105105
grad = similar(x)
106106
return last(value_gradient_and_hessian!(grad, hess, backend, f, x, extras))
107107
end
@@ -114,7 +114,7 @@ Compute the Hessian `hess = ∇²f(x)` of an array-to-scalar function.
114114
$HESS_NOTES
115115
"""
116116
function hessian(
117-
backend::AbstractADType, f, x::AbstractArray, extras=prepare_hessian(backend, f, x)
118-
)
117+
backend::AbstractADType, f::F, x::AbstractArray, extras=prepare_hessian(backend, f, x)
118+
) where {F}
119119
return last(value_gradient_and_hessian(backend, f, x, extras))
120120
end

0 commit comments

Comments
 (0)