Skip to content

Handle commutativity correctly in scalar rules #275

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ ChainRulesCore = "0.9.12"
ChainRulesTestUtils = "0.4.2, 0.5"
Compat = "3"
FiniteDifferences = "0.10"
Quaternions = "0.4"
Reexport = "0.2"
Requires = "0.5.2, 1"
julia = "1"
Expand All @@ -24,9 +25,10 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Quaternions = "94ee1d12-ae83-5a48-8b1c-48b8ff168ae0"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ChainRulesTestUtils", "Compat", "FiniteDifferences", "NaNMath", "Random", "SpecialFunctions", "Test"]
test = ["ChainRulesTestUtils", "Compat", "FiniteDifferences", "NaNMath", "Quaternions", "Random", "SpecialFunctions", "Test"]
2 changes: 2 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ if VERSION < v"1.3.0-DEV.142"
import LinearAlgebra: dot
end

# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the style guide said to put spaces here.
But now that i look, I am not sure that it mentions it JuliaDiff/BlueStyle#77
Still it is what we do else-where

Suggested change
const CommutativeMulNumber = Union{Real,Complex}
const CommutativeMulNumber = Union{Real, Complex}


include("rulesets/Core/core.jl")

Expand Down
97 changes: 63 additions & 34 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ function rrule(::typeof(hypot), z::Complex)
end

@scalar_rule fma(x, y, z) (y, x, One())
@scalar_rule muladd(x, y, z) (y, x, One())
@scalar_rule muladd(x, y::CommutativeMulNumber, z) (y, x, One())
@scalar_rule rem2pi(x, r::RoundingMode) (One(), DoesNotExist())
@scalar_rule(
mod(x, y),
Expand All @@ -90,47 +90,47 @@ end
@scalar_rule(ldexp(x, y), (2^y, DoesNotExist()))

# Can't multiply though sqrt in acosh because of negative complex case for x
@scalar_rule acosh(x) inv(sqrt(x - 1) * sqrt(x + 1))
@scalar_rule acoth(x) inv(1 - x ^ 2)
@scalar_rule acsch(x) -(inv(x ^ 2 * sqrt(1 + x ^ -2)))
@scalar_rule acosh(x::CommutativeMulNumber) inv(sqrt(x - 1) * sqrt(x + 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the logic used to determine if Multiplicative Commutative is needed for univariate functions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the extension of a function to the complex numbers and matrices is in the form of a power series, then the non-commutativity becomes a problem for non-commutative numbers, and I restrict it.

@scalar_rule acoth(x::CommutativeMulNumber) inv(1 - x ^ 2)
@scalar_rule acsch(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 + x ^ -2)))
@scalar_rule acsch(x::Real) -(inv(abs(x) * sqrt(1 + x ^ 2)))
@scalar_rule asech(x) -(inv(x * sqrt(1 - x ^ 2)))
@scalar_rule asinh(x) inv(sqrt(x ^ 2 + 1))
@scalar_rule atanh(x) inv(1 - x ^ 2)
@scalar_rule asech(x::CommutativeMulNumber) -(inv(x * sqrt(1 - x ^ 2)))
@scalar_rule asinh(x::CommutativeMulNumber) inv(sqrt(x ^ 2 + 1))
@scalar_rule atanh(x::CommutativeMulNumber) inv(1 - x ^ 2)


@scalar_rule acosd(x) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2)
@scalar_rule acotd(x) (-(oftype(x, 180)) / π) / (1 + x ^ 2)
@scalar_rule acscd(x) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2)
@scalar_rule acosd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2)
@scalar_rule acotd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / (1 + x ^ 2)
@scalar_rule acscd(x::CommutativeMulNumber) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2)
@scalar_rule acscd(x::Real) ((-(oftype(x, 180)) / π) / abs(x)) / sqrt(x ^ 2 - 1)
@scalar_rule asecd(x) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2)
@scalar_rule asecd(x::CommutativeMulNumber) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2)
@scalar_rule asecd(x::Real) ((oftype(x, 180) / π) / abs(x)) / sqrt(x ^ 2 - 1)
@scalar_rule asind(x) (oftype(x, 180) / π) / sqrt(1 - x ^ 2)
@scalar_rule atand(x) (oftype(x, 180) / π) / (1 + x ^ 2)

