From 837642c2f81fb3ae4c196fa0feca8c7b1cfa4ff5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jofre=20Vall=C3=A8s=20Muns?= <61060572+jofrevalles@users.noreply.github.com> Date: Wed, 19 Jun 2024 13:15:33 +0200 Subject: [PATCH] Add `order` keyword argument in `Chain` constructor functions (#47) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add order kwarg in Chain constructor * Add tests * Fix minor typos * Add minor fixes in Chain constructor * Enhance tests * Format code * Apply @mofeing suggestions from code review Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> * Remove unnecessary helper function * Apply suggestions from code review Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> * Minor aesthetic updates in code --------- Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com> --- src/Ansatz/Chain.jl | 95 +++++++++++++++---- test/Ansatz/Chain_test.jl | 193 ++++++++++++++++++++++++++++++++------ 2 files changed, 241 insertions(+), 47 deletions(-) diff --git a/src/Ansatz/Chain.jl b/src/Ansatz/Chain.jl index f33e2ce..4bc699b 100644 --- a/src/Ansatz/Chain.jl +++ b/src/Ansatz/Chain.jl @@ -30,14 +30,30 @@ function Chain(tn::TensorNetwork, sites, args...; kwargs...) Chain(Quantum(tn, sites), args...; kwargs...) end -function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}) +defaultorder(::Type{Chain}, ::State) = (:o, :l, :r) +defaultorder(::Type{Chain}, ::Operator) = (:o, :i, :l, :r) + +function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, State())) @assert all(==(3) ∘ ndims, arrays) "All arrays must have 3 dimensions" + issetequal(order, defaultorder(Chain, State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(State())))")) n = length(arrays) symbols = [nextindex() for _ in 1:2n] _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]]) + inds = map(order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n+mod1(i, n)] + elseif dir == :l + symbols[n+mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) @@ -45,22 +61,37 @@ function Chain(::State, boundary::Periodic, arrays::Vector{<:AbstractArray}) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}) +function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, State())) @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" @assert all(==(3) ∘ ndims, arrays[2:end-1]) "All arrays must have 3 dimensions" @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" + issetequal(order, defaultorder(Chain, State())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, State())))")) n = length(arrays) - symbols = [nextindex() for _ in 1:2n-1] + symbols = [nextindex() for _ in 1:2n] _tensors = map(enumerate(arrays)) do (i, array) - if i == 1 - Tensor(array, [symbols[1], symbols[1+n]]) + _order = if i == 1 + filter(x -> x != :l, order) elseif i == n - Tensor(array, [symbols[n], symbols[n+mod1(n - 1, n)]]) + filter(x -> x != :r, order) else - Tensor(array, [symbols[i], symbols[n+mod1(i - 1, n)], symbols[n+mod1(i, n)]]) + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n+mod1(i, n)] + elseif dir == :l + symbols[n+mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end end + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) @@ -68,14 +99,29 @@ function Chain(::State, boundary::Open, arrays::Vector{<:AbstractArray}) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}) - @assert all(==(4) ∘ ndims, arrays) "All arrays must have 3 dimensions" +function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator())) + @assert all(==(4) ∘ ndims, arrays) "All arrays must have 4 dimensions" + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n] _tensors = map(enumerate(arrays)) do (i, array) - Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]]) + inds = map(order) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i+n] + elseif dir == :l + symbols[2n+mod1(i - 1, n)] + elseif dir == :r + symbols[2n+mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) @@ -84,22 +130,39 @@ function Chain(::Operator, boundary::Periodic, arrays::Vector{<:AbstractArray}) Chain(Quantum(TensorNetwork(_tensors), sitemap), boundary) end -function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}) +function Chain(::Operator, boundary::Open, arrays::Vector{<:AbstractArray}; order = defaultorder(Chain, Operator())) @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" @assert all(==(4) ∘ ndims, arrays[2:end-1]) "All arrays must have 4 dimensions" @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" + issetequal(order, defaultorder(Chain, Operator())) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(Chain, Operator())))")) n = length(arrays) symbols = [nextindex() for _ in 1:3n-1] _tensors = map(enumerate(arrays)) do (i, array) - if i == 1 - Tensor(array, [symbols[1], symbols[n+1], symbols[1+2n]]) + _order = if i == 1 + filter(x -> x != :l, order) elseif i == n - Tensor(array, [symbols[n], symbols[2n], symbols[2n+mod1(n - 1, n)]]) + filter(x -> x != :r, order) else - Tensor(array, [symbols[i], symbols[i+n], symbols[2n+mod1(i - 1, n)], symbols[2n+mod1(i, n)]]) + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i+n] + elseif dir == :l + symbols[2n+mod1(i - 1, n)] + elseif dir == :r + symbols[2n+mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end end + Tensor(array, inds) end sitemap = Dict(Site(i) => symbols[i] for i in 1:n) diff --git a/test/Ansatz/Chain_test.jl b/test/Ansatz/Chain_test.jl index 0eb7bcb..7496d85 100644 --- a/test/Ansatz/Chain_test.jl +++ b/test/Ansatz/Chain_test.jl @@ -1,35 +1,166 @@ @testset "Chain ansatz" begin - qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) - @test socket(qtn) == State() - @test ninputs(qtn) == 0 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - @test socket(qtn) == State() - @test ninputs(qtn) == 0 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing - - qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) - @test socket(qtn) == Operator() - @test ninputs(qtn) == 3 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) - @test socket(qtn) == Operator() - @test ninputs(qtn) == 3 - @test noutputs(qtn) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + @testset "Periodic boundary" begin + @testset "State" begin + qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) + @test socket(qtn) == State() + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3"]) + @test boundary(qtn) == Periodic() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing + + arrays = [rand(2, 1, 4), rand(2, 4, 3), rand(2, 3, 1)] + qtn = Chain(State(), Periodic(), arrays) # Default order (:o, :l, :r) + + @test size(tensors(qtn; at = Site(1))) == (2, 1, 4) + @test size(tensors(qtn; at = Site(2))) == (2, 4, 3) + @test size(tensors(qtn; at = Site(3))) == (2, 3, 1) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + arrays = [permutedims(array, (3, 1, 2)) for array in arrays] # now we have (:r, :o, :l) + qtn = Chain(State(), Periodic(), arrays, order = [:r, :o, :l]) + + @test size(tensors(qtn; at = Site(1))) == (4, 2, 1) + @test size(tensors(qtn; at = Site(2))) == (3, 2, 4) + @test size(tensors(qtn; at = Site(3))) == (1, 2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:nsites(qtn) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + end + end + + @testset "Operator" begin + qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) + @test socket(qtn) == Operator() + @test ninputs(qtn) == 3 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(qtn) == Periodic() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing + + arrays = [rand(2, 4, 1, 3), rand(2, 4, 3, 6), rand(2, 4, 6, 1)] # Default order (:o, :i, :l, :r) + qtn = Chain(Operator(), Periodic(), arrays) + + @test size(tensors(qtn; at = Site(1))) == (2, 4, 1, 3) + @test size(tensors(qtn; at = Site(2))) == (2, 4, 3, 6) + @test size(tensors(qtn; at = Site(3))) == (2, 4, 6, 1) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 + end + + arrays = [permutedims(array, (4, 1, 3, 2)) for array in arrays] # now we have (:r, :o, :l, :i) + qtn = Chain(Operator(), Periodic(), arrays, order = [:r, :o, :l, :i]) + + @test size(tensors(qtn; at = Site(1))) == (3, 2, 1, 4) + @test size(tensors(qtn; at = Site(2))) == (6, 2, 3, 4) + @test size(tensors(qtn; at = Site(3))) == (1, 2, 6, 4) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) !== nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 + end + end + end + + @testset "Open boundary" begin + @testset "State" begin + qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + @test socket(qtn) == State() + @test ninputs(qtn) == 0 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3"]) + @test boundary(qtn) == Open() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + + arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] + qtn = Chain(State(), Open(), arrays) # Default order (:o, :l, :r) + + @test size(tensors(qtn; at = Site(1))) == (2, 1) + @test size(tensors(qtn; at = Site(2))) == (2, 1, 3) + @test size(tensors(qtn; at = Site(3))) == (2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) + + arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) + qtn = Chain(State(), Open(), arrays, order = [:r, :o, :l]) + + @test size(tensors(qtn; at = Site(1))) == (1, 2) + @test size(tensors(qtn; at = Site(2))) == (3, 2, 1) + @test size(tensors(qtn; at = Site(3))) == (2, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:nsites(qtn) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + end + end + @testset "Operator" begin + qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) + @test socket(qtn) == Operator() + @test ninputs(qtn) == 3 + @test noutputs(qtn) == 3 + @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(qtn) == Open() + @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing + + arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r) + qtn = Chain(Operator(), Open(), arrays) + + @test size(tensors(qtn; at = Site(1))) == (2, 4, 1) + @test size(tensors(qtn; at = Site(2))) == (2, 4, 1, 3) + @test size(tensors(qtn; at = Site(3))) == (2, 4, 3) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 + end + + arrays = [ + permutedims(arrays[1], (3, 1, 2)), + permutedims(arrays[2], (4, 1, 3, 2)), + permutedims(arrays[3], (1, 3, 2)), + ] # now we have (:r, :o, :l, :i) + qtn = Chain(Operator(), Open(), arrays, order = [:r, :o, :l, :i]) + + @test size(tensors(qtn; at = Site(1))) == (1, 2, 4) + @test size(tensors(qtn; at = Site(2))) == (3, 2, 1, 4) + @test size(tensors(qtn; at = Site(3))) == (2, 3, 4) + + @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing + @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing + @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing + + for i in 1:length(arrays) + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i))) == 2 + @test size(TensorNetwork(qtn), inds(qtn; at = Site(i; dual = true))) == 4 + end + end + end @testset "Site" begin using Qrochet: leftsite, rightsite