From 78ec48cc5f1778797a4c26111a3274b48d561611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 12 Sep 2023 16:31:58 +0200 Subject: [PATCH] Test AD rules for `Tensor` contructor --- test/integration/ChainRules_test.jl | 26 +++++++------------------- test/runtests.jl | 1 + 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/test/integration/ChainRules_test.jl b/test/integration/ChainRules_test.jl index d35f5a23..db80830d 100644 --- a/test/integration/ChainRules_test.jl +++ b/test/integration/ChainRules_test.jl @@ -11,26 +11,14 @@ ) end - @testset "contract" begin - @testset "TensorNetwork" begin - tn = rand(TensorNetwork, 2, 3) + @testset "Tensor" begin + test_frule(Tensor, fill(1.0), Symbol[]) + test_rrule(Tensor, fill(1.0), Symbol[]) - @test frule((nothing, tn), contract, tn) isa Tuple{Tensor{eltype(tn),0},Tensor{eltype(tn),0}} - @test rrule(contract, tn) isa Tuple{Tensor{eltype(tn),0},Function} + test_frule(Tensor, fill(1.0, 2), Symbol[:i]) + test_rrule(Tensor, fill(1.0, 2), Symbol[:i]) - # TODO FiniteDifferences crashes - # test_frule(contract, tn) - # test_rrule(contract, tn) - end - end - - @testset "replace" begin - using UUIDs: uuid4 - - tn = rand(TensorNetwork, 10, 3) - mapping = [label => Symbol(uuid4()) for label in inds(tn)] - - # TODO fails in check_result.jl@161 -> `c_actual = collect(Broadcast.materialize(actual))` - # test_rrule(replace, tn, mapping...) + test_frule(Tensor, fill(1.0, 2, 3), Symbol[:i, :j]) + test_rrule(Tensor, fill(1.0, 2, 3), Symbol[:i, :j]) end end diff --git a/test/runtests.jl b/test/runtests.jl index 4244cb16..032113af 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,6 +16,7 @@ using OMEinsum end @testset "Integration tests" verbose = true begin + include("integration/ChainRules_test.jl") include("integration/BlockArray_test.jl") include("integration/Quac_test.jl") include("integration/Makie_test.jl")