@scalar_rule cot(x) -((1 + Ω ^ 2))
@scalar_rule coth(x) -(csch(x) ^ 2)
@scalar_rule cotd(x) -(π / oftype(x, 180)) * (1 + Ω ^ 2)
@scalar_rule csc(x) -Ω * cot(x)
@scalar_rule cscd(x) -(π / oftype(x, 180)) * Ω * cotd(x)
@scalar_rule csch(x) -(coth(x)) * Ω
@scalar_rule sec(x) Ω * tan(x)
@scalar_rule secd(x) (π / oftype(x, 180)) * Ω * tand(x)
@scalar_rule sech(x) -(tanh(x)) * Ω

@scalar_rule acot(x) -(inv(1 + x ^ 2))
@scalar_rule acsc(x) -(inv(x ^ 2 * sqrt(1 - x ^ -2)))
@scalar_rule asind(x::CommutativeMulNumber) (oftype(x, 180) / π) / sqrt(1 - x ^ 2)
@scalar_rule atand(x::CommutativeMulNumber) (oftype(x, 180) / π) / (1 + x ^ 2)

@scalar_rule cot(x::CommutativeMulNumber) -((1 + Ω ^ 2))
@scalar_rule coth(x::CommutativeMulNumber) -(csch(x) ^ 2)
@scalar_rule cotd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * (1 + Ω ^ 2)
@scalar_rule csc(x::CommutativeMulNumber) -Ω * cot(x)
@scalar_rule cscd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * Ω * cotd(x)
@scalar_rule csch(x::CommutativeMulNumber) -(coth(x)) * Ω
@scalar_rule sec(x::CommutativeMulNumber) Ω * tan(x)
@scalar_rule secd(x::CommutativeMulNumber) (π / oftype(x, 180)) * Ω * tand(x)
@scalar_rule sech(x::CommutativeMulNumber) -(tanh(x)) * Ω

@scalar_rule acot(x::CommutativeMulNumber) -(inv(1 + x ^ 2))
@scalar_rule acsc(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 - x ^ -2)))
@scalar_rule acsc(x::Real) -(inv(abs(x) * sqrt(x ^ 2 - 1)))
@scalar_rule asec(x) inv(x ^ 2 * sqrt(1 - x ^ -2))
@scalar_rule asec(x::CommutativeMulNumber) inv(x ^ 2 * sqrt(1 - x ^ -2))
@scalar_rule asec(x::Real) inv(abs(x) * sqrt(x ^ 2 - 1))

@scalar_rule cosd(x) -(π / oftype(x, 180)) * sind(x)
@scalar_rule cospi(x) -π * sinpi(x)
@scalar_rule sind(x) (π / oftype(x, 180)) * cosd(x)
@scalar_rule sinpi(x) π * cospi(x)
@scalar_rule tand(x) (π / oftype(x, 180)) * (1 + Ω ^ 2)
@scalar_rule cosd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * sind(x)
@scalar_rule cospi(x::CommutativeMulNumber) -π * sinpi(x)
@scalar_rule sind(x::CommutativeMulNumber) (π / oftype(x, 180)) * cosd(x)
@scalar_rule sinpi(x::CommutativeMulNumber) π * cospi(x)
@scalar_rule tand(x::CommutativeMulNumber) (π / oftype(x, 180)) * (1 + Ω ^ 2)

@scalar_rule sinc(x) cosc(x)
@scalar_rule sinc(x::CommutativeMulNumber) cosc(x)

@scalar_rule(
clamp(x, low, high),
Expand All @@ -140,7 +140,36 @@ end
),
(!(islow | ishigh), islow, ishigh),
)
@scalar_rule x \ y (-(Ω / x), one(y) / x)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we move the other scalar rule for muladd to be here also?

# product rule requires special care for arguments where `muladd` is non-commutative
function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number)
∂xyz = muladd(Δx, y, muladd(x, Δy, Δz))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use MulAddMacro.jl here?
Its a dependency of ChainRulesCore already.
I think then we can just write:

Suggested change
∂xyz = muladd(Δx, y, muladd(x, Δy, Δz))
@muladd ∂xyz = Δx*y + x*Δy + Δz

and then i think the macro takes care of rearranging.

idk if that really adds clarity or not. What do you thing?

return muladd(x, y, z), ∂xyz
end

