Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Refactor mixed_canonize! #14

Merged
merged 7 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 52 additions & 26 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,24 +147,24 @@ end
canonize_site(tn::Chain, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...)
canonize_site!(tn::Chain, args...; kwargs...) = canonize_site!(boundary(tn), tn, args...; kwargs...)

# NOTE: in mode == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex!
function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, mode = :qr)
# NOTE: in method == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex!
function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, method = :qr)
left_inds = Symbol[]
right_inds = Symbol[]

virtualind = if direction === :left
site == Site(nsites(tn)) && throw(ArgumentError("Cannot left-canonize right-most tensor"))
push!(right_inds, rightindex(tn, site))
site == Site(1) && throw(ArgumentError("Cannot right-canonize left-most tensor"))
push!(right_inds, leftindex(tn, site))

site == Site(1) || push!(left_inds, leftindex(tn, site))
site == Site(nsites(tn)) || push!(left_inds, rightindex(tn, site))
push!(left_inds, Quantum(tn)[site])

only(right_inds)
elseif direction === :right
site == Site(1) && throw(ArgumentError("Cannot right-canonize left-most tensor"))
push!(right_inds, leftindex(tn, site))
site == Site(nsites(tn)) && throw(ArgumentError("Cannot left-canonize right-most tensor"))
push!(right_inds, rightindex(tn, site))

site == Site(nsites(tn)) || push!(left_inds, rightindex(tn, site))
site == Site(1) || push!(left_inds, leftindex(tn, site))
push!(left_inds, Quantum(tn)[site])

only(right_inds)
Expand All @@ -173,12 +173,12 @@ function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, mode =
end

tmpind = gensym(:tmp)
if mode == :qr
qr!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind)
elseif mode == :svd
if method === :svd
svd!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind)
elseif method === :qr
qr!(TensorNetwork(tn); left_inds, right_inds, virtualind = tmpind)
else
throw(ArgumentError("Unknown mode=:$mode"))
throw(ArgumentError("Unknown factorization method=:$method"))
end

