diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index cfafe6a5..b697c5b7 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -18,6 +18,7 @@ include("derivative.jl") include("gradient.jl") include("jacobian.jl") include("hessian.jl") +include("complex.jl") export DiffResults diff --git a/src/complex.jl b/src/complex.jl new file mode 100644 index 00000000..41551488 --- /dev/null +++ b/src/complex.jl @@ -0,0 +1,172 @@ +Base.prevfloat(x::Dual{T,V}) where {T,V<:AbstractFloat} = prevfloat(x.value) +Base.nextfloat(x::Dual{T,V}) where {T,V<:AbstractFloat} = nextfloat(x.value) + +function Base.log(z::Complex{T}) where {A, T<:Dual{A,<:AbstractFloat}} + T1::T = 1.25 + T2::T = 3 + ln2::T = log(convert(T,2)) #0.6931471805599453 + x, y = reim(z) + ρ, k = Base.ssqs(x,y) + ax = abs(x) + ay = abs(y) + if ax < ay + θ, β = ax, ay + else + θ, β = ay, ax + end + if k==0 && (0.5 < β*β) && (β <= T1 || ρ < T2) + ρρ = log1p((β-1)*(β+1)+θ*θ)/2 + else + ρρ = log(ρ)/2 + k*ln2 + end + Complex(ρρ, angle(z)) +end +function Base.tanh(z::Complex{T}) where {A, T<:Dual{A,<:AbstractFloat}} + Ω = prevfloat(typemax(T)) + ξ, η = reim(z) + if isnan(ξ) && η==0 return Complex(ξ, η) end + if 4*abs(ξ) > asinh(Ω) #Overflow? + Complex(copysign(one(T),ξ), + copysign(zero(T),η*(isfinite(η) ? sin(2*abs(η)) : one(η)))) + else + t = tan(η) + β = 1+t*t #sec(η)^2 + s = sinh(ξ) + ρ = sqrt(1 + s*s) #cosh(ξ) + if isinf(t) + Complex(ρ/s,1/t) + else + Complex(β*ρ*s,t)/(1+β*s*s) + end + end +end + +_convert(T, x::Dual) = convert(T, x.value) +function Base._cpow(z::Union{Dual{A,T}, Complex{<:Dual{A,T}}}, p::Union{Dual{B,T}, Complex{<:Dual{B,T}}}) where {T,A,B} + if isreal(p) + pᵣ = real(p) + if isinteger(pᵣ) && abs(pᵣ) < typemax(Int32) + # |p| < typemax(Int32) serves two purposes: it prevents overflow + # when converting p to Int, and it also turns out to be roughly + # the crossover point for exp(p*log(z)) or similar to be faster. + if iszero(pᵣ) # fix signs of imaginary part for z^0 + zer = flipsign(copysign(zero(T),pᵣ), imag(z)) + return Complex(one(T), zer) + end + ip = _convert(Int, pᵣ) + if isreal(z) + zᵣ = real(z) + if ip < 0 + iszero(z) && return Complex(T(NaN),T(NaN)) + re = Base.power_by_squaring(inv(zᵣ), -ip) + im = -imag(z) + else + re = Base.power_by_squaring(zᵣ, ip) + im = imag(z) + end + # slightly tricky to get the correct sign of zero imag. part + return Complex(re, ifelse(iseven(ip) & signbit(zᵣ), -im, im)) + else + return ip < 0 ? Base.power_by_squaring(inv(z), -ip) : Base.power_by_squaring(z, ip) + end + elseif isreal(z) + # (note: if both z and p are complex with ±0.0 imaginary parts, + # the sign of the ±0.0 imaginary part of the result is ambiguous) + if iszero(real(z)) + return pᵣ > 0 ? complex(z) : Complex(T(NaN),T(NaN)) # 0 or NaN+NaN*im + elseif real(z) > 0 + return Complex(real(z)^pᵣ, z isa Real ? ifelse(real(z) < 1, -imag(p), imag(p)) : flipsign(imag(z), pᵣ)) + else + zᵣ = real(z) + rᵖ = (-zᵣ)^pᵣ + if isfinite(pᵣ) + # figuring out the sign of 0.0 when p is a complex number + # with zero imaginary part and integer/2 real part could be + # improved here, but it's not clear if it's worth it… + return rᵖ * complex(cospi(pᵣ), flipsign(sinpi(pᵣ),imag(z))) + else + iszero(rᵖ) && return zero(Complex{T}) # no way to get correct signs of 0.0 + return Complex(T(NaN),T(NaN)) # non-finite phase angle or NaN input + end + end + else + rᵖ = abs(z)^pᵣ + ϕ = pᵣ*angle(z) + end + elseif isreal(z) + iszero(z) && return real(p) > 0 ? complex(z) : Complex(T(NaN),T(NaN)) # 0 or NaN+NaN*im + zᵣ = real(z) + pᵣ, pᵢ = reim(p) + if zᵣ > 0 + rᵖ = zᵣ^pᵣ + ϕ = pᵢ*log(zᵣ) + else + r = -zᵣ + θ = copysign(T(π),imag(z)) + rᵖ = r^pᵣ * exp(-pᵢ*θ) + ϕ = pᵣ*θ + pᵢ*log(r) + end + else + pᵣ, pᵢ = reim(p) + r = abs(z) + θ = angle(z) + rᵖ = r^pᵣ * exp(-pᵢ*θ) + ϕ = pᵣ*θ + pᵢ*log(r) + end + + if isfinite(ϕ) + return rᵖ * cis(ϕ) + else + iszero(rᵖ) && return zero(Complex{T}) # no way to get correct signs of 0.0 + return Complex(T(NaN),T(NaN)) # non-finite phase angle or NaN input + end +end + +function Base.ssqs(x::T, y::T) where T<:Dual + k::Int = 0 + ρ = x*x + y*y + if !isfinite(ρ) && (isinf(x) || isinf(y)) + ρ = convert(T, Inf) + elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ=0 + x * (1<x^3, x->x^0.5, sqrt] + println(" ...testing Complex Valued $OP") + check_complex_jacobian(OP, 4.0+2im) + end +end +end diff --git a/test/runtests.jl b/test/runtests.jl index e19c9527..e610d0b0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,3 +31,7 @@ println("done (took $t seconds).") println("Testing miscellaneous functionality...") t = @elapsed include("MiscTest.jl") println("done (took $t seconds).") + +println("Testing complex numbers...") +t = @elapsed include("ComplexTest.jl") +println("done (took $t seconds).")