@@ -23,22 +23,25 @@ function GradientCache(
23
23
end
24
24
25
25
if typeof (x)<: AbstractArray # the vector->scalar case
26
- # need cache arrays for x1 (c1) and epsilon (c2, only if non-StridedArray)
26
+ # need cache arrays for x1 (c1) and epsilon (c2) (both only if non-StridedArray)
27
27
if fdtype!= Val{:complex } # complex-mode FD only needs one cache, for x+eps*im
28
- if typeof (c1)!= typeof (x) || size (c1)!= size (x)
29
- _c1 = similar (x)
30
- else
31
- _c1 = c1
32
- end
33
- if typeof (c2)!= Void && x<: StridedVector
34
- warn (" c2 cache isn't necessary when x<:StridedVector." )
35
- end
36
- if (typeof (c2)== Void || eltype (c2)!= real (eltype (x))) && ! (typeof (x)<: StridedVector )
37
- _c2 = zeros (real (eltype (x)), size (x))
38
- elseif typeof (x)<: StridedArray
28
+ if typeof (x)<: StridedVector
29
+ _c1 = nothing
39
30
_c2 = nothing
31
+ if typeof (c1)!= Void || typeof (c2)!= Void
32
+ warn (" For StridedVectors, neither c1 nor c2 are necessary." )
33
+ end
40
34
else
41
- _c2 = c2
35
+ if typeof (c1)!= typeof (x) || size (c1)!= size (x)
36
+ _c1 = similar (x)
37
+ else
38
+ _c1 = c1
39
+ end
40
+ if (typeof (c2)== Void || eltype (c2)!= real (eltype (x)))
41
+ _c2 = zeros (real (eltype (x)), size (x))
42
+ else
43
+ _c2 = c2
44
+ end
42
45
end
43
46
else
44
47
if ! (returntype<: Real )
@@ -186,30 +189,31 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV
186
189
fx, c1, c2 = cache. fx, cache. c1, cache. c2
187
190
if fdtype != Val{:complex }
188
191
epsilon_factor = compute_epsilon_factor (fdtype, eltype (x))
189
- copy! (c1,x)
190
192
end
191
193
if fdtype == Val{:forward }
192
194
@inbounds for i ∈ eachindex (x)
193
195
epsilon = compute_epsilon (fdtype, x[i], epsilon_factor)
194
- c1_old = c1[i]
195
- c1[i] += epsilon
196
+ x_old = x[i]
197
+ x[i] += epsilon
198
+ dfi = f (x)
199
+ x[i] = x_old
196
200
if typeof (fx) != Void
197
- df[i] = ( f (c1) - fx) / epsilon
201
+ dfi -= fx
198
202
else
199
- df[i] = ( f (c1) - f (x)) / epsilon
203
+ dfi -= f (x)
200
204
end
201
- c1 [i] = c1_old
205
+ df [i] = dfi / epsilon
202
206
end
203
207
elseif fdtype == Val{:central }
204
208
@inbounds for i ∈ eachindex (x)
205
209
epsilon = compute_epsilon (fdtype, x[i], epsilon_factor)
206
- c1_old = c1 [i]
207
- c1 [i] += epsilon
208
- x_old = x[i]
209
- x[i] -= epsilon
210
- df[i] = ( f (c1) - f (x)) / ( 2 * epsilon )
211
- c1 [i] = c1_old
212
- x [i] = x_old
210
+ x_old = x [i]
211
+ x [i] += epsilon
212
+ dfi = f (x)
213
+ x[i] = x_old - epsilon
214
+ dfi -= f (x )
215
+ x [i] = x_old
216
+ df [i] = dfi / ( 2 * epsilon)
213
217
end
214
218
elseif fdtype == Val{:complex } && returntype <: Real
215
219
copy! (c1,x)
0 commit comments