contract!(TensorNetwork(tn), virtualind)
Expand All @@ -187,6 +187,40 @@ function canonize_site!(::Open, tn::Chain, site::Site; direction::Symbol, mode =
return tn
end

function isleftcanonical(qtn::Chain, site; atol::Real = 1e-12)
right_ind = rightindex(qtn, site)

# we are at right-most site, which cannot be left-canonical
if isnothing(right_ind)
return false
end

# TODO is replace(conj(A)...) copying too much?
tensor = select(qtn, :tensor, site)
contracted = contract(tensor, replace(conj(tensor), right_ind => :new_ind_name))
n = size(tensor, right_ind)
identity_matrix = Matrix(I, n, n)

return isapprox(contracted, identity_matrix; atol)
end

function isrightcanonical(qtn::Chain, site; atol::Real = 1e-12)
left_ind = leftindex(qtn, site)

# we are at left-most site, which cannot be right-canonical
if isnothing(left_ind)
return false
end

#TODO is replace(conj(A)...) copying too much?
tensor = select(qtn, :tensor, site)
contracted = contract(tensor, replace(conj(tensor), left_ind => :new_ind_name))
n = size(tensor, left_ind)
identity_matrix = Matrix(I, n, n)

return isapprox(contracted, identity_matrix; atol)
end

mixed_canonize(tn::Chain, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...)
mixed_canonize!(tn::Chain, args...; kwargs...) = mixed_canonize!(boundary(tn), tn, args...; kwargs...)

Expand All @@ -198,22 +232,14 @@ for i < center the tensors are left-canonical and for i > center the tensors are
and in the center there is a matrix with singular values.
"""
function mixed_canonize!(::Open, tn::Chain, center::Site) # TODO: center could be a range of sites
N = length(sites(tn))

# Left-to-right QR sweep -> get left-canonical tensors
for i in 1:N-1
canonize_site!(tn, Site(i); direction = :left, mode = :qr)
# left-to-right QR sweep (left-canonical tensors)
for i in 1:center.id-1
canonize_site!(tn, Site(i); direction = :right, method = :qr)
mofeing marked this conversation as resolved.
Show resolved Hide resolved
end

# Right-to-left QR sweep -> get right-canonical tensors for i > center
for i in N:-1:1
if i > center.id
canonize_site!(tn, Site(i); direction = :right, mode = :qr)
elseif i == center.id
canonize_site!(tn, Site(i); direction = :left, mode = :svd)
else
canonize_site!(tn, Site(i); direction = :left, mode = :qr)
end
# right-to-left QR sweep (right-canonical tensors)
for i in nsites(tn):-1:center.id+1
canonize_site!(tn, Site(i); direction = :left, method = :qr)
mofeing marked this conversation as resolved.
Show resolved Hide resolved
end

return tn
mofeing marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
4 changes: 2 additions & 2 deletions src/Qrochet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ export Product
include("Ansatz/Chain.jl")
export Chain
export MPS, pMPS, MPO, pMPO
export leftindex, rightindex, canonize_site, canonize_site!
export mixed_canonize, mixed_canonize!
export leftindex, rightindex, canonize_site, canonize_site!, truncate!
export mixed_canonize, mixed_canonize!, isleftcanonical, isrightcanonical

# reexports from Tenet
using Tenet
Expand Down
88 changes: 37 additions & 51 deletions test/Ansatz/Chain_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,69 +54,55 @@
@testset "Canonization" begin
using Tenet

function is_left_canonical(qtn, s::Site)
label_r = rightindex(qtn, s)
A = select(qtn, :tensor, s)
try
contracted = contract(A, replace(conj(A), label_r => :new_ind_name))
return isapprox(contracted, Matrix{Float64}(I, size(A, label_r), size(A, label_r)), atol = 1e-12)
catch
return false
end
end

function is_right_canonical(qtn, s::Site)
label_l = leftindex(qtn, s)
A = select(qtn, :tensor, s)
try
contracted = contract(A, replace(conj(A), label_l => :new_ind_name))
return isapprox(contracted, Matrix{Float64}(I, size(A, label_l), size(A, label_l)), atol = 1e-12)
catch
return false
end
end

@testset "canonize_site" begin
qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)])

@test_throws ArgumentError canonize_site!(qtn, Site(1); direction = :right)
@test_throws ArgumentError canonize_site!(qtn, Site(3); direction = :left)

for mode in [:qr, :svd]
for i in 1:length(sites(qtn))
if i != 1
canonized = canonize_site(qtn, Site(i); direction = :right, mode = mode)
@test is_right_canonical(canonized, Site(i))
@test isapprox(
contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())),
contract(TensorNetwork(qtn)),
)
elseif i != length(sites(qtn))
canonized = canonize_site(qtn, Site(i); direction = :left, mode = mode)
@test is_left_canonical(canonized, Site(i))
@test isapprox(
contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())),
contract(TensorNetwork(qtn)),
)
end
end
@test_throws ArgumentError canonize_site!(qtn, Site(1); direction = :left)
@test_throws ArgumentError canonize_site!(qtn, Site(3); direction = :right)

for method in [:qr, :svd]
canonized = canonize_site(qtn, site"1"; direction = :right, method = method)
@test isleftcanonical(canonized, site"1")
@test isapprox(
contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())),
contract(TensorNetwork(qtn)),
)

canonized = canonize_site(qtn, site"2"; direction = :right, method = method)
@test isleftcanonical(canonized, site"2")
@test isapprox(
contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())),
contract(TensorNetwork(qtn)),
)

canonized = canonize_site(qtn, site"2"; direction = :left, method = method)
@test isrightcanonical(canonized, site"2")
@test isapprox(
contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())),
contract(TensorNetwork(qtn)),
)

canonized = canonize_site(qtn, site"3"; direction = :left, method = method)
@test isrightcanonical(canonized, site"3")
@test isapprox(
contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())),
contract(TensorNetwork(qtn)),
)
end

# Ensure that svd creates a new tensor
@test length(tensors(canonize_site(qtn, Site(2); direction = :right, mode = :svd))) == 4
@test length(tensors(canonize_site(qtn, Site(2); direction = :left, method = :svd))) == 4
end

@testset "mixed_canonize" begin
qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)])
canonized = mixed_canonize(qtn, Site(3))

@test is_left_canonical(canonized, Site(1))
@test is_left_canonical(canonized, Site(2))
@test is_left_canonical(canonized, Site(3))
@test is_right_canonical(canonized, Site(4))
@test is_right_canonical(canonized, Site(5))

@test length(tensors(canonized)) == 6 # 5 tensors + 1 singular value matrix
@test isleftcanonical(canonized, Site(1))
@test isleftcanonical(canonized, Site(2))
@test !isleftcanonical(canonized, Site(3)) && !isrightcanonical(canonized, Site(3))
@test isrightcanonical(canonized, Site(4))
@test isrightcanonical(canonized, Site(5))

@test isapprox(
contract(transform(TensorNetwork(canonized), Tenet.HyperindConverter())),
Expand Down
Loading