Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Extend evolve! for Chain in canonical form (#31)
Browse files Browse the repository at this point in the history
* Add first implementation

* Fix typo

* Fix pinv atol

* Add Quac integration tests for evolve

* Fix format

* Add missing import

* Fix location of tests

* Remove unnecessary import

* Remove unnecessary import

* Remove tests

* Change function names

* Fix code

* Fix typo

* Replace condition for isnothing function

* Replace condition for isnothing function

* Fix format

* Lower the atol pinv threshold

* Add delete_lambda as kwarg argument in contract

* Refactor code from main functions

* Fix typo

* Fix default delete_lambda kwarg

* Format code

* Fix format

* Fix format

* Add docstrings

* Change name to contract_2sitewf!

* Format code

* Create unpack_2sitewf! function

* Refactor `Site` to N-dimensional coordinates

* Refactor code

* Fix typo

* Fix typo

* Fix typo

* Fix typo

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
3 people authored Mar 25, 2024
1 parent 5be4435 commit 6a189a1
Showing 1 changed file with 88 additions and 8 deletions.
96 changes: 88 additions & 8 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,21 @@ end
Tenet.contract(tn::Chain, query::Symbol, args...; kwargs...) = contract!(copy(tn), Val(query), args...; kwargs...)
Tenet.contract!(tn::Chain, query::Symbol, args...; kwargs...) = contract!(tn, Val(query), args...; kwargs...)

function Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; direction::Symbol = :left)
"""
Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; direction::Symbol = :left, delete_Λ = true)
For a given [`Chain`](@ref) tensor network, contracts the singular values Λ between two sites `site1` and `site2`.
The `direction` keyword argument specifies the direction of the contraction, and the `delete_Λ` keyword argument
specifies whether to delete the singular values tensor after the contraction.
"""
function Tenet.contract!(
tn::Chain,
::Val{:between},
site1::Site,
site2::Site;
direction::Symbol = :left,
delete_Λ = true,
)
Λᵢ = select(tn, :between, site1, site2)
Λᵢ === nothing && return tn

Expand All @@ -244,7 +258,7 @@ function Tenet.contract!(tn::Chain, ::Val{:between}, site1::Site, site2::Site; d
throw(ArgumentError("Unknown direction=:$direction"))
end

delete!(TensorNetwork(tn), Λᵢ)
delete_Λ && delete!(TensorNetwork(tn), Λᵢ)

return tn
end
Expand Down Expand Up @@ -447,7 +461,7 @@ end
Applies a local operator `gate` to the [`Chain`](@ref) tensor network.
"""
function evolve!(qtn::Chain, gate::Dense; threshold = nothing, maxdim = nothing)
function evolve!(qtn::Chain, gate::Dense; threshold = nothing, maxdim = nothing, iscanonical = false)
# check gate is a valid operator
if !(socket(gate) isa Operator)
throw(ArgumentError("Gate must be an operator, but got $(socket(gate))"))
Expand All @@ -474,7 +488,7 @@ function evolve!(qtn::Chain, gate::Dense; threshold = nothing, maxdim = nothing)
range != gate_inputs && throw(ArgumentError("Gate lanes must be contiguous"))

# TODO check correctly for periodic boundary conditions
evolve_2site!(qtn, gate; threshold, maxdim)
evolve_2site!(qtn, gate; threshold, maxdim, iscanonical = iscanonical)
else
# TODO generalize for more than 2 lanes
throw(ArgumentError("Invalid number of lanes $(nlanes(gate)), maximum is 2"))
Expand Down Expand Up @@ -502,17 +516,18 @@ function evolve_1site!(qtn::Chain, gate::Dense)
contract!(TensorNetwork(qtn), contracting_index)
end

function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim)
# TODO: Maybe rename iscanonical kwarg ?
function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim, iscanonical = false)
# shallow copy to avoid problems if errors in mid execution
gate = copy(gate)

bond = sitel, siter = minmax(outputs(gate)...)
left_inds::Vector{Symbol} = !isnothing(leftindex(qtn, sitel)) ? [leftindex(qtn, sitel)] : Symbol[]
right_inds::Vector{Symbol} = !isnothing(rightindex(qtn, siter)) ? [rightindex(qtn, siter)] : Symbol[]

# contract virtual index
virtualind::Symbol = select(qtn, :bond, bond...)
contract!(TensorNetwork(qtn), virtualind)

iscanonical ? contract_2sitewf!(qtn, bond) : contract!(TensorNetwork(qtn), virtualind)

# reindex contracting index
contracting_inds = [gensym(:tmp) for _ in inputs(gate)]
Expand All @@ -537,8 +552,12 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim)
# decompose using SVD
push!(left_inds, select(qtn, :index, sitel))
push!(right_inds, select(qtn, :index, siter))
svd!(TensorNetwork(qtn); left_inds, right_inds, virtualind)

if iscanonical
unpack_2sitewf!(qtn, bond, left_inds, right_inds, virtualind)
else
svd!(TensorNetwork(qtn); left_inds, right_inds, virtualind)
end
# truncate virtual index
if any(!isnothing, [threshold, maxdim])
truncate!(qtn, bond; threshold, maxdim)
Expand All @@ -547,6 +566,67 @@ function evolve_2site!(qtn::Chain, gate::Dense; threshold, maxdim)
return qtn
end

"""
contract_2sitewf!(ψ::Chain, bond)
For a given [`Chain`](@ref) in the canonical form, creates the two-site wave function θ with Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁,
where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ.
"""
function contract_2sitewf!::Chain, bond)
# TODO Check if ψ is in canonical form

sitel, siter = bond # TODO Check if bond is valid
(0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) ||
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : select(ψ, :between, Site(id(sitel) - 1), sitel)
Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : select(ψ, :between, siter, Site(id(siter) + 1))

!isnothing(Λᵢ₋₁) && contract!(ψ, :between, Site(id(sitel) - 1), sitel; direction = :right, delete_Λ = false)
!isnothing(Λᵢ₊₁) && contract!(ψ, :between, siter, Site(id(siter) + 1); direction = :left, delete_Λ = false)

contract!(TensorNetwork(ψ), select(ψ, :bond, bond...))

return ψ
end

"""
unpack_2sitewf!(ψ::Chain, bond)
For a given [`Chain`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical
form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`.
"""
function unpack_2sitewf!::Chain, bond, left_inds, right_inds, virtualind)
# TODO Check if ψ is in canonical form

sitel, siter = bond # TODO Check if bond is valid
(0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) ||
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : select(ψ, :between, Site(id(sitel) - 1), sitel)
Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : select(ψ, :between, siter, Site(id(siter) + 1))

# do svd of the θ tensor
θ = select(ψ, :tensor, sitel)
U, s, Vt = svd(θ; left_inds, right_inds, virtualind)

# contract with the inverse of Λᵢ and Λᵢ₊₂
Γᵢ₋₁ =
isnothing(Λᵢ₋₁) ? U :
contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)), atol = 1e-32)), inds(Λᵢ₋₁)), dims = ())
Γᵢ =
isnothing(Λᵢ₊₁) ? Vt :
contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)), atol = 1e-32)), inds(Λᵢ₊₁)), Vt, dims = ())

delete!(TensorNetwork(ψ), θ)

push!(TensorNetwork(ψ), Γᵢ₋₁)
push!(TensorNetwork(ψ), s)
push!(TensorNetwork(ψ), Γᵢ)

return ψ
end

function expect::Chain, observables)
# contract observable with TN
ϕ = copy(ψ)
Expand Down

0 comments on commit 6a189a1

Please sign in to comment.