From b86c970dca698116db65788d238bab2922092c8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Thu, 20 Jun 2024 13:32:12 +0200 Subject: [PATCH] Fix `leftsite` and `rightsite` for `adjoint` `Chain`s (#50) * Fix rightsite and leftsite for adjoint Chains * Add tests for adjoint in Chain --- src/Ansatz/Chain.jl | 13 ++++++++----- test/Ansatz/Chain_test.jl | 13 +++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index abb88aa..bb0ff55 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -183,19 +183,22 @@ function Base.convert(::Type{Chain}, qtn::Product) end leftsite(tn::Chain, site::Site) = leftsite(boundary(tn), tn, site) -leftsite(::Open, tn::Chain, site::Site) = id(site) ∈ range(2, nlanes(tn)) ? Site(id(site) - 1) : nothing -leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn))) +leftsite(::Open, tn::Chain, site::Site) = + id(site) ∈ range(2, nlanes(tn)) ? Site(id(site) - 1; dual = isdual(site)) : nothing +leftsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) - 1, nlanes(tn)); dual = isdual(site)) rightsite(tn::Chain, site::Site) = rightsite(boundary(tn), tn, site) -rightsite(::Open, tn::Chain, site::Site) = id(site) ∈ range(1, nlanes(tn) - 1) ? Site(id(site) + 1) : nothing -rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn))) +rightsite(::Open, tn::Chain, site::Site) = + id(site) ∈ range(1, nlanes(tn) - 1) ? Site(id(site) + 1; dual = isdual(site)) : nothing +rightsite(::Periodic, tn::Chain, site::Site) = Site(mod1(id(site) + 1, nlanes(tn)); dual = isdual(site)) leftindex(tn::Chain, site::Site) = leftindex(boundary(tn), tn, site) leftindex(::Open, tn::Chain, site::Site) = site == site"1" ? nothing : leftindex(Periodic(), tn, site) leftindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, leftsite(tn, site))) rightindex(tn::Chain, site::Site) = rightindex(boundary(tn), tn, site) -rightindex(::Open, tn::Chain, site::Site) = site == Site(nlanes(tn)) ? nothing : rightindex(Periodic(), tn, site) +rightindex(::Open, tn::Chain, site::Site) = + site == Site(nlanes(tn); dual = isdual(site)) ? nothing : rightindex(Periodic(), tn, site) rightindex(::Periodic, tn::Chain, site::Site) = inds(tn; bond = (site, rightsite(tn, site))) Base.adjoint(chain::Chain) = Chain(adjoint(Quantum(chain)), boundary(chain)) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index d332397..1b7ab95 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -380,5 +380,18 @@ isapprox(norm(qtn), 1.0) end + @testset "adjoint" begin + qtn = rand(Chain, Open, State; n = 5, p = 2, χ = 10) + adjoint_qtn = adjoint(qtn) + + for i in 1:nsites(qtn) + i < nsites(qtn) && + @test rightindex(adjoint_qtn, Site(i; dual = true)) == Symbol(String(rightindex(qtn, Site(i))) * "'") + i > 1 && @test leftindex(adjoint_qtn, Site(i; dual = true)) == Symbol(String(leftindex(qtn, Site(i))) * "'") + end + + @test isapprox(contract(TensorNetwork(qtn)), contract(TensorNetwork(adjoint_qtn))) + end + # TODO test `evolve!` methods end