Skip to content
This repository has been archived by the owner on Jul 7, 2024. It is now read-only.

Commit

Permalink
Add order keyword argument in Chain constructor functions (#47)
Browse files Browse the repository at this point in the history
* Add order kwarg in Chain constructor

* Add tests

* Fix minor typos

* Add minor fixes in Chain constructor

* Enhance tests

* Format code

* Apply @mofeing suggestions from code review

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>

* Remove unnecessary helper function

* Apply suggestions from code review

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>

* Minor aesthetic updates in code

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
jofrevalles and mofeing authored Jun 19, 2024
1 parent 2721de5 commit 837642c
Show file tree
Hide file tree
Showing 2 changed files with 241 additions and 47 deletions.
95 changes: 79 additions & 16 deletions src/Ansatz/Chain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,52 +30,98 @@ function Chain(tn::TensorNetwork, sites, args...; kwargs...)
Chain(Quantum(tn, sites), args...; kwargs...)
end

function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray})
defaultorder(::Type{Chain}, ::State) = (:o, :l, :r)
defaultorder(::Type{Chain}, ::Operator) = (:o, :i, :l, :r)

function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, State()))
@assert all(==(3) ndims, arrays) "All arrays must have 3 dimensions"
issetequal(order, defaultorder(Chain, State())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:2n]

_tensors = map(enumerate(arrays)) do (i, array)
Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]])
inds = map(order) do dir
if dir == :o
symbols[i]
elseif dir == :r
symbols[n+mod1(i, n)]
elseif dir == :l
symbols[n+mod1(i - 1, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)

Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
end

function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray})
function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, State()))
@assert ndims(arrays[1]) == 2 "First array must have 2 dimensions"
@assert all(==(3) ndims, arrays[2:end-1]) "All arrays must have 3 dimensions"
@assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions"
issetequal(order, defaultorder(Chain, State())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:2n-1]
symbols = [nextindex() for _ in 1:2n]

_tensors = map(enumerate(arrays)) do (i, array)
if i == 1
Tensor(array, [symbols[1], symbols[1+n]])
_order = if i == 1
filter(x -> x != :l, order)
elseif i == n
Tensor(array, [symbols[n], symbols[n+mod1(n - 1, n)]])
filter(x -> x != :r, order)
else
Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]])
order
end

inds = map(_order) do dir
if dir == :o
symbols[i]
elseif dir == :r
symbols[n+mod1(i, n)]
elseif dir == :l
symbols[n+mod1(i - 1, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)

Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
end

function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray})
@assert all(==(4) ndims, arrays) "All arrays must have 3 dimensions"
function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator()))
@assert all(==(4) ndims, arrays) "All arrays must have 4 dimensions"
issetequal(order, defaultorder(Chain, Operator())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:3n]

_tensors = map(enumerate(arrays)) do (i, array)
Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]])
inds = map(order) do dir
if dir == :o
symbols[i]
elseif dir == :i
symbols[i+n]
elseif dir == :l
symbols[2n+mod1(i - 1, n)]
elseif dir == :r
symbols[2n+mod1(i, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
Expand All @@ -84,22 +130,39 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray})
Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary)
end

function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray})
function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator()))
@assert ndims(arrays[1]) == 3 "First array must have 3 dimensions"
@assert all(==(4) ndims, arrays[2:end-1]) "All arrays must have 4 dimensions"
@assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions"
issetequal(order, defaultorder(Chain, Operator())) ||
throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))"))

n = length(arrays)
symbols = [nextindex() for _ in 1:3n-1]

_tensors = map(enumerate(arrays)) do (i, array)
if i == 1
Tensor(array, [symbols[1], symbols[n+1], symbols[1+2n]])
_order = if i == 1
filter(x -> x != :l, order)
elseif i == n
Tensor(array, [symbols[n], symbols[2n], symbols[2n+mod1(n - 1, n)]])
filter(x -> x != :r, order)
else
Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]])
order
end

inds = map(_order) do dir
if dir == :o
symbols[i]
elseif dir == :i
symbols[i+n]
elseif dir == :l
symbols[2n+mod1(i - 1, n)]
elseif dir == :r
symbols[2n+mod1(i, n)]
else
throw(ArgumentError("Invalid direction: $dir"))
end
end
Tensor(array, inds)
end

sitemap = Dict(Site(i) => symbols[i] for i in 1:n)
Expand Down
193 changes: 162 additions & 31 deletions test/Ansatz/Chain_test.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,166 @@
@testset "Chain ansatz" begin
qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3])
@test socket(qtn) == State()
@test ninputs(qtn) == 0
@test noutputs(qtn) == 3
@test issetequal(sites(qtn), [site"1", site"2", site"3"])
@test boundary(qtn) == Periodic()
@test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing

qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)])
@test socket(qtn) == State()
@test ninputs(qtn) == 0
@test noutputs(qtn) == 3
@test issetequal(sites(qtn), [site"1", site"2", site"3"])
@test boundary(qtn) == Open()
@test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing

qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3])
@test socket(qtn) == Operator()
@test ninputs(qtn) == 3
@test noutputs(qtn) == 3
@test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"])
@test boundary(qtn) == Periodic()
@test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing

qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)])
@test socket(qtn) == Operator()
@test ninputs(qtn) == 3
@test noutputs(qtn) == 3
@test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"])
@test boundary(qtn) == Open()
@test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing
@testset "Periodic boundary" begin
@testset "State" begin
qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3])
@test socket(qtn) == State()
@test ninputs(qtn) == 0
@test noutputs(qtn) == 3
@test issetequal(sites(qtn), [site"1", site"2", site"3"])
@test boundary(qtn) == Periodic()
@test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing

arrays = [rand(2, 1, 4), rand(2, 4, 3), rand(2, 3, 1)]
qtn = Chain(State(), Periodic(), arrays) # Default order (:o, :l, :r)

@test size(tensors(qtn; at = Site(1))) == (2, 1, 4)
@test size(tensors(qtn; at = Site(2))) == (2, 4, 3)
@test size(tensors(qtn; at = Site(3))) == (2, 3, 1)

@test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3))
@test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1))
@test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2))

arrays = [permutedims(array, (3, 1, 2)) for array in arrays] # now we have (:r, :o, :l)
qtn = Chain(State(), Periodic(), arrays, order = [:r, :o, :l])

@test size(tensors(qtn; at = Site(1))) == (4, 2, 1)
@test size(tensors(qtn; at = Site(2))) == (3, 2, 4)
@test size(tensors(qtn; at = Site(3))) == (1, 2, 3)

@test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3))
@test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1))
@test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2))

for i in 1:nsites(qtn)
@test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2
end
end

@testset "Operator" begin
qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3])
@test socket(qtn) == Operator()
@test ninputs(qtn) == 3
@test noutputs(qtn) == 3
@test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"])
@test boundary(qtn) == Periodic()
@test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing

arrays = [rand(2, 4, 1, 3), rand(2, 4, 3, 6), rand(2, 4, 6, 1)] # Default order (:o, :i, :l, :r)
qtn = Chain(Operator(), Periodic(), arrays)

@test size(tensors(qtn; at = Site(1))) == (2, 4, 1, 3)
@test size(tensors(qtn; at = Site(2))) == (2, 4, 3, 6)
@test size(tensors(qtn; at = Site(3))) == (2, 4, 6, 1)

@test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3))
@test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1))
@test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2))

for i in 1:length(arrays)
@test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2
@test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4
end

arrays = [permutedims(array, (4, 1, 3, 2)) for array in arrays] # now we have (:r, :o, :l, :i)
qtn = Chain(Operator(), Periodic(), arrays, order = [:r, :o, :l, :i])

@test size(tensors(qtn; at = Site(1))) == (3, 2, 1, 4)
@test size(tensors(qtn; at = Site(2))) == (6, 2, 3, 4)
@test size(tensors(qtn; at = Site(3))) == (1, 2, 6, 4)

@test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) !== nothing
@test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing
@test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing

for i in 1:length(arrays)
@test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2
@test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4
end
end
end

@testset "Open boundary" begin
@testset "State" begin
qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)])
@test socket(qtn) == State()
@test ninputs(qtn) == 0
@test noutputs(qtn) == 3
@test issetequal(sites(qtn), [site"1", site"2", site"3"])
@test boundary(qtn) == Open()
@test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing

arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)]
qtn = Chain(State(), Open(), arrays) # Default order (:o, :l, :r)

@test size(tensors(qtn; at = Site(1))) == (2, 1)
@test size(tensors(qtn; at = Site(2))) == (2, 1, 3)
@test size(tensors(qtn; at = Site(3))) == (2, 3)

@test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing
@test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1))
@test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2))

arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l)
qtn = Chain(State(), Open(), arrays, order = [:r, :o, :l])

@test size(tensors(qtn; at = Site(1))) == (1, 2)
@test size(tensors(qtn; at = Site(2))) == (3, 2, 1)
@test size(tensors(qtn; at = Site(3))) == (2, 3)

@test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing
@test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing
@test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing

for i in 1:nsites(qtn)
@test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2
end
end
@testset "Operator" begin
qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)])
@test socket(qtn) == Operator()
@test ninputs(qtn) == 3
@test noutputs(qtn) == 3
@test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"])
@test boundary(qtn) == Open()
@test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing

arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r)
qtn = Chain(Operator(), Open(), arrays)

@test size(tensors(qtn; at = Site(1))) == (2, 4, 1)
@test size(tensors(qtn; at = Site(2))) == (2, 4, 1, 3)
@test size(tensors(qtn; at = Site(3))) == (2, 4, 3)

@test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing
@test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing
@test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing

for i in 1:length(arrays)
@test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2
@test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4
end

arrays = [
permutedims(arrays[1], (3, 1, 2)),
permutedims(arrays[2], (4, 1, 3, 2)),
permutedims(arrays[3], (1, 3, 2)),
] # now we have (:r, :o, :l, :i)
qtn = Chain(Operator(), Open(), arrays, order = [:r, :o, :l, :i])

@test size(tensors(qtn; at = Site(1))) == (1, 2, 4)
@test size(tensors(qtn; at = Site(2))) == (3, 2, 1, 4)
@test size(tensors(qtn; at = Site(3))) == (2, 3, 4)

@test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing
@test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing
@test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing

for i in 1:length(arrays)
@test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2
@test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4
end
end
end

@testset "Site" begin
using Qrochet: leftsite, rightsite
Expand Down

0 comments on commit 837642c

Please sign in to comment.