function rrule(::typeof(muladd), x::Number, y::Number, z::Number)
function muladd_pullback(ΔΩ)
return (NO_FIELDS, ΔΩ * y', x' * ΔΩ, ΔΩ)
end
return muladd(x, y, z), muladd_pullback
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given we have both of these do we actually get a benefit from defining the @scalar_rule one?
I suspect it optimizes to be identical.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, probably not. Will delete.


@scalar_rule x::CommutativeMulNumber \ y (-(x \ Ω), x \ one(y))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets not use infix form.
I don't think we should have been using it in the first place, but i think it is particularly hard to read once adding type constraints.

Suggested change
@scalar_rule x::CommutativeMulNumber \ y (-(x \ Ω), x \ one(y))
@scalar_rule \(x::CommutativeMulNumber, y) (-(x \ Ω), x \ one(y))

though that isn't a huge amount better.

As i mentioned above, do we need this given we have the cases for Number, Number?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not, I'll benchmark to be sure.


# quotient rule requires special care for arguments where `\` is non-commutative
function frule((_, Δx, Δy), ::typeof(\), x::Number, y::Number)
Ω = x \ y
return Ω, x \ muladd(-Δx, Ω, Δy)
end

function rrule(::typeof(\), x::Number, y::Number)
Ω = x \ y
function ldiv_pullback(ΔΩ)
∂y = x' \ ΔΩ
return (NO_FIELDS, -(∂y * Ω'), ∂y)
end
return Ω, ldiv_pullback
end

function frule((_, ẏ), ::typeof(identity), x)
return (x, ẏ)
Expand Down
83 changes: 55 additions & 28 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,36 @@ let
# Include inside this quote any rules that should have FastMath versions
fastable_ast = quote
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to have a think to be sure that it was fine to add the type constraints to this code since it is going to be transformed to fast math versions.
I think it is fine.

# Trig-Basics
@scalar_rule cos(x) -(sin(x))
@scalar_rule sin(x) cos(x)
@scalar_rule tan(x) 1 + Ω ^ 2

@scalar_rule cos(x::CommutativeMulNumber) -(sin(x))
@scalar_rule sin(x::CommutativeMulNumber) cos(x)
@scalar_rule tan(x::CommutativeMulNumber) 1 + Ω ^ 2

# Trig-Hyperbolic
@scalar_rule cosh(x) sinh(x)
@scalar_rule sinh(x) cosh(x)
@scalar_rule tanh(x) 1 - Ω ^ 2
@scalar_rule cosh(x::CommutativeMulNumber) sinh(x)
@scalar_rule sinh(x::CommutativeMulNumber) cosh(x)
@scalar_rule tanh(x::CommutativeMulNumber) 1 - Ω ^ 2

# Trig- Inverses
@scalar_rule acos(x) -(inv(sqrt(1 - x ^ 2)))
@scalar_rule asin(x) inv(sqrt(1 - x ^ 2))
@scalar_rule atan(x) inv(1 + x ^ 2)
@scalar_rule acos(x::CommutativeMulNumber) -(inv(sqrt(1 - x ^ 2)))
@scalar_rule asin(x::CommutativeMulNumber) inv(sqrt(1 - x ^ 2))
@scalar_rule atan(x::CommutativeMulNumber) inv(1 + x ^ 2)

# Trig-Multivariate
@scalar_rule atan(y, x) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u)
@scalar_rule sincos(x) @setup((sinx, cosx) = Ω) cosx -sinx
@scalar_rule atan(y::Real, x::Real) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only 2-arg atan implemented in Base is for real args. In principle, someone could define a 2-arg atan for a different number type, and this rule wouldn't work.

@scalar_rule sincos(x::CommutativeMulNumber) @setup((sinx, cosx) = Ω) cosx -sinx

# exponents
@scalar_rule cbrt(x) inv(3 * Ω ^ 2)
@scalar_rule inv(x) -(Ω ^ 2)
@scalar_rule sqrt(x) inv(2Ω)
@scalar_rule exp(x) Ω
@scalar_rule exp10(x) Ω * log(oftype(x, 10))
@scalar_rule exp2(x) Ω * log(oftype(x, 2))
@scalar_rule expm1(x) exp(x)
@scalar_rule log(x) inv(x)
@scalar_rule log10(x) inv(x) / log(oftype(x, 10))
@scalar_rule log1p(x) inv(x + 1)
@scalar_rule log2(x) inv(x) / log(oftype(x, 2))
@scalar_rule cbrt(x::CommutativeMulNumber) inv(3 * Ω ^ 2)
@scalar_rule inv(x::CommutativeMulNumber) -(Ω ^ 2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this down to be with the general Number case?

@scalar_rule sqrt(x::CommutativeMulNumber) inv(2Ω)
@scalar_rule exp(x::CommutativeMulNumber) Ω
@scalar_rule exp10(x::CommutativeMulNumber) Ω * log(oftype(x, 10))
@scalar_rule exp2(x::CommutativeMulNumber) Ω * log(oftype(x, 2))
@scalar_rule expm1(x::CommutativeMulNumber) exp(x)
@scalar_rule log(x::CommutativeMulNumber) inv(x)
@scalar_rule log10(x::CommutativeMulNumber) inv(x) / log(oftype(x, 10))
@scalar_rule log1p(x::CommutativeMulNumber) inv(x + 1)
@scalar_rule log2(x::CommutativeMulNumber) inv(x) / log(oftype(x, 2))


# Unary complex functions
Expand Down Expand Up @@ -78,7 +77,7 @@ let
end

## angle
function frule((_, Δx), ::typeof(angle), x)
function frule((_, Δx), ::typeof(angle), x::Union{Real, Complex})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a comment here saying that this is specificially only for Real and Complex
since angle isn't mathematically well defined for anything else,
even if it is a CommutativeMulNumber

Ω = angle(x)
# `ifelse` is applied only to denominator to ensure type-stability.
∂Ω = _imagconjtimes(x, Δx) / ifelse(iszero(x), one(x), abs2(x))
Expand Down Expand Up @@ -133,9 +132,9 @@ let

@scalar_rule x + y (One(), One())
@scalar_rule x - y (One(), -1)
@scalar_rule x / y (one(x) / y, -(Ω / y))
@scalar_rule x / y::CommutativeMulNumber (one(x) / y, -(Ω / y))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

per above

Suggested change
@scalar_rule x / y::CommutativeMulNumber (one(x) / y, -/ y))
@scalar_rule /(x, y::CommutativeMulNumber) (one(x) / y, -/ y))

