Skip to content

Commit

Permalink
Fix inds(; parallelto) on hyperedges
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Sep 9, 2024
1 parent 570fcec commit c53c6e4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,7 +416,10 @@ end
end

@kwmethod function inds(tn::AbstractTensorNetwork; parallelto)
return mapreduce(inds, , tensors(tn; contains=parallelto))
candidates = filter!(!=(parallelto), mapreduce(inds, , tensors(tn; contains=parallelto)))
return filter(candidates) do i
length(tensors(tn; contains=i)) == length(tensors(tn; contains=parallelto))
end
end

@kwmethod function tensors(tn::AbstractTensorNetwork;)
Expand Down
13 changes: 12 additions & 1 deletion test/TensorNetwork_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,17 @@
@test issetequal(inds(tn; set=:open), [:j, :k])
@test issetequal(inds(tn; set=:inner), [:i, :l, :m])
@test issetequal(inds(tn; set=:hyper), [:i])

@testset "parallelto" begin
tn = TensorNetwork([Tensor(zeros(2, 2), [:i, :j]), Tensor(zeros(2, 2), [:i, :j])])
@test issetequal(inds(tn; parallelto=:i), [:j])

tn = TensorNetwork([Tensor(zeros(2, 2), [:i, :j]), Tensor(zeros(2, 2, 2), [:i, :j, :k])])
@test issetequal(inds(tn; parallelto=:i), [:j])

tn = TensorNetwork([Tensor(zeros(2, 2), [:i, :j]), Tensor(zeros(2, 2), [:i, :j]), Tensor(zeros(2), [:j])])
@test isempty(inds(tn; parallelto=:i))
end
end

@testset "size" begin
Expand Down Expand Up @@ -597,7 +608,7 @@
@test issetequal(inds(tn), [:i, :j])
@test size(tn, :i) == 2
@test size(tn, :j) == 2
@test Tenet.ntensors(tn) == 2
@test Tenet.ntensors(tn) == 3
end

@testset "selectdim" begin
Expand Down

0 comments on commit c53c6e4

Please sign in to comment.