From bd64864bb1604039098433a9e46abc11b122d0dd Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 7 Feb 2024 10:46:25 +0800 Subject: [PATCH 1/4] Implement Tangent subtraction --- src/tangent_arithmetic.jl | 1 + test/tangent_types/structural_tangent.jl | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/src/tangent_arithmetic.jl b/src/tangent_arithmetic.jl index 18ae7b3ad..f79957311 100644 --- a/src/tangent_arithmetic.jl +++ b/src/tangent_arithmetic.jl @@ -146,6 +146,7 @@ Base.:+(a::Dict, d::Tangent{P}) where {P} = merge(+, a, backing(d)) Base.:+(a::StructuralTangent{P}, b::P) where {P} = b + a Base.:-(tangent::StructuralTangent{P}) where {P} = map(-, tangent) +Base.:-(a::StructuralTangent{P}, b::StructuralTangent{P}) where {P} = a + (-b) # We intentionally do not define, `Base.*(::Tangent, ::Tangent)` as that is not meaningful # In general one doesn't have to represent multiplications of 2 tangents diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index c177b05f4..0cd06a16d 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -358,6 +358,20 @@ end @test -1.0 * t == -t end + @test "subtraction" begin + a = Tangent{Foo}(; x=2.0, y=-2.0) + b = Tangent{Foo}(; x=1.0, y=2.0) + @test (a - b) == Tangent{Foo}(; x=1.0, y=-4.0) + + a = Tangent{Foo}(; x=2.0, y=-2.0) + b = Tangent{Foo}(; x=1.0) + @test (a - b) == Tangent{Foo}(; x=1.0, y=-2.0) + + a = Tangent{Tuple{Float64, Float64}}(2.0, 3.0) + b = Tangent{Tuple{Float64, Float64}}(1.0, 1.0) + @test (a - b) == Tangent{Tuple{Float64, Float64}}(1.0, 2.0) + end + @testset "scaling" begin @test ( 2 * Tangent{Foo}(; y=1.5, x=2.5) == From a52a213576d50701e2a98fe30d60bb7baae334a9 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 7 Feb 2024 10:50:14 +0800 Subject: [PATCH 2/4] Also test for mutable tangent --- test/tangent_types/structural_tangent.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 0cd06a16d..40b1e2a56 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -370,6 +370,11 @@ end a = Tangent{Tuple{Float64, Float64}}(2.0, 3.0) b = Tangent{Tuple{Float64, Float64}}(1.0, 1.0) @test (a - b) == Tangent{Tuple{Float64, Float64}}(1.0, 2.0) + + a = MutableTangent{MFoo}(x=1.5, y=1.5) + b = MutableTangent{MFoo}(x=0.5, y=0.5) + @test (a - b) == MutableTangent{MFoo}(x=1.0, y=1.0) + end @testset "scaling" begin From 7230e9403c250c56081a923103e2852664d73586 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 7 Feb 2024 10:52:15 +0800 Subject: [PATCH 3/4] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/tangent_types/structural_tangent.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 40b1e2a56..807c32a33 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -366,15 +366,14 @@ end a = Tangent{Foo}(; x=2.0, y=-2.0) b = Tangent{Foo}(; x=1.0) @test (a - b) == Tangent{Foo}(; x=1.0, y=-2.0) - - a = Tangent{Tuple{Float64, Float64}}(2.0, 3.0) - b = Tangent{Tuple{Float64, Float64}}(1.0, 1.0) - @test (a - b) == Tangent{Tuple{Float64, Float64}}(1.0, 2.0) - a = MutableTangent{MFoo}(x=1.5, y=1.5) - b = MutableTangent{MFoo}(x=0.5, y=0.5) - @test (a - b) == MutableTangent{MFoo}(x=1.0, y=1.0) - + a = Tangent{Tuple{Float64,Float64}}(2.0, 3.0) + b = Tangent{Tuple{Float64,Float64}}(1.0, 1.0) + @test (a - b) == Tangent{Tuple{Float64,Float64}}(1.0, 2.0) + + a = MutableTangent{MFoo}(; x=1.5, y=1.5) + b = MutableTangent{MFoo}(; x=0.5, y=0.5) + @test (a - b) == MutableTangent{MFoo}(; x=1.0, y=1.0) end @testset "scaling" begin From c7e00c7380910b444e66d8fa3ee7313133032502 Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 7 Feb 2024 11:42:18 +0800 Subject: [PATCH 4/4] testset not test --- test/tangent_types/structural_tangent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/tangent_types/structural_tangent.jl b/test/tangent_types/structural_tangent.jl index 807c32a33..d93487703 100644 --- a/test/tangent_types/structural_tangent.jl +++ b/test/tangent_types/structural_tangent.jl @@ -358,7 +358,7 @@ end @test -1.0 * t == -t end - @test "subtraction" begin + @testset "subtraction" begin a = Tangent{Foo}(; x=2.0, y=-2.0) b = Tangent{Foo}(; x=1.0, y=2.0) @test (a - b) == Tangent{Foo}(; x=1.0, y=-4.0)