Skip to content

Commit e6e3417

Browse files
committed
Minor perf improvements for gradients #2.
1 parent 624039b commit e6e3417

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

src/gradients.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@ function GradientCache(
3333
if typeof(c2)!=Void && x<:StridedVector
3434
warn("c2 cache isn't necessary when x<:StridedVector.")
3535
end
36-
if typeof(c2)==Void || eltype(c2)!=real(eltype(x)) && !(x<:StridedVector)
36+
if (typeof(c2)==Void || eltype(c2)!=real(eltype(x))) && !(typeof(x)<:StridedVector)
3737
_c2 = zeros(real(eltype(x)), size(x))
38+
elseif typeof(x)<:StridedArray
39+
_c2 = nothing
3840
else
3941
_c2 = c2
4042
end
@@ -176,7 +178,7 @@ function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Abstract
176178
end
177179
df
178180
end
179-
#=
181+
180182
function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedVector{<:Number},
181183
cache::GradientCache{T1,T2,T3,fdtype,returntype,inplace}) where {T1,T2,T3,fdtype,returntype,inplace}
182184

@@ -189,24 +191,24 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV
189191
if fdtype == Val{:forward}
190192
@inbounds for i eachindex(x)
191193
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
192-
c2_old = c2[i]
193-
c2[i] += epsilon
194+
c1_old = c1[i]
195+
c1[i] += epsilon
194196
if typeof(fx) != Void
195-
df[i] = (f(c2) - fx) / epsilon
197+
df[i] = (f(c1) - fx) / epsilon
196198
else
197-
df[i] = (f(c2) - f(x)) / epsilon
199+
df[i] = (f(c1) - f(x)) / epsilon
198200
end
199-
c2[i] = c2_old
201+
c1[i] = c1_old
200202
end
201203
elseif fdtype == Val{:central}
202204
@inbounds for i eachindex(x)
203205
epsilon = compute_epsilon(fdtype, x[i], epsilon_factor)
204-
c2_old = c2[i]
205-
c2[i] += epsilon
206+
c1_old = c1[i]
207+
c1[i] += epsilon
206208
x_old = x[i]
207209
x[i] -= epsilon
208-
df[i] = (f(c2) - f(x)) / (2*epsilon)
209-
c2[i] = c2_old
210+
df[i] = (f(c1) - f(x)) / (2*epsilon)
211+
c1[i] = c1_old
210212
x[i] = x_old
211213
end
212214
elseif fdtype == Val{:complex} && returntype <: Real
@@ -224,7 +226,7 @@ function finite_difference_gradient!(df::StridedVector{<:Number}, f, x::StridedV
224226
end
225227
df
226228
end
227-
=#
229+
228230
# vector of derivatives of a scalar->vector map
229231
# this is effectively a vector of partial derivatives, but we still call it a gradient
230232
function finite_difference_gradient!(df::AbstractArray{<:Number}, f, x::Number,

0 commit comments

Comments
 (0)