#log(complex(x)) is required so it gives correct complex answer for x<0
@scalar_rule(x ^ y,
@scalar_rule(x::CommutativeMulNumber ^ y::CommutativeMulNumber,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per style guide for over length lines. https://github.com/invenia/BlueStyle#method-definitions

Suggested change
@scalar_rule(x::CommutativeMulNumber ^ y::CommutativeMulNumber,
@scalar_rule(
^(x::CommutativeMulNumber, y::CommutativeMulNumber),

(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))),
)
# x^y for x < 0 errors when y is not an integer, but then derivative wrt y
Expand All @@ -157,14 +156,14 @@ let

# `sign`

function frule((_, Δx), ::typeof(sign), x)
function frule((_, Δx), ::typeof(sign), x::Number)
n = ifelse(iszero(x), one(x), abs(x))
Ω = x isa Real ? sign(x) : x / n
∂Ω = Ω * (_imagconjtimes(Ω, Δx) / n) * im
return Ω, ∂Ω
end

function rrule(::typeof(sign), x)
function rrule(::typeof(sign), x::Number)
n = ifelse(iszero(x), one(x), abs(x))
Ω = x isa Real ? sign(x) : x / n
function sign_pullback(ΔΩ)
Expand All @@ -174,6 +173,34 @@ let
return Ω, sign_pullback
end

function frule((_, Δx), ::typeof(inv), x::Number)
Ω = inv(x)
return Ω, -(Ω * Δx * Ω)
end

function rrule(::typeof(inv), x::Number)
Ω = inv(x)
function inv_pullback(ΔΩ)
return (NO_FIELDS, -(Ω' * ΔΩ * Ω'))
end
return Ω, inv_pullback
end

# quotient rule requires special care for arguments where `/` is non-commutative
function frule((_, Δx, Δy), ::typeof(/), x::Number, y::Number)
Ω = x / y
return Ω, muladd(-Ω, Δy, Δx) / y
end

function rrule(::typeof(/), x::Number, y::Number)
Ω = x / y
function rdiv_pullback(ΔΩ)
∂x = ΔΩ / y'
return (NO_FIELDS, ∂x, -(Ω' * ∂x))
end
return Ω, rdiv_pullback
end

# product rule requires special care for arguments where `mul` is non-commutative
function frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number)
# Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more
Expand Down
60 changes: 60 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,64 @@
frule_test(clamp, (x, Δx), (y, Δy), (z, Δz))
rrule_test(clamp, Δk, (x, x̄), (y, ȳ), (z, z̄))
end

@testset "non-commutative number (quaternion)" begin
function FiniteDifferences.to_vec(q::Quaternion)
function Quaternion_from_vec(q_vec)
return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4])
end
return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec
end
Comment on lines +184 to +189
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should just have this defined in ChainRulesTestUtils.jl
Whole point of that package is to avoid defining re-usable stuff inside the tests.

