From 771a06fb8d07b4ff464c7c433fea31ebe4824973 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Sun, 24 May 2020 01:05:29 -0400 Subject: [PATCH 1/5] new, complex number support --- src/ForwardDiff.jl | 1 + src/complex.jl | 172 ++++++++++++++++++++++++++++++++++++++++++++ test/ComplexTest.jl | 55 ++++++++++++++ test/runtests.jl | 4 ++ 4 files changed, 232 insertions(+) create mode 100644 src/complex.jl create mode 100644 test/ComplexTest.jl diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index cfafe6a5..b697c5b7 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -18,6 +18,7 @@ include("derivative.jl") include("gradient.jl") include("jacobian.jl") include("hessian.jl") +include("complex.jl") export DiffResults diff --git a/src/complex.jl b/src/complex.jl new file mode 100644 index 00000000..57c7c143 --- /dev/null +++ b/src/complex.jl @@ -0,0 +1,172 @@ +Base.prevfloat(x::Dual{T,V}) where {T,V<:AbstractFloat} = prevfloat(x.value) +Base.nextfloat(x::Dual{T,V}) where {T,V<:AbstractFloat} = nextfloat(x.value) + +function Base.log(z::Complex{T}) where {A, FT<:AbstractFloat, T<:Dual{A,FT}} + T1::T = 1.25 + T2::T = 3 + ln2::T = log(convert(T,2)) #0.6931471805599453 + x, y = reim(z) + ρ, k = Base.ssqs(x,y) + ax = abs(x) + ay = abs(y) + if ax < ay + θ, β = ax, ay + else + θ, β = ay, ax + end + if k==0 && (0.5 < β*β) && (β <= T1 || ρ < T2) + ρρ = log1p((β-1)*(β+1)+θ*θ)/2 + else + ρρ = log(ρ)/2 + k*ln2 + end + Complex(ρρ, angle(z)) +end +function Base.tanh(z::Complex{T}) where {A, FT<:AbstractFloat, T<:Dual{A,FT}} + Ω = prevfloat(typemax(T)) + ξ, η = reim(z) + if isnan(ξ) && η==0 return Complex(ξ, η) end + if 4*abs(ξ) > asinh(Ω) #Overflow? + Complex(copysign(one(T),ξ), + copysign(zero(T),η*(isfinite(η) ? sin(2*abs(η)) : one(η)))) + else + t = tan(η) + β = 1+t*t #sec(η)^2 + s = sinh(ξ) + ρ = sqrt(1 + s*s) #cosh(ξ) + if isinf(t) + Complex(ρ/s,1/t) + else + Complex(β*ρ*s,t)/(1+β*s*s) + end + end +end + +_convert(T, x::Dual) = convert(T, x.value) +function Base._cpow(z::Union{Dual{A,T}, Complex{<:Dual{A,T}}}, p::Union{Dual{B,T}, Complex{<:Dual{B,T}}}) where {T,A,B} + if isreal(p) + pᵣ = real(p) + if isinteger(pᵣ) && abs(pᵣ) < typemax(Int32) + # |p| < typemax(Int32) serves two purposes: it prevents overflow + # when converting p to Int, and it also turns out to be roughly + # the crossover point for exp(p*log(z)) or similar to be faster. + if iszero(pᵣ) # fix signs of imaginary part for z^0 + zer = flipsign(copysign(zero(T),pᵣ), imag(z)) + return Complex(one(T), zer) + end + ip = _convert(Int, pᵣ) + if isreal(z) + zᵣ = real(z) + if ip < 0 + iszero(z) && return Complex(T(NaN),T(NaN)) + re = Base.power_by_squaring(inv(zᵣ), -ip) + im = -imag(z) + else + re = Base.power_by_squaring(zᵣ, ip) + im = imag(z) + end + # slightly tricky to get the correct sign of zero imag. part + return Complex(re, ifelse(iseven(ip) & signbit(zᵣ), -im, im)) + else + return ip < 0 ? Base.power_by_squaring(inv(z), -ip) : Base.power_by_squaring(z, ip) + end + elseif isreal(z) + # (note: if both z and p are complex with ±0.0 imaginary parts, + # the sign of the ±0.0 imaginary part of the result is ambiguous) + if iszero(real(z)) + return pᵣ > 0 ? complex(z) : Complex(T(NaN),T(NaN)) # 0 or NaN+NaN*im + elseif real(z) > 0 + return Complex(real(z)^pᵣ, z isa Real ? ifelse(real(z) < 1, -imag(p), imag(p)) : flipsign(imag(z), pᵣ)) + else + zᵣ = real(z) + rᵖ = (-zᵣ)^pᵣ + if isfinite(pᵣ) + # figuring out the sign of 0.0 when p is a complex number + # with zero imaginary part and integer/2 real part could be + # improved here, but it's not clear if it's worth it… + return rᵖ * complex(cospi(pᵣ), flipsign(sinpi(pᵣ),imag(z))) + else + iszero(rᵖ) && return zero(Complex{T}) # no way to get correct signs of 0.0 + return Complex(T(NaN),T(NaN)) # non-finite phase angle or NaN input + end + end + else + rᵖ = abs(z)^pᵣ + ϕ = pᵣ*angle(z) + end + elseif isreal(z) + iszero(z) && return real(p) > 0 ? complex(z) : Complex(T(NaN),T(NaN)) # 0 or NaN+NaN*im + zᵣ = real(z) + pᵣ, pᵢ = reim(p) + if zᵣ > 0 + rᵖ = zᵣ^pᵣ + ϕ = pᵢ*log(zᵣ) + else + r = -zᵣ + θ = copysign(T(π),imag(z)) + rᵖ = r^pᵣ * exp(-pᵢ*θ) + ϕ = pᵣ*θ + pᵢ*log(r) + end + else + pᵣ, pᵢ = reim(p) + r = abs(z) + θ = angle(z) + rᵖ = r^pᵣ * exp(-pᵢ*θ) + ϕ = pᵣ*θ + pᵢ*log(r) + end + + if isfinite(ϕ) + return rᵖ * cis(ϕ) + else + iszero(rᵖ) && return zero(Complex{T}) # no way to get correct signs of 0.0 + return Complex(T(NaN),T(NaN)) # non-finite phase angle or NaN input + end +end + +function Base.ssqs(x::T, y::T) where T<:Dual + k::Int = 0 + ρ = x*x + y*y + if !isfinite(ρ) && (isinf(x) || isinf(y)) + ρ = convert(T, Inf) + elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ=0 + x * (1<x^3, x->x^0.5, sqrt] + check_complex_jacobian(OP, 4.0+2im) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index e19c9527..e610d0b0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,3 +31,7 @@ println("done (took $t seconds).") println("Testing miscellaneous functionality...") t = @elapsed include("MiscTest.jl") println("done (took $t seconds).") + +println("Testing complex numbers...") +t = @elapsed include("ComplexTest.jl") +println("done (took $t seconds).") From 27673919b48002373a9b67f0f41287c51c230955 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Sun, 24 May 2020 01:19:17 -0400 Subject: [PATCH 2/5] fix a test --- test/ComplexTest.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/ComplexTest.jl b/test/ComplexTest.jl index 84ce930b..f81f03bf 100644 --- a/test/ComplexTest.jl +++ b/test/ComplexTest.jl @@ -42,8 +42,6 @@ function check_complex_jacobian(f, args...; kwargs...) end @testset "complex" begin - ForwardDiff._ldexp(2.0, -3) == 0.25 - ForwardDiff._ldexp(2.0, 3) == 16.0 for OP in [+, *, /, -, ^] @show OP check_complex_jacobian(OP, 4.0+2im, 2.0+1im) From 5625f3b40313bceda0234c43e4f7c1a0dcd5e696 Mon Sep 17 00:00:00 2001 From: Leo Date: Mon, 25 May 2020 23:41:48 -0400 Subject: [PATCH 3/5] Update src/complex.jl Co-authored-by: Yingbo Ma --- src/complex.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/complex.jl b/src/complex.jl index 57c7c143..800f873d 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -1,7 +1,7 @@ Base.prevfloat(x::Dual{T,V}) where {T,V<:AbstractFloat} = prevfloat(x.value) Base.nextfloat(x::Dual{T,V}) where {T,V<:AbstractFloat} = nextfloat(x.value) -function Base.log(z::Complex{T}) where {A, FT<:AbstractFloat, T<:Dual{A,FT}} +function Base.log(z::Complex{T}) where {A, T<:Dual{A,<:AbstractFloat}} T1::T = 1.25 T2::T = 3 ln2::T = log(convert(T,2)) #0.6931471805599453 From 836c028613bdbcb0e446c7e39ba68bd07b36e05b Mon Sep 17 00:00:00 2001 From: Leo Date: Mon, 25 May 2020 23:42:03 -0400 Subject: [PATCH 4/5] Update src/complex.jl Co-authored-by: Yingbo Ma --- src/complex.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/complex.jl b/src/complex.jl index 800f873d..41551488 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -21,7 +21,7 @@ function Base.log(z::Complex{T}) where {A, T<:Dual{A,<:AbstractFloat}} end Complex(ρρ, angle(z)) end -function Base.tanh(z::Complex{T}) where {A, FT<:AbstractFloat, T<:Dual{A,FT}} +function Base.tanh(z::Complex{T}) where {A, T<:Dual{A,<:AbstractFloat}} Ω = prevfloat(typemax(T)) ξ, η = reim(z) if isnan(ξ) && η==0 return Complex(ξ, η) end From 7851f5f6efe505b14487966685917d042c7cc4d2 Mon Sep 17 00:00:00 2001 From: GiggleLiu Date: Tue, 26 May 2020 00:03:51 -0400 Subject: [PATCH 5/5] fix tests --- test/ComplexTest.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/ComplexTest.jl b/test/ComplexTest.jl index f81f03bf..f2cc438c 100644 --- a/test/ComplexTest.jl +++ b/test/ComplexTest.jl @@ -1,3 +1,4 @@ +module ComplexTest using ForwardDiff: Dual using Test, ForwardDiff @@ -41,13 +42,15 @@ function check_complex_jacobian(f, args...; kwargs...) @test isapprox(nj, fj, atol=1e-5) end -@testset "complex" begin +@testset "complex instructions" begin for OP in [+, *, /, -, ^] - @show OP + println(" ...testing Complex Valued $OP") check_complex_jacobian(OP, 4.0+2im, 2.0+1im) end for OP in [abs, abs2, real, imag, conj, adjoint, sin, cos, tan, sinh, cosh, tanh, exp, log, angle, x->x^3, x->x^0.5, sqrt] + println(" ...testing Complex Valued $OP") check_complex_jacobian(OP, 4.0+2im) end end +end