diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index 1804293..d034873 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -343,6 +343,45 @@ function isrightcanonical(qtn::Chain, site; atol::Real = 1e-12) return isapprox(contracted, identity_matrix; atol) end +canonize(tn::Chain, args...; kwargs...) = canonize!(copy(tn), args...; kwargs...) +canonize!(tn::Chain, args...; kwargs...) = canonize!(boundary(tn), tn, args...; kwargs...) + +""" +canonize(boundary::Boundary, tn::Chain) + +Transform a `Chain` tensor network into the canonical form (Vidal form), that is, +we have the singular values matrix Λᵢ between each tensor Γᵢ₋₁ and Γᵢ. +""" +function canonize!(::Open, tn::Chain) + Λ = Tensor[] + + # right-to-left QR sweep, get right-canonical tensors + for i in nsites(tn):-1:2 + canonize_site!(tn, Site(i); direction = :left, method = :qr) + end + + # left-to-right SVD sweep, get left-canonical tensors and singular values without reversing + for i in 1:nsites(tn)-1 + canonize_site!(tn, Site(i); direction = :right, method = :svd) + + # extract the singular values and contract them with the next tensor + Λᵢ = pop!(TensorNetwork(tn), select(tn, :between, Site(i), Site(i + 1))) + Aᵢ₊₁ = select(tn, :tensor, Site(i + 1)) + replace!(TensorNetwork(tn), Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ, dims = ())) + push!(Λ, Λᵢ) + end + + for i in 2:nsites(tn) # tensors at i in "A" form, need to contract (Λᵢ)⁻¹ with A to get Γᵢ + Λᵢ = Λ[i-1] # singular values start between site 1 and 2 + A = select(tn, :tensor, Site(i)) + Γᵢ = contract(A, Tensor(pinv.(parent(Λᵢ)), inds(Λᵢ)), dims = ()) + replace!(TensorNetwork(tn), A => Γᵢ) + push!(TensorNetwork(tn), Λᵢ) + end + + return tn +end + mixed_canonize(tn::Chain, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...) mixed_canonize!(tn::Chain, args...; kwargs...) = mixed_canonize!(boundary(tn), tn, args...; kwargs...) diff --git a/src/Qrochet.jl b/src/Qrochet.jl index 41e1040..3bee8a0 100644 --- a/src/Qrochet.jl +++ b/src/Qrochet.jl @@ -22,7 +22,8 @@ include("Ansatz/Chain.jl") export Chain export MPS, pMPS, MPO, pMPO export leftindex, rightindex, isleftcanonical, isrightcanonical -export canonize_site, canonize_site!, truncate!, mixed_canonize, mixed_canonize! +export canonize_site, canonize_site!, truncate! +export canonize, canonize!, mixed_canonize, mixed_canonize! export evolve! diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 6cff86c..22b8b10 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -147,6 +147,50 @@ @test length(tensors(canonize_site(qtn, Site(2); direction = :left, method = :svd))) == 4 end + @testset "canonize" begin + using Qrochet: isleftcanonical, isrightcanonical + + qtn = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = canonize(qtn) + + @test length(tensors(canonized)) == 9 # 5 tensors + 4 singular values vectors + @test isapprox( + contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())), + contract(TensorNetwork(qtn)), + ) + @test isapprox(norm(qtn), norm(canonized)) + + # Extract the singular values between each adjacent pair of sites in the canonized chain + Λ = [select(canonized, :between, Site(i), Site(i + 1)) for i in 1:4] + @test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2 + + for i in 1:4 + canonized = canonize(qtn) + + if i == 1 + @test isleftcanonical(canonized, Site(i)) + else + Γᵢ = select(canonized, :tensor, Site(i)) + Λᵢ = pop!(TensorNetwork(canonized), select(canonized, :between, Site(i - 1), Site(i))) + replace!(TensorNetwork(canonized), Γᵢ => contract(Λᵢ, Γᵢ; dims = ())) + @test isleftcanonical(canonized, Site(i)) + end + end + + for i in 2:5 + canonized = canonize(qtn) + + if i == 5 + @test isrightcanonical(canonized, Site(i)) + else + Γᵢ = select(canonized, :tensor, Site(i)) + Λᵢ₊₁ = pop!(TensorNetwork(canonized), select(canonized, :between, Site(i), Site(i + 1))) + replace!(TensorNetwork(canonized), Γᵢ => contract(Γᵢ, Λᵢ₊₁; dims = ())) + @test isrightcanonical(canonized, Site(i)) + end + end + end + @testset "mixed_canonize" begin qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) canonized = mixed_canonize(qtn, Site(3))