Skip to content

Commit 2a510c8

Browse files
committed
Fixes for type inference and misc bugs.
1 parent a2366c4 commit 2a510c8

File tree

4 files changed

+58
-50
lines changed

4 files changed

+58
-50
lines changed

src/derivatives.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#=
22
Single-point derivatives of scalar->scalar maps.
33
=#
4-
function finite_difference_derivative(f, x::T, fdtype::DataType=Val{:central},
5-
returntype::DataType=eltype(x), f_x::Union{Void,T}=nothing) where T<:Number
4+
function finite_difference_derivative(f, x::T, fdtype::Type{T1}=Val{:central},
5+
returntype::Type{T2}=eltype(x), f_x::Union{Void,T}=nothing) where {T<:Number,T1,T2}
66

77
epsilon = compute_epsilon(fdtype, x)
88
if fdtype==Val{:forward}
@@ -74,8 +74,8 @@ function DerivativeCache(
7474
x :: AbstractArray{<:Number},
7575
fx :: Union{Void,AbstractArray{<:Number}} = nothing,
7676
epsilon :: Union{Void,AbstractArray{<:Real}} = nothing,
77-
fdtype :: DataType = Val{:central},
78-
returntype :: DataType = eltype(x))
77+
fdtype :: Type{T1} = Val{:central},
78+
returntype :: Type{T2} = eltype(x)) where {T1,T2}
7979

8080
if fdtype==Val{:complex} && !(eltype(returntype)<:Real)
8181
fdtype_error(returntype)
@@ -99,8 +99,6 @@ function DerivativeCache(
9999
if typeof(epsilon)==Void || eltype(epsilon)!=real(eltype(x))
100100
epsilon = zeros(real(eltype(x)), size(x))
101101
end
102-
epsilon_factor = compute_epsilon_factor(fdtype, eltype(x))
103-
@. epsilon = compute_epsilon(fdtype, x, epsilon_factor)
104102
_epsilon = epsilon
105103
end
106104
DerivativeCache{typeof(_fx),typeof(_epsilon),fdtype,returntype}(_fx,_epsilon)
@@ -112,10 +110,10 @@ Compute the derivative df of a scalar-valued map f at a collection of points x.
112110
function finite_difference_derivative(
113111
f,
114112
x :: AbstractArray{<:Number},
115-
fdtype :: DataType = Val{:central},
116-
returntype :: DataType = eltype(x), # return type of f
113+
fdtype :: Type{T1} = Val{:central},
114+
returntype :: Type{T2} = eltype(x), # return type of f
117115
fx :: Union{Void,AbstractArray{<:Number}} = nothing,
118-
epsilon :: Union{Void,AbstractArray{<:Real}} = nothing)
116+
epsilon :: Union{Void,AbstractArray{<:Real}} = nothing) where {T1,T2}
119117

120118
df = zeros(returntype, size(x))
121119
finite_difference_derivative!(df, f, x, fdtype, returntype, fx, epsilon)
@@ -125,10 +123,10 @@ function finite_difference_derivative!(
125123
df :: AbstractArray{<:Number},
126124
f,
127125
x :: AbstractArray{<:Number},
128-
fdtype :: DataType = Val{:central},
129-
returntype :: DataType = eltype(x),
126+
fdtype :: Type{T1} = Val{:central},
127+
returntype :: Type{T2} = eltype(x),
130128
fx :: Union{Void,AbstractArray{<:Number}} = nothing,
131-
epsilon :: Union{Void,AbstractArray{<:Real}} = nothing)
129+
epsilon :: Union{Void,AbstractArray{<:Real}} = nothing) where {T1,T2}
132130

133131
cache = DerivativeCache(x, fx, epsilon, fdtype, returntype)
134132
finite_difference_derivative!(df, f, x, cache)
@@ -138,6 +136,10 @@ function finite_difference_derivative!(df::AbstractArray{<:Number}, f, x::Abstra
138136
cache::DerivativeCache{T1,T2,fdtype,returntype}) where {T1,T2,fdtype,returntype}
139137

