-
Notifications
You must be signed in to change notification settings - Fork 93
Handle commutativity correctly in scalar rules #275
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
dbc3d08
8758584
4add85d
9401209
51180ca
431fe8c
2f14f73
d0d1954
967a1bc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -76,7 +76,7 @@ function rrule(::typeof(hypot), z::Complex) | |||||
end | ||||||
|
||||||
@scalar_rule fma(x, y, z) (y, x, One()) | ||||||
@scalar_rule muladd(x, y, z) (y, x, One()) | ||||||
@scalar_rule muladd(x, y::CommutativeMulNumber, z) (y, x, One()) | ||||||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
@scalar_rule rem2pi(x, r::RoundingMode) (One(), DoesNotExist()) | ||||||
@scalar_rule( | ||||||
mod(x, y), | ||||||
|
@@ -90,47 +90,47 @@ end | |||||
@scalar_rule(ldexp(x, y), (2^y, DoesNotExist())) | ||||||
|
||||||
# Can't multiply though sqrt in acosh because of negative complex case for x | ||||||
@scalar_rule acosh(x) inv(sqrt(x - 1) * sqrt(x + 1)) | ||||||
@scalar_rule acoth(x) inv(1 - x ^ 2) | ||||||
@scalar_rule acsch(x) -(inv(x ^ 2 * sqrt(1 + x ^ -2))) | ||||||
@scalar_rule acosh(x::CommutativeMulNumber) inv(sqrt(x - 1) * sqrt(x + 1)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the logic used to determine if Multiplicative Commutative is needed for univariate functions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the extension of a function to the complex numbers and matrices is in the form of a power series, then the non-commutativity becomes a problem for non-commutative numbers, and I restrict it. |
||||||
@scalar_rule acoth(x::CommutativeMulNumber) inv(1 - x ^ 2) | ||||||
@scalar_rule acsch(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 + x ^ -2))) | ||||||
@scalar_rule acsch(x::Real) -(inv(abs(x) * sqrt(1 + x ^ 2))) | ||||||
@scalar_rule asech(x) -(inv(x * sqrt(1 - x ^ 2))) | ||||||
@scalar_rule asinh(x) inv(sqrt(x ^ 2 + 1)) | ||||||
@scalar_rule atanh(x) inv(1 - x ^ 2) | ||||||
@scalar_rule asech(x::CommutativeMulNumber) -(inv(x * sqrt(1 - x ^ 2))) | ||||||
@scalar_rule asinh(x::CommutativeMulNumber) inv(sqrt(x ^ 2 + 1)) | ||||||
@scalar_rule atanh(x::CommutativeMulNumber) inv(1 - x ^ 2) | ||||||
|
||||||
|
||||||
@scalar_rule acosd(x) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2) | ||||||
@scalar_rule acotd(x) (-(oftype(x, 180)) / π) / (1 + x ^ 2) | ||||||
@scalar_rule acscd(x) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2) | ||||||
@scalar_rule acosd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / sqrt(1 - x ^ 2) | ||||||
@scalar_rule acotd(x::CommutativeMulNumber) (-(oftype(x, 180)) / π) / (1 + x ^ 2) | ||||||
@scalar_rule acscd(x::CommutativeMulNumber) ((-(oftype(x, 180)) / π) / x ^ 2) / sqrt(1 - x ^ -2) | ||||||
@scalar_rule acscd(x::Real) ((-(oftype(x, 180)) / π) / abs(x)) / sqrt(x ^ 2 - 1) | ||||||
@scalar_rule asecd(x) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2) | ||||||
@scalar_rule asecd(x::CommutativeMulNumber) ((oftype(x, 180) / π) / x ^ 2) / sqrt(1 - x ^ -2) | ||||||
@scalar_rule asecd(x::Real) ((oftype(x, 180) / π) / abs(x)) / sqrt(x ^ 2 - 1) | ||||||
@scalar_rule asind(x) (oftype(x, 180) / π) / sqrt(1 - x ^ 2) | ||||||
@scalar_rule atand(x) (oftype(x, 180) / π) / (1 + x ^ 2) | ||||||
|
||||||
@scalar_rule cot(x) -((1 + Ω ^ 2)) | ||||||
@scalar_rule coth(x) -(csch(x) ^ 2) | ||||||
@scalar_rule cotd(x) -(π / oftype(x, 180)) * (1 + Ω ^ 2) | ||||||
@scalar_rule csc(x) -Ω * cot(x) | ||||||
@scalar_rule cscd(x) -(π / oftype(x, 180)) * Ω * cotd(x) | ||||||
@scalar_rule csch(x) -(coth(x)) * Ω | ||||||
@scalar_rule sec(x) Ω * tan(x) | ||||||
@scalar_rule secd(x) (π / oftype(x, 180)) * Ω * tand(x) | ||||||
@scalar_rule sech(x) -(tanh(x)) * Ω | ||||||
|
||||||
@scalar_rule acot(x) -(inv(1 + x ^ 2)) | ||||||
@scalar_rule acsc(x) -(inv(x ^ 2 * sqrt(1 - x ^ -2))) | ||||||
@scalar_rule asind(x::CommutativeMulNumber) (oftype(x, 180) / π) / sqrt(1 - x ^ 2) | ||||||
@scalar_rule atand(x::CommutativeMulNumber) (oftype(x, 180) / π) / (1 + x ^ 2) | ||||||
|
||||||
@scalar_rule cot(x::CommutativeMulNumber) -((1 + Ω ^ 2)) | ||||||
@scalar_rule coth(x::CommutativeMulNumber) -(csch(x) ^ 2) | ||||||
@scalar_rule cotd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * (1 + Ω ^ 2) | ||||||
@scalar_rule csc(x::CommutativeMulNumber) -Ω * cot(x) | ||||||
@scalar_rule cscd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * Ω * cotd(x) | ||||||
@scalar_rule csch(x::CommutativeMulNumber) -(coth(x)) * Ω | ||||||
@scalar_rule sec(x::CommutativeMulNumber) Ω * tan(x) | ||||||
@scalar_rule secd(x::CommutativeMulNumber) (π / oftype(x, 180)) * Ω * tand(x) | ||||||
@scalar_rule sech(x::CommutativeMulNumber) -(tanh(x)) * Ω | ||||||
|
||||||
@scalar_rule acot(x::CommutativeMulNumber) -(inv(1 + x ^ 2)) | ||||||
@scalar_rule acsc(x::CommutativeMulNumber) -(inv(x ^ 2 * sqrt(1 - x ^ -2))) | ||||||
@scalar_rule acsc(x::Real) -(inv(abs(x) * sqrt(x ^ 2 - 1))) | ||||||
@scalar_rule asec(x) inv(x ^ 2 * sqrt(1 - x ^ -2)) | ||||||
@scalar_rule asec(x::CommutativeMulNumber) inv(x ^ 2 * sqrt(1 - x ^ -2)) | ||||||
@scalar_rule asec(x::Real) inv(abs(x) * sqrt(x ^ 2 - 1)) | ||||||
|
||||||
@scalar_rule cosd(x) -(π / oftype(x, 180)) * sind(x) | ||||||
@scalar_rule cospi(x) -π * sinpi(x) | ||||||
@scalar_rule sind(x) (π / oftype(x, 180)) * cosd(x) | ||||||
@scalar_rule sinpi(x) π * cospi(x) | ||||||
@scalar_rule tand(x) (π / oftype(x, 180)) * (1 + Ω ^ 2) | ||||||
@scalar_rule cosd(x::CommutativeMulNumber) -(π / oftype(x, 180)) * sind(x) | ||||||
@scalar_rule cospi(x::CommutativeMulNumber) -π * sinpi(x) | ||||||
@scalar_rule sind(x::CommutativeMulNumber) (π / oftype(x, 180)) * cosd(x) | ||||||
@scalar_rule sinpi(x::CommutativeMulNumber) π * cospi(x) | ||||||
@scalar_rule tand(x::CommutativeMulNumber) (π / oftype(x, 180)) * (1 + Ω ^ 2) | ||||||
|
||||||
@scalar_rule sinc(x) cosc(x) | ||||||
@scalar_rule sinc(x::CommutativeMulNumber) cosc(x) | ||||||
|
||||||
@scalar_rule( | ||||||
clamp(x, low, high), | ||||||
|
@@ -140,7 +140,36 @@ end | |||||
), | ||||||
(!(islow | ishigh), islow, ishigh), | ||||||
) | ||||||
@scalar_rule x \ y (-(Ω / x), one(y) / x) | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we move the other scalar rule for |
||||||
# product rule requires special care for arguments where `muladd` is non-commutative | ||||||
function frule((_, Δx, Δy, Δz), ::typeof(muladd), x::Number, y::Number, z::Number) | ||||||
∂xyz = muladd(Δx, y, muladd(x, Δy, Δz)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we use MulAddMacro.jl here?
Suggested change
and then i think the macro takes care of rearranging. idk if that really adds clarity or not. What do you thing? |
||||||
return muladd(x, y, z), ∂xyz | ||||||
end | ||||||
|
||||||
function rrule(::typeof(muladd), x::Number, y::Number, z::Number) | ||||||
function muladd_pullback(ΔΩ) | ||||||
return (NO_FIELDS, ΔΩ * y', x' * ΔΩ, ΔΩ) | ||||||
end | ||||||
return muladd(x, y, z), muladd_pullback | ||||||
end | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given we have both of these do we actually get a benefit from defining the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, probably not. Will delete. |
||||||
|
||||||
@scalar_rule x::CommutativeMulNumber \ y (-(x \ Ω), x \ one(y)) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lets not use infix form.
Suggested change
though that isn't a huge amount better. As i mentioned above, do we need this given we have the cases for Number, Number? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably not, I'll benchmark to be sure. |
||||||
|
||||||
# quotient rule requires special care for arguments where `\` is non-commutative | ||||||
function frule((_, Δx, Δy), ::typeof(\), x::Number, y::Number) | ||||||
Ω = x \ y | ||||||
return Ω, x \ muladd(-Δx, Ω, Δy) | ||||||
end | ||||||
|
||||||
function rrule(::typeof(\), x::Number, y::Number) | ||||||
Ω = x \ y | ||||||
function ldiv_pullback(ΔΩ) | ||||||
∂y = x' \ ΔΩ | ||||||
return (NO_FIELDS, -(∂y * Ω'), ∂y) | ||||||
end | ||||||
return Ω, ldiv_pullback | ||||||
end | ||||||
|
||||||
function frule((_, ẏ), ::typeof(identity), x) | ||||||
return (x, ẏ) | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -2,37 +2,36 @@ let | |||||||
# Include inside this quote any rules that should have FastMath versions | ||||||||
fastable_ast = quote | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I had to have a think to be sure that it was fine to add the type constraints to this code since it is going to be transformed to fast math versions. |
||||||||
# Trig-Basics | ||||||||
@scalar_rule cos(x) -(sin(x)) | ||||||||
@scalar_rule sin(x) cos(x) | ||||||||
@scalar_rule tan(x) 1 + Ω ^ 2 | ||||||||
|
||||||||
@scalar_rule cos(x::CommutativeMulNumber) -(sin(x)) | ||||||||
@scalar_rule sin(x::CommutativeMulNumber) cos(x) | ||||||||
@scalar_rule tan(x::CommutativeMulNumber) 1 + Ω ^ 2 | ||||||||
|
||||||||
# Trig-Hyperbolic | ||||||||
@scalar_rule cosh(x) sinh(x) | ||||||||
@scalar_rule sinh(x) cosh(x) | ||||||||
@scalar_rule tanh(x) 1 - Ω ^ 2 | ||||||||
@scalar_rule cosh(x::CommutativeMulNumber) sinh(x) | ||||||||
@scalar_rule sinh(x::CommutativeMulNumber) cosh(x) | ||||||||
@scalar_rule tanh(x::CommutativeMulNumber) 1 - Ω ^ 2 | ||||||||
|
||||||||
# Trig- Inverses | ||||||||
@scalar_rule acos(x) -(inv(sqrt(1 - x ^ 2))) | ||||||||
@scalar_rule asin(x) inv(sqrt(1 - x ^ 2)) | ||||||||
@scalar_rule atan(x) inv(1 + x ^ 2) | ||||||||
@scalar_rule acos(x::CommutativeMulNumber) -(inv(sqrt(1 - x ^ 2))) | ||||||||
@scalar_rule asin(x::CommutativeMulNumber) inv(sqrt(1 - x ^ 2)) | ||||||||
@scalar_rule atan(x::CommutativeMulNumber) inv(1 + x ^ 2) | ||||||||
|
||||||||
# Trig-Multivariate | ||||||||
@scalar_rule atan(y, x) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u) | ||||||||
@scalar_rule sincos(x) @setup((sinx, cosx) = Ω) cosx -sinx | ||||||||
@scalar_rule atan(y::Real, x::Real) @setup(u = x ^ 2 + y ^ 2) (x / u, -y / u) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The only 2-arg |
||||||||
@scalar_rule sincos(x::CommutativeMulNumber) @setup((sinx, cosx) = Ω) cosx -sinx | ||||||||
|
||||||||
# exponents | ||||||||
@scalar_rule cbrt(x) inv(3 * Ω ^ 2) | ||||||||
@scalar_rule inv(x) -(Ω ^ 2) | ||||||||
@scalar_rule sqrt(x) inv(2Ω) | ||||||||
@scalar_rule exp(x) Ω | ||||||||
@scalar_rule exp10(x) Ω * log(oftype(x, 10)) | ||||||||
@scalar_rule exp2(x) Ω * log(oftype(x, 2)) | ||||||||
@scalar_rule expm1(x) exp(x) | ||||||||
@scalar_rule log(x) inv(x) | ||||||||
@scalar_rule log10(x) inv(x) / log(oftype(x, 10)) | ||||||||
@scalar_rule log1p(x) inv(x + 1) | ||||||||
@scalar_rule log2(x) inv(x) / log(oftype(x, 2)) | ||||||||
@scalar_rule cbrt(x::CommutativeMulNumber) inv(3 * Ω ^ 2) | ||||||||
@scalar_rule inv(x::CommutativeMulNumber) -(Ω ^ 2) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we move this down to be with the general |
||||||||
@scalar_rule sqrt(x::CommutativeMulNumber) inv(2Ω) | ||||||||
@scalar_rule exp(x::CommutativeMulNumber) Ω | ||||||||
@scalar_rule exp10(x::CommutativeMulNumber) Ω * log(oftype(x, 10)) | ||||||||
@scalar_rule exp2(x::CommutativeMulNumber) Ω * log(oftype(x, 2)) | ||||||||
@scalar_rule expm1(x::CommutativeMulNumber) exp(x) | ||||||||
@scalar_rule log(x::CommutativeMulNumber) inv(x) | ||||||||
@scalar_rule log10(x::CommutativeMulNumber) inv(x) / log(oftype(x, 10)) | ||||||||
@scalar_rule log1p(x::CommutativeMulNumber) inv(x + 1) | ||||||||
@scalar_rule log2(x::CommutativeMulNumber) inv(x) / log(oftype(x, 2)) | ||||||||
|
||||||||
|
||||||||
# Unary complex functions | ||||||||
|
@@ -78,7 +77,7 @@ let | |||||||
end | ||||||||
|
||||||||
## angle | ||||||||
function frule((_, Δx), ::typeof(angle), x) | ||||||||
function frule((_, Δx), ::typeof(angle), x::Union{Real, Complex}) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add a comment here saying that this is specificially only for |
||||||||
Ω = angle(x) | ||||||||
# `ifelse` is applied only to denominator to ensure type-stability. | ||||||||
∂Ω = _imagconjtimes(x, Δx) / ifelse(iszero(x), one(x), abs2(x)) | ||||||||
|
@@ -133,9 +132,9 @@ let | |||||||
|
||||||||
@scalar_rule x + y (One(), One()) | ||||||||
@scalar_rule x - y (One(), -1) | ||||||||
@scalar_rule x / y (one(x) / y, -(Ω / y)) | ||||||||
@scalar_rule x / y::CommutativeMulNumber (one(x) / y, -(Ω / y)) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. per above
Suggested change
|
||||||||
#log(complex(x)) is required so it gives correct complex answer for x<0 | ||||||||
@scalar_rule(x ^ y, | ||||||||
@scalar_rule(x::CommutativeMulNumber ^ y::CommutativeMulNumber, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Per style guide for over length lines. https://github.com/invenia/BlueStyle#method-definitions
Suggested change
|
||||||||
(ifelse(iszero(x), zero(Ω), y * Ω / x), Ω * log(complex(x))), | ||||||||
) | ||||||||
# x^y for x < 0 errors when y is not an integer, but then derivative wrt y | ||||||||
|
@@ -157,14 +156,14 @@ let | |||||||
|
||||||||
# `sign` | ||||||||
|
||||||||
function frule((_, Δx), ::typeof(sign), x) | ||||||||
function frule((_, Δx), ::typeof(sign), x::Number) | ||||||||
n = ifelse(iszero(x), one(x), abs(x)) | ||||||||
Ω = x isa Real ? sign(x) : x / n | ||||||||
∂Ω = Ω * (_imagconjtimes(Ω, Δx) / n) * im | ||||||||
return Ω, ∂Ω | ||||||||
end | ||||||||
|
||||||||
function rrule(::typeof(sign), x) | ||||||||
function rrule(::typeof(sign), x::Number) | ||||||||
n = ifelse(iszero(x), one(x), abs(x)) | ||||||||
Ω = x isa Real ? sign(x) : x / n | ||||||||
function sign_pullback(ΔΩ) | ||||||||
|
@@ -174,6 +173,34 @@ let | |||||||
return Ω, sign_pullback | ||||||||
end | ||||||||
|
||||||||
function frule((_, Δx), ::typeof(inv), x::Number) | ||||||||
Ω = inv(x) | ||||||||
return Ω, -(Ω * Δx * Ω) | ||||||||
end | ||||||||
|
||||||||
function rrule(::typeof(inv), x::Number) | ||||||||
Ω = inv(x) | ||||||||
function inv_pullback(ΔΩ) | ||||||||
return (NO_FIELDS, -(Ω' * ΔΩ * Ω')) | ||||||||
end | ||||||||
return Ω, inv_pullback | ||||||||
end | ||||||||
|
||||||||
# quotient rule requires special care for arguments where `/` is non-commutative | ||||||||
function frule((_, Δx, Δy), ::typeof(/), x::Number, y::Number) | ||||||||
Ω = x / y | ||||||||
return Ω, muladd(-Ω, Δy, Δx) / y | ||||||||
end | ||||||||
|
||||||||
function rrule(::typeof(/), x::Number, y::Number) | ||||||||
Ω = x / y | ||||||||
function rdiv_pullback(ΔΩ) | ||||||||
∂x = ΔΩ / y' | ||||||||
return (NO_FIELDS, ∂x, -(Ω' * ∂x)) | ||||||||
end | ||||||||
return Ω, rdiv_pullback | ||||||||
end | ||||||||
|
||||||||
# product rule requires special care for arguments where `mul` is non-commutative | ||||||||
function frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) | ||||||||
# Optimized version of `Δx .* y .+ x .* Δy`. Also, it is potentially more | ||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -179,4 +179,64 @@ | |||||||||||
frule_test(clamp, (x, Δx), (y, Δy), (z, Δz)) | ||||||||||||
rrule_test(clamp, Δk, (x, x̄), (y, ȳ), (z, z̄)) | ||||||||||||
end | ||||||||||||
|
||||||||||||
@testset "non-commutative number (quaternion)" begin | ||||||||||||
function FiniteDifferences.to_vec(q::Quaternion) | ||||||||||||
function Quaternion_from_vec(q_vec) | ||||||||||||
return Quaternion(q_vec[1], q_vec[2], q_vec[3], q_vec[4]) | ||||||||||||
end | ||||||||||||
return [q.s, q.v1, q.v2, q.v3], Quaternion_from_vec | ||||||||||||
end | ||||||||||||
Comment on lines
+184
to
+189
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should just have this defined in ChainRulesTestUtils.jl Its still type piracy there but it makes sense for us to define proper testing functionality for testing this using ChainRulesTestUtils. Or we could move it to FiniteDifferences.jl that would also be acceptable, and not type-piracy. |
||||||||||||
|
||||||||||||
@testset "unary functions" begin | ||||||||||||
@testset "$f" for f in (+, -, inv, identity, one, zero, transpose, adjoint, real) | ||||||||||||
test_scalar(f, quatrand()) | ||||||||||||
end | ||||||||||||
end | ||||||||||||
|
||||||||||||
@testset "binary functions" begin | ||||||||||||
@testset "$f(::Quaternion, ::Quaternion)" for f in (+, -, *, /, \) | ||||||||||||
x, ẋ, x̄ = quatrand(), quatrand(), quatrand() | ||||||||||||
y, ẏ, ȳ = quatrand(), quatrand(), quatrand() | ||||||||||||
ΔΩ = quatrand() | ||||||||||||
frule_test(f, (x, ẋ), (y, ẏ)) | ||||||||||||
rrule_test(f, ΔΩ, (x, x̄), (y, ȳ)) | ||||||||||||
end | ||||||||||||
@testset "/(::Quaternion, ::Real)" begin | ||||||||||||
x, ẋ = quatrand(), quatrand(), quatrand() | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
y, ẏ = randn(3) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. typo?
Suggested change
|
||||||||||||
frule_test(/, (x, ẋ), (y, ẏ)) | ||||||||||||
# don't test rrule, because it doesn't project adjoint of y to the reals | ||||||||||||
# so fd won't agree | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we have a way to test these? Why don't we run into this problem for Complex numbers? |
||||||||||||
end | ||||||||||||
@testset "\\(::Real, ::Quaternion)" begin | ||||||||||||
x, ẋ, x̄ = randn(3) | ||||||||||||
y, ẏ, ȳ = quatrand(), quatrand(), quatrand() | ||||||||||||
ΔΩ = quatrand() | ||||||||||||
Comment on lines
+213
to
+215
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. since we are not testing
Suggested change
|
||||||||||||
frule_test(\, (x, ẋ), (y, ẏ)) | ||||||||||||
# don't test rrule, because it doesn't project adjoint of x to the reals | ||||||||||||
# so fd won't agree | ||||||||||||
end | ||||||||||||
end | ||||||||||||
|
||||||||||||
@testset "ternary functions" begin | ||||||||||||
@testset "$f(::Quaternion, ::Quaternion, ::Quaternion)" for f in (muladd,) | ||||||||||||
x, ẋ, x̄ = quatrand(), quatrand(), quatrand() | ||||||||||||
y, ẏ, ȳ = quatrand(), quatrand(), quatrand() | ||||||||||||
z, ż, z̄ = quatrand(), quatrand(), quatrand() | ||||||||||||
ΔΩ = quatrand() | ||||||||||||
frule_test(f, (x, ẋ), (y, ẏ), (z, ż)) | ||||||||||||
rrule_test(f, ΔΩ, (x, x̄), (y, ȳ), (z, z̄)) | ||||||||||||
end | ||||||||||||
@testset "muladd(::Quaternion, ::Real, ::Quaternion)" begin | ||||||||||||
x, ẋ, x̄ = quatrand(), quatrand(), quatrand() | ||||||||||||
y, ẏ, ȳ = randn(3) | ||||||||||||
z, ż, z̄ = quatrand(), quatrand(), quatrand() | ||||||||||||
ΔΩ = quatrand() | ||||||||||||
frule_test(muladd, (x, ẋ), (y, ẏ), (z, ż)) | ||||||||||||
# don't test rrule, because it doesn't project adjoint of y to the reals | ||||||||||||
# so fd won't agree | ||||||||||||
end | ||||||||||||
end | ||||||||||||
end | ||||||||||||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought the style guide said to put spaces here.
But now that i look, I am not sure that it mentions it JuliaDiff/BlueStyle#77
Still it is what we do else-where