Skip to content

Commit a2366c4

Browse files
committed
Fixed Jacobians, added inplace function support to gradients, fixed all tests.
1 parent a0dc5d0 commit a2366c4

File tree

5 files changed

+227
-290
lines changed

5 files changed

+227
-290
lines changed

src/derivatives.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#=
2-
Sinple-point derivatives of scalar->scalar maps.
2+
Single-point derivatives of scalar->scalar maps.
33
=#
44
function finite_difference_derivative(f, x::T, fdtype::DataType=Val{:central},
55
returntype::DataType=eltype(x), f_x::Union{Void,T}=nothing) where T<:Number
66

7-
epsilon = compute_epsilon(fdtype, real(x))
7+
epsilon = compute_epsilon(fdtype, x)
88
if fdtype==Val{:forward}
99
return (f(x+epsilon) - f(x)) / epsilon
1010
elseif fdtype==Val{:central}
@@ -73,7 +73,7 @@ end
7373
function DerivativeCache(
7474
x :: AbstractArray{<:Number},
7575
fx :: Union{Void,AbstractArray{<:Number}} = nothing,
76-
epsilon :: Union{Void,AbstractArray{<:Number}} = nothing,
76+
epsilon :: Union{Void,AbstractArray{<:Real}} = nothing,
7777
fdtype :: DataType = Val{:central},
7878
returntype :: DataType = eltype(x))
7979

@@ -99,8 +99,8 @@ function DerivativeCache(
9999
if typeof(epsilon)==Void || eltype(epsilon)!=real(eltype(x))
100100
epsilon = zeros(real(eltype(x)), size(x))
101101
end
102-
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
103-
@. epsilon = compute_epsilon(fdtype, real(x), epsilon_factor)
102+
epsilon_factor = compute_epsilon_factor(fdtype, eltype(x))
103+
@. epsilon = compute_epsilon(fdtype, x, epsilon_factor)
104104
_epsilon = epsilon
105105
end
106106
DerivativeCache{typeof(_fx),typeof(_epsilon),fdtype,returntype}(_fx,_epsilon)
@@ -115,7 +115,7 @@ function finite_difference_derivative(
115115
fdtype :: DataType = Val{:central},
116116
returntype :: DataType = eltype(x), # return type of f
117117
fx :: Union{Void,AbstractArray{<:Number}} = nothing,
118-
epsilon :: Union{Void,AbstractArray{<:Real}} = nothing)
118+
epsilon :: Union{Void,AbstractArray{<:Real}} = nothing)
119119

120120
df = zeros(returntype, size(x))
121121
finite_difference_derivative!(df, f, x, fdtype, returntype, fx, epsilon)
@@ -160,10 +160,10 @@ Optimized implementations for StridedArrays.
160160
Essentially, the only difference between these and the AbstractArray case
161161
is that here we can compute the epsilon one by one in local variables and avoid caching it.
162162
=#
163-
function _finite_difference_derivative!(df::StridedArray, f, x::StridedArray,
163+
function finite_difference_derivative!(df::StridedArray, f, x::StridedArray,
164164
cache::DerivativeCache{T1,T2,fdtype,returntype}) where {T1,T2,fdtype,returntype}
165165

166-
epsilon_factor = compute_epsilon_factor(fdtype, real(eltype(x)))
166+
epsilon_factor = compute_epsilon_factor(fdtype, eltype(x))
167167
if fdtype == Val{:forward}
168168
fx = cache.fx
169169
@inbounds for i eachindex(x)

src/finitediff.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ Very heavily inspired by Calculus.jl, but with an emphasis on performance and Di
66
Compute the finite difference interval epsilon.
77
Reference: Numerical Recipes, chapter 5.7.
88
=#
9-
@inline function compute_epsilon(::Type{Val{:forward}}, x::T, eps_sqrt::T=sqrt(eps(T))) where T<:Real
10-
eps_sqrt * max(one(T), abs(x))
9+
@inline function compute_epsilon(::Type{Val{:forward}}, x::T, eps_sqrt=sqrt(eps(real(T)))) where T<:Number
10+
eps_sqrt * max(one(real(T)), abs(x))
1111
end
1212

13-
@inline function compute_epsilon(::Type{Val{:central}}, x::T, eps_cbrt::T=cbrt(eps(T))) where T<:Real
14-
eps_cbrt * max(one(T), abs(x))
13+
@inline function compute_epsilon(::Type{Val{:central}}, x::T, eps_cbrt=cbrt(eps(real(T)))) where T<:Number
14+
eps_cbrt * max(one(real(T)), abs(x))
1515
end
1616

1717
@inline function compute_epsilon(::Type{Val{:complex}}, x::T, ::Union{Void,T}=nothing) where T<:Real
@@ -20,20 +20,20 @@ end
2020

2121
@inline function compute_epsilon_factor(fdtype::DataType, ::Type{T}) where T<:Number
2222
if fdtype==Val{:forward}
23-
return sqrt(eps(T))
23+
return sqrt(eps(real(T)))
2424
elseif fdtype==Val{:central}
25-
return cbrt(eps(T))
25+
return cbrt(eps(real(T)))
2626
else
27-
return one(T)
27+
return one(real(T))
2828
end
2929
end
3030

31-
function fdtype_error(funtype::DataType=Val{:Real})
32-
if funtype == Val{:Real}
31+
function fdtype_error(funtype::DataType=Float64)
32+
if funtype<:Real
3333
error("Unrecognized fdtype: valid values are Val{:forward}, Val{:central} and Val{:complex}.")
34-
elseif funtype == Val{:Complex}
34+
elseif funtype<:Complex
3535
error("Unrecognized fdtype: valid values are Val{:forward} or Val{:central}.")
3636
else
37-
error("Unrecognized funtype: valid values are Val{:Real} or Val{:Complex}.")
37+
error("Unrecognized returntype: should be a subtype of Real or Complex.")
3838
end
3939
end

0 commit comments

Comments
 (0)