Skip to content

Commit

Permalink
Test AD rules for Tensor contructor
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 12, 2023
1 parent 5f0e058 commit 78ec48c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 19 deletions.
26 changes: 7 additions & 19 deletions test/integration/ChainRules_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,14 @@
)
end

@testset "contract" begin
@testset "TensorNetwork" begin
tn = rand(TensorNetwork, 2, 3)
@testset "Tensor" begin
test_frule(Tensor, fill(1.0), Symbol[])
test_rrule(Tensor, fill(1.0), Symbol[])

@test frule((nothing, tn), contract, tn) isa Tuple{Tensor{eltype(tn),0},Tensor{eltype(tn),0}}
@test rrule(contract, tn) isa Tuple{Tensor{eltype(tn),0},Function}
test_frule(Tensor, fill(1.0, 2), Symbol[:i])
test_rrule(Tensor, fill(1.0, 2), Symbol[:i])

# TODO FiniteDifferences crashes
# test_frule(contract, tn)
# test_rrule(contract, tn)
end
end

@testset "replace" begin
using UUIDs: uuid4

tn = rand(TensorNetwork, 10, 3)
mapping = [label => Symbol(uuid4()) for label in inds(tn)]

# TODO fails in check_result.jl@161 -> `c_actual = collect(Broadcast.materialize(actual))`
# test_rrule(replace, tn, mapping...)
test_frule(Tensor, fill(1.0, 2, 3), Symbol[:i, :j])
test_rrule(Tensor, fill(1.0, 2, 3), Symbol[:i, :j])
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using OMEinsum
end

@testset "Integration tests" verbose = true begin
include("integration/ChainRules_test.jl")
include("integration/BlockArray_test.jl")
include("integration/Quac_test.jl")
include("integration/Makie_test.jl")
Expand Down

0 comments on commit 78ec48c

Please sign in to comment.