140138
fx, epsilon = cache.fx, cache.epsilon
139+
if typeof(epsilon) != Void
140+
epsilon_factor = compute_epsilon_factor(fdtype, eltype(x))
141+
@. epsilon = compute_epsilon(fdtype, x, epsilon_factor)
142+
end
141143
if fdtype == Val{:forward}
142144
if typeof(fx) == Void
143145
@. df = (f(x+epsilon) - f(x)) / epsilon

src/gradients.jl

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ function GradientCache(
1010
fx :: Union{Void,<:Number,AbstractArray{<:Number}} = nothing,
1111
c1 :: Union{Void,AbstractArray{<:Number}} = nothing,
1212
c2 :: Union{Void,AbstractArray{<:Number}} = nothing,
13-
fdtype :: DataType = Val{:central},
14-
returntype :: DataType = eltype(x),
15-
inplace :: Bool = true)
13+
fdtype :: Type{T1} = Val{:central},
14+
returntype :: Type{T2} = eltype(x),
15+
inplace :: Type{Val{T3}} = Val{true}) where {T1,T2,T3}
1616

1717
if fdtype!=Val{:forward} && typeof(fx)!=Void
1818
warn("Pre-computed function values are only useful for fdtype == Val{:forward}.")
@@ -30,13 +30,10 @@ function GradientCache(
3030
else
3131
_c1 = c1
3232
end
33-
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
34-
@. _c1 = compute_epsilon(fdtype, real(x), epsilon_factor)
35-
3633
if typeof(c2)!=typeof(x) || size(c2)!=size(x)
37-
_c2 = copy(x)
34+
_c2 = similar(x)
3835
else
39-
copy!(_c2, x)
36+
_c2 = c2
4037
end
4138
else
4239
if !(returntype<:Real)
@@ -76,16 +73,16 @@ function GradientCache(
7673
GradientCache{typeof(_fx),typeof(_c1),typeof(_c2),fdtype,returntype,inplace}(_fx,_c1,_c2)
7774
end
7875

79-
function finite_difference_gradient(f, x, fdtype::DataType=Val{:central},
80-
returntype::DataType=eltype(x), inplace::Bool=true,
76+
function finite_difference_gradient(f, x, fdtype::Type{T1}=Val{:central},
77+
returntype::Type{T2}=eltype(x), inplace::Type{Val{T3}}=Val{true},
8178
fx::Union{Void,AbstractArray{<:Number}}=nothing,
8279
c1::Union{Void,AbstractArray{<:Number}}=nothing,
83-
c2::Union{Void,AbstractArray{<:Number}}=nothing)
80+
c2::Union{Void,AbstractArray{<:Number}}=nothing) where {T1,T2,T3}
8481

8582
if typeof(x) <: AbstractArray
8683
df = zeros(returntype, size(x))
8784
else
88-
if inplace
85+
if inplace == Val{true}
8986
if typeof(fx)==Void && typeof(c1)==Void && typeof(c2)==Void
9087
error("In the scalar->vector in-place map case, at least one of fx, c1 or c2 must be provided, otherwise we cannot infer the return size.")
9188
else
@@ -102,12 +99,12 @@ function finite_difference_gradient(f, x, fdtype::DataType=Val{:central},
10299
finite_difference_gradient!(df,f,x,cache)
103100
end
104101

105-
function finite_difference_gradient!(df, f, x, fdtype::DataType=Val{:central},
106-
returntype::DataType=eltype(x), inplace::Bool=true,
102+
function finite_difference_gradient!(df, f, x, fdtype::Type{T1}=Val{:central},
103+
returntype::Type{T2}=eltype(x), inplace::Type{Val{T3}}=Val{true},
107104
fx::Union{Void,AbstractArray{<:Number}}=nothing,
108105
c1::Union{Void,AbstractArray{<:Number}}=nothing,
109106
c2::Union{Void,AbstractArray{<:Number}}=nothing,
110-
)
107+
) where {T1,T2,T3}
111108

112109
cache = GradientCache(df,x,fx,c1,c2,fdtype,returntype,inplace)
113110
finite_difference_gradient!(df,f,x,cache)
@@ -131,9 +128,13 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
131128
cache::GradientCache{T1,T2,T3,fdtype,returntype,inplace}) where {T1,T2,T3,fdtype,returntype,inplace}
132129

133130
# NOTE: in this case epsilon is a vector, we need two arrays for epsilon and x1
134-
# c1 denotes epsilon (pre-computed by the cache constructor),
135-
# c2 is x1, pre-set to the values of x by the cache constructor
131+
# c1 denotes epsilon, c2 is x1, pre-set to the values of x by the cache constructor
136132
fx, c1, c2 = cache.fx, cache.c1, cache.c2
133+
if fdtype != Val{:complex}
134+
epsilon_factor = compute_epsilon_factor(fdtype, eltype(x))
135+
@. c1 = compute_epsilon(fdtype, x, epsilon_factor)
136+
copy!(c2,x)
137+
end
137138
if fdtype == Val{:forward}
138139
@inbounds for i eachindex(x)
139140
c2[i] += c1[i]
@@ -153,6 +154,7 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
153154
x[i] += c1[i]
154155
end
155156
elseif fdtype == Val{:complex} && returntype <: Real
157+
copy!(c1,x)
156158
epsilon_complex = eps(real(eltype(x)))
157159
# we use c1 here to avoid typing issues with x
158160
@inbounds for i eachindex(x)
@@ -176,27 +178,27 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,
176178
fx, c1, c2 = cache.fx, cache.c1, cache.c2
177179

178180
if fdtype == Val{:forward}
179-
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
180-
epsilon = compute_epsilon(Val{:forward}, real(x), epsilon_factor)
181-
if inplace
181+
epsilon_factor = compute_epsilon_factor(fdtype, eltype(x))
182+
epsilon = compute_epsilon(Val{:forward}, x, epsilon_factor)
183+
if inplace == Val{true}
182184
f(c1, x+epsilon)
183185
else
184186
c1 .= f(x+epsilon)
185187
end
186188
if typeof(fx) != Void
187189
@. df = (c1 - fx) / epsilon
188190
else
189-
if inplace
191+
if inplace == Val{true}
190192
f(c2, x)
191193
else
192194
c2 .= f(x)
193195
end
194196
@. df = (c1 - c2) / epsilon
195197
end
196198
elseif fdtype == Val{:central}
197-
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
198-
epsilon = compute_epsilon(Val{:central}, real(x), epsilon_factor)
199-
if inplace
199+
epsilon_factor = compute_epsilon_factor(fdtype, eltype(x))
200+
epsilon = compute_epsilon(Val{:central}, x, epsilon_factor)
201+
if inplace == Val{true}
200202
f(c1, x+epsilon)
201203
f(c2, x-epsilon)
202204
else
@@ -206,7 +208,7 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,
206208
@. df = (c1 - c2) / (2*epsilon)
207209
elseif fdtype == Val{:complex} && returntype <: Real
208210
epsilon_complex = eps(real(eltype(x)))
209-
if inplace
211+
if inplace == Val{true}
210212
f(c1, x+im*epsilon_complex)
211213
else
212214
c1 .= f(x+im*epsilon_complex)

src/jacobians.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ function JacobianCache(
99
x1 :: Union{Void,AbstractArray{<:Number}} = nothing,
1010
fx :: Union{Void,AbstractArray{<:Number}} = nothing,
1111
fx1 :: Union{Void,AbstractArray{<:Number}} = nothing,
12-
fdtype :: DataType = Val{:central},
13-
returntype :: DataType = eltype(x),
14-
inplace :: Bool = true)
12+
fdtype :: Type{T1} = Val{:central},
13+
returntype :: Type{T2} = eltype(x),
14+
inplace :: Type{Val{T3}} = Val{true}) where {T1,T2,T3}
1515

1616
if fdtype==Val{:complex}
1717
if !(returntype<:Real)
@@ -49,8 +49,12 @@ function JacobianCache(
4949
JacobianCache{typeof(_x1),typeof(_fx),typeof(_fx1),fdtype,returntype,inplace}(_x1,_fx,_fx1)
5050
end
5151

52-
function finite_difference_jacobian(f,x,fdtype=Val{:central},returntype=eltype(x))
53-
cache = JacobianCache(x,nothing,nothing,nothing,fdtype,returntype)
52+
function finite_difference_jacobian(f, x::AbstractArray{<:Number},
53+
fdtype :: Type{T1}=Val{:central},
54+
returntype :: Type{T2}=eltype(x),
55+
inplace :: Type{Val{T3}}=Val{true}) where {T1,T2,T3}
56+
57+
cache = JacobianCache(x,nothing,nothing,nothing,fdtype,returntype,inplace)
5458
finite_difference_jacobian(f,x,cache)
5559
end
5660

@@ -74,7 +78,7 @@ function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f,x::AbstractA
7478
epsilon = compute_epsilon(Val{:forward}, x[i], epsilon_factor)
7579
x1_save = x1[i]
7680
x1[i] += epsilon
77-
if inplace
81+
if inplace == Val{true}
7882
f(fx1, x1)
7983
f(fx, x)
8084
else
@@ -93,7 +97,7 @@ function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f,x::AbstractA
9397
x_save = x[i]
9498
x1[i] += epsilon
9599
x[i] -= epsilon
96-
if inplace
100+
if inplace == Val{true}
97101
f(fx1, x1)
98102
f(fx, x)
99103
else
@@ -109,7 +113,7 @@ function finite_difference_jacobian!(J::AbstractMatrix{<:Number}, f,x::AbstractA
109113
@inbounds for i 1:n
110114
x1_save = x1[i]
111115
x1[i] += im * epsilon
112-
if inplace
116+
if inplace == Val{true}
113117
f(fx,x1)
114118
else
115119
fx .= f(x1)

test/finitedifftests.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ complex_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:compl
136136

137137
@time @testset "Gradient of f:scalar->vector real-valued tests" begin
138138
@test_broken err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}), df_ref) < 1e-4
139-
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}, eltype(x), true, fx), df_ref) < 1e-4
140-
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}, eltype(x), true, fx), df_ref) < 1e-8
141-
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:complex}, eltype(x), true, fx), df_ref) < 1e-15
139+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}, eltype(x), Val{true}, fx), df_ref) < 1e-4
140+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}, eltype(x), Val{true}, fx), df_ref) < 1e-8
141+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:complex}, eltype(x), Val{true}, fx), df_ref) < 1e-15
142142

143143
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:forward}), df_ref) < 1e-4
144144
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:central}), df_ref) < 1e-8
@@ -159,8 +159,8 @@ forward_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:forwa
159159
central_cache = DiffEqDiffTools.GradientCache(df,x,fx,nothing,nothing,Val{:central})
160160

161161
@time @testset "Gradient of f:vector->scalar complex-valued tests" begin
162-
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}, eltype(x), true, fx), df_ref) < 1e-4
163-
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}, eltype(x), true, fx), df_ref) < 1e-7
162+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:forward}, eltype(x), Val{true}, fx), df_ref) < 1e-4
163+
@test err_func(DiffEqDiffTools.finite_difference_gradient(f, x, Val{:central}, eltype(x), Val{true}, fx), df_ref) < 1e-7
164164

165165
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:forward}), df_ref) < 1e-4
166166
@test err_func(DiffEqDiffTools.finite_difference_gradient!(df, f, x, Val{:central}), df_ref) < 1e-7

0 commit comments

Comments
 (0)