Skip to content

Commit 76a28c5

Browse files
committed
Gradient performance fixes #3.
1 parent e6e3417 commit 76a28c5

File tree

1 file changed

+30
-26
lines changed

1 file changed

+30
-26
lines changed

src/gradients.jl

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,25 @@ function GradientCache(
2323
end
2424

2525
if typeof(x)<:AbstractArray # the vector->scalar case
26-
# need cache arrays for x1 (c1) and epsilon (c2, only if non-StridedArray)
26+
# need cache arrays for x1 (c1) and epsilon (c2) (both only if non-StridedArray)
2727
if fdtype!=Val{:complex} # complex-mode FD only needs one cache, for x+eps*im
28-
if typeof(c1)!=typeof(x) || size(c1)!=size(x)
29-
_c1 = similar(x)
30-
else
31-
_c1 = c1
32-
end
33-
if typeof(c2)!=Void && x<:StridedVector
34-
warn("c2 cache isn't necessary when x<:StridedVector.")
35-
end
36-
if (typeof(c2)==Void || eltype(c2)!=real(eltype(x))) && !(typeof(x)<:StridedVector)
37-
_c2 = zeros(real(eltype(x)), size(x))
38-
elseif typeof(x)<:StridedArray
28+
if typeof(x)<:StridedVector
29+
_c1 = nothing
3930
_c2 = nothing
31+
if typeof(c1)!=Void || typeof(c2)!=Void
32+
warn("For StridedVectors, neither c1 nor c2 are necessary.")
33+
end
4034
else
41-
_c2 = c2
35+
if typeof(c1)!=typeof(x) || size(c1)!=size(x)
36+
_c1 = similar(x)
37+
else
38+
_c1 = c1
39+
end
40+
if (typeof(c2)==Void || eltype(c2)!=real(eltype(x)))
41+
_c2 = zeros(real(eltype(x)), size(x))
42+
else
43+
_c2 = c2
44+
end
4245
end
4346
else
4447
if !(returntype<:Real)
@@ -186,30 +189,31 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV
186189
fx, c1, c2 = cache.fx, cache.c1, cache.c2
187190
if fdtype != Val{:complex}
188191
epsilon_factor = compute_epsilon_factor(fdtype, eltype(x))
189-
copy!(c1,x)
190192
end
191193
if fdtype == Val{:forward}
192194
@inbounds for i eachindex(x)
193195
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
194-
c1_old = c1[i]
195-
c1[i] += epsilon
196+
x_old = x[i]
197+
x[i] += epsilon
198+
dfi = f(x)
199+
x[i] = x_old
196200
if typeof(fx) != Void
197-
df[i] = (f(c1) - fx) / epsilon
201+
dfi -= fx
198202
else
199-
df[i] = (f(c1) - f(x)) / epsilon
203+
dfi -= f(x)
200204
end
201-
c1[i] = c1_old
205+
df[i] = dfi / epsilon
202206
end
203207
elseif fdtype == Val{:central}
204208
@inbounds for i eachindex(x)
205209
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
206-
c1_old = c1[i]
207-
c1[i] += epsilon
208-
x_old = x[i]
209-
x[i] -= epsilon
210-
df[i] = (f(c1) - f(x)) / (2*epsilon)
211-
c1[i] = c1_old
212-
x[i] = x_old
210+
x_old = x[i]
211+
x[i] += epsilon
212+
dfi = f(x)
213+
x[i] = x_old - epsilon
214+
dfi -= f(x)
215+
x[i] = x_old
216+
df[i] = dfi / (2*epsilon)
213217
end
214218
elseif fdtype == Val{:complex} && returntype <: Real
215219
copy!(c1,x)

0 commit comments

Comments
 (0)