Its still type piracy there but it makes sense for us to define proper testing functionality for testing this using ChainRulesTestUtils.
JuliaDiff/ChainRulesTestUtils.jl#61

Or we could move it to FiniteDifferences.jl that would also be acceptable, and not type-piracy.


@testset "unary functions" begin
@testset "$f" for f in (+, -, inv, identity, one, zero, transpose, adjoint, real)
test_scalar(f, quatrand())
end
end

@testset "binary functions" begin
@testset "$f(::Quaternion, ::Quaternion)" for f in (+, -, *, /, \)
x, ẋ, x̄ = quatrand(), quatrand(), quatrand()
y, ẏ, ȳ = quatrand(), quatrand(), quatrand()
ΔΩ = quatrand()
frule_test(f, (x, ẋ), (y, ẏ))
rrule_test(f, ΔΩ, (x, x̄), (y, ȳ))
end
@testset "/(::Quaternion, ::Real)" begin
x, ẋ = quatrand(), quatrand(), quatrand()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
x, ẋ = quatrand(), quatrand(), quatrand()
x, ẋ = quatrand(), quatrand()

y, ẏ = randn(3)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo?

Suggested change
y, ẏ = randn(3)
y, ẏ = randn(2)

frule_test(/, (x, ẋ), (y, ẏ))
# don't test rrule, because it doesn't project adjoint of y to the reals
# so fd won't agree
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a way to test these?

Why don't we run into this problem for Complex numbers?

end
@testset "\\(::Real, ::Quaternion)" begin
x, ẋ, x̄ = randn(3)
y, ẏ, ȳ = quatrand(), quatrand(), quatrand()
ΔΩ = quatrand()
Comment on lines +213 to +215
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we are not testing rrule don't need these/

Suggested change
x, ẋ, x̄ = randn(3)
y, ẏ, ȳ = quatrand(), quatrand(), quatrand()
ΔΩ = quatrand()
x, ẋ = randn(2)
y, ẏ = quatrand(), quatrand()

frule_test(\, (x, ẋ), (y, ẏ))
# don't test rrule, because it doesn't project adjoint of x to the reals
# so fd won't agree
end
end

@testset "ternary functions" begin
@testset "$f(::Quaternion, ::Quaternion, ::Quaternion)" for f in (muladd,)
x, ẋ, x̄ = quatrand(), quatrand(), quatrand()
y, ẏ, ȳ = quatrand(), quatrand(), quatrand()
z, ż, z̄ = quatrand(), quatrand(), quatrand()
ΔΩ = quatrand()
frule_test(f, (x, ẋ), (y, ẏ), (z, ż))
rrule_test(f, ΔΩ, (x, x̄), (y, ȳ), (z, z̄))
end
@testset "muladd(::Quaternion, ::Real, ::Quaternion)" begin
x, ẋ, x̄ = quatrand(), quatrand(), quatrand()
y, ẏ, ȳ = randn(3)
z, ż, z̄ = quatrand(), quatrand(), quatrand()
ΔΩ = quatrand()
frule_test(muladd, (x, ẋ), (y, ẏ), (z, ż))
# don't test rrule, because it doesn't project adjoint of y to the reals
# so fd won't agree
end
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using FiniteDifferences
using LinearAlgebra
using LinearAlgebra.BLAS
using LinearAlgebra: dot
using Quaternions
using Random
using Statistics
using Test
Expand Down