Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jofrevalles committed Jul 3, 2023
1 parent 47ce3ef commit 71f28a4
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions test/Transformations_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,4 +254,33 @@
@test contract(A, contract(B, C, dims = [])) contract(A_2, contract(B_2, C_2, dims = []))
end
end

@testset "SplitSimplification" begin
using Tenet: SplitSimplification

v1 = Tensor([1, 2, 3], (:i,))
v2 = Tensor([4, 5, 6], (:j, ))
m1 = Tensor(rand(3, 3), (:k, :l))

t1 = contract(v1, v2)
tensor = contract(t1, m1) # Define a tensor which can be splitted in three

tn = TensorNetwork([tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))])
reduced = transform(tn, SplitSimplification)

# Test that the new tensors in reduced are smaller than the deleted ones
deleted_tensors = filter(t -> labels(t) labels.(tensors(reduced)), tensors(tn))
new_tensors = filter(t -> labels(t) labels.(tensors(tn)), tensors(reduced))

smallest_deleted = minimum(prod size, deleted_tensors)
largest_new = maximum(prod size, new_tensors)

@test smallest_deleted > largest_new

# Test that the resulting contraction is the same as the original
# TODO: Change for: @test contract(reduced) ≈ contract(tn), when is fixed
A_2, B_2, C_2, D_2, E_2 = tensors(reduced)
c_reduced = contract(contract(contract(contract(A_2, B_2), C_2), D_2), E_2)
@test contract(contract(tensors(tn)[1], tensors(tn)[2]), tensors(tn)[3]) c_reduced
end
end

0 comments on commit 71f28a4

Please sign in to comment.