diff --git a/Project.toml b/Project.toml index b50964182..cb2715d7e 100644 --- a/Project.toml +++ b/Project.toml @@ -15,6 +15,7 @@ ChainRulesCore = "0.9.12" ChainRulesTestUtils = "0.4.2, 0.5" Compat = "3" FiniteDifferences = "0.10" +Quaternions = "0.4" Reexport = "0.2" Requires = "0.5.2, 1" julia = "1" @@ -24,9 +25,10 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["ChainRulesTestUtils", "Compat", "FiniteDifferences", "NaNMath", "Random", "SpecialFunctions", "Test"] +test = ["ChainRulesTestUtils", "Compat", "FiniteDifferences", "NaNMath", "Quaternions", "Random", "SpecialFunctions", "Test"] diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 09dce121c..3ae0b80d2 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -22,6 +22,8 @@ if VERSION < v"1.3.0-DEV.142" import LinearAlgebra: dot end +# numbers that we know commute under multiplication +const CommutativeMulNumber = Union{Real,Complex} include("rulesets/Core/core.jl") diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index aa5ffbd28..12bdb9b88 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -76,7 +76,7 @@ function rrule(::typeof(hypot), z::Complex) end @scalar_rule fma(x, y, z) (y, x, One()) -@scalar_rule muladd(x, y, z) (y, x, One()) +@scalar_rule muladd(x, y::CommutativeMulNumber, z) (y, x, One()) @scalar_rule rem2pi(x, r::RoundingMode) (One(), DoesNotExist()) @scalar_rule( mod(x, y), @@ -90,47 +90,47 @@ end @scalar_rule(ldexp(x, y), (2^y, DoesNotExist())) # Can't multiply though sqrt in acosh because of negative complex case for x -@scalar_rule acosh(x) inv(sqrt(x - 1) * sqrt(x + 1)) -@scalar_rule acoth(x) inv(1 - x ^ 2) -@scalar_rule acsch(x) -(inv(x ^ 2 * sqrt(1 + x ^ -2))) +@scalar_rule acosh(x::CommutativeMulNumber) inv(sqrt(x - 1) * sqrt(x + 1)) +@scalar_rule acoth(x::CommutativeMulNumber) inv(1 - x ^ 2) +@scalar_rule acsch(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 + x ^ -2))) @scalar_rule acsch(x::Real) -(inv(abs(x) * sqrt(1 + x ^ 2))) -@scalar_rule asech(x) -(inv(x * sqrt(1 - x ^ 2))) -@scalar_rule asinh(x) inv(sqrt(x ^ 2 + 1)) -@scalar_rule atanh(x) inv(1 - x ^ 2) +@scalar_rule asech(x::CommutativeMulNumber) -(inv(x * sqrt(1 - x ^ 2))) +@scalar_rule asinh(x::CommutativeMulNumber) inv(sqrt(x ^ 2 + 1)) +@scalar_rule atanh(x::CommutativeMulNumber) inv(1 - x ^ 2) -@scalar_rule acosd(x) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2) -@scalar_rule acotd(x) (-(oftype(x, 180)) / π) / (1 + x ^ 2) -@scalar_rule acscd(x) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2) +@scalar_rule acosd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2) +@scalar_rule acotd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / (1 + x ^ 2) +@scalar_rule acscd(x::CommutativeMulNumber) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2) @scalar_rule acscd(x::Real) ((-(oftype(x, 180)) / π) / abs(x)) / sqrt(x ^ 2 - 1) -@scalar_rule asecd(x) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2) +@scalar_rule asecd(x::CommutativeMulNumber) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2) @scalar_rule asecd(x::Real) ((oftype(x, 180) / π) / abs(x)) / sqrt(x ^ 2 - 1) -@scalar_rule asind(x) (oftype(x, 180) / π) / sqrt(1 - x ^ 2) -@scalar_rule atand(x) (oftype(x, 180) / π) / (1 + x ^ 2) - -@scalar_rule cot(x) -((1 + Ω ^ 2)) -@scalar_rule coth(x) -(csch(x) ^ 2) -@scalar_rule cotd(x) -(π / oftype(x, 180)) * (1 + Ω ^ 2) -@scalar_rule csc(x) -Ω * cot(x) -@scalar_rule cscd(x) -(π / oftype(x, 180)) * Ω * cotd(x) -@scalar_rule csch(x) -(coth(x)) * Ω -@scalar_rule sec(x) Ω * tan(x) -@scalar_rule secd(x) (π / oftype(x, 180)) * Ω * tand(x) -@scalar_rule sech(x) -(tanh(x)) * Ω - -@scalar_rule acot(x) -(inv(1 + x ^ 2)) -@scalar_rule acsc(x) -(inv(x ^ 2 * sqrt(1 - x ^ -2))) +@scalar_rule asind(x::CommutativeMulNumber) (oftype(x, 180) / π) / sqrt(1 - x ^ 2) +@scalar_rule atand(x::CommutativeMulNumber) (oftype(x, 180) / π) / (1 + x ^ 2) + +@scalar_rule cot(x::CommutativeMulNumber) -((1 + Ω ^ 2)) +@scalar_rule coth(x::CommutativeMulNumber) -(csch(x) ^ 2) +@scalar_rule cotd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * (1 + Ω ^ 2) +@scalar_rule csc(x::CommutativeMulNumber) -Ω * cot(x) +@scalar_rule cscd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * Ω * cotd(x) +@scalar_rule csch(x::CommutativeMulNumber) -(coth(x)) * Ω +@scalar_rule sec(x::CommutativeMulNumber) Ω * tan(x) +@scalar_rule secd(x::CommutativeMulNumber) (π / oftype(x, 180)) * Ω * tand(x) +@scalar_rule sech(x::CommutativeMulNumber) -(tanh(x)) * Ω + +@scalar_rule acot(x::CommutativeMulNumber) -(inv(1 + x ^ 2)) +@scalar_rule acsc(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 - x ^ -2))) @scalar_rule acsc(x::Real) -(inv(abs(x) * sqrt(x ^ 2 - 1))) -@scalar_rule asec(x) inv(x ^ 2 * sqrt(1 - x ^ -2)) +@scalar_rule asec(x::CommutativeMulNumber) inv(x ^ 2 * sqrt(1 - x ^ -2)) @scalar_rule asec(x::Real) inv(abs(x) * sqrt(x ^ 2 - 1)) -@scalar_rule cosd(x) -(π / oftype(x, 180)) * sind(x) -@scalar_rule cospi(x) -π * sinpi(x) -@scalar_rule sind(x) (π / oftype(x, 180)) * cosd(x) -@scalar_rule sinpi(x) π * cospi(x) -@scalar_rule tand(x) (π / oftype(x, 180)) * (1 + Ω ^ 2) +@scalar_rule cosd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * sind(x) +@scalar_rule cospi(x::CommutativeMulNumber) -π * sinpi(x) +@scalar_rule sind(x::CommutativeMulNumber) (π / oftype(x, 180)) * cosd(x) +@scalar_rule sinpi(x::CommutativeMulNumber) π * cospi(x) +@scalar_rule tand(x::CommutativeMulNumber) (π / oftype(x, 180)) * (1 + Ω ^ 2) -@scalar_rule sinc(x) cosc(x) +@scalar_rule sinc(x::CommutativeMulNumber) cosc(x) @scalar_rule( clamp(x, low, high), @@ -140,7 +140,36 @@ end ), (!(islow | ishigh), islow, ishigh), ) -@scalar_rule x \ y (-(Ω / x), one(y) / x) + +# product rule requires special care for arguments where `muladd` is non-commutative +function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number) + ∂xyz = muladd(Δx, y, muladd(x, Δy, Δz)) + return muladd(x, y, z), ∂xyz +end + +function rrule(::typeof(muladd), x::Number, y::Number, z::Number) + function muladd_pullback(ΔΩ) + return (NO_FIELDS, ΔΩ * y', x' * ΔΩ, ΔΩ) + end + return muladd(x, y, z), muladd_pullback +end + +@scalar_rule x::CommutativeMulNumber \ y (-(x \ Ω), x \ one(y)) + +# quotient rule requires special care for arguments where `\` is non-commutative +function frule((_, Δx, Δy), ::typeof(\), x::Number, y::Number) + Ω = x \ y + return Ω, x \ muladd(-Δx, Ω, Δy) +end + +function rrule(::typeof(\), x::Number, y::Number) + Ω = x \ y + function ldiv_pullback(ΔΩ) + ∂y = x' \ ΔΩ + return (NO_FIELDS, -(∂y * Ω'), ∂y) + end + return Ω, ldiv_pullback +end function frule((_, ẏ), ::typeof(identity), x) return (x, ẏ) diff --git a/src/rulesets/Base/fastmath_able.jl b/src/rulesets/Base/fastmath_able.jl index 477bd5851..536982ef9 100644 --- a/src/rulesets/Base/fastmath_able.jl +++ b/src/rulesets/Base/fastmath_able.jl @@ -2,37 +2,36 @@ let # Include inside this quote any rules that should have FastMath versions fastable_ast = quote # Trig-Basics - @scalar_rule cos(x) -(sin(x)) - @scalar_rule sin(x) cos(x) - @scalar_rule tan(x) 1 + Ω ^ 2 - + @scalar_rule cos(x::CommutativeMulNumber) -(sin(x)) + @scalar_rule sin(x::CommutativeMulNumber) cos(x) + @scalar_rule tan(x::CommutativeMulNumber) 1 + Ω ^ 2 # Trig-Hyperbolic - @scalar_rule cosh(x) sinh(x) - @scalar_rule sinh(x) cosh(x) - @scalar_rule tanh(x) 1 - Ω ^ 2 + @scalar_rule cosh(x::CommutativeMulNumber) sinh(x) + @scalar_rule sinh(x::CommutativeMulNumber) cosh(x) + @scalar_rule tanh(x::CommutativeMulNumber) 1 - Ω ^ 2 # Trig- Inverses - @scalar_rule acos(x) -(inv(sqrt(1 - x ^ 2))) - @scalar_rule asin(x) inv(sqrt(1 - x ^ 2)) - @scalar_rule atan(x) inv(1 + x ^ 2) + @scalar_rule acos(x::CommutativeMulNumber) -(inv(sqrt(1 - x ^ 2))) + @scalar_rule asin(x::CommutativeMulNumber) inv(sqrt(1 - x ^ 2)) + @scalar_rule atan(x::CommutativeMulNumber) inv(1 + x ^ 2) # Trig-Multivariate - @scalar_rule atan(y, x) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u) - @scalar_rule sincos(x) @setup((sinx, cosx) = Ω) cosx -sinx + @scalar_rule atan(y::Real, x::Real) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u) + @scalar_rule sincos(x::CommutativeMulNumber) @setup((sinx, cosx) = Ω) cosx -sinx # exponents - @scalar_rule cbrt(x) inv(3 * Ω ^ 2) - @scalar_rule inv(x) -(Ω ^ 2) - @scalar_rule sqrt(x) inv(2Ω) - @scalar_rule exp(x) Ω - @scalar_rule exp10(x) Ω * log(oftype(x, 10)) - @scalar_rule exp2(x) Ω * log(oftype(x, 2)) - @scalar_rule expm1(x) exp(x) - @scalar_rule log(x) inv(x) - @scalar_rule log10(x) inv(x) / log(oftype(x, 10)) - @scalar_rule log1p(x) inv(x + 1) - @scalar_rule log2(x) inv(x) / log(oftype(x, 2)) + @scalar_rule cbrt(x::CommutativeMulNumber) inv(3 * Ω ^ 2) + @scalar_rule inv(x::CommutativeMulNumber) -(Ω ^ 2) + @scalar_rule sqrt(x::CommutativeMulNumber) inv(2Ω) + @scalar_rule exp(x::CommutativeMulNumber) Ω + @scalar_rule exp10(x::CommutativeMulNumber) Ω * log(oftype(x, 10)) + @scalar_rule exp2(x::CommutativeMulNumber) Ω * log(oftype(x, 2)) + @scalar_rule expm1(x::CommutativeMulNumber) exp(x) + @scalar_rule log(x::CommutativeMulNumber) inv(x) + @scalar_rule log10(x::CommutativeMulNumber) inv(x) / log(oftype(x, 10)) + @scalar_rule log1p(x::CommutativeMulNumber) inv(x + 1) + @scalar_rule log2(x::CommutativeMulNumber) inv(x) / log(oftype(x, 2)) # Unary complex functions @@ -78,7 +77,7 @@ let end ## angle - function frule((_, Δx), ::typeof(angle), x) + function frule((_, Δx), ::typeof(angle), x::Union{Real, Complex}) Ω = angle(x) # `ifelse` is applied only to denominator to ensure type-stability. ∂Ω = _imagconjtimes(x, Δx) / ifelse(iszero(x), one(x), abs2(x)) @@ -133,9 +132,9 @@ let @scalar_rule x + y (One(), One()) @scalar_rule x - y (One(), -1) - @scalar_rule x / y (one(x) / y, -(Ω / y)) + @scalar_rule x / y::CommutativeMulNumber (one(x) / y, -(Ω / y)) #log(complex(x)) is required so it gives correct complex answer for x<0 - @scalar_rule(x ^ y, + @scalar_rule(x::CommutativeMulNumber ^ y::CommutativeMulNumber, (ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))), ) # x^y for x < 0 errors when y is not an integer, but then derivative wrt y @@ -157,14 +156,14 @@ let # `sign` - function frule((_, Δx), ::typeof(sign), x) + function frule((_, Δx), ::typeof(sign), x::Number) n = ifelse(iszero(x), one(x), abs(x)) Ω = x isa Real ? sign(x) : x / n ∂Ω = Ω * (_imagconjtimes(Ω, Δx) / n) * im return Ω, ∂Ω end - function rrule(::typeof(sign), x) + function rrule(::typeof(sign), x::Number) n = ifelse(iszero(x), one(x), abs(x)) Ω = x isa Real ? sign(x) : x / n function sign_pullback(ΔΩ) @@ -174,6 +173,34 @@ let return Ω, sign_pullback end + function frule((_, Δx), ::typeof(inv), x::Number) + Ω = inv(x) + return Ω, -(Ω * Δx * Ω) + end + + function rrule(::typeof(inv), x::Number) + Ω = inv(x) + function inv_pullback(ΔΩ) + return (NO_FIELDS, -(Ω' * ΔΩ * Ω')) + end + return Ω, inv_pullback + end + + # quotient rule requires special care for arguments where `/` is non-commutative + function frule((_, Δx, Δy), ::typeof(/), x::Number, y::Number) + Ω = x / y + return Ω, muladd(-Ω, Δy, Δx) / y + end + + function rrule(::typeof(/), x::Number, y::Number) + Ω = x / y + function rdiv_pullback(ΔΩ) + ∂x = ΔΩ / y' + return (NO_FIELDS, ∂x, -(Ω' * ∂x)) + end + return Ω, rdiv_pullback + end + # product rule requires special care for arguments where `mul` is non-commutative function frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) # Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 21a97c256..a590ec4b7 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -179,4 +179,64 @@ frule_test(clamp, (x, Δx), (y, Δy), (z, Δz)) rrule_test(clamp, Δk, (x, x̄), (y, ȳ), (z, z̄)) end + + @testset "non-commutative number (quaternion)" begin + function FiniteDifferences.to_vec(q::Quaternion) + function Quaternion_from_vec(q_vec) + return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4]) + end + return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec + end + + @testset "unary functions" begin + @testset "$f" for f in (+, -, inv, identity, one, zero, transpose, adjoint, real) + test_scalar(f, quatrand()) + end + end + + @testset "binary functions" begin + @testset "$f(::Quaternion, ::Quaternion)" for f in (+, -, *, /, \) + x, ẋ, x̄ = quatrand(), quatrand(), quatrand() + y, ẏ, ȳ = quatrand(), quatrand(), quatrand() + ΔΩ = quatrand() + frule_test(f, (x, ẋ), (y, ẏ)) + rrule_test(f, ΔΩ, (x, x̄), (y, ȳ)) + end + @testset "/(::Quaternion, ::Real)" begin + x, ẋ = quatrand(), quatrand(), quatrand() + y, ẏ = randn(3) + frule_test(/, (x, ẋ), (y, ẏ)) + # don't test rrule, because it doesn't project adjoint of y to the reals + # so fd won't agree + end + @testset "\\(::Real, ::Quaternion)" begin + x, ẋ, x̄ = randn(3) + y, ẏ, ȳ = quatrand(), quatrand(), quatrand() + ΔΩ = quatrand() + frule_test(\, (x, ẋ), (y, ẏ)) + # don't test rrule, because it doesn't project adjoint of x to the reals + # so fd won't agree + end + end + + @testset "ternary functions" begin + @testset "$f(::Quaternion, ::Quaternion, ::Quaternion)" for f in (muladd,) + x, ẋ, x̄ = quatrand(), quatrand(), quatrand() + y, ẏ, ȳ = quatrand(), quatrand(), quatrand() + z, ż, z̄ = quatrand(), quatrand(), quatrand() + ΔΩ = quatrand() + frule_test(f, (x, ẋ), (y, ẏ), (z, ż)) + rrule_test(f, ΔΩ, (x, x̄), (y, ȳ), (z, z̄)) + end + @testset "muladd(::Quaternion, ::Real, ::Quaternion)" begin + x, ẋ, x̄ = quatrand(), quatrand(), quatrand() + y, ẏ, ȳ = randn(3) + z, ż, z̄ = quatrand(), quatrand(), quatrand() + ΔΩ = quatrand() + frule_test(muladd, (x, ẋ), (y, ẏ), (z, ż)) + # don't test rrule, because it doesn't project adjoint of y to the reals + # so fd won't agree + end + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 154306904..78b2abf12 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ using FiniteDifferences using LinearAlgebra using LinearAlgebra.BLAS using LinearAlgebra: dot +using Quaternions using Random using Statistics using Test