From 146838d6fe3b35363df528d8d18c256831803696 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Wed, 4 Oct 2023 23:01:28 +0200 Subject: [PATCH 01/29] Refactor `TensorNetwork` to `@class` --- Project.toml | 1 + ext/TenetChainRulesCoreExt.jl | 17 +-- src/Helpers.jl | 10 -- src/TensorNetwork.jl | 254 +++++++++++++--------------------- src/Transformations.jl | 31 +++-- test/Helpers_test.jl | 19 --- test/TensorNetwork_test.jl | 56 ++++---- test/Transformations_test.jl | 16 +-- 8 files changed, 157 insertions(+), 247 deletions(-) diff --git a/Project.toml b/Project.toml index 737bdbe3..f47fde21 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.2.0" [deps] Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" +Classes = "1a9c1350-211b-5766-99cd-4544d885a0d1" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" diff --git a/ext/TenetChainRulesCoreExt.jl b/ext/TenetChainRulesCoreExt.jl index b1235f8d..e86b8e17 100644 --- a/ext/TenetChainRulesCoreExt.jl +++ b/ext/TenetChainRulesCoreExt.jl @@ -1,6 +1,7 @@ module TenetChainRulesCoreExt using Tenet +using Classes using ChainRulesCore function ChainRulesCore.ProjectTo(tensor::T) where {T<:Tensor} @@ -26,29 +27,29 @@ ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds) = T(data, inds), Tensor_pull @non_differentiable intersect(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...) @non_differentiable symdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...) -function ChainRulesCore.ProjectTo(tn::T) where {T<:TensorNetwork} +function ChainRulesCore.ProjectTo(tn::T) where {T<:absclass(TensorNetwork)} ProjectTo{T}(; tensors = ProjectTo(tn.tensors), metadata = tn.metadata) end -function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:TensorNetwork} +function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:absclass(TensorNetwork)} dx.tensors isa NoTangent && return NoTangent() Tangent{TensorNetwork}(tensors = projector.tensors(dx.tensors)) end -function Base.:+(x::TensorNetwork{A}, Δ::Tangent{TensorNetwork}) where {A<:Ansatz} +function Base.:+(x::T, Δ::Tangent{TensorNetwork}) where {T<:absclass(TensorNetwork)} # TODO match tensors by indices tensors = map(+, x.tensors, Δ.tensors) - TensorNetwork{A}(tensors; x.metadata...) + T(tensors, ...) # TODO fix how to pass metadata end -function ChainRulesCore.frule((_, Δ), T::Type{<:TensorNetwork}, tensors; metadata...) - T(tensors; metadata...), Tangent{TensorNetwork}(tensors = Δ) +function ChainRulesCore.frule((_, Δ), T::Type{<:absclass(TensorNetwork)}, tensors) + T(tensors), Tangent{TensorNetwork}(tensors = Δ) end TensorNetwork_pullback(Δ::Tangent{TensorNetwork}) = (NoTangent(), Δ.tensors) TensorNetwork_pullback(Δ::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ)) -function ChainRulesCore.rrule(T::Type{<:TensorNetwork}, tensors; metadata...) 
- T(tensors; metadata...), TensorNetwork_pullback +function ChainRulesCore.rrule(T::Type{TensorNetwork}, tensors) + T(tensors), TensorNetwork_pullback end end \ No newline at end of file diff --git a/src/Helpers.jl b/src/Helpers.jl index 67cfb344..3aa502e1 100644 --- a/src/Helpers.jl +++ b/src/Helpers.jl @@ -67,16 +67,6 @@ julia> letter(20204) letter(i) = Iterators.drop(Iterators.filter(isletter, Iterators.map(Char, 1:2^21-1)), i - 1) |> iterate |> first |> Symbol -Base.merge(@nospecialize(A::Type{<:NamedTuple}), @nospecialize(Bs::Type{<:NamedTuple}...)) = NamedTuple{ - foldl((acc, B) -> (acc..., B...), Iterators.map(fieldnames, Bs); init = fieldnames(A)), - foldl((acc, B) -> Tuple{fieldtypes(acc)...,B...}, Iterators.map(fieldtypes, Bs); init = A), -} - -function superansatzes(T) - S = supertype(T) - return T === Ansatz ? (T,) : (T, superansatzes(S)...) -end - # NOTE from https://stackoverflow.com/q/54652787 function nonunique(x) uniqueindexes = indexin(unique(x), x) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 46762f34..86b50b1e 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -3,109 +3,54 @@ using Random using EinExprs using OMEinsum using ValSplit +using Classes """ - Ansatz - -Type representing the predefined form of the Tensor Network. -""" -abstract type Ansatz end - -""" - Arbitrary - -Tensor Networks without a predefined form. -""" -abstract type Arbitrary <: Ansatz end - -# NOTE currently, these are implementation details -function checkmeta end -function metadata end - -""" - TensorNetwork{Ansatz} + TensorNetwork Graph of interconnected tensors, representing a multilinear equation. Graph vertices represent tensors and graph edges, tensor indices. """ -struct TensorNetwork{A<:Ansatz,M<:NamedTuple} +@class TensorNetwork begin indices::Dict{Symbol,Vector{Int}} tensors::Vector{Tensor} - metadata::M - - function TensorNetwork{A}(tensors; metadata...) where {A} - indices = reduce(enumerate(tensors); init = Dict{Symbol,Vector{Int}}([])) do dict, (i, tensor) - mergewith(vcat, dict, Dict([index => [i] for index in inds(tensor)])) - end - - # Check for inconsistent dimensions - for (index, idxs) in indices - allequal(Iterators.map(i -> size(tensors[i], index), idxs)) || - throw(DimensionMismatch("Different sizes specified for index $index")) - end - - M = Tenet.metadata(A) - metadata = M((; metadata...)) - - tn = new{A,M}(indices, tensors, metadata) - - checkansatz(tn) - return tn - end end -TensorNetwork{A}(; metadata...) where {A<:Ansatz} = TensorNetwork{A}(Tensor[]; metadata...) - -# ansatz defaults to `Arbitrary` -TensorNetwork(args...; kwargs...) = TensorNetwork{Arbitrary}(args...; kwargs...) - -# TODO maybe rename it as `convert` method? -TensorNetwork{A}(tn::TensorNetwork{B}; metadata...) where {A,B} = - TensorNetwork{A}(tensors(tn); merge(tn.metadata, metadata)...) +TensorNetwork() = TensorNetwork(Tensor[]) +function TensorNetwork(tensors) + indices = reduce(enumerate(tensors); init = Dict{Symbol,Vector{Int}}([])) do dict, (i, tensor) + mergewith(vcat, dict, Dict([index => [i] for index in inds(tensor)])) + end -# TODO do sth to skip checkansatz? 
like @inbounds -function checkansatz(tn::TensorNetwork{A}) where {A<:Ansatz} - for T in superansatzes(A) - checkmeta(T, tn) || throw(ErrorException("\"$T\" metadata is not valid")) + # check for inconsistent dimensions + for (index, idxs) in indices + allequal(Iterators.map(i -> size(tensors[i], index), idxs)) || + throw(DimensionMismatch("Different sizes specified for index $index")) end + + return TensorNetwork(indices, tensors) end -checkmeta(::Type{<:Ansatz}, ::TensorNetwork) = true -checkmeta(tn::TensorNetwork{T}) where {T<:Ansatz} = all(A -> checkmeta(A, tn), superansatzes(T)) +# TODO maybe rename it as `convert` method? +# TensorNetwork{A}(tn::absclass(TensorNetwork){B}; metadata...) where {A,B} = +# TensorNetwork{A}(tensors(tn); merge(tn.metadata, metadata)...) -metadata(::Type{<:Ansatz}) = NamedTuple{(),Tuple{}} -metadata(T::Type{<:Arbitrary}) = metadata(supertype(T)) +Base.copy(tn::TensorNetwork) = TensorNetwork(copy(tensors(tn))) -Base.summary(io::IO, x::TensorNetwork) = print(io, "$(length(x))-tensors $(typeof(x))") -Base.show(io::IO, tn::TensorNetwork) = +Base.summary(io::IO, x::absclass(TensorNetwork)) = print(io, "$(length(x))-tensors $(typeof(x))") +Base.show(io::IO, tn::absclass(TensorNetwork)) = print(io, "$(typeof(tn))(#tensors=$(length(tn.tensors)), #inds=$(length(tn.indices)))") """ - copy(tn::TensorNetwork) - -Return a shallow copy of the [`TensorNetwork`](@ref). -""" -Base.copy(tn::TensorNetwork{A}) where {A} = TensorNetwork{A}(copy(tn.tensors); deepcopy(tn.metadata)...) - -""" - ansatz(::TensorNetwork{Ansatz}) - ansatz(::Type{<:TensorNetwork{Ansatz}}) - -Return the `Ansatz` of a [`TensorNetwork`](@ref) type or object. -""" -ansatz(::Type{<:TensorNetwork{A}}) where {A} = A -ansatz(::TensorNetwork{A}) where {A} = A - -""" - tensors(tn::TensorNetwork) + tensors(tn::AbstractTensorNetwork) Return a list of the `Tensor`s in the [`TensorNetwork`](@ref). """ -tensors(tn::TensorNetwork) = tn.tensors -arrays(tn::TensorNetwork) = parent.(tensors(tn)) +tensors(tn::absclass(TensorNetwork)) = tn.tensors +arrays(tn::absclass(TensorNetwork)) = parent.(tensors(tn)) """ - inds(tn::TensorNetwork, set = :all) + inds(tn::AbstractTensorNetwork, set = :all) Return the names of the indices in the [`TensorNetwork`](@ref). @@ -118,46 +63,38 @@ Return the names of the indices in the [`TensorNetwork`](@ref). + `:inner` Indices mentioned at least twice. + `:hyper` Indices mentioned at least in three tensors. """ -EinExprs.inds(tn::TensorNetwork; set::Symbol = :all, kwargs...) = inds(tn, set; kwargs...) -@valsplit 2 EinExprs.inds(tn::TensorNetwork, set::Symbol, args...) = throw(MethodError(inds, "set=$set not recognized")) -EinExprs.inds(tn::TensorNetwork, ::Val{:all}) = collect(keys(tn.indices)) -EinExprs.inds(tn::TensorNetwork, ::Val{:open}) = map(first, Iterators.filter(==(1) ∘ length ∘ last, tn.indices)) -EinExprs.inds(tn::TensorNetwork, ::Val{:inner}) = map(first, Iterators.filter(>=(2) ∘ length ∘ last, tn.indices)) -EinExprs.inds(tn::TensorNetwork, ::Val{:hyper}) = map(first, Iterators.filter(>=(3) ∘ length ∘ last, tn.indices)) +EinExprs.inds(tn::absclass(TensorNetwork); set::Symbol = :all, kwargs...) = inds(tn, set; kwargs...) +@valsplit 2 EinExprs.inds(tn::absclass(TensorNetwork), set::Symbol, args...) 
= + throw(MethodError(inds, "set=$set not recognized")) +EinExprs.inds(tn::absclass(TensorNetwork), ::Val{:all}) = collect(keys(tn.indices)) +EinExprs.inds(tn::absclass(TensorNetwork), ::Val{:open}) = + map(first, Iterators.filter(==(1) ∘ length ∘ last, tn.indices)) +EinExprs.inds(tn::absclass(TensorNetwork), ::Val{:inner}) = + map(first, Iterators.filter(>=(2) ∘ length ∘ last, tn.indices)) +EinExprs.inds(tn::absclass(TensorNetwork), ::Val{:hyper}) = + map(first, Iterators.filter(>=(3) ∘ length ∘ last, tn.indices)) """ - size(tn::TensorNetwork) - size(tn::TensorNetwork, index) + size(tn::AbstractTensorNetwork) + size(tn::AbstractTensorNetwork, index) Return a mapping from indices to their dimensionalities. If `index` is set, return the dimensionality of `index`. This is equivalent to `size(tn)[index]`. """ -Base.size(tn::TensorNetwork) = Dict(i => size(tn, i) for (i, x) in tn.indices) -Base.size(tn::TensorNetwork, i::Symbol) = size(tn.tensors[first(tn.indices[i])], i) - -Base.eltype(tn::TensorNetwork) = promote_type(eltype.(tensors(tn))...) - -Base.getindex(tn::TensorNetwork, key::Symbol) = tn.metadata[key] -Base.fieldnames(tn::T) where {T<:TensorNetwork} = fieldnames(T) -Base.propertynames(tn::TensorNetwork{A,N}) where {A,N} = tuple(fieldnames(tn)..., fieldnames(N)...) -Base.getproperty(tn::T, name::Symbol) where {T<:TensorNetwork} = - if hasfield(T, name) - getfield(tn, name) - elseif hasfield(fieldtype(T, :metadata), name) - getfield(getfield(tn, :metadata), name) - else - throw(KeyError(name)) - end +Base.size(tn::absclass(TensorNetwork)) = Dict(i => size(tn, i) for (i, x) in tn.indices) +Base.size(tn::absclass(TensorNetwork), i::Symbol) = size(tn.tensors[first(tn.indices[i])], i) + +Base.eltype(tn::absclass(TensorNetwork)) = promote_type(eltype.(tensors(tn))...) """ - push!(tn::TensorNetwork, tensor::Tensor) + push!(tn::AbstractTensorNetwork, tensor::Tensor) Add a new `tensor` to the Tensor Network. See also: [`append!`](@ref), [`pop!`](@ref). """ -function Base.push!(tn::TensorNetwork, tensor::Tensor) +function Base.push!(tn::absclass(TensorNetwork), tensor::Tensor) for i in Iterators.filter(i -> size(tn, i) != size(tensor, i), inds(tensor) ∩ inds(tn)) throw(DimensionMismatch("size(tensor,$i)=$(size(tensor,i)) but should be equal to size(tn,$i)=$(size(tn,i))")) end @@ -172,22 +109,22 @@ function Base.push!(tn::TensorNetwork, tensor::Tensor) end """ - append!(tn::TensorNetwork, tensors::AbstractVecOrTuple{<:Tensor}) - append!(A::TensorNetwork, B::TensorNetwork) + append!(tn::AbstractTensorNetwork, tensors::AbstractVecOrTuple{<:Tensor}) + append!(A::AbstractTensorNetwork, B::AbstractTensorNetwork) Add a list of tensors to the first `TensorNetwork`. 
See also: [`push!`](@ref)
"""
-Base.append!(tn::TensorNetwork, t::AbstractVecOrTuple{<:Tensor}) = (foreach(Base.Fix1(push!, tn), t); tn)
-function Base.append!(A::TensorNetwork, B::TensorNetwork)
+Base.append!(tn::absclass(TensorNetwork), t::AbstractVecOrTuple{<:Tensor}) = (foreach(Base.Fix1(push!, tn), t); tn)
+function Base.append!(A::absclass(TensorNetwork), B::absclass(TensorNetwork))
     append!(A, tensors(B))
     # TODO define behaviour
     # merge!(A.metadata, B.metadata)
     return A
 end
 
-function Base.popat!(tn::TensorNetwork, i::Integer)
+function Base.popat!(tn::absclass(TensorNetwork), i::Integer)
     tensor = popat!(tn.tensors, i)
 
     # unlink indices
@@ -207,22 +144,22 @@ function Base.popat!(tn::TensorNetwork, i::Integer)
 end
 
 """
-    pop!(tn::TensorNetwork, tensor::Tensor)
-    pop!(tn::TensorNetwork, i::Union{Symbol,AbstractVecOrTuple{Symbol}})
+    pop!(tn::AbstractTensorNetwork, tensor::Tensor)
+    pop!(tn::AbstractTensorNetwork, i::Union{Symbol,AbstractVecOrTuple{Symbol}})
 
Remove a tensor from the Tensor Network and return it. If a `Tensor` is passed, then the first tensor that satisfies
_egality_ (i.e. `≡` or `===`) will be removed. If a `Symbol` or a list of `Symbol`s is passed, then remove and return the tensors that contain all the indices.
 
See also: [`push!`](@ref), [`delete!`](@ref).
"""
-function Base.pop!(tn::TensorNetwork, tensor::Tensor)
+function Base.pop!(tn::absclass(TensorNetwork), tensor::Tensor)
     i = findfirst(t -> t === tensor, tn.tensors)
     popat!(tn, i)
 end
 
-Base.pop!(tn::TensorNetwork, i::Symbol) = pop!(tn, (i,))
+Base.pop!(tn::absclass(TensorNetwork), i::Symbol) = pop!(tn, (i,))
 
-function Base.pop!(tn::TensorNetwork, i::AbstractVecOrTuple{Symbol})::Vector{Tensor}
+function Base.pop!(tn::absclass(TensorNetwork), i::AbstractVecOrTuple{Symbol})::Vector{Tensor}
     tensors = select(tn, i)
     for tensor in tensors
         _ = pop!(tn, tensor)
@@ -232,23 +169,23 @@ function Base.pop!(tn::TensorNetwork, i::AbstractVecOrTuple{Symbol})::Vector{Ten
 end
 
 """
-    delete!(tn::TensorNetwork, x)
+    delete!(tn::AbstractTensorNetwork, x)
 
Like [`pop!`](@ref) but return the [`TensorNetwork`](@ref) instead.
"""
-Base.delete!(tn::TensorNetwork, x) = (_ = pop!(tn, x); tn)
+Base.delete!(tn::absclass(TensorNetwork), x) = (_ = pop!(tn, x); tn)
 
 """
-    replace(tn::TensorNetwork, old => new...)
+    replace(tn::AbstractTensorNetwork, old => new...)
 
Return a copy of the [`TensorNetwork`](@ref) where `old` has been replaced by `new`.
 
See also: [`replace!`](@ref).
"""
-Base.replace(tn::TensorNetwork, old_new::Pair...) = replace!(copy(tn), old_new...)
+Base.replace(tn::absclass(TensorNetwork), old_new::Pair...) = replace!(copy(tn), old_new...)
 
 """
-    replace!(tn::TensorNetwork, old => new...)
+    replace!(tn::AbstractTensorNetwork, old => new...)
 
Replace the element in `old` with the one in `new`. Depending on the types of `old` and `new`, the following behaviour is expected:
 
See also: [`replace`](@ref).
"""
-function Base.replace!(tn::TensorNetwork, old_new::Pair...)
+function Base.replace!(tn::absclass(TensorNetwork), old_new::Pair...)
for pair in old_new replace!(tn, pair) end return tn end -function Base.replace!(tn::TensorNetwork, pair::Pair{<:Tensor,<:Tensor}) +function Base.replace!(tn::absclass(TensorNetwork), pair::Pair{<:Tensor,<:Tensor}) old_tensor, new_tensor = pair # check if old and new tensors are compatible @@ -279,7 +216,7 @@ function Base.replace!(tn::TensorNetwork, pair::Pair{<:Tensor,<:Tensor}) return tn end -function Base.replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol}) +function Base.replace!(tn::absclass(TensorNetwork), old_new::Pair{Symbol,Symbol}) old, new = old_new new ∈ inds(tn) && throw(ArgumentError("new symbol $new is already present")) @@ -292,7 +229,7 @@ function Base.replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol}) return tn end -function Base.replace!(tn::TensorNetwork, old_new::Pair{<:Tensor,<:TensorNetwork}) +function Base.replace!(tn::absclass(TensorNetwork), old_new::Pair{<:Tensor,<:AbstractTensorNetwork}) old, new = old_new issetequal(inds(new, set = :open), inds(old)) || throw(ArgumentError("indices must match")) @@ -306,29 +243,29 @@ function Base.replace!(tn::TensorNetwork, old_new::Pair{<:Tensor,<:TensorNetwork end """ - select(tn::TensorNetwork, i) + select(tn::AbstractTensorNetwork, i) Return tensors whose indices match with the list of indices `i`. """ -select(tn::TensorNetwork, i::AbstractVecOrTuple{Symbol}) = filter(Base.Fix1(⊆, i) ∘ inds, tensors(tn)) -select(tn::TensorNetwork, i::Symbol) = map(x -> tn.tensors[x], unique(tn.indices[i])) +select(tn::absclass(TensorNetwork), i::AbstractVecOrTuple{Symbol}) = filter(Base.Fix1(⊆, i) ∘ inds, tensors(tn)) +select(tn::absclass(TensorNetwork), i::Symbol) = map(x -> tn.tensors[x], unique(tn.indices[i])) """ - in(tensor::Tensor, tn::TensorNetwork) + in(tensor::Tensor, tn::AbstractTensorNetwork) Return `true` if there is a `Tensor` in `tn` for which `==` evaluates to `true`. This method is equivalent to `tensor ∈ tensors(tn)` code, but it's faster on large amount of tensors. """ -Base.in(tensor::Tensor, tn::TensorNetwork) = in(tensor, select(tn, inds(tensor))) +Base.in(tensor::Tensor, tn::absclass(TensorNetwork)) = in(tensor, select(tn, inds(tensor))) """ - slice!(tn::TensorNetwork, index::Symbol, i) + slice!(tn::AbstractTensorNetwork, index::Symbol, i) In-place projection of `index` on dimension `i`. See also: [`selectdim`](@ref), [`view`](@ref). """ -function slice!(tn::TensorNetwork, label::Symbol, i) +function slice!(tn::absclass(TensorNetwork), label::Symbol, i) for tensor in select(tn, label) pos = findfirst(t -> t === tensor, tn.tensors) tn.tensors[pos] = selectdim(tensor, label, i) @@ -340,23 +277,23 @@ function slice!(tn::TensorNetwork, label::Symbol, i) end """ - selectdim(tn::TensorNetwork, index::Symbol, i) + selectdim(tn::AbstractTensorNetwork, index::Symbol, i) Return a copy of the [`TensorNetwork`](@ref) where `index` has been projected to dimension `i`. See also: [`view`](@ref), [`slice!`](@ref). """ -Base.selectdim(tn::TensorNetwork, label::Symbol, i) = @view tn[label=>i] +Base.selectdim(tn::absclass(TensorNetwork), label::Symbol, i) = @view tn[label=>i] """ - view(tn::TensorNetwork, index => i...) + view(tn::AbstractTensorNetwork, index => i...) Return a copy of the [`TensorNetwork`](@ref) where each `index` has been projected to dimension `i`. It is equivalent to a recursive call of [`selectdim`](@ref). See also: [`selectdim`](@ref), [`slice!`](@ref). """ -function Base.view(tn::TensorNetwork, slices::Pair{Symbol,<:Any}...) 
+function Base.view(tn::absclass(TensorNetwork), slices::Pair{Symbol,<:Any}...) tn = copy(tn) for (label, i) in slices @@ -419,12 +356,12 @@ function Base.rand( push!.(inputs, (ind,)) end - tensors = [Tensor(rand([size_dict[ind] for ind in input]...), tuple(input...)) for input in inputs] + tensors = Tensor[Tensor(rand([size_dict[ind] for ind in input]...), tuple(input...)) for input in inputs] TensorNetwork(tensors) end """ - einexpr(tn::TensorNetwork; optimizer = EinExprs.Greedy, output = inds(tn, :open), kwargs...) + einexpr(tn::AbstractTensorNetwork; optimizer = EinExprs.Greedy, output = inds(tn, :open), kwargs...) Search a contraction path for the given [`TensorNetwork`](@ref) and return it as a `EinExpr`. @@ -436,7 +373,7 @@ Search a contraction path for the given [`TensorNetwork`](@ref) and return it as See also: [`contract`](@ref). """ -EinExprs.einexpr(tn::TensorNetwork; optimizer = Greedy, outputs = inds(tn, :open), kwargs...) = einexpr( +EinExprs.einexpr(tn::absclass(TensorNetwork); optimizer = Greedy, outputs = inds(tn, :open), kwargs...) = einexpr( optimizer, EinExpr( outputs, @@ -448,13 +385,13 @@ EinExprs.einexpr(tn::TensorNetwork; optimizer = Greedy, outputs = inds(tn, :open # TODO sequence of indices? # TODO what if parallel neighbour indices? """ - contract!(tn::TensorNetwork, index) + contract!(tn::AbstractTensorNetwork, index) In-place contraction of tensors connected to `index`. See also: [`contract`](@ref). """ -function contract!(tn::TensorNetwork, i) +function contract!(tn::absclass(TensorNetwork), i) tensor = reduce(pop!(tn, i)) do acc, tensor contract(acc, tensor, dims = i) end @@ -464,7 +401,7 @@ function contract!(tn::TensorNetwork, i) end """ - contract(tn::TensorNetwork; kwargs...) + contract(tn::AbstractTensorNetwork; kwargs...) Contract a [`TensorNetwork`](@ref). The contraction order will be first computed by [`einexpr`](@ref). @@ -472,7 +409,7 @@ The `kwargs` will be passed down to the [`einexpr`](@ref) function. See also: [`einexpr`](@ref), [`contract!`](@ref). """ -function contract(tn::TensorNetwork; path = einexpr(tn)) +function contract(tn::absclass(TensorNetwork); path = einexpr(tn)) # TODO does `first` work always? length(path.args) == 0 && return select(tn, inds(path)) |> first @@ -480,27 +417,22 @@ function contract(tn::TensorNetwork; path = einexpr(tn)) contract(intermediates...; dims = suminds(path)) end -contract!(t::Tensor, tn::TensorNetwork; kwargs...) = contract!(tn, t; kwargs...) -contract!(tn::TensorNetwork, t::Tensor; kwargs...) = (push!(tn, t); contract(tn; kwargs...)) -contract(t::Tensor, tn::TensorNetwork; kwargs...) = contract(tn, t; kwargs...) -contract(tn::TensorNetwork, t::Tensor; kwargs...) = contract!(copy(tn), t; kwargs...) - -struct TNSampler{A<:Ansatz,NT<:NamedTuple} <: Random.Sampler{TensorNetwork{A}} - parameters::NT - - TNSampler{A}(; kwargs...) where {A} = new{A,typeof(values(kwargs))}(values(kwargs)) -end +contract!(t::Tensor, tn::absclass(TensorNetwork); kwargs...) = contract!(tn, t; kwargs...) +contract!(tn::absclass(TensorNetwork), t::Tensor; kwargs...) = (push!(tn, t); contract(tn; kwargs...)) +contract(t::Tensor, tn::absclass(TensorNetwork); kwargs...) = contract(tn, t; kwargs...) +contract(tn::absclass(TensorNetwork), t::Tensor; kwargs...) = contract!(copy(tn), t; kwargs...) -Base.getproperty(obj::TNSampler{A,<:NamedTuple{K}}, name::Symbol) where {A,K} = - name ∈ K ? 
getfield(obj, :parameters)[name] : getfield(obj, name) -Base.get(obj::TNSampler, name, default) = get(getfield(obj, :parameters), name, default) +# struct TNSampler{A<:Ansatz,NT<:NamedTuple} <: Random.Sampler{TensorNetwork{A}} +# parameters::NT -Base.eltype(::TNSampler{A}) where {A<:Ansatz} = TensorNetwork{A} +# TNSampler{A}(; kwargs...) where {A} = new{A,typeof(values(kwargs))}(values(kwargs)) +# end -Base.rand(A::Type{<:Ansatz}; kwargs...) = rand(Random.default_rng(), A; kwargs...) -Base.rand(rng::AbstractRNG, ::Type{A}; kwargs...) where {A<:Ansatz} = rand(rng, TNSampler{A}(; kwargs...)) +# Base.getproperty(obj::TNSampler{A,<:NamedTuple{K}}, name::Symbol) where {A,K} = +# name ∈ K ? getfield(obj, :parameters)[name] : getfield(obj, name) +# Base.get(obj::TNSampler, name, default) = get(getfield(obj, :parameters), name, default) -Base.convert(::Type{T}, tn::TensorNetwork{A}) where {T<:Ansatz,A<:T} = - TensorNetwork{T}(tensors(tn); metadata(T)(tn.metadata)...) +# Base.eltype(::TNSampler{A}) where {A<:Ansatz} = TensorNetwork{A} -Base.convert(::Type{T}, tn::TensorNetwork{A}; metadata...) where {A<:Ansatz,T<:A} = TensorNetwork{T}(tn; metadata...) +# Base.rand(A::Type{<:Ansatz}; kwargs...) = rand(Random.default_rng(), A; kwargs...) +# Base.rand(rng::AbstractRNG, ::Type{A}; kwargs...) where {A<:Ansatz} = rand(rng, TNSampler{A}(; kwargs...)) diff --git a/src/Transformations.jl b/src/Transformations.jl index 3e5ba7d0..54a3ba84 100644 --- a/src/Transformations.jl +++ b/src/Transformations.jl @@ -8,27 +8,28 @@ using Combinatorics: combinations abstract type Transformation end """ - transform(tn::TensorNetwork, config::Transformation) - transform(tn::TensorNetwork, configs) + transform(tn::AbstractTensorNetwork, config::Transformation) + transform(tn::AbstractTensorNetwork, configs) Return a new [`TensorNetwork`](@ref) where some `Transformation` has been performed into it. See also: [`transform!`](@ref). """ -transform(tn::TensorNetwork, transformations) = (tn = deepcopy(tn); transform!(tn, transformations); return tn) +transform(tn::absclass(TensorNetwork), transformations) = + (tn = deepcopy(tn); transform!(tn, transformations); return tn) """ - transform!(tn::TensorNetwork, config::Transformation) - transform!(tn::TensorNetwork, configs) + transform!(tn::AbstractTensorNetwork, config::Transformation) + transform!(tn::AbstractTensorNetwork, configs) In-place version of [`transform`](@ref). """ function transform! end -transform!(tn::TensorNetwork, transformation::Type{<:Transformation}; kwargs...) = +transform!(tn::absclass(TensorNetwork), transformation::Type{<:Transformation}; kwargs...) 
= transform!(tn, transformation(kwargs...)) -function transform!(tn::TensorNetwork, transformations) +function transform!(tn::absclass(TensorNetwork), transformations) for transformation in transformations transform!(tn, transformation) end @@ -43,7 +44,7 @@ This transformation is always used by default when visualizing a `TensorNetwork` """ struct HyperindConverter <: Transformation end -function hyperflatten(tn::TensorNetwork) +function hyperflatten(tn::absclass(TensorNetwork)) map(inds(tn, :hyper)) do hyperindex n = select(tn, hyperindex) |> length map(1:n) do i @@ -52,7 +53,7 @@ function hyperflatten(tn::TensorNetwork) end |> Dict end -function transform!(tn::TensorNetwork, ::HyperindConverter) +function transform!(tn::absclass(TensorNetwork), ::HyperindConverter) for (flatindices, hyperindex) in hyperflatten(tn) # insert COPY tensor array = DeltaArray{length(flatindices)}(ones(size(tn, hyperindex))) @@ -82,7 +83,7 @@ Base.@kwdef struct DiagonalReduction <: Transformation atol::Float64 = 1e-12 end -function transform!(tn::TensorNetwork, config::DiagonalReduction) +function transform!(tn::absclass(TensorNetwork), config::DiagonalReduction) for tensor in filter(tensor -> !(parenttype(typeof(tensor)) <: DeltaArray), tensors(tn)) diaginds = find_diag_axes(tensor, atol = config.atol) isempty(diaginds) && continue @@ -111,7 +112,7 @@ function transform!(tn::TensorNetwork, config::DiagonalReduction) return (; target = target, copies = copies) end - transformed_tn = TensorNetwork([transformed_tensor.target, transformed_tensor.copies...]) + transformed_tn = TensorNetwork(Tensor[transformed_tensor.target, transformed_tensor.copies...]) replace!(tn, tensor => transformed_tn) end @@ -125,7 +126,7 @@ Preemptively contract tensors whose result doesn't increase in size. """ struct RankSimplification <: Transformation end -function transform!(tn::TensorNetwork, ::RankSimplification) +function transform!(tn::absclass(TensorNetwork), ::RankSimplification) @label rank_transformation_start for tensor in tensors(tn) # TODO replace this code for `neighbours` method @@ -173,7 +174,7 @@ Base.@kwdef struct AntiDiagonalGauging <: Transformation skip::Vector{Symbol} = Symbol[] end -function transform!(tn::TensorNetwork, config::AntiDiagonalGauging) +function transform!(tn::absclass(TensorNetwork), config::AntiDiagonalGauging) skip_inds = isempty(config.skip) ? inds(tn, set = :open) : config.skip for idx in keys(tn.tensors) @@ -212,7 +213,7 @@ Base.@kwdef struct ColumnReduction <: Transformation skip::Vector{Symbol} = Symbol[] end -function transform!(tn::TensorNetwork, config::ColumnReduction) +function transform!(tn::absclass(TensorNetwork), config::ColumnReduction) skip_inds = isempty(config.skip) ? inds(tn, set = :open) : config.skip for tensor in tn.tensors @@ -284,7 +285,7 @@ Base.@kwdef struct SplitSimplification <: Transformation atol::Float64 = 1e-10 # A threshold for SVD rank determination end -function transform!(tn::TensorNetwork, config::SplitSimplification) +function transform!(tn::absclass(TensorNetwork), config::SplitSimplification) @label split_simplification_start for tensor in tensors(tn) inds = Tenet.inds(tensor) diff --git a/test/Helpers_test.jl b/test/Helpers_test.jl index 5c9d6138..5d0b4e6c 100644 --- a/test/Helpers_test.jl +++ b/test/Helpers_test.jl @@ -12,23 +12,4 @@ # NOTE probabilitic testing due to time taken by `letter`. refactor when `letter` is optimized. 
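     # (`letter(i)` walks the Unicode codepoints in order and returns the i-th
     # `Char` for which `isletter` holds, wrapped in a `Symbol`, so sampling 1000
     # random indices covers the domain without a full sweep.)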
@test all(isletter ∘ only ∘ String, Iterators.map(letter, rand(1:Tenet.NUM_UNICODE_LETTERS, 1000))) end - - @testset "merge" begin - N = NamedTuple{(),Tuple{}} - @test merge(N, N) === N - - A = NamedTuple{(:a,),Tuple{Int}} - @test merge(A, N) === merge(N, A) === A - - B = NamedTuple{(:b,),Tuple{Float64}} - @test merge(A, B) === - merge(A, B, N) === - merge(N, A, B) === - merge(A, N, B) === - NamedTuple{(:a, :b),Tuple{Int,Float64}} - end - - @testset "superansatzes" begin - @test Tenet.superansatzes(Arbitrary) === (Arbitrary, Ansatz) - end end diff --git a/test/TensorNetwork_test.jl b/test/TensorNetwork_test.jl index 24c13e17..afe65b73 100644 --- a/test/TensorNetwork_test.jl +++ b/test/TensorNetwork_test.jl @@ -2,7 +2,6 @@ @testset "Constructors" begin @testset "empty" begin tn = TensorNetwork() - @test ansatz(tn) == ansatz(typeof(tn)) === Tenet.Arbitrary @test isempty(tensors(tn)) @test isempty(inds(tn)) @test isempty(size(tn)) @@ -10,7 +9,7 @@ @testset "list" begin tensor = Tensor(zeros(2, 3), (:i, :j)) - tn = TensorNetwork([tensor]) + tn = TensorNetwork(Tensor[tensor]) @test only(tensors(tn)) === tensor @@ -61,7 +60,7 @@ @testset "pop!" begin @testset "by reference" begin tensor = Tensor(zeros(2, 3), (:i, :j)) - tn = TensorNetwork([tensor]) + tn = TensorNetwork(Tensor[tensor]) @test pop!(tn, tensor) === tensor @test length(tn.tensors) == 0 @@ -71,7 +70,7 @@ @testset "by symbol" begin tensor = Tensor(zeros(2, 3), (:i, :j)) - tn = TensorNetwork([tensor]) + tn = TensorNetwork(Tensor[tensor]) @test only(pop!(tn, :i)) === tensor @test length(tn.tensors) == 0 @@ -81,7 +80,7 @@ @testset "by symbols" begin tensor = Tensor(zeros(2, 3), (:i, :j)) - tn = TensorNetwork([tensor]) + tn = TensorNetwork(Tensor[tensor]) @test only(pop!(tn, (:i, :j))) === tensor @test length(tn.tensors) == 0 @@ -93,7 +92,7 @@ # TODO by simbols @testset "delete!" 
begin tensor = Tensor(zeros(2, 3), (:i, :j)) - tn = TensorNetwork([tensor]) + tn = TensorNetwork(Tensor[tensor]) @test delete!(tn, tensor) === tn @test length(tn.tensors) == 0 @@ -115,12 +114,13 @@ @testset "rand" begin tn = rand(TensorNetwork, 10, 3) - @test tn isa TensorNetwork{Arbitrary} + @test tn isa TensorNetwork @test length(tn.tensors) == 10 end @testset "copy" begin - tn = rand(TensorNetwork, 10, 3) + tensor = Tensor(zeros(2, 2), (:i, :j)) + tn = TensorNetwork(Tensor[tensor]) tn_copy = copy(tn) @test tensors(tn_copy) !== tensors(tn) && all(tensors(tn_copy) .=== tensors(tn)) @@ -128,12 +128,14 @@ end @testset "inds" begin - tn = TensorNetwork([ - Tensor(zeros(2, 2), (:i, :j)), - Tensor(zeros(2, 2), (:i, :k)), - Tensor(zeros(2, 2, 2), (:i, :l, :m)), - Tensor(zeros(2, 2), (:l, :m)), - ]) + tn = TensorNetwork( + Tensor[ + Tensor(zeros(2, 2), (:i, :j)), + Tensor(zeros(2, 2), (:i, :k)), + Tensor(zeros(2, 2, 2), (:i, :l, :m)), + Tensor(zeros(2, 2), (:l, :m)), + ], + ) @test issetequal(inds(tn), (:i, :j, :k, :l, :m)) @test issetequal(inds(tn, :open), (:j, :k)) @@ -142,12 +144,14 @@ end @testset "size" begin - tn = TensorNetwork([ - Tensor(zeros(2, 3), (:i, :j)), - Tensor(zeros(2, 4), (:i, :k)), - Tensor(zeros(2, 5, 6), (:i, :l, :m)), - Tensor(zeros(5, 6), (:l, :m)), - ]) + tn = TensorNetwork( + Tensor[ + Tensor(zeros(2, 3), (:i, :j)), + Tensor(zeros(2, 4), (:i, :k)), + Tensor(zeros(2, 5, 6), (:i, :l, :m)), + Tensor(zeros(5, 6), (:l, :m)), + ], + ) @test size(tn) == Dict((:i => 2, :j => 3, :k => 4, :l => 5, :m => 6)) @test all([size(tn, :i) == 2, size(tn, :j) == 3, size(tn, :k) == 4, size(tn, :l) == 5, size(tn, :m) == 6]) @@ -160,7 +164,7 @@ t_ik = Tensor(zeros(2, 2), (:i, :k)) t_ilm = Tensor(zeros(2, 2, 2), (:i, :l, :m)) t_lm = Tensor(zeros(2, 2), (:l, :m)) - tn = TensorNetwork([t_ij, t_ik, t_ilm, t_lm]) + tn = TensorNetwork(Tensor[t_ij, t_ik, t_ilm, t_lm]) @test issetequal(select(tn, :i), (t_ij, t_ik, t_ilm)) @test issetequal(select(tn, :j), (t_ij,)) @@ -201,7 +205,7 @@ A = Tensor(rand(2, 2, 2), (:i, :j, :k)) B = Tensor(rand(2, 2, 2), (:k, :l, :m)) - tn = TensorNetwork([A, B]) + tn = TensorNetwork(Tensor[A, B]) @test contract(tn) isa Tensor end @@ -210,7 +214,7 @@ t_ik = Tensor(zeros(2, 2), (:i, :k)) t_ilm = Tensor(zeros(2, 2, 2), (:i, :l, :m)) t_lm = Tensor(zeros(2, 2), (:l, :m)) - tn = TensorNetwork([t_ij, t_ik, t_ilm, t_lm]) + tn = TensorNetwork(Tensor[t_ij, t_ik, t_ilm, t_lm]) @testset "replace inds" begin mapping = (:i => :u, :j => :v, :k => :w, :l => :x, :m => :y) @@ -251,7 +255,7 @@ # New tensor network with two tensors with the same inds A = Tensor(rand(2, 2), (:u, :w)) B = Tensor(rand(2, 2), (:u, :w)) - tn = TensorNetwork([A, B]) + tn = TensorNetwork(Tensor[A, B]) new_tensor = Tensor(rand(2, 2), (:u, :w)) @@ -259,7 +263,7 @@ @test A === tn.tensors[1] @test new_tensor === tn.tensors[2] - tn = TensorNetwork([A, B]) + tn = TensorNetwork(Tensor[A, B]) replace!(tn, A => new_tensor) @test issetequal(tensors(tn), [new_tensor, B]) @@ -268,7 +272,7 @@ A = Tensor(zeros(2, 2), (:i, :j)) B = Tensor(zeros(2, 2), (:j, :k)) C = Tensor(zeros(2, 2), (:k, :l)) - tn = TensorNetwork([A, B, C]) + tn = TensorNetwork(Tensor[A, B, C]) @test_throws ArgumentError replace!(tn, A => B, B => C, C => A) diff --git a/test/Transformations_test.jl b/test/Transformations_test.jl index e8813a2b..5c690336 100644 --- a/test/Transformations_test.jl +++ b/test/Transformations_test.jl @@ -25,7 +25,7 @@ t_ik = Tensor(zeros(2, 2), (:i, :k)) t_ilm = Tensor(zeros(2, 2, 2), (:i, :l, :m)) t_lm = Tensor(zeros(2, 2), (:l, 
:m)) - tn = TensorNetwork([t_ij, t_ik, t_ilm, t_lm]) + tn = TensorNetwork(Tensor[t_ij, t_ik, t_ilm, t_lm]) transform!(tn, HyperindConverter) @test isempty(inds(tn, :hyper)) @@ -66,7 +66,7 @@ @test issetequal(find_diag_axes(A), [[:i, :j]]) - tn = TensorNetwork([A, B, C]) + tn = TensorNetwork(Tensor[A, B, C]) reduced = transform(tn, DiagonalReduction) @test all( @@ -100,7 +100,7 @@ @test issetequal(find_diag_axes(A), [[:i, :l], [:j, :m]]) @test issetequal(find_diag_axes(B), [[:j, :n, :o]]) - tn = TensorNetwork([A, B, C]) + tn = TensorNetwork(Tensor[A, B, C]) reduced = transform(tn, DiagonalReduction) # Test that all tensors (that are no COPY tensors) in reduced have no @@ -124,7 +124,7 @@ D = Tensor(rand(2), (:p,)) E = Tensor(rand(2, 2, 2, 2), (:o, :p, :q, :j)) - tn = TensorNetwork([A, B, C, D, E]) + tn = TensorNetwork(Tensor[A, B, C, D, E]) reduced = transform(tn, RankSimplification) # Test that the resulting tn contains no tensors with larger rank than the original @@ -175,7 +175,7 @@ @test issetequal(find_anti_diag_axes(parent(A)), [(1, 4), (2, 5)]) @test issetequal(find_anti_diag_axes(parent(B)), [(1, 2)]) - tn = TensorNetwork([A, B, C]) + tn = TensorNetwork(Tensor[A, B, C]) gauged = transform(tn, AntiDiagonalGauging) # Test that all tensors in gauged have no antidiagonals @@ -201,7 +201,7 @@ @test issetequal(find_zero_columns(parent(A)), [(2, 1), (2, 2)]) - tn = TensorNetwork([A, B, C]) + tn = TensorNetwork(Tensor[A, B, C]) reduced = transform(tn, ColumnReduction) # Test that all the tensors in reduced have no columns and they do not have the 2nd :j index @@ -226,7 +226,7 @@ @test issetequal(find_zero_columns(parent(A)), [(2, 2)]) - tn = TensorNetwork([A, B, C]) + tn = TensorNetwork(Tensor[A, B, C]) reduced = transform(tn, ColumnReduction) # Test that all the tensors in reduced have no columns and they have smaller dimensions in the 2nd :j index @@ -252,7 +252,7 @@ t1 = contract(v1, v2) tensor = contract(t1, m1) # Define a tensor which can be splitted in three - tn = TensorNetwork([tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) + tn = TensorNetwork(Tensor[tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) reduced = transform(tn, SplitSimplification) # Test that the new tensors in reduced are smaller than the deleted ones From 6f32e3d187844a1068a6b195625ee8b2dd1353e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 5 Oct 2023 00:29:40 +0200 Subject: [PATCH 02/29] Fix refactor in `ChainRulesCore`,`FiniteDifferences` extensions --- ext/TenetChainRulesCoreExt.jl | 12 ++++++++++-- ext/TenetFiniteDifferencesExt.jl | 13 +++++++++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/ext/TenetChainRulesCoreExt.jl b/ext/TenetChainRulesCoreExt.jl index e86b8e17..4e5432c1 100644 --- a/ext/TenetChainRulesCoreExt.jl +++ b/ext/TenetChainRulesCoreExt.jl @@ -38,8 +38,16 @@ end function Base.:+(x::T, Δ::Tangent{TensorNetwork}) where {T<:absclass(TensorNetwork)} # TODO match tensors by indices - tensors = map(+, x.tensors, Δ.tensors) - T(tensors, ...) # TODO fix how to pass metadata + tensors = map(+, tensors(x), Δ.tensors) + + # TODO create function fitted for this? or maybe standardize constructors? + T(map(fieldnames(T)) do fieldname + if fieldname === :tensors + tensors + else + getfield(x, fieldname) + end + end...) 
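+    # (The positional splat above assumes every subclass of `TensorNetwork`
+    # keeps a constructor accepting all of its fields in declaration order;
+    # see the TODO above about standardizing constructors.)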
end function ChainRulesCore.frule((_, Δ), T::Type{<:absclass(TensorNetwork)}, tensors) diff --git a/ext/TenetFiniteDifferencesExt.jl b/ext/TenetFiniteDifferencesExt.jl index e7b453f6..cf39c270 100644 --- a/ext/TenetFiniteDifferencesExt.jl +++ b/ext/TenetFiniteDifferencesExt.jl @@ -1,13 +1,22 @@ module TenetFiniteDifferencesExt using Tenet +using Classes using FiniteDifferences -function FiniteDifferences.to_vec(x::TensorNetwork{A}) where {A<:Ansatz} +function FiniteDifferences.to_vec(x::T) where {T<:absclass(TensorNetwork)} x_vec, back = to_vec(x.tensors) function TensorNetwork_from_vec(v) tensors = back(v) - TensorNetwork{A}(tensors; x.metadata...) + + # TODO create function fitted for this? or maybe standardize constructors? + T(map(fieldnames(T)) do fieldname + if fieldname === :tensors + tensors + else + getfield(x, fieldname) + end + end...) end return x_vec, TensorNetwork_from_vec From dc00471d058cc511d79a7054281ee51c9f381b89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 5 Oct 2023 00:58:01 +0200 Subject: [PATCH 03/29] Fix `ProjectTo` to `TensorNetwork` --- ext/TenetChainRulesCoreExt.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ext/TenetChainRulesCoreExt.jl b/ext/TenetChainRulesCoreExt.jl index 4e5432c1..a38ae735 100644 --- a/ext/TenetChainRulesCoreExt.jl +++ b/ext/TenetChainRulesCoreExt.jl @@ -28,7 +28,15 @@ ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds) = T(data, inds), Tensor_pull @non_differentiable symdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...) function ChainRulesCore.ProjectTo(tn::T) where {T<:absclass(TensorNetwork)} - ProjectTo{T}(; tensors = ProjectTo(tn.tensors), metadata = tn.metadata) + # TODO create function to extract extra fields + fields = map(fieldnames(T)) do fieldname + if fieldname === :tensors + :tensors => ProjectTo(tn.tensors) + else + fieldname => getfield(tn, fieldname) + end + end + ProjectTo{T}(; fields...) end function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:absclass(TensorNetwork)} From b5b3317e27c136173d8a67bca1c45cd78317f6e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 5 Oct 2023 19:07:17 +0200 Subject: [PATCH 04/29] Import `EinExprs.inds` symbol --- src/Tenet.jl | 2 ++ src/TensorNetwork.jl | 16 ++++++---------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/Tenet.jl b/src/Tenet.jl index 6dfb180b..b9b92b07 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -1,5 +1,7 @@ module Tenet +import EinExprs: inds + include("Helpers.jl") include("Tensor.jl") diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 86b50b1e..c5179f9a 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -63,16 +63,12 @@ Return the names of the indices in the [`TensorNetwork`](@ref). + `:inner` Indices mentioned at least twice. + `:hyper` Indices mentioned at least in three tensors. """ -EinExprs.inds(tn::absclass(TensorNetwork); set::Symbol = :all, kwargs...) = inds(tn, set; kwargs...) -@valsplit 2 EinExprs.inds(tn::absclass(TensorNetwork), set::Symbol, args...) 
= - throw(MethodError(inds, "set=$set not recognized")) -EinExprs.inds(tn::absclass(TensorNetwork), ::Val{:all}) = collect(keys(tn.indices)) -EinExprs.inds(tn::absclass(TensorNetwork), ::Val{:open}) = - map(first, Iterators.filter(==(1) ∘ length ∘ last, tn.indices)) -EinExprs.inds(tn::absclass(TensorNetwork), ::Val{:inner}) = - map(first, Iterators.filter(>=(2) ∘ length ∘ last, tn.indices)) -EinExprs.inds(tn::absclass(TensorNetwork), ::Val{:hyper}) = - map(first, Iterators.filter(>=(3) ∘ length ∘ last, tn.indices)) +inds(tn::absclass(TensorNetwork); set::Symbol = :all, kwargs...) = inds(tn, set; kwargs...) +@valsplit 2 inds(tn::absclass(TensorNetwork), set::Symbol, args...) = throw(MethodError(inds, "unknown set=$set")) +inds(tn::absclass(TensorNetwork), ::Val{:all}) = collect(keys(tn.indices)) +inds(tn::absclass(TensorNetwork), ::Val{:open}) = map(first, Iterators.filter(==(1) ∘ length ∘ last, tn.indices)) +inds(tn::absclass(TensorNetwork), ::Val{:inner}) = map(first, Iterators.filter(>=(2) ∘ length ∘ last, tn.indices)) +inds(tn::absclass(TensorNetwork), ::Val{:hyper}) = map(first, Iterators.filter(>=(3) ∘ length ∘ last, tn.indices)) """ size(tn::AbstractTensorNetwork) From eb5094a910358869d0ccfe54bd24bcbdc467160b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 5 Oct 2023 20:36:36 +0200 Subject: [PATCH 05/29] Fix invalidation of `EinExprs.inds` symbol import --- src/Numerics.jl | 1 - src/Tensor.jl | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/Numerics.jl b/src/Numerics.jl index e8121118..ec58264c 100644 --- a/src/Numerics.jl +++ b/src/Numerics.jl @@ -1,7 +1,6 @@ using OMEinsum using LinearAlgebra using UUIDs: uuid4 -using EinExprs: inds # TODO test array container typevar on output for op in [ diff --git a/src/Tensor.jl b/src/Tensor.jl index ea5737e3..514590f9 100644 --- a/src/Tensor.jl +++ b/src/Tensor.jl @@ -1,6 +1,5 @@ using Base: @propagate_inbounds using Base.Broadcast: Broadcasted, ArrayStyle -using EinExprs using ImmutableArrays struct Tensor{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} @@ -23,6 +22,8 @@ Tensor(data::A, inds::NTuple{N,Symbol}) where {T,N,A<:AbstractArray{T,N}} = Tens Tensor(data::AbstractArray{T,0}) where {T} = Tensor(data, Symbol[]) Tensor(data::Number) = Tensor(fill(data)) +inds(t::Tensor) = t.inds + function Base.copy(t::Tensor{T,N,<:SubArray{T,N}}) where {T,N} data = copy(t.data) inds = t.inds @@ -78,8 +79,6 @@ function Base.isapprox(a::Tensor, b::Tensor) end end -EinExprs.inds(t::Tensor) = t.inds - # NOTE: `replace` does not currenly support cyclic replacements Base.replace(t::Tensor, old_new::Pair{Symbol,Symbol}...) = Tensor(parent(t), replace(inds(t), old_new...)) From 2c91f0e53f82733e9653b4d5a7ec75f13f0c50da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Thu, 5 Oct 2023 20:37:12 +0200 Subject: [PATCH 06/29] Fix refactor on `Makie` extension --- ext/TenetMakieExt.jl | 6 +++--- test/integration/Makie_test.jl | 7 ++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/ext/TenetMakieExt.jl b/ext/TenetMakieExt.jl index 08463e04..a050f0be 100644 --- a/ext/TenetMakieExt.jl +++ b/ext/TenetMakieExt.jl @@ -19,7 +19,7 @@ Plot a [`TensorNetwork`](@ref) as a graph. - `labels` If `true`, show the labels of the tensor indices. Defaults to `false`. - The rest of `kwargs` are passed to `GraphMakie.graphplot`. """ -function Makie.plot(@nospecialize tn::TensorNetwork; kwargs...) 
+function Makie.plot(@nospecialize tn::absclass(TensorNetwork); kwargs...) f = Figure() ax, p = plot!(f[1, 1], tn; kwargs...) return Makie.FigureAxisPlot(f, ax, p) @@ -28,7 +28,7 @@ end # NOTE this is a hack! we did it in order not to depend on NetworkLayout but can be unstable __networklayout_dim(x) = typeof(x).super.parameters |> first -function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::TensorNetwork; kwargs...) +function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::absclass(TensorNetwork); kwargs...) ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3 Axis3(f[1, 1]) else @@ -45,7 +45,7 @@ function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::TensorNetw return Makie.AxisPlot(ax, p) end -function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::TensorNetwork; labels = false, kwargs...) +function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::absclass(TensorNetwork); labels = false, kwargs...) hypermap = Tenet.hyperflatten(tn) tn = transform(tn, Tenet.HyperindConverter) diff --git a/test/integration/Makie_test.jl b/test/integration/Makie_test.jl index 48a74428..f95bfa5b 100644 --- a/test/integration/Makie_test.jl +++ b/test/integration/Makie_test.jl @@ -2,11 +2,8 @@ using CairoMakie using NetworkLayout: Spring - tn = TensorNetwork([ - Tensor(rand(2, 2, 2, 2), (:x, :y, :z, :t)), - Tensor(rand(2, 2), (:x, :y)), - Tensor(rand(2), (:x,)), - ]) + tensors = Tensor[Tensor(rand(2, 2, 2, 2), (:x, :y, :z, :t)), Tensor(rand(2, 2), (:x, :y)), Tensor(rand(2), (:x,))] + tn = TensorNetwork(tensors) @testset "plot!" begin f = Figure() From a941fcd4f5edebb09749e2aa4886fd3a0441f945 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 6 Oct 2023 11:08:59 +0200 Subject: [PATCH 07/29] Fix `Classes` import in `Makie` extension --- ext/TenetMakieExt.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/ext/TenetMakieExt.jl b/ext/TenetMakieExt.jl index a050f0be..7f0a2f91 100644 --- a/ext/TenetMakieExt.jl +++ b/ext/TenetMakieExt.jl @@ -4,6 +4,7 @@ using Tenet using Combinatorics: combinations using Graphs using Makie +using Classes using GraphMakie From 270be2baf4440921cecc5560333e1b5251b8132d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 6 Oct 2023 17:01:49 +0200 Subject: [PATCH 08/29] Split functionality from `append!(::TensorNetwork)` to `merge!` --- docs/src/tensor-network.md | 1 + src/TensorNetwork.jl | 30 ++++++++++++++++++++---------- test/TensorNetwork_test.jl | 8 +++++++- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/docs/src/tensor-network.md b/docs/src/tensor-network.md index bbf01cd9..c78cb112 100644 --- a/docs/src/tensor-network.md +++ b/docs/src/tensor-network.md @@ -42,6 +42,7 @@ ansatz ```@docs push!(::TensorNetwork, ::Tensor) append!(::TensorNetwork, ::Base.AbstractVecOrTuple{<:Tensor}) +merge!(::AbstractTensorNetwork, ::AbstractTensorNetwork) pop!(::TensorNetwork, ::Tensor) delete!(::TensorNetwork, ::Any) ``` diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index c5179f9a..c55f5257 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -106,20 +106,30 @@ end """ append!(tn::AbstractTensorNetwork, tensors::AbstractVecOrTuple{<:Tensor}) - append!(A::AbstractTensorNetwork, B::AbstractTensorNetwork) -Add a list of tensors to the first `TensorNetwork`. +Add a list of tensors to a `TensorNetwork`. -See also: [`push!`](@ref) +See also: [`push!`](@ref), [`merge!`](@ref). 
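+
+A minimal usage sketch (sizes and index names are arbitrary):
+
+```julia
+tn = TensorNetwork()
+append!(tn, Tensor[Tensor(zeros(2, 3), (:i, :j)), Tensor(zeros(3, 2), (:j, :k))])
+```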
""" -Base.append!(tn::absclass(TensorNetwork), t::AbstractVecOrTuple{<:Tensor}) = (foreach(Base.Fix1(push!, tn), t); tn) -function Base.append!(A::absclass(TensorNetwork), B::absclass(TensorNetwork)) - append!(A, tensors(B)) - # TODO define behaviour - # merge!(A.metadata, B.metadata) - return A +function Base.append!(tn::absclass(TensorNetwork), ts::AbstractVecOrTuple{<:Tensor}) + for tensor in ts + push!(tn, tensor) + end + tn end +""" + merge!(self::AbstractTensorNetwork, others::AbstractTensorNetwork...) + merge(self::AbstractTensorNetwork, others::AbstractTensorNetwork...) + +Fuse various [`TensorNetwork`](@ref)s into one. + +See also: [`append!`](@ref). +""" +Base.merge!(self::absclass(TensorNetwork), other::absclass(TensorNetwork)) = append!(self, tensors(other)) +Base.merge!(self::absclass(TensorNetwork), others::absclass(TensorNetwork)...) = foldl(merge!, others; init = self) +Base.merge(self::absclass(TensorNetwork), others::absclass(TensorNetwork)...) = merge!(copy(self), others...) + function Base.popat!(tn::absclass(TensorNetwork), i::Integer) tensor = popat!(tn.tensors, i) @@ -232,7 +242,7 @@ function Base.replace!(tn::absclass(TensorNetwork), old_new::Pair{<:Tensor,<:Abs # rename internal indices so there is no accidental hyperedge replace!(new, [index => Symbol(uuid4()) for index in filter(∈(inds(tn)), inds(new, set = :inner))]...) - append!(tn, new) + merge!(tn, new) delete!(tn, old) return tn diff --git a/test/TensorNetwork_test.jl b/test/TensorNetwork_test.jl index afe65b73..86d2f4bf 100644 --- a/test/TensorNetwork_test.jl +++ b/test/TensorNetwork_test.jl @@ -52,8 +52,14 @@ append!(B, [tensor]) @test only(tensors(B)) === tensor + end + + @testset "merge!" begin + tensor = Tensor(zeros(2, 3), (:i, :j)) + A = TensorNetwork(Tensor[tensor]) + B = TensorNetwork() - append!(A, B) + merge!(A, B) @test only(tensors(A)) === tensor end From 79027fd14af61d39c8090f779a39807d82d13a77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 6 Oct 2023 17:25:31 +0200 Subject: [PATCH 09/29] Autoimplement `copy` for `TensorNetwork` subtypes --- src/TensorNetwork.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index c55f5257..9fa57d23 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -35,7 +35,7 @@ end # TensorNetwork{A}(tn::absclass(TensorNetwork){B}; metadata...) where {A,B} = # TensorNetwork{A}(tensors(tn); merge(tn.metadata, metadata)...) -Base.copy(tn::TensorNetwork) = TensorNetwork(copy(tensors(tn))) +Base.copy(tn::T) where {T<:absclass(TensorNetwork)} = T(map(field -> copy(getfield(tn, field)), fieldnames(T))...) Base.summary(io::IO, x::absclass(TensorNetwork)) = print(io, "$(length(x))-tensors $(typeof(x))") Base.show(io::IO, tn::absclass(TensorNetwork)) = From 7b48e03e70bc40d384e3dd8c51f8eb39e934c737 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 6 Oct 2023 18:08:23 +0200 Subject: [PATCH 10/29] Fix `replace!(::TensorNetwork)` for list of `Pair`s --- src/TensorNetwork.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 9fa57d23..7ed45f48 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -188,7 +188,7 @@ Return a copy of the [`TensorNetwork`](@ref) where `old` has been replaced by `n See also: [`replace!`](@ref). """ -Base.replace(tn::absclass(TensorNetwork), old_new::Pair...) = replace!(copy(tn), old_new...) 
+Base.replace(tn::absclass(TensorNetwork), old_new::Pair...) = replace!(copy(tn), old_new) """ replace!(tn::AbstractTensorNetwork, old => new...) @@ -200,7 +200,8 @@ Replace the element in `old` with the one in `new`. Depending on the types of `o See also: [`replace`](@ref). """ -function Base.replace!(tn::absclass(TensorNetwork), old_new::Pair...) +Base.replace!(tn::absclass(TensorNetwork), old_new::Pair...) = replace!(tn, old_new) +function Base.replace!(tn::absclass(TensorNetwork), old_new::Base.AbstractVecOrTuple{Pair}) for pair in old_new replace!(tn, pair) end From e26b1f0398fe5538e148a70ac261b073656621db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Fri, 6 Oct 2023 19:05:37 +0200 Subject: [PATCH 11/29] Fix mutation on `merge(::TensorNetwork)` `copy` is not acting as expected and the copied TN has the `.indices` field mutated. --- src/TensorNetwork.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 7ed45f48..dc869df2 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -128,7 +128,7 @@ See also: [`append!`](@ref). """ Base.merge!(self::absclass(TensorNetwork), other::absclass(TensorNetwork)) = append!(self, tensors(other)) Base.merge!(self::absclass(TensorNetwork), others::absclass(TensorNetwork)...) = foldl(merge!, others; init = self) -Base.merge(self::absclass(TensorNetwork), others::absclass(TensorNetwork)...) = merge!(copy(self), others...) +Base.merge(self::absclass(TensorNetwork), others::absclass(TensorNetwork)...) = merge!(deepcopy(self), others...) # TODO deepcopy because `indices` are not correctly copied and it mutates function Base.popat!(tn::absclass(TensorNetwork), i::Integer) tensor = popat!(tn.tensors, i) From 4e8ea4b39fcec1fd818010285aeec2db1decafab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sat, 7 Oct 2023 01:58:23 +0200 Subject: [PATCH 12/29] Refactor `Quantum` TNs --- Project.toml | 2 - src/Quantum/Quantum.jl | 311 +++++++++++++++-------------------------- src/Tenet.jl | 9 +- test/Project.toml | 1 - test/Quantum_test.jl | 187 +++++++++++++------------ test/runtests.jl | 13 +- 6 files changed, 221 insertions(+), 302 deletions(-) diff --git a/Project.toml b/Project.toml index f47fde21..110450ad 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,6 @@ authors = ["Sergio Sánchez Ramírez "] version = "0.2.0" [deps] -Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" Classes = "1a9c1350-211b-5766-99cd-4544d885a0d1" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188" @@ -35,7 +34,6 @@ TenetMakieExt = "Makie" TenetQuacExt = "Quac" [compat] -Bijections = "0.1" ChainRulesCore = "1.0" Combinatorics = "1.0" DeltaArrays = "0.1.1" diff --git a/src/Quantum/Quantum.jl b/src/Quantum/Quantum.jl index 9522da68..6b5df73a 100644 --- a/src/Quantum/Quantum.jl +++ b/src/Quantum/Quantum.jl @@ -1,256 +1,170 @@ using LinearAlgebra using UUIDs: uuid4 using ValSplit -using Bijections -using EinExprs: inds +using Classes """ - Quantum <: Ansatz + QuantumTensorNetwork Tensor Network `Ansatz` that has a notion of sites and directionality (input/output). """ -abstract type Quantum <: Arbitrary end - -# NOTE Storing `Plug` type on type parameters is not compatible with `Composite` ansatz. Use Holy traits instead. 
-metadata(::Type{Quantum}) = merge(metadata(supertype(Quantum)), @NamedTuple begin - plug::Type{<:Plug} - interlayer::Vector{Bijection{Int,Symbol}} -end) - -function checkmeta(::Type{Quantum}, tn::TensorNetwork) - # TODO run this check depending if State or Operator - length(tn.interlayer) >= 1 || return false - - # meta's indices exist - all(bij -> values(bij) ⊆ inds(tn), tn.interlayer) || return false - - return true +@class QuantumTensorNetwork <: TensorNetwork begin + input::Vector{Symbol} + output::Vector{Symbol} end -abstract type Boundary end -abstract type Open <: Boundary end -abstract type Periodic <: Boundary end -abstract type Infinite <: Boundary end - -""" - boundary(::TensorNetwork) - boundary(::Type{<:TensorNetwork}) - -Return the `Boundary` type of the [`TensorNetwork`](@ref). The following `Boundary`s are defined in `Tenet`: - - - `Open` - - `Periodic` - - `Infinite` -""" -function boundary end -boundary(::T) where {T<:TensorNetwork} = boundary(T) -boundary(::Type{T}) where {T<:TensorNetwork} = boundary(ansatz(T)) - -abstract type Plug end -abstract type Property <: Plug end -abstract type State <: Plug end -abstract type Operator <: Plug end - -""" - plug(::TensorNetwork{<:Quantum}) - plug(::Type{<:TensorNetwork}) - -Return the `Plug` type of the [`TensorNetwork`](@ref). The following `Plug`s are defined in `Tenet`: - - - `State` Only outputs. - - `Operator` Inputs and outputs. - - `Property` No inputs nor outputs. -""" -function plug end -plug(tn::TensorNetwork{<:Quantum}) = tn.plug -plug(T::Type{<:TensorNetwork}) = plug(ansatz(T)) +inds(tn::absclass(QuantumTensorNetwork), ::Val{:in}) = tuple(tn.input...) +inds(tn::absclass(QuantumTensorNetwork), ::Val{:in}, site) = tn.input[site] +inds(tn::absclass(QuantumTensorNetwork), ::Val{:out}) = tuple(tn.output...) +inds(tn::absclass(QuantumTensorNetwork), ::Val{:out}, site) = tn.output[site] +inds(tn::absclass(QuantumTensorNetwork), ::Val{:physical}) = ∪(tn.input, tn.output) +inds(tn::absclass(QuantumTensorNetwork), ::Val{:virtual}) = setdiff(inds(tn, Val(:all)), inds(tn, Val(:physical))) """ - sites(tn::TensorNetwork{<:Quantum}) + sites(tn::AbstractQuantumTensorNetwork, dir) Return the sites in which the [`TensorNetwork`](@ref) acts. """ -sites(tn::TensorNetwork) = collect(mapreduce(keys, ∪, tn.interlayer)) - -EinExprs.inds(tn::TensorNetwork, ::Val{:plug}) = unique(Iterators.flatten(Iterators.map(values, tn.interlayer))) -EinExprs.inds(tn::TensorNetwork, ::Val{:plug}, site) = last(tn.interlayer)[site] # inds(tn, Val(:in), site) ∪ inds(tn, Val(:out), site) -EinExprs.inds(tn::TensorNetwork, ::Val{:virtual}) = setdiff(inds(tn, Val(:all)), inds(tn, Val(:plug))) +sites(tn::absclass(QuantumTensorNetwork)) = sites(tn, :in) ∪ sites(tn, :out) +function sites(tn::absclass(QuantumTensorNetwork), dir) + if dir === :in + firstindex(tn.input):lastindex(tn.input) + elseif dir === :out + firstindex(tn.output):lastindex(tn.output) + else + throw(MethodError("unknown dir=$dir")) + end +end -""" - tensors(tn::TensorNetwork{<:Quantum}, site::Integer) +function Base.replace!(tn::absclass(QuantumTensorNetwork), old_new::Pair{Symbol,Symbol}) + Base.@invoke replace!(tn::absclass(TensorNetwork), old_new::Pair{Symbol,Symbol}) -Return the `Tensor` connected to the [`TensorNetwork`](@ref) on `site`. - -See also: [`sites`](@ref). -""" -tensors(tn::TensorNetwork{<:Quantum}, site::Integer, args...) = tensors(plug(tn), tn, site, args...) 
-tensors(::Type{State}, tn::TensorNetwork{<:Quantum}, site) = select(tn, inds(tn, :plug, site)) |> only -@valsplit 4 tensors(T::Type{Operator}, tn::TensorNetwork{<:Quantum}, site, dir::Symbol) = - throw(MethodError(sites, "dir=$dir not recognized")) - -function Base.replace!(tn::TensorNetwork{<:Quantum}, old_new::Pair{Symbol,Symbol}) - # replace indices in tensor network - Base.@invoke replace!(tn::TensorNetwork, old_new::Pair{Symbol,Symbol}) - - old, new = old_new - - # replace indices in interlayers (quantum-specific) - for interlayer in Iterators.filter(∋(old) ∘ image, tn.interlayer) - site = interlayer(old) - delete!(interlayer, site) - interlayer[site] = new - end + replace!(tn.input, old_new) + replace!(tn.output, old_new) return tn end -## `Composite` type """ - Composite <: Quantum + adjoint(tn::AbstractQuantumTensorNetwork) -A [`Quantum`](@ref) ansatz that represents several connected layers of [`Quantum`](@ref) [`TensorNetwork`](@ref)s. +Return the adjoint [`TensorNetwork`](@ref). # Implementation details -Introduces a field named `layermeta` that stores the metadata of each layer. - -See also: [`hcat`](@ref). +The tensors are not transposed, just `conj!` is applied to them. """ -abstract type Composite{Ts<:Tuple} <: Quantum end -Composite(@nospecialize(Ts::Type{<:Quantum}...)) = Composite{Tuple{Ts...}} -Base.fieldtypes(::Type{Composite{Ts}}) where {Ts} = fieldtypes(Ts) - -metadata(::Type{<:Composite}) = merge(metadata(Quantum), @NamedTuple begin - layermeta::Vector{Dict{Symbol,Any}} -end) - -function checkmeta(As::Type{<:Composite}, tn::TensorNetwork) - for (i, A) in enumerate(fieldtypes(As)) - tn_view = layers(tn, i) - checkansatz(tn_view) - end +function Base.adjoint(tn::absclass(QuantumTensorNetwork)) + tn = deepcopy(tn) - return true -end + # swap input/output + temp = copy(tn.input) + resize!(tn.input, length(tn.output)) + copy!(tn.input, tn.output) + resize!(tn.output, length(temp)) + copy!(tn.output, temp) -Base.length(@nospecialize(T::Type{<:Composite})) = length(fieldtypes(T)) + foreach(conj!, tensors(tn)) -# TODO create view of TN -""" - layers(tn::TensorNetwork{<:Composite}, i) + return tn +end -Return a [`TensorNetwork`](@ref) that is shallow copy of the ``i``-th layer of a `Composite` Tensor Network. -""" -function layers(tn::TensorNetwork{As}, i) where {As<:Composite} - A = fieldtypes(As)[i] - layer_plug = tn.layermeta[i][:plug] # TODO more programmatic access (e.g. plug(tn, i)?) - meta = tn.layermeta[i] +function Base.merge!(self::absclass(QuantumTensorNetwork), other::absclass(QuantumTensorNetwork)) + sites(self, :out) == sites(other, :in) || + throw(DimensionMismatch("both `QuantumTensorNetwork`s must contain the same set of sites")) - if layer_plug <: State && 1 < i < length(fieldtypes(As)) - throw(ErrorException("Layer #$i is a state but it is not a extreme layer")) + # copy to avoid mutation if reindex is needed + # TODO deepcopy because `indices` are not correctly copied and it mutates + other = deepcopy(other) + + # reindex other if needed + if inds(self, set = :out) != inds(other, set = :in) + replace!(other, map(splat(=>), zip(inds(other, set = :in), inds(self, set = :out)))) end - interlayer = if layer_plug <: State - i == 1 ? 
[first(tn.interlayer)] : [last(tn.interlayer)] - elseif layer_plug <: Operator - # shift if first layer is a state - tn.layermeta[1][:plug] <: State && (i = i - 1) - tn.interlayer[i:i+1] + # reindex inner indices of `other` to avoid accidental hyperindices + conflict = inds(self, set = :virtual) ∩ inds(other, set = :virtual) + if !isempty(conflict) + replace!(other, map(i -> i => Symbol(uuid4()), conflict)) end - return TensorNetwork{A}( - # TODO revise this - #filter(tensor -> get(tensor.meta, :layer, nothing) == i, tensors(tn)); - tensors(tn); - plug = layer_plug, - interlayer, - meta..., - ) + @invoke merge!(self::absclass(TensorNetwork), other::absclass(TensorNetwork)) + + # update i/o + copy!(self.output, other.output) + + self end -Base.merge(::Type{State}, ::Type{State}) = Property -Base.merge(::Type{State}, ::Type{Operator}) = State -Base.merge(::Type{Operator}, ::Type{State}) = State -Base.merge(::Type{Operator}, ::Type{Operator}) = Operator +function contract(a::absclass(QuantumTensorNetwork), b::absclass(QuantumTensorNetwork); kwargs...) + contract(merge(a, b); kwargs...) +end -# TODO implement hcat when QA or QB <: Composite -""" - hcat(A::TensorNetwork{<:Quantum}, B::TensorNetwork{<:Quantum}...)::TensorNetwork{<:Composite} +# Plug trait +abstract type Plug end +struct Property <: Plug end +struct State <: Plug end +struct Dual <: Plug end +struct Operator <: Plug end -Join [`TensorNetwork`](@ref)s into one by matching sites. """ -function Base.hcat(A::TensorNetwork{QA}, B::TensorNetwork{QB}) where {QA<:Quantum,QB<:Quantum} - issetequal(sites(A), sites(B)) || - throw(DimensionMismatch("A and B must contain the same set of sites in order to connect them")) + plug(::QuantumTensorNetwork) - # rename connector indices - newinds = Dict([s => Symbol(uuid4()) for s in sites(A)]) - - B = copy(B) +Return the `Plug` type of the [`TensorNetwork`](@ref). The following `Plug`s are defined in `Tenet`: - for site in sites(B) - a = inds(A, :plug, site) - b = inds(B, :plug, site) - if a != b && a ∉ inds(B) - replace!(B, b => a) - end + - `Property` No inputs nor outputs. + - `State` Only outputs. + - `Dual` Only inputs. + - `Operator` Inputs and outputs. +""" +function plug(tn) + if isempty(tn.input) && isempty(tn.output) + Property() + elseif isempty(tn.input) + State() + elseif isempty(tn.output) + Dual() + else + Operator() end - - # rename inner indices of B to avoid hyperindices - replace!(B, [i => Symbol(uuid4()) for i in inds(B, :inner)]...) - - combined_plug = merge(plug(A), plug(B)) - - # merge tensors and indices - interlayer = [A.interlayer..., collect(Iterators.drop(B.interlayer, 1))...] - - # TODO merge metadata? - layermeta = Dict{Symbol,Any}[ - Dict(Iterators.filter(((k, v),) -> k !== :interlayer, pairs(A.metadata))), - Dict(Iterators.filter(((k, v),) -> k !== :interlayer, pairs(B.metadata))), - ] - - return TensorNetwork{Composite(QA, QB)}([tensors(A)..., tensors(B)...]; plug = combined_plug, interlayer, layermeta) end -Base.hcat(tns::TensorNetwork...) = reduce(hcat, tns) +# Boundary trait +abstract type Boundary end +struct Open <: Boundary end +struct Periodic <: Boundary end +struct Infinite <: Boundary end """ - adjoint(tn::TensorNetwork{<:Quantum}) - -Return the adjoint [`TensorNetwork`](@ref). + boundary(::QuantumTensorNetwork) -# Implementation details +Return the `Boundary` type of the [`TensorNetwork`](@ref). The following `Boundary`s are defined in `Tenet`: -The tensors are not transposed, just `conj!` is applied to them. 
+  - `Open`
+  - `Periodic`
+  - `Infinite`
 """
-function Base.adjoint(tn::TensorNetwork{<:Quantum})
-    tn = deepcopy(tn)
-
-    reverse!(tn.interlayer)
-    foreach(conj!, tensors(tn))
-
-    return tn
-end
-
-contract(a::TensorNetwork{<:Quantum}, b::TensorNetwork{<:Quantum}; kwargs...) = contract(hcat(a, b); kwargs...)
+function boundary end
 
 # TODO look for more stable ways
 """
-    norm(ψ::TensorNetwork{<:Quantum}, p::Real=2)
+    norm(ψ::AbstractQuantumTensorNetwork, p::Real=2)
 
 Compute the ``p``-norm of a [`Quantum`](@ref) [`TensorNetwork`](@ref).
 
 See also: [`normalize!`](@ref).
 """
-function LinearAlgebra.norm(ψ::TensorNetwork{<:Quantum}, p::Real = 2; kwargs...)
-    p != 2 && throw(ArgumentError("p=$p is not implemented yet"))
+function LinearAlgebra.norm(ψ::absclass(QuantumTensorNetwork), p::Real = 2; kwargs...)
+    p == 2 || throw(ArgumentError("p=$p is not implemented yet"))
+
+    tn = merge(ψ, ψ')
+    all(isempty, [tn.input, tn.output]) || throw(ErrorException("unimplemented if <ψ|ψ> is an operator"))
 
-    return contract(hcat(ψ, ψ'); kwargs...) |> only |> sqrt |> abs
+    return contract(tn; kwargs...) |> only |> sqrt |> abs
 end
 
 """
-    normalize!(ψ::TensorNetwork{<:Quantum}, p::Real = 2; insert::Union{Nothing,Int} = nothing)
+    normalize!(ψ::AbstractQuantumTensorNetwork, p::Real = 2; insert::Union{Nothing,Int} = nothing)
 
 In-place normalize the [`TensorNetwork`](@ref).
 
@@ -266,12 +180,12 @@ In-place normalize the [`TensorNetwork`](@ref).
 See also: [`norm`](@ref).
 """
 function LinearAlgebra.normalize!(
-    ψ::TensorNetwork{<:Quantum},
+    ψ::absclass(QuantumTensorNetwork),
     p::Real = 2;
     insert::Union{Nothing,Int} = nothing,
     kwargs...,
 )
-    norm = LinearAlgebra.norm(ψ; kwargs...)
+    norm = LinearAlgebra.norm(ψ, p; kwargs...)
 
     if isnothing(insert)
         # method 1: divide all tensors by (√v)^(1/n)
@@ -282,7 +196,7 @@ function LinearAlgebra.normalize!(
         end
     else
         # method 2: divide only one tensor
-        tensor = tensors(ψ, insert)
+        tensor = ψ.tensors[insert] # tensors(ψ, insert) # TODO fix this to match site?
        tensor ./= norm
     end
 end
@@ -300,10 +214,9 @@ fidelity(a, b; kwargs...) = abs(only(contract(a, b'; kwargs...)))^2
 
 Return the marginal quantum state at `site`.
 """
 function marginal(ψ, site)
-    tensor = tensors(ψ, site)
-    index = inds(ψ, :plug, site)
-    sum(tensor, inds = setdiff(inds(tensor), [index]))
-end
+    plug(ψ) == State() || throw(ErrorException("unimplemented"))
 
-include("MP.jl")
-include("PEP.jl")
+    siteindex = inds(ψ, :out, site)
+    tensor = only(select(ψ, siteindex))
+    sum(tensor, inds = setdiff(inds(tensor), [siteindex]))
+end
diff --git a/src/Tenet.jl b/src/Tenet.jl
index b9b92b07..49ec6486 100644
--- a/src/Tenet.jl
+++ b/src/Tenet.jl
@@ -12,19 +12,14 @@ include("Numerics.jl")
 include("TensorNetwork.jl")
 export TensorNetwork, tensors, arrays, select, slice!
 export contract, contract!
-export Ansatz, ansatz, Arbitrary
 
 include("Transformations.jl")
 export transform, transform!
 
 include("Quantum/Quantum.jl")
-export Quantum
+export QuantumTensorNetwork, sites, fidelity
+export Plug, plug, Property, State, Dual, Operator
 export Boundary, boundary, Open, Periodic, Infinite
-export Plug, plug, Property, State, Operator
-export sites, fidelity
-
-export MatrixProduct, MPS, MPO
-export ProjectedEntangledPair, PEPS, PEPO
 
 # reexports from LinearAlgebra
 export norm, normalize!
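
For illustration, the refactored interface composes like this (a minimal usage
sketch mirroring the tests below; the shown results are what the new methods
are expected to return, not captured REPL output):

julia> state = QuantumTensorNetwork(
           TensorNetwork(Tensor[Tensor(rand(2, 2), (:i, :k)), Tensor(rand(3, 2, 4), (:j, :k, :l))]),
           Symbol[],  # no input indices, so this network plugs as a `State`
           [:i, :j],  # output indices, one per site
       );

julia> plug(state)
State()

julia> sites(state, :out)  # sites are now positions into the `output` vector
1:2

julia> inds(state, set = :out)
(:i, :j)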
diff --git a/test/Project.toml b/test/Project.toml index 32241a30..f186cef7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,5 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/test/Quantum_test.jl b/test/Quantum_test.jl index 451ca7d6..b5f49978 100644 --- a/test/Quantum_test.jl +++ b/test/Quantum_test.jl @@ -1,50 +1,38 @@ @testset "Quantum" begin - using Bijections - - struct MockState <: Quantum end - Tenet.plug(::Type{MockState}) = State - Tenet.metadata(::Type{MockState}) = Tenet.metadata(Quantum) - - struct MockOperator <: Quantum end - Tenet.plug(::Type{MockOperator}) = Operator - Tenet.metadata(::Type{MockOperator}) = Tenet.metadata(Quantum) - - state = TensorNetwork{MockState}( - [Tensor(rand(2, 2), (:i, :k)), Tensor(rand(3, 2, 4), (:j, :k, :l))]; - plug = State, - interlayer = [Bijection(Dict([1 => :i, 2 => :j]))], + state = QuantumTensorNetwork( + TensorNetwork(Tensor[Tensor(rand(2, 2), (:i, :k)), Tensor(rand(3, 2, 4), (:j, :k, :l))]), + Symbol[], # input + [:i, :j], # output ) - operator = TensorNetwork{MockOperator}( - [Tensor(rand(2, 4, 2), (:a, :c, :d)), Tensor(rand(3, 4, 3, 5), (:b, :c, :e, :f))]; - plug = Operator, - interlayer = [Bijection(Dict([1 => :a, 2 => :b])), Bijection(Dict([1 => :d, 2 => :e]))], + operator = QuantumTensorNetwork( + TensorNetwork(Tensor[Tensor(rand(2, 4, 2), (:a, :c, :d)), Tensor(rand(3, 4, 3, 5), (:b, :c, :e, :f))]), + [:a, :b], # input + [:d, :e], # output ) - @testset "metadata" begin + @testset "adjoint" begin @testset "State" begin - @test Tenet.checkmeta(state) - @test hasproperty(state, :interlayer) - @test only(state.interlayer) == Bijection(Dict([1 => :i, 2 => :j])) + adj = adjoint(state) + @test adj.input == state.output + @test adj.output == state.input + @test all(((a, b),) -> a == conj(b), zip(tensors(state), tensors(adj))) end @testset "Operator" begin - @test Tenet.checkmeta(operator) - @test hasproperty(operator, :interlayer) - @test operator.interlayer == [Bijection(Dict([1 => :a, 2 => :b])), Bijection(Dict([1 => :d, 2 => :e]))] + adj = adjoint(operator) + @test adj.input == operator.output + @test adj.output == operator.input + @test all(((a, b),) -> a == conj(b), zip(tensors(operator), tensors(adj))) end end @testset "plug" begin - @test plug(state) === State - - @test plug(operator) === Operator + @test plug(state) == State() + @test plug(state') == Dual() + @test plug(operator) == Operator() end - # TODO write tests for - # - boundary - # - tensors - @testset "sites" begin @test issetequal(sites(state), [1, 2]) @test issetequal(sites(operator), [1, 2]) @@ -54,88 +42,111 @@ @testset "State" begin @test issetequal(inds(state), [:i, :j, :k, :l]) @test issetequal(inds(state, set = :open), [:i, :j, :l]) - @test issetequal(inds(state, set = :plug), [:i, :j]) @test issetequal(inds(state, set = :inner), [:k]) @test isempty(inds(state, set = :hyper)) + @test isempty(inds(state, set = :in)) + @test issetequal(inds(state, set = :out), [:i, :j]) + @test issetequal(inds(state, set = :physical), [:i, :j]) @test issetequal(inds(state, set = :virtual), [:k, :l]) end - # TODO change the indices @testset "Operator" begin @test issetequal(inds(operator), [:a, :b, :c, :d, :e, :f]) @test issetequal(inds(operator, set = :open), [:a, :b, :d, :e, :f]) - @test issetequal(inds(operator, set = :plug), [:a, :b, 
:d, :e]) @test issetequal(inds(operator, set = :inner), [:c]) @test isempty(inds(operator, set = :hyper)) - @test_broken issetequal(inds(operator, set = :virtual), [:c]) + @test issetequal(inds(operator, set = :in), [:a, :b]) + @test issetequal(inds(operator, set = :out), [:d, :e]) + @test issetequal(inds(operator, set = :physical), [:a, :b, :d, :e]) + @test issetequal(inds(operator, set = :virtual), [:c, :f]) end end - @testset "adjoint" begin - @testset "State" begin - adj = adjoint(state) + @testset "merge" begin + @testset "(State, State)" begin + tn = merge(state, state') - @test issetequal(sites(state), sites(adj)) - @test all(i -> inds(state, :plug, i) == inds(adj, :plug, i), sites(state)) - end + @test plug(tn) == Property() - @testset "Operator" begin - adj = adjoint(operator) + @test isempty(sites(tn, :in)) + @test isempty(sites(tn, :out)) - @test issetequal(sites(operator), sites(adj)) - @test_broken all(i -> inds(operator, :plug, i) == inds(adj, :plug, i), sites(operator)) - @test all(i -> first(operator.interlayer)[i] == last(adj.interlayer)[i], sites(operator)) - @test all(i -> last(operator.interlayer)[i] == first(adj.interlayer)[i], sites(operator)) - end - end - - @testset "hcat" begin - @testset "(State, State)" begin - expectation = hcat(state, state) - @test issetequal(sites(expectation), sites(state)) - @test issetequal(inds(expectation, set = :plug), inds(state, set = :plug)) - @test isempty(inds(expectation, set = :open)) - @test issetequal(inds(expectation, set = :inner), inds(expectation, set = :all)) + @test isempty(inds(tn, set = :in)) + @test isempty(inds(tn, set = :out)) + @test isempty(inds(tn, set = :physical)) + @test issetequal(inds(tn), inds(tn, set = :virtual)) end @testset "(State, Operator)" begin - expectation = hcat(state, operator) - @test issetequal(sites(expectation), sites(state)) - @test_broken issetequal(inds(expectation, set = :plug), inds(operator, set = :plug)) - @test_broken isempty(inds(expectation, set = :open)) - @test_broken issetequal(inds(expectation, set = :inner), inds(expectation, set = :all)) + tn = merge(state, operator) + + @test plug(tn) == State() + + @test isempty(sites(tn, :in)) + @test issetequal(sites(tn, :out), sites(operator, :out)) + + @test isempty(inds(tn, set = :in)) + @test issetequal(inds(tn, set = :out), inds(operator, :out)) + @test issetequal(inds(tn, set = :physical), inds(operator, :out)) + @test issetequal(inds(tn, set = :virtual), inds(state) ∪ inds(operator, :virtual)) end @testset "(Operator, State)" begin - expectation = hcat(operator, state) - @test issetequal(sites(expectation), sites(state)) - @test_broken issetequal(inds(expectation, set = :plug), inds(state, set = :plug)) - @test_broken isempty(inds(expectation, set = :open)) - @test_broken issetequal(inds(expectation, set = :inner), inds(expectation, set = :all)) + tn = merge(operator, state') + + @test plug(tn) == Dual() + + @test issetequal(sites(tn, :in), sites(operator, :in)) + @test isempty(sites(tn, :out)) + + @test issetequal(inds(tn, set = :in), inds(operator, :in)) + @test isempty(inds(tn, set = :out)) + @test issetequal(inds(tn, set = :physical), inds(operator, :in)) + @test issetequal( + inds(tn, set = :virtual), + inds(state, :virtual) ∪ inds(operator, :virtual) ∪ inds(operator, :out), + ) end @testset "(Operator, Operator)" begin - expectation = hcat(operator, operator) - @test issetequal(sites(expectation), sites(state)) - @test issetequal(inds(expectation, set = :plug), inds(operator, set = :plug)) - @test isempty(inds(expectation, set 
= :open))
-            @test issetequal(inds(expectation, set = :inner), inds(expectation, set = :all))
+            tn = merge(operator, operator')
+
+            @test plug(tn) == Operator()
+
+            @test issetequal(sites(tn, :in), sites(operator, :in))
+            @test issetequal(sites(tn, :out), sites(operator, :in))
+
+            @test issetequal(inds(tn, set = :in), inds(operator, :in))
+            @test issetequal(inds(tn, set = :out), inds(operator, :in))
+            @test issetequal(inds(tn, set = :physical), inds(operator, :in))
+            @test inds(operator, :virtual) ⊆ inds(tn, set = :virtual)
         end
 
-        # @testset "(State, Operator, State)" begin
-        #     expectation = hcat(state, operator, state')
-        #     @test_broken issetequal(sites(expectation), sites(state))
-        #     @test_broken issetequal(inds(expectation, set = :plug), inds(operator, set = :plug))
-        #     @test_broken isempty(inds(expectation, set = :open))
-        #     @test_broken issetequal(inds(expectation, set = :inner), inds(expectation, set = :all))
-        # end
-
-        # @testset "(Operator, Operator, Operator)" begin
-        #     expectation = hcat(operator, operator, operator)
-        #     @test_broken issetequal(sites(expectation), sites(state))
-        #     @test_broken issetequal(inds(expectation, set = :plug), inds(operator, set = :plug))
-        #     @test_broken isempty(inds(expectation, set = :open))
-        #     @test_broken issetequal(inds(expectation, set = :inner), inds(expectation, set = :all))
-        # end
+        @testset "(Operator', Operator)" begin
+            tn = merge(operator', operator)
+
+            @test plug(tn) == Operator()
+
+            @test issetequal(sites(tn, :in), sites(operator, :out))
+            @test issetequal(sites(tn, :out), sites(operator, :out))
+
+            @test issetequal(inds(tn, set = :in), inds(operator, :out))
+            @test issetequal(inds(tn, set = :out), inds(operator, :out))
+            @test issetequal(inds(tn, set = :physical), inds(operator, :out))
+            @test inds(operator, :virtual) ⊆ inds(tn, set = :virtual)
+        end
+
+        @testset "(State, Operator, State)" begin
+            tn = merge(state, operator, state')
+
+            @test plug(tn) == Property()
+
+            @test isempty(sites(tn, :in))
+            @test isempty(sites(tn, :out))
+
+            @test isempty(inds(tn, set = :in))
+            @test isempty(inds(tn, set = :out))
+            @test isempty(inds(tn, set = :physical))
+        end
     end
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 032113af..f14a7ab0 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,23 +2,26 @@ using Test
 using Tenet
 using OMEinsum
 
-@testset "Unit tests" verbose = true begin
+@testset "Core tests" verbose = true begin
     include("Helpers_test.jl")
     include("Tensor_test.jl")
     include("Numerics_test.jl")
     include("TensorNetwork_test.jl")
-    include("Quantum_test.jl")
     include("Transformations_test.jl")
+end
+
+@testset "Quantum tests" verbose = true begin
+    include("Quantum_test.jl")
 
     # Ansatz Tensor Networks
-    include("MatrixProductState_test.jl")
-    include("MatrixProductOperator_test.jl")
+    # include("MatrixProductState_test.jl")
+    # include("MatrixProductOperator_test.jl")
 end
 
 @testset "Integration tests" verbose = true begin
     include("integration/ChainRules_test.jl")
     include("integration/BlockArray_test.jl")
-    include("integration/Quac_test.jl")
+    # include("integration/Quac_test.jl")
     include("integration/Makie_test.jl")
 end
 
From dac1cb84fe4993e8276bb0a0d910966949f0ceb3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
Date: Sun, 8 Oct 2023 01:13:29 +0200
Subject: [PATCH 13/29] Refactor `TNSampler` to new OOP architecture

---
 src/TensorNetwork.jl | 19 +++++++++----------
 1 file changed, 9 insertions(+), 10 deletions(-)

diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl
index dc869df2..d38bb338 
100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -429,17 +429,16 @@ contract!(tn::absclass(TensorNetwork), t::Tensor; kwargs...) = (push!(tn, t); co contract(t::Tensor, tn::absclass(TensorNetwork); kwargs...) = contract(tn, t; kwargs...) contract(tn::absclass(TensorNetwork), t::Tensor; kwargs...) = contract!(copy(tn), t; kwargs...) -# struct TNSampler{A<:Ansatz,NT<:NamedTuple} <: Random.Sampler{TensorNetwork{A}} -# parameters::NT +struct TNSampler{T<:absclass(TensorNetwork)} <: Random.Sampler{T} + config::Dict{Symbol,Any} -# TNSampler{A}(; kwargs...) where {A} = new{A,typeof(values(kwargs))}(values(kwargs)) -# end + TNSampler{T}(; kwargs...) where {T} = new{T}(kwargs) +end -# Base.getproperty(obj::TNSampler{A,<:NamedTuple{K}}, name::Symbol) where {A,K} = -# name ∈ K ? getfield(obj, :parameters)[name] : getfield(obj, name) -# Base.get(obj::TNSampler, name, default) = get(getfield(obj, :parameters), name, default) +Base.eltype(::TNSampler{T}) where {T} = T -# Base.eltype(::TNSampler{A}) where {A<:Ansatz} = TensorNetwork{A} +Base.getproperty(obj::TNSampler, name::Symbol) = name === :config ? getfield(obj, :config) : obj.config[name] +Base.get(obj::TNSampler, name, default) = get(obj.config, name, default) -# Base.rand(A::Type{<:Ansatz}; kwargs...) = rand(Random.default_rng(), A; kwargs...) -# Base.rand(rng::AbstractRNG, ::Type{A}; kwargs...) where {A<:Ansatz} = rand(rng, TNSampler{A}(; kwargs...)) +Base.rand(T::Type{<:absclass(TensorNetwork)}; kwargs...) = rand(Random.default_rng(), T; kwargs...) +Base.rand(rng::AbstractRNG, T::Type{<:absclass(TensorNetwork)}; kwargs...) = rand(rng, TNSampler{T}(; kwargs...)) From 7bcc2f7d23d8c7e7d2c3f9d902788c406cc4bcd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Sun, 8 Oct 2023 01:51:00 +0200 Subject: [PATCH 14/29] Refactor `MatrixProduct` --- src/Quantum/MP.jl | 86 +++++++++++------------------- src/Tenet.jl | 3 ++ test/MatrixProductOperator_test.jl | 51 ++++++++---------- test/MatrixProductState_test.jl | 69 +++++++++++------------- test/runtests.jl | 6 +-- 5 files changed, 92 insertions(+), 123 deletions(-) diff --git a/src/Quantum/MP.jl b/src/Quantum/MP.jl index f9dba7b5..a1118e44 100644 --- a/src/Quantum/MP.jl +++ b/src/Quantum/MP.jl @@ -1,44 +1,31 @@ using UUIDs: uuid4 using Base.Iterators: flatten using Random -using Bijections using Muscle: gramschmidt! using EinExprs: inds +using Classes """ MatrixProduct{P<:Plug,B<:Boundary} <: Quantum A generic ansatz representing Matrix Product State (MPS) and Matrix Product Operator (MPO) topology, aka Tensor Train. Type variable `P` represents the `Plug` type (`State` or `Operator`) and `B` represents the `Boundary` type (`Open` or `Periodic`). - -# Ansatz Fields - - - `χ::Union{Nothing,Int}` Maximum virtual bond dimension. """ -abstract type MatrixProduct{P,B} <: Quantum where {P<:Plug,B<:Boundary} end - -boundary(::Type{<:MatrixProduct{P,B}}) where {P,B} = B -plug(::Type{<:MatrixProduct{P}}) where {P} = P +@class MatrixProduct{P<:Plug,B<:Boundary} <: QuantumTensorNetwork function MatrixProduct{P}(arrays; boundary::Type{<:Boundary} = Open, kwargs...) where {P<:Plug} MatrixProduct{P,boundary}(arrays; kwargs...) 
end
 
-metadata(::Type{<:MatrixProduct}) = merge(metadata(supertype(MatrixProduct)), @NamedTuple begin
-    χ::Union{Nothing,Int}
-end)
-
-function checkmeta(::Type{MatrixProduct{P,B}}, tn::TensorNetwork) where {P,B}
-    # meta has correct type
-    isnothing(tn.χ) || tn.χ > 0 || return false
-
-    # no virtual index has dimensionality bigger than χ
-    all(i -> isnothing(tn.χ) || size(tn, i) <= tn.χ, inds(tn, :virtual)) || return false
+const MPS = MatrixProduct{State}
+const MPO = MatrixProduct{Operator}
 
-    return true
-end
+plug(::T) where {T<:absclass(MatrixProduct)} = plug(T)
+plug(::Type{<:MatrixProduct{P}}) where {P} = P()
+boundary(::T) where {T<:absclass(MatrixProduct)} = boundary(T)
+boundary(::Type{<:MatrixProduct{P,B}}) where {P,B} = B()
 
-_sitealias(::Type{MatrixProduct{P,Open}}, order, n, i) where {P<:Plug} =
+sitealias(::Type{MatrixProduct{P,Open}}, order, n, i) where {P<:Plug} =
     if i == 1
         filter(!=(:l), order)
     elseif i == n
         filter(!=(:r), order)
     else
         order
     end
-_sitealias(::Type{MatrixProduct{P,Periodic}}, order, n, i) where {P<:Plug} = tuple(order...)
-_sitealias(::Type{MatrixProduct{P,Infinite}}, order, n, i) where {P<:Plug} = tuple(order...)
+sitealias(::Type{MatrixProduct{P,Periodic}}, order, n, i) where {P<:Plug} = tuple(order...)
+sitealias(::Type{MatrixProduct{P,Infinite}}, order, n, i) where {P<:Plug} = tuple(order...)
 
-defaultorder(::Type{MatrixProduct{State}}) = (:l, :r, :o)
-defaultorder(::Type{MatrixProduct{Operator}}) = (:l, :r, :i, :o)
+defaultorder(::Type{<:MatrixProduct{Property}}) = (:l, :r)
+defaultorder(::Type{<:MatrixProduct{State}}) = (:l, :r, :o)
+defaultorder(::Type{<:MatrixProduct{Operator}}) = (:l, :r, :i, :o)
 
 """
-    MatrixProduct{P,B}(arrays::AbstractArray[]; χ::Union{Nothing,Int} = nothing, order = defaultorder(MatrixProduct{P}))
+    MatrixProduct{P,B}(arrays::Vector{<:AbstractArray}; order = defaultorder(MatrixProduct{P}))
 
 Construct a [`TensorNetwork`](@ref) with [`MatrixProduct`](@ref) ansatz, from the arrays of the tensors.
 
 # Keyword Arguments
 
-  - `χ` Maximum virtual bond dimension. Defaults to `nothing`.
   - `order` Order of tensor indices on `arrays`. Defaults to `(:l, :r, :o)` if `P` is a `State`, `(:l, :r, :i, :o)` if `Operator`.
""" -function MatrixProduct{P,B}( - arrays; - χ = nothing, - order = defaultorder(MatrixProduct{P}), - metadata..., -) where {P<:Plug,B<:Boundary} +function MatrixProduct{P,B}(arrays; order = defaultorder(MatrixProduct{P})) where {P<:Plug,B<:Boundary} issetequal(order, defaultorder(MatrixProduct{P})) || throw( ArgumentError( "`order` must be a permutation of $(join(String.(defaultorder(MatrixProduct{P})), ',', " and "))", @@ -76,19 +58,21 @@ function MatrixProduct{P,B}( n = length(arrays) vinds = Dict(x => Symbol(uuid4()) for x in ringpeek(1:n)) - oinds = Dict(i => Symbol(uuid4()) for i in 1:n) - iinds = Dict(i => Symbol(uuid4()) for i in 1:n) + oinds = map(_ -> Symbol(uuid4()), 1:n) + iinds = map(_ -> Symbol(uuid4()), 1:n) - interlayer = if P <: State - [Bijection(oinds)] + input, output = if P <: Property + Symbol[], Symbol[] + elseif P <: State + Symbol[], oinds elseif P <: Operator - [Bijection(iinds), Bijection(oinds)] + iinds, oinds else - throw(ErrorException("Plug $P is not valid")) + throw(ArgumentError("Plug $P is not valid")) end - tensors = map(enumerate(arrays)) do (i, array) - dirs = _sitealias(MatrixProduct{P,B}, order, n, i) + tensors::Vector{Tensor} = map(enumerate(arrays)) do (i, array) + dirs = sitealias(MatrixProduct{P,B}, order, n, i) inds = map(dirs) do dir if dir === :l @@ -105,15 +89,9 @@ function MatrixProduct{P,B}( Tensor(array, inds) end - return TensorNetwork{MatrixProduct{P,B}}(tensors; χ, plug = P, interlayer, metadata...) + return MatrixProduct{P,B}(QuantumTensorNetwork(TensorNetwork(tensors), input, output)) end -const MPS = MatrixProduct{State} -const MPO = MatrixProduct{Operator} - -tensors(ψ::TensorNetwork{MatrixProduct{P,Infinite}}, site::Int, args...) where {P<:Plug} = - tensors(plug(ψ), ψ, mod1(site, length(ψ.tensors)), args...) 
- # NOTE does not use optimal contraction path, but "parallel-optimal" which costs x2 more # function contractpath(a::TensorNetwork{<:MatrixProductState}, b::TensorNetwork{<:MatrixProductState}) # !issetequal(sites(a), sites(b)) && throw(ArgumentError("both tensor networks are expected to have same sites")) @@ -134,7 +112,7 @@ function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{MatrixProduct{Sta p = get(sampler, :p, 2) T = get(sampler, :eltype, Float64) - arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i + arrays::Vector{AbstractArray{T}} = map(1:n) do i χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 χl = min(χ, p^(i - 1)) χr = min(χ, p^i) @@ -159,7 +137,7 @@ function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{MatrixProduct{Sta # normalize state arrays[1] ./= sqrt(p) - MatrixProduct{State,Open}(arrays; χ = χ) + MatrixProduct{State,Open}(arrays) end # TODO let choose the orthogonality center @@ -172,7 +150,7 @@ function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{MatrixProduct{Ope ip = op = p - arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i + arrays::Vector{AbstractArray{T}} = map(1:n) do i χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 χl = min(χ, ip^(i - 1) * op^(i - 1)) χr = min(χ, ip^i * op^i) @@ -199,7 +177,7 @@ function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{MatrixProduct{Ope ζ = min(χ, ip * op) arrays[1] ./= sqrt(ζ) - MatrixProduct{Operator,Open}(arrays; χ = χ) + MatrixProduct{Operator,Open}(arrays) end # TODO stable renormalization @@ -210,7 +188,7 @@ function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{MatrixProduct{P,P p = get(sampler, :p, 2) T = get(sampler, :eltype, Float64) - A = MatrixProduct{P,Periodic}([rand(rng, T, [P === State ? (χ, χ, p) : (χ, χ, p, p)]...) for _ in 1:n]; χ = χ) + A = MatrixProduct{P,Periodic}([rand(rng, T, [P === State ? (χ, χ, p) : (χ, χ, p, p)]...) for _ in 1:n]) normalize!(A) return A diff --git a/src/Tenet.jl b/src/Tenet.jl index 49ec6486..0b7bc146 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -21,6 +21,9 @@ export QuantumTensorNetwork, sites, fidelity export Plug, plug, Property, State, Dual, Operator export Boundary, boundary, Open, Periodic, Infinite +include("Quantum/MP.jl") +export MatrixProduct, MPS, MPO + # reexports from LinearAlgebra export norm, normalize! 
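
For illustration, random instances now go through the `TNSampler` machinery
introduced in the previous patch (a minimal sketch mirroring the tests below;
values are illustrative, not captured REPL output):

julia> ψ = rand(MatrixProduct{State,Open}, n = 7, p = 2, χ = 32);

julia> ψ isa MPS{Open}
true

julia> length(tensors(ψ))
7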
diff --git a/test/MatrixProductOperator_test.jl b/test/MatrixProductOperator_test.jl index fa71fbb7..ca25dd47 100644 --- a/test/MatrixProductOperator_test.jl +++ b/test/MatrixProductOperator_test.jl @@ -1,13 +1,11 @@ @testset "MatrixProduct{Operator}" begin - using Tenet: Operator, Composite - @testset "plug" begin - @test plug(MatrixProduct{Operator}) === Operator - @test all(T -> plug(MatrixProduct{Operator,T}) === Operator, [Open, Periodic]) + @test plug(MatrixProduct{Operator}) === Operator() + @test all(T -> plug(MatrixProduct{Operator,T}) === Operator(), [Open, Periodic]) end @testset "boundary" begin - @test all(B -> boundary(MatrixProduct{Operator,B}) == B, [Open, Periodic]) + @test all(B -> boundary(MatrixProduct{Operator,B}) == B(), [Open, Periodic]) end @testset "Constructor" begin @@ -16,25 +14,25 @@ @test begin arrays = [rand(2, 2, 2)] - MatrixProduct{Operator}(arrays) isa TensorNetwork{MatrixProduct{Operator,Open}} + MatrixProduct{Operator}(arrays) isa MPO{Open} end @test begin arrays = [rand(2, 2, 2), rand(2, 2, 2)] - MatrixProduct{Operator}(arrays) isa TensorNetwork{MatrixProduct{Operator,Open}} + MatrixProduct{Operator}(arrays) isa MPO{Open} end @testset "`Open` boundary" begin # product operator @test begin arrays = [rand(1, 2, 2), rand(1, 1, 2, 2), rand(1, 2, 2)] - MatrixProduct{Operator,Open}(arrays) isa TensorNetwork{MatrixProduct{Operator,Open}} + MatrixProduct{Operator,Open}(arrays) isa MPO{Open} end # alternative constructor @test begin arrays = [rand(1, 2, 2), rand(1, 1, 2, 2), rand(1, 2, 2)] - MatrixProduct{Operator}(arrays; boundary = Open) isa TensorNetwork{MatrixProduct{Operator,Open}} + MatrixProduct{Operator}(arrays; boundary = Open) isa MPO{Open} end # entangling operator @@ -42,7 +40,7 @@ i = 3 o = 5 arrays = [rand(2, i, o), rand(2, 4, i, o), rand(4, i, o)] - MatrixProduct{Operator,Open}(arrays) isa TensorNetwork{MatrixProduct{Operator,Open}} + MatrixProduct{Operator,Open}(arrays) isa MPO{Open} end # entangling operator - change order @@ -50,14 +48,13 @@ i = 3 o = 5 arrays = [rand(i, 2, o), rand(2, i, 4, o), rand(4, i, o)] - MatrixProduct{Operator,Open}(arrays, order = (:l, :i, :r, :o)) isa - TensorNetwork{MatrixProduct{Operator,Open}} + MatrixProduct{Operator,Open}(arrays, order = (:l, :i, :r, :o)) isa MPO{Open} end # fail on Open with Periodic format @test_throws MethodError begin arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - MatrixProduct{Operator,Open}(arrays) isa TensorNetwork{MatrixProduct{Operator,Open}} + MatrixProduct{Operator,Open}(arrays) isa MPO{Open} end end @@ -65,13 +62,13 @@ # product operator @test begin arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - MatrixProduct{Operator,Periodic}(arrays) isa TensorNetwork{MatrixProduct{Operator,Periodic}} + MatrixProduct{Operator,Periodic}(arrays) isa MPO{Periodic} end # alternative constructor @test begin arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - MatrixProduct{Operator}(arrays; boundary = Periodic) isa TensorNetwork{MatrixProduct{Operator,Periodic}} + MatrixProduct{Operator}(arrays; boundary = Periodic) isa MPO{Periodic} end # entangling operator @@ -79,7 +76,7 @@ i = 3 o = 5 arrays = [rand(2, 4, i, o), rand(4, 8, i, o), rand(8, 2, i, o)] - MatrixProduct{Operator,Periodic}(arrays) isa TensorNetwork{MatrixProduct{Operator,Periodic}} + MatrixProduct{Operator,Periodic}(arrays) isa MPO{Periodic} end # entangling operator - change order @@ -87,14 +84,13 @@ i = 3 o = 5 arrays = [rand(2, i, 4, o), rand(4, i, 8, o), rand(8, i, 2, o)] - 
MatrixProduct{Operator,Periodic}(arrays, order = (:l, :i, :r, :o)) isa - TensorNetwork{MatrixProduct{Operator,Periodic}} + MatrixProduct{Operator,Periodic}(arrays, order = (:l, :i, :r, :o)) isa MPO{Periodic} end # fail on Periodic with Open format @test_throws MethodError begin arrays = [rand(1, 2, 2), rand(1, 1, 2, 2), rand(1, 2, 2)] - MatrixProduct{Operator,Periodic}(arrays) isa TensorNetwork{MatrixProduct{Operator,Periodic}} + MatrixProduct{Operator,Periodic}(arrays) isa MPO{Periodic} end end @@ -102,13 +98,13 @@ # product operator @test begin arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - MatrixProduct{Operator,Infinite}(arrays) isa TensorNetwork{MatrixProduct{Operator,Infinite}} + MatrixProduct{Operator,Infinite}(arrays) isa MPO{Infinite} end # alternative constructor @test begin arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - MatrixProduct{Operator}(arrays; boundary = Infinite) isa TensorNetwork{MatrixProduct{Operator,Infinite}} + MatrixProduct{Operator}(arrays; boundary = Infinite) isa MPO{Infinite} end # entangling operator @@ -116,7 +112,7 @@ i = 3 o = 5 arrays = [rand(2, 4, i, o), rand(4, 8, i, o), rand(8, 2, i, o)] - MatrixProduct{Operator,Infinite}(arrays) isa TensorNetwork{MatrixProduct{Operator,Infinite}} + MatrixProduct{Operator,Infinite}(arrays) isa MPO{Infinite} end # entangling operator - change order @@ -124,14 +120,13 @@ i = 3 o = 5 arrays = [rand(2, i, 4, o), rand(4, i, 8, o), rand(8, i, 2, o)] - MatrixProduct{Operator,Infinite}(arrays, order = (:l, :i, :r, :o)) isa - TensorNetwork{MatrixProduct{Operator,Infinite}} + MatrixProduct{Operator,Infinite}(arrays, order = (:l, :i, :r, :o)) isa MPO{Infinite} end # fail on Infinite with Open format @test_throws MethodError begin arrays = [rand(1, 2, 2), rand(1, 1, 2, 2), rand(1, 2, 2)] - MatrixProduct{Operator,Infinite}(arrays) isa TensorNetwork{MatrixProduct{Operator,Infinite}} + MatrixProduct{Operator,Infinite}(arrays) isa MPO{Infinite} end @testset "metadata" begin @@ -151,7 +146,7 @@ mps = MatrixProduct{State,Open}(arrays) arrays_o = [rand(2, 2, 2), rand(2, 2, 2)] mpo = MatrixProduct{Operator}(arrays_o) - hcat(mps, mpo) isa TensorNetwork{<:Composite} + merge(mps, mpo) isa QuantumTensorNetwork end @test begin @@ -159,13 +154,13 @@ mps = MatrixProduct{State,Open}(arrays) arrays_o = [rand(2, 2, 2), rand(2, 2, 2)] mpo = MatrixProduct{Operator}(arrays_o) - hcat(mpo, mps) isa TensorNetwork{<:Composite} + merge(mpo, mps') isa QuantumTensorNetwork end @test begin arrays = [rand(2, 2, 2), rand(2, 2, 2)] mpo = MatrixProduct{Operator}(arrays) - hcat(mpo, mpo) isa TensorNetwork{<:Composite} + merge(mpo, mpo') isa QuantumTensorNetwork end end diff --git a/test/MatrixProductState_test.jl b/test/MatrixProductState_test.jl index 3d714bea..80f96d58 100644 --- a/test/MatrixProductState_test.jl +++ b/test/MatrixProductState_test.jl @@ -1,13 +1,11 @@ @testset "MatrixProduct{State}" begin - using Tenet: Composite - @testset "plug" begin - @test plug(MatrixProduct{State}) === State - @test all(T -> plug(MatrixProduct{State,T}) === State, [Open, Periodic]) + @test plug(MatrixProduct{State}) == State() + @test all(T -> plug(MatrixProduct{State,T}) == State(), [Open, Periodic]) end @testset "boundary" begin - @test all(B -> boundary(MatrixProduct{State,B}) == B, [Open, Periodic]) + @test all(B -> boundary(MatrixProduct{State,B}) == B(), [Open, Periodic]) end @testset "Constructor" begin @@ -16,44 +14,44 @@ @test begin arrays = [rand(1, 2)] - MatrixProduct{State}(arrays) isa 
TensorNetwork{MatrixProduct{State,Open}} + MatrixProduct{State}(arrays) isa MPS{Open} end @test begin arrays = [rand(1, 2), rand(1, 2)] - MatrixProduct{State}(arrays) isa TensorNetwork{MatrixProduct{State,Open}} + MatrixProduct{State}(arrays) isa MPS{Open} end @testset "`Open` boundary" begin # product state @test begin arrays = [rand(1, 2), rand(1, 1, 2), rand(1, 2)] - MatrixProduct{State,Open}(arrays) isa TensorNetwork{MatrixProduct{State,Open}} + MatrixProduct{State,Open}(arrays) isa MPS{Open} end # entangled state @test begin arrays = [rand(2, 2), rand(2, 4, 2), rand(4, 1, 2), rand(1, 2)] - MatrixProduct{State,Open}(arrays) isa TensorNetwork{MatrixProduct{State,Open}} + MatrixProduct{State,Open}(arrays) isa MPS{Open} end @testset "custom order" begin arrays = [rand(3, 1), rand(3, 1, 3), rand(1, 3)] ψ = MatrixProduct{State,Open}(arrays, order = (:r, :o, :l)) - @test ψ isa TensorNetwork{MatrixProduct{State,Open}} + @test ψ isa MPS{Open} end # alternative constructor @test begin arrays = [rand(1, 2), rand(1, 1, 2), rand(1, 2)] - MatrixProduct{State}(arrays; boundary = Open) isa TensorNetwork{MatrixProduct{State,Open}} + MatrixProduct{State}(arrays; boundary = Open) isa MPS{Open} end # fail on Open with Periodic format @test_throws Exception begin arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - MatrixProduct{State,Open}(arrays) isa TensorNetwork{MatrixProduct{State,Open}} + MatrixProduct{State,Open}(arrays) isa MPS{Open} end @testset "rand" begin @@ -62,8 +60,8 @@ @testset "χ = $χ" for χ in [4, 32] ψ = rand(MatrixProduct{State,Open}, n = 7, p = 2, χ = χ) - @test ψ isa TensorNetwork{MatrixProduct{State,Open}} - @test length(ψ) == 7 + @test ψ isa MPS{Open} + @test length(tensors(ψ)) == 7 @test maximum(vind -> size(ψ, vind), inds(ψ, :inner)) <= 32 end end @@ -73,32 +71,32 @@ # product state @test begin arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - MatrixProduct{State,Periodic}(arrays) isa TensorNetwork{MatrixProduct{State,Periodic}} + MatrixProduct{State,Periodic}(arrays) isa MPS{Periodic} end # entangled state @test begin arrays = [rand(3, 4, 2), rand(4, 8, 2), rand(8, 3, 2)] - MatrixProduct{State,Periodic}(arrays) isa TensorNetwork{MatrixProduct{State,Periodic}} + MatrixProduct{State,Periodic}(arrays) isa MPS{Periodic} end @testset "custom order" begin arrays = [rand(3, 1, 3), rand(3, 1, 3), rand(3, 1, 3)] ψ = MatrixProduct{State,Periodic}(arrays, order = (:r, :o, :l)) - @test ψ isa TensorNetwork{MatrixProduct{State,Periodic}} + @test ψ isa MPS{Periodic} end # alternative constructor @test begin arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - MatrixProduct{State}(arrays; boundary = Periodic) isa TensorNetwork{MatrixProduct{State,Periodic}} + MatrixProduct{State}(arrays; boundary = Periodic) isa MPS{Periodic} end # fail on Periodic with Open format @test_throws Exception begin arrays = [rand(1, 2), rand(1, 1, 2), rand(1, 2)] - MatrixProduct{State,Periodic}(arrays) isa TensorNetwork{MatrixProduct{State,Periodic}} + MatrixProduct{State,Periodic}(arrays) isa MPS{Periodic} end end @@ -106,59 +104,56 @@ # product state @test begin arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - MatrixProduct{State,Infinite}(arrays) isa TensorNetwork{MatrixProduct{State,Infinite}} + MatrixProduct{State,Infinite}(arrays) isa MPS{Infinite} end # entangled state @test begin arrays = [rand(3, 4, 2), rand(4, 8, 2), rand(8, 3, 2)] - MatrixProduct{State,Infinite}(arrays) isa TensorNetwork{MatrixProduct{State,Infinite}} + MatrixProduct{State,Infinite}(arrays) isa MPS{Infinite} end 
@testset "custom order" begin arrays = [rand(3, 1, 3), rand(3, 1, 3), rand(3, 1, 3)] ψ = MatrixProduct{State,Infinite}(arrays, order = (:r, :o, :l)) - @test ψ isa TensorNetwork{MatrixProduct{State,Infinite}} + @test ψ isa MPS{Infinite} end # alternative constructor @test begin arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - MatrixProduct{State}(arrays; boundary = Infinite) isa TensorNetwork{MatrixProduct{State,Infinite}} + MatrixProduct{State}(arrays; boundary = Infinite) isa MPS{Infinite} end # fail on Infinite with Open format @test_throws Exception begin arrays = [rand(1, 2), rand(1, 1, 2), rand(1, 2)] - MatrixProduct{State,Infinite}(arrays) isa TensorNetwork{MatrixProduct{State,Infinite}} + MatrixProduct{State,Infinite}(arrays) isa MPS{Infinite} end - @testset "metadata" begin - @testset "tensors" begin - arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - ψ = MatrixProduct{State,Infinite}(arrays, order = (:l, :r, :o)) + # @testset "tensors" begin + # arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] + # ψ = MatrixProduct{State,Infinite}(arrays, order = (:l, :r, :o)) - @test tensors(ψ, 1) isa Tensor - @test length(ψ) == Inf - @test tensors(ψ, 4) == tensors(ψ, 1) - @test tensors(ψ, 0) == tensors(ψ, 3) - end - end + # @test tensors(ψ, 1) isa Tensor + # @test tensors(ψ, 4) == tensors(ψ, 1) + # @test tensors(ψ, 0) == tensors(ψ, 3) + # end end end - @testset "hcat" begin + @testset "merge" begin @test begin arrays = [rand(2, 2), rand(2, 2)] mps = MatrixProduct{State,Open}(arrays) - hcat(mps, mps) isa TensorNetwork{<:Composite} + merge(mps, mps') isa QuantumTensorNetwork end @test begin arrays = [rand(1, 1, 2), rand(1, 1, 2)] mps = MatrixProduct{State,Periodic}(arrays) - hcat(mps, mps) isa TensorNetwork{<:Composite} + merge(mps, mps') isa QuantumTensorNetwork end end diff --git a/test/runtests.jl b/test/runtests.jl index f14a7ab0..fb521a4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,10 +12,8 @@ end @testset "Quantum tests" verbose = true begin include("Quantum_test.jl") - - # Ansatz Tensor Networks - # include("MatrixProductState_test.jl") - # include("MatrixProductOperator_test.jl") + include("MatrixProductState_test.jl") + include("MatrixProductOperator_test.jl") end @testset "Integration tests" verbose = true begin From 462a1617ee7f86eea297dedf1c3e5f852eb0173c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 9 Oct 2023 14:32:07 +0200 Subject: [PATCH 15/29] Refactor `Quac` extension --- ext/TenetQuacExt.jl | 15 ++++++--------- test/integration/Quac_test.jl | 30 +++++++++++++----------------- test/runtests.jl | 2 +- 3 files changed, 20 insertions(+), 27 deletions(-) diff --git a/ext/TenetQuacExt.jl b/ext/TenetQuacExt.jl index 83e1dfb7..3879fdbf 100644 --- a/ext/TenetQuacExt.jl +++ b/ext/TenetQuacExt.jl @@ -2,12 +2,11 @@ module TenetQuacExt using Tenet using Quac: Circuit, lanes, arraytype, Swap -using Bijections -function Tenet.TensorNetwork(circuit::Circuit) +function Tenet.QuantumTensorNetwork(circuit::Circuit) n = lanes(circuit) wire = [[Tenet.letter(i)] for i in 1:n] - tensors = Tensor[] + tn = TensorNetwork() i = n + 1 @@ -29,15 +28,13 @@ function Tenet.TensorNetwork(circuit::Circuit) end |> x -> zip(x...) 
|> Iterators.flatten |> collect tensor = Tensor(array, inds) - push!(tensors, tensor) + push!(tn, tensor) end - interlayer = [ - Bijection(Dict([site => first(index) for (site, index) in enumerate(wire)])), - Bijection(Dict([site => last(index) for (site, index) in enumerate(wire)])), - ] + input = first.(wire) + output = last.(wire) - return TensorNetwork{Quantum}(tensors; plug = Tenet.Operator, interlayer) + return QuantumTensorNetwork(tn, input, output) end end diff --git a/test/integration/Quac_test.jl b/test/integration/Quac_test.jl index 9714613c..1b5179ba 100644 --- a/test/integration/Quac_test.jl +++ b/test/integration/Quac_test.jl @@ -1,29 +1,25 @@ @testset "Quac" begin - using Tenet: TensorNetwork, ansatz, Quantum, sites using Quac - n = 2 - qft = Quac.Algorithms.QFT(n) + using UUIDs: uuid4 @testset "Constructor" begin - tn = TensorNetwork(qft) - - @test ansatz(tn) == Quantum - @test tn isa TensorNetwork{Quantum} + n = 2 + qft = Quac.Algorithms.QFT(n) + tn = QuantumTensorNetwork(qft) + @test tn isa QuantumTensorNetwork @test issetequal(sites(tn), 1:n) end # TODO currently broken - # @testset "hcat" begin - # n = 2 - # qft = Quac.Algorithms.QFT(n) - # tn = TensorNetwork(qft) - - # newtn = hcat(tn, tn) + @testset "merge" begin + n = 2 + qft = QuantumTensorNetwork(Quac.Algorithms.QFT(n)) + iqft = replace(qft, [index => Symbol(uuid4()) for index in inds(qft)]...) - # @test ansatz(newtn) <: Composite(Quantum, Quantum) - # @test issetequal(sites(newtn), 1:2) + tn = merge(qft, iqft) - # # TODO @test_throws ErrorException ... - # end + @test tn isa QuantumTensorNetwork + @test issetequal(sites(tn), 1:2) + end end diff --git a/test/runtests.jl b/test/runtests.jl index fb521a4a..f35ac9ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,7 +19,7 @@ end @testset "Integration tests" verbose = true begin include("integration/ChainRules_test.jl") include("integration/BlockArray_test.jl") - # include("integration/Quac_test.jl") + include("integration/Quac_test.jl") include("integration/Makie_test.jl") end From 2057a141f7614563341a4a24fd1c4a129098e512 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 9 Oct 2023 14:32:39 +0200 Subject: [PATCH 16/29] Test changes for MPO --- test/MatrixProductOperator_test.jl | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/test/MatrixProductOperator_test.jl b/test/MatrixProductOperator_test.jl index ca25dd47..08441fbf 100644 --- a/test/MatrixProductOperator_test.jl +++ b/test/MatrixProductOperator_test.jl @@ -128,19 +128,10 @@ arrays = [rand(1, 2, 2), rand(1, 1, 2, 2), rand(1, 2, 2)] MatrixProduct{Operator,Infinite}(arrays) isa MPO{Infinite} end - - @testset "metadata" begin - @testset "tensors" begin - arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - ψ = MatrixProduct{Operator,Infinite}(arrays, order = (:l, :r, :i, :o)) - - @test length(ψ) == Inf - end - end end end - @testset "hcat" begin + @testset "merge" begin @test begin arrays = [rand(2, 2), rand(2, 2)] mps = MatrixProduct{State,Open}(arrays) @@ -166,7 +157,7 @@ @testset "norm" begin mpo = rand(MatrixProduct{Operator,Open}, n = 8, p = 2, χ = 8) - @test norm(mpo) ≈ 1 + @test_broken norm(mpo) ≈ 1 end # @testset "Initialization" begin From ebcab1a412624b1fa2e05046095053f0aef2f378 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 9 Oct 2023 14:32:52 +0200 Subject: [PATCH 17/29] Refactor `replace` --- src/TensorNetwork.jl | 12 +++--------- 1 file changed, 3 insertions(+), 9 
deletions(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index d38bb338..60d7321b 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -181,17 +181,9 @@ Like [`pop!`](@ref) but return the [`TensorNetwork`](@ref) instead. """ Base.delete!(tn::absclass(TensorNetwork), x) = (_ = pop!(tn, x); tn) -""" - replace(tn::AbstractTensorNetwork, old => new...) - -Return a copy of the [`TensorNetwork`](@ref) where `old` has been replaced by `new`. - -See also: [`replace!`](@ref). -""" -Base.replace(tn::absclass(TensorNetwork), old_new::Pair...) = replace!(copy(tn), old_new) - """ replace!(tn::AbstractTensorNetwork, old => new...) + replace(tn::AbstractTensorNetwork, old => new...) Replace the element in `old` with the one in `new`. Depending on the types of `old` and `new`, the following behaviour is expected: @@ -207,6 +199,8 @@ function Base.replace!(tn::absclass(TensorNetwork), old_new::Base.AbstractVecOrT end return tn end +Base.replace(tn::absclass(TensorNetwork), old_new::Pair...) = replace(tn, old_new) +Base.replace(tn::absclass(TensorNetwork), old_new) = replace!(copy(tn), old_new) function Base.replace!(tn::absclass(TensorNetwork), pair::Pair{<:Tensor,<:Tensor}) old_tensor, new_tensor = pair From f28bcf48f8a03105b35f4caf4b772cc007e01a05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Mon, 9 Oct 2023 17:56:03 +0200 Subject: [PATCH 18/29] Refactor things for which I'm too lazy to name --- src/Quantum/MP.jl | 14 ++-- src/Quantum/PEP.jl | 105 ++++++++++++--------------- src/Quantum/Quantum.jl | 52 +++++++++----- src/Tenet.jl | 3 + test/MatrixProductOperator_test.jl | 111 ++++++++++++----------------- test/MatrixProductState_test.jl | 108 ++++++++++++++-------------- 6 files changed, 191 insertions(+), 202 deletions(-) diff --git a/src/Quantum/MP.jl b/src/Quantum/MP.jl index a1118e44..52ca5cb7 100644 --- a/src/Quantum/MP.jl +++ b/src/Quantum/MP.jl @@ -6,12 +6,12 @@ using EinExprs: inds using Classes """ - MatrixProduct{P<:Plug,B<:Boundary} <: Quantum + MatrixProduct{P<:Plug,B<:Boundary} <: Ansatz A generic ansatz representing Matrix Product State (MPS) and Matrix Product Operator (MPO) topology, aka Tensor Train. Type variable `P` represents the `Plug` type (`State` or `Operator`) and `B` represents the `Boundary` type (`Open` or `Periodic`). """ -@class MatrixProduct{P<:Plug,B<:Boundary} <: QuantumTensorNetwork +struct MatrixProduct{P<:Plug,B<:Boundary} <: Ansatz end function MatrixProduct{P}(arrays; boundary::Type{<:Boundary} = Open, kwargs...) where {P<:Plug} MatrixProduct{P,boundary}(arrays; kwargs...) 
@@ -20,9 +20,7 @@ end const MPS = MatrixProduct{State} const MPO = MatrixProduct{Operator} -plug(::T) where {T<:absclass(MatrixProduct)} = plug(T) plug(::Type{<:MatrixProduct{P}}) where {P} = P() -boundary(::T) where {T<:absclass(MatrixProduct)} = boundary(T) boundary(::Type{<:MatrixProduct{P,B}}) where {P,B} = B() sitealias(::Type{MatrixProduct{P,Open}}, order, n, i) where {P<:Plug} = @@ -89,7 +87,7 @@ function MatrixProduct{P,B}(arrays; order = defaultorder(MatrixProduct{P})) wher Tensor(array, inds) end - return MatrixProduct{P,B}(QuantumTensorNetwork(TensorNetwork(tensors), input, output)) + return QuantumTensorNetwork(TensorNetwork(tensors), input, output) end # NOTE does not use optimal contraction path, but "parallel-optimal" which costs x2 more @@ -106,7 +104,7 @@ end # end # TODO let choose the orthogonality center -function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{MatrixProduct{State,Open}}) +function Base.rand(rng::Random.AbstractRNG, sampler::QTNSampler{MatrixProduct{State,Open}}) n = sampler.n χ = sampler.χ p = get(sampler, :p, 2) @@ -142,7 +140,7 @@ end # TODO let choose the orthogonality center # TODO different input/output physical dims -function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{MatrixProduct{Operator,Open}}) +function Base.rand(rng::Random.AbstractRNG, sampler::QTNSampler{MatrixProduct{Operator,Open}}) n = sampler.n χ = sampler.χ p = get(sampler, :p, 2) @@ -182,7 +180,7 @@ end # TODO stable renormalization # TODO different input/output physical dims for Operator -function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{MatrixProduct{P,Periodic}}) where {P<:Plug} +function Base.rand(rng::Random.AbstractRNG, sampler::QTNSampler{MatrixProduct{P,Periodic}}) where {P<:Plug} n = sampler.n χ = sampler.χ p = get(sampler, :p, 2) diff --git a/src/Quantum/PEP.jl b/src/Quantum/PEP.jl index d9d865f9..b246c1ef 100644 --- a/src/Quantum/PEP.jl +++ b/src/Quantum/PEP.jl @@ -1,40 +1,25 @@ using UUIDs: uuid4 -using EinExprs: inds +using Classes """ - ProjectedEntangledPair{P<:Plug,B<:Boundary} <: Quantum + ProjectedEntangledPair{P<:Plug,B<:Boundary} <: Ansatz A generic ansatz representing Projected Entangled Pair States (PEPS) and Projected Entangled Pair Operators (PEPO). Type variable `P` represents the `Plug` type (`State` or `Operator`) and `B` represents the `Boundary` type (`Open` or `Periodic`). - -# Ansatz Fields - - - `χ::Union{Nothing,Int}` Maximum virtual bond dimension. """ -abstract type ProjectedEntangledPair{P,B} <: Quantum where {P<:Plug,B<:Boundary} end - -boundary(::Type{<:ProjectedEntangledPair{P,B}}) where {P,B} = B -plug(::Type{<:ProjectedEntangledPair{P}}) where {P} = P +struct ProjectedEntangledPair{P<:Plug,B<:Boundary} <: Ansatz end function ProjectedEntangledPair{P}(arrays; boundary::Type{<:Boundary} = Open, kwargs...) where {P<:Plug} ProjectedEntangledPair{P,boundary}(arrays; kwargs...) 
end -metadata(T::Type{<:ProjectedEntangledPair}) = merge(metadata(supertype(T)), @NamedTuple begin - χ::Union{Nothing,Int} -end) - -function checkmeta(::Type{ProjectedEntangledPair{P,B}}, tn::TensorNetwork) where {P,B} - # meta has correct value - isnothing(tn.χ) || tn.χ > 0 || return false - - # no virtual index has dimensionality bigger than χ - all(i -> isnothing(tn.χ) || size(tn, i) <= tn.χ, inds(tn, :virtual)) || return false +const PEPS = ProjectedEntangledPair{State} +const PEPO = ProjectedEntangledPair{Operator} - return true -end +plug(::Type{<:ProjectedEntangledPair{P}}) where {P} = P() +boundary(::Type{<:ProjectedEntangledPair{P,B}}) where {P,B} = B() -function _sitealias(::Type{ProjectedEntangledPair{P,Open}}, order, size, pos) where {P<:Plug} +function sitealias(::Type{<:ProjectedEntangledPair{P,Open}}, order, size, pos) where {P<:Plug} m, n = size i, j = pos @@ -44,11 +29,11 @@ function _sitealias(::Type{ProjectedEntangledPair{P,Open}}, order, size, pos) wh !(i == 1 && dir === :u || i == m && dir === :d || j == 1 && dir === :l || j == n && dir === :r) end end -_sitealias(::Type{ProjectedEntangledPair{P,Periodic}}, order, _, _) where {P<:Plug} = tuple(order...) -_sitealias(::Type{ProjectedEntangledPair{P,Infinite}}, order, _, _) where {P<:Plug} = tuple(order...) +sitealias(::Type{<:ProjectedEntangledPair{P,Periodic}}, order, _, _) where {P<:Plug} = tuple(order...) +sitealias(::Type{<:ProjectedEntangledPair{P,Infinite}}, order, _, _) where {P<:Plug} = tuple(order...) -defaultorder(::Type{ProjectedEntangledPair{State}}) = (:l, :r, :u, :d, :o) -defaultorder(::Type{ProjectedEntangledPair{Operator}}) = (:l, :r, :u, :d, :i, :o) +defaultorder(::Type{<:ProjectedEntangledPair{State}}) = (:l, :r, :u, :d, :o) +defaultorder(::Type{<:ProjectedEntangledPair{Operator}}) = (:l, :r, :u, :d, :i, :o) """ ProjectedEntangledPair{P,B}(arrays::Matrix{AbstractArray}; χ::Union{Nothing,Int} = nothing, order = defaultorder(ProjectedEntangledPair{P})) @@ -57,7 +42,6 @@ Construct a [`TensorNetwork`](@ref) with [`ProjectedEntangledPair`](@ref) ansatz # Keyword Arguments - - `χ` Maximum virtual bond dimension. Defaults to `nothing`. - `order` Order of the tensor indices on `arrays`. Defaults to `(:l, :r, :u, :d, :o)` if `P` is a `State`, `(:l, :r, :u, :d, :i, :o)` if `Operator`. 
""" function ProjectedEntangledPair{P,B}( @@ -89,41 +73,46 @@ function ProjectedEntangledPair{P,B}( throw(ErrorException("Plug $P is not valid")) end - tensors = map(zip(Iterators.map(Tuple, eachindex(IndexCartesian(), arrays)), arrays)) do ((i, j), array) - dirs = _sitealias(ProjectedEntangledPair{P,B}, order, (m, n), (i, j)) - - inds = map(dirs) do dir - if dir === :l - hinds[(i, (mod1(j - 1, n), j))] - elseif dir === :r - hinds[(i, (j, mod1(j + 1, n)))] - elseif dir === :u - vinds[((mod1(i - 1, m), i), j)] - elseif dir === :d - vinds[((i, mod1(i + 1, m)), j)] - elseif dir === :i - iinds[(i, j)] - elseif dir === :o - oinds[(i, j)] + input, output = if P <: Property + Symbol[], Symbol[] + elseif P <: State + Symbol[], [oinds[i, j] for i in 1:m, j in 1:n] + elseif P <: Operator + [iinds[i, j] for i in 1:m, j in 1:n], [oinds[i, j] for i in 1:m, j in 1:n] + else + throw(ArgumentError("Plug $P is not valid")) + end + + tensors::Vector{Tensor} = + map(zip(Iterators.map(Tuple, eachindex(IndexCartesian(), arrays)), arrays)) do ((i, j), array) + dirs = sitealias(ProjectedEntangledPair{P,B}, order, (m, n), (i, j)) + + inds = map(dirs) do dir + if dir === :l + hinds[(i, (mod1(j - 1, n), j))] + elseif dir === :r + hinds[(i, (j, mod1(j + 1, n)))] + elseif dir === :u + vinds[((mod1(i - 1, m), i), j)] + elseif dir === :d + vinds[((i, mod1(i + 1, m)), j)] + elseif dir === :i + iinds[(i, j)] + elseif dir === :o + oinds[(i, j)] + end end - end - Tensor(array, inds) - end |> vec + Tensor(array, inds) + end |> vec - return TensorNetwork{ProjectedEntangledPair{P,B}}(tensors; χ, plug = P, interlayer, metadata...) + return QuantumTensorNetwork(TensorNetwork(tensors), input, output) end -const PEPS = ProjectedEntangledPair{State} -const PEPO = ProjectedEntangledPair{Operator} - -tensors(ψ::TensorNetwork{ProjectedEntangledPair{P,Infinite}}, site::Int, args...) where {P<:Plug} = - tensors(plug(ψ), ψ, mod1(site, length(ψ.tensors)), args...) - # TODO normalize # TODO let choose the orthogonality center # TODO different input/output physical dims -function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{ProjectedEntangledPair{P,Open}}) where {P<:Plug} +function Base.rand(rng::Random.AbstractRNG, sampler::QTNSampler{ProjectedEntangledPair{P,Open}}) where {P<:Plug} rows = sampler.rows cols = sampler.cols χ = sampler.χ @@ -159,13 +148,13 @@ function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{ProjectedEntangle # normalize state arrays[1, 1] ./= P <: State ? sqrt(p) : p - ProjectedEntangledPair{P,Open}(arrays; χ) + ProjectedEntangledPair{P,Open}(arrays) end # TODO normalize # TODO let choose the orthogonality center # TODO different input/output physical dims -function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{ProjectedEntangledPair{P,Periodic}}) where {P<:Plug} +function Base.rand(rng::Random.AbstractRNG, sampler::QTNSampler{ProjectedEntangledPair{P,Periodic}}) where {P<:Plug} rows = sampler.rows cols = sampler.cols χ = sampler.χ @@ -192,5 +181,5 @@ function Base.rand(rng::Random.AbstractRNG, sampler::TNSampler{ProjectedEntangle # normalize state arrays[1, 1] ./= P <: State ? sqrt(p) : p - ProjectedEntangledPair{P,Periodic}(arrays; χ) + ProjectedEntangledPair{P,Periodic}(arrays) end diff --git a/src/Quantum/Quantum.jl b/src/Quantum/Quantum.jl index 6b5df73a..26b7ce92 100644 --- a/src/Quantum/Quantum.jl +++ b/src/Quantum/Quantum.jl @@ -6,7 +6,7 @@ using Classes """ QuantumTensorNetwork -Tensor Network `Ansatz` that has a notion of sites and directionality (input/output). 
+Tensor Network that has a notion of sites and directionality (input/output). """ @class QuantumTensorNetwork <: TensorNetwork begin input::Vector{Symbol} @@ -129,23 +129,6 @@ function plug(tn) end end -# Boundary trait -abstract type Boundary end -struct Open <: Boundary end -struct Periodic <: Boundary end -struct Infinite <: Boundary end - -""" - boundary(::QuantumTensorNetwork) - -Return the `Boundary` type of the [`TensorNetwork`](@ref). The following `Boundary`s are defined in `Tenet`: - - - `Open` - - `Periodic` - - `Infinite` -""" -function boundary end - # TODO look for more stable ways """ norm(ψ::AbstractQuantumTensorNetwork, p::Real=2) @@ -220,3 +203,36 @@ function marginal(ψ, site) tensor = only(select(tn, siteindex)) sum(tensor, inds = setdiff(inds(tensor), [siteindex])) end + +# Boundary trait +abstract type Boundary end +struct Open <: Boundary end +struct Periodic <: Boundary end +struct Infinite <: Boundary end + +""" + boundary(::QuantumTensorNetwork) + +Return the `Boundary` type of the [`TensorNetwork`](@ref). The following `Boundary`s are defined in `Tenet`: + + - `Open` + - `Periodic` + - `Infinite` +""" +function boundary end + +abstract type Ansatz end + +struct QTNSampler{A<:Ansatz} <: Random.Sampler{QuantumTensorNetwork} + config::Dict{Symbol,Any} + + QTNSampler{A}(; kwargs...) where {A} = new{A}(kwargs) +end + +Base.eltype(::QTNSampler{A}) where {A} = A + +Base.getproperty(obj::QTNSampler, name::Symbol) = name === :config ? getfield(obj, :config) : obj.config[name] +Base.get(obj::QTNSampler, name, default) = get(obj.config, name, default) + +Base.rand(A::Type{<:Ansatz}; kwargs...) = rand(Random.default_rng(), A; kwargs...) +Base.rand(rng::AbstractRNG, A::Type{<:Ansatz}; kwargs...) = rand(rng, QTNSampler{A}(; kwargs...)) \ No newline at end of file diff --git a/src/Tenet.jl b/src/Tenet.jl index 0b7bc146..8e0b066e 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -24,6 +24,9 @@ export Boundary, boundary, Open, Periodic, Infinite include("Quantum/MP.jl") export MatrixProduct, MPS, MPO +include("Quantum/PEP.jl") +export ProjectedEntangledPair, PEPS, PEPO + # reexports from LinearAlgebra export norm, normalize! 
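With `Ansatz` reduced to a plain sampler tag, random instances are now drawn through `QTNSampler`: its keyword arguments are stashed in the `config` dictionary and read back via `getproperty`/`get`. A minimal usage sketch (assuming `ProjectedEntangledPair <: Ansatz` as declared in `src/Quantum/PEP.jl`; the keyword values are illustrative, and `p` is assumed to be the physical-dimension keyword as in the MPS tests):

```julia
using Tenet

# kwargs land in the sampler's config; the PEP sampler above reads `rows`, `cols` and `χ`
ψ = rand(ProjectedEntangledPair{State,Open}, rows = 3, cols = 3, χ = 2, p = 2)
ψ isa QuantumTensorNetwork  # true: the ansatz constructors now return a QuantumTensorNetwork
```
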
diff --git a/test/MatrixProductOperator_test.jl b/test/MatrixProductOperator_test.jl index 08441fbf..24ee9a2b 100644 --- a/test/MatrixProductOperator_test.jl +++ b/test/MatrixProductOperator_test.jl @@ -14,25 +14,25 @@ @test begin arrays = [rand(2, 2, 2)] - MatrixProduct{Operator}(arrays) isa MPO{Open} + MatrixProduct{Operator}(arrays) isa QuantumTensorNetwork end @test begin arrays = [rand(2, 2, 2), rand(2, 2, 2)] - MatrixProduct{Operator}(arrays) isa MPO{Open} + MatrixProduct{Operator}(arrays) isa QuantumTensorNetwork end @testset "`Open` boundary" begin # product operator @test begin arrays = [rand(1, 2, 2), rand(1, 1, 2, 2), rand(1, 2, 2)] - MatrixProduct{Operator,Open}(arrays) isa MPO{Open} + MatrixProduct{Operator,Open}(arrays) isa QuantumTensorNetwork end # alternative constructor @test begin arrays = [rand(1, 2, 2), rand(1, 1, 2, 2), rand(1, 2, 2)] - MatrixProduct{Operator}(arrays; boundary = Open) isa MPO{Open} + MatrixProduct{Operator}(arrays; boundary = Open) isa QuantumTensorNetwork end # entangling operator @@ -40,7 +40,7 @@ i = 3 o = 5 arrays = [rand(2, i, o), rand(2, 4, i, o), rand(4, i, o)] - MatrixProduct{Operator,Open}(arrays) isa MPO{Open} + MatrixProduct{Operator,Open}(arrays) isa QuantumTensorNetwork end # entangling operator - change order @@ -48,13 +48,13 @@ i = 3 o = 5 arrays = [rand(i, 2, o), rand(2, i, 4, o), rand(4, i, o)] - MatrixProduct{Operator,Open}(arrays, order = (:l, :i, :r, :o)) isa MPO{Open} + MatrixProduct{Operator,Open}(arrays, order = (:l, :i, :r, :o)) isa QuantumTensorNetwork end # fail on Open with Periodic format @test_throws MethodError begin arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - MatrixProduct{Operator,Open}(arrays) isa MPO{Open} + MatrixProduct{Operator,Open}(arrays) isa QuantumTensorNetwork end end @@ -62,13 +62,13 @@ # product operator @test begin arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - MatrixProduct{Operator,Periodic}(arrays) isa MPO{Periodic} + MatrixProduct{Operator,Periodic}(arrays) isa QuantumTensorNetwork end # alternative constructor @test begin arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - MatrixProduct{Operator}(arrays; boundary = Periodic) isa MPO{Periodic} + MatrixProduct{Operator}(arrays; boundary = Periodic) isa QuantumTensorNetwork end # entangling operator @@ -76,7 +76,7 @@ i = 3 o = 5 arrays = [rand(2, 4, i, o), rand(4, 8, i, o), rand(8, 2, i, o)] - MatrixProduct{Operator,Periodic}(arrays) isa MPO{Periodic} + MatrixProduct{Operator,Periodic}(arrays) isa QuantumTensorNetwork end # entangling operator - change order @@ -84,51 +84,51 @@ i = 3 o = 5 arrays = [rand(2, i, 4, o), rand(4, i, 8, o), rand(8, i, 2, o)] - MatrixProduct{Operator,Periodic}(arrays, order = (:l, :i, :r, :o)) isa MPO{Periodic} + MatrixProduct{Operator,Periodic}(arrays, order = (:l, :i, :r, :o)) isa QuantumTensorNetwork end # fail on Periodic with Open format @test_throws MethodError begin arrays = [rand(1, 2, 2), rand(1, 1, 2, 2), rand(1, 2, 2)] - MatrixProduct{Operator,Periodic}(arrays) isa MPO{Periodic} + MatrixProduct{Operator,Periodic}(arrays) isa QuantumTensorNetwork end end - @testset "`Infinite` boundary" begin - # product operator - @test begin - arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - MatrixProduct{Operator,Infinite}(arrays) isa MPO{Infinite} - end - - # alternative constructor - @test begin - arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] - MatrixProduct{Operator}(arrays; boundary = Infinite) isa MPO{Infinite} - end - - # entangling 
operator - @test begin - i = 3 - o = 5 - arrays = [rand(2, 4, i, o), rand(4, 8, i, o), rand(8, 2, i, o)] - MatrixProduct{Operator,Infinite}(arrays) isa MPO{Infinite} - end - - # entangling operator - change order - @test begin - i = 3 - o = 5 - arrays = [rand(2, i, 4, o), rand(4, i, 8, o), rand(8, i, 2, o)] - MatrixProduct{Operator,Infinite}(arrays, order = (:l, :i, :r, :o)) isa MPO{Infinite} - end - - # fail on Infinite with Open format - @test_throws MethodError begin - arrays = [rand(1, 2, 2), rand(1, 1, 2, 2), rand(1, 2, 2)] - MatrixProduct{Operator,Infinite}(arrays) isa MPO{Infinite} - end - end + # @testset "`Infinite` boundary" begin + # # product operator + # @test skip = true begin + # arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] + # MatrixProduct{Operator,Infinite}(arrays) isa MPO{Infinite} + # end + + # # alternative constructor + # @test skip = true begin + # arrays = [rand(1, 1, 2, 2), rand(1, 1, 2, 2), rand(1, 1, 2, 2)] + # MatrixProduct{Operator}(arrays; boundary = Infinite) isa MPO{Infinite} + # end + + # # entangling operator + # @test skip = true begin + # i = 3 + # o = 5 + # arrays = [rand(2, 4, i, o), rand(4, 8, i, o), rand(8, 2, i, o)] + # MatrixProduct{Operator,Infinite}(arrays) isa MPO{Infinite} + # end + + # # entangling operator - change order + # @test skip = true begin + # i = 3 + # o = 5 + # arrays = [rand(2, i, 4, o), rand(4, i, 8, o), rand(8, i, 2, o)] + # MatrixProduct{Operator,Infinite}(arrays, order = (:l, :i, :r, :o)) isa MPO{Infinite} + # end + + # # fail on Infinite with Open format + # @test_throws MethodError begin + # arrays = [rand(1, 2, 2), rand(1, 1, 2, 2), rand(1, 2, 2)] + # MatrixProduct{Operator,Infinite}(arrays) isa MPO{Infinite} + # end + # end end @testset "merge" begin @@ -159,21 +159,4 @@ mpo = rand(MatrixProduct{Operator,Open}, n = 8, p = 2, χ = 8) @test_broken norm(mpo) ≈ 1 end - - # @testset "Initialization" begin - # for params in [ - # (2, 2, 2, 1), - # (2, 2, 2, 2), - # (4, 4, 4, 16), - # (4, 2, 2, 8), - # (4, 2, 3, 8), - # (6, 2, 2, 4), - # (8, 2, 3, 4), - # # (1, 2, 2, 1), - # # (1, 3, 3, 1), - # # (1, 1, 1, 1), - # ] - # @test rand(MatrixProduct{Operator,Open}, params...) 
isa TensorNetwork{MatrixProduct{Operator,Open}} - # end - # end end diff --git a/test/MatrixProductState_test.jl b/test/MatrixProductState_test.jl index 80f96d58..a2066c94 100644 --- a/test/MatrixProductState_test.jl +++ b/test/MatrixProductState_test.jl @@ -14,44 +14,44 @@ @test begin arrays = [rand(1, 2)] - MatrixProduct{State}(arrays) isa MPS{Open} + MatrixProduct{State}(arrays) isa QuantumTensorNetwork end @test begin arrays = [rand(1, 2), rand(1, 2)] - MatrixProduct{State}(arrays) isa MPS{Open} + MatrixProduct{State}(arrays) isa QuantumTensorNetwork end @testset "`Open` boundary" begin # product state @test begin arrays = [rand(1, 2), rand(1, 1, 2), rand(1, 2)] - MatrixProduct{State,Open}(arrays) isa MPS{Open} + MatrixProduct{State,Open}(arrays) isa QuantumTensorNetwork end # entangled state @test begin arrays = [rand(2, 2), rand(2, 4, 2), rand(4, 1, 2), rand(1, 2)] - MatrixProduct{State,Open}(arrays) isa MPS{Open} + MatrixProduct{State,Open}(arrays) isa QuantumTensorNetwork end @testset "custom order" begin arrays = [rand(3, 1), rand(3, 1, 3), rand(1, 3)] ψ = MatrixProduct{State,Open}(arrays, order = (:r, :o, :l)) - @test ψ isa MPS{Open} + @test ψ isa QuantumTensorNetwork end # alternative constructor @test begin arrays = [rand(1, 2), rand(1, 1, 2), rand(1, 2)] - MatrixProduct{State}(arrays; boundary = Open) isa MPS{Open} + MatrixProduct{State}(arrays; boundary = Open) isa QuantumTensorNetwork end # fail on Open with Periodic format @test_throws Exception begin arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - MatrixProduct{State,Open}(arrays) isa MPS{Open} + MatrixProduct{State,Open}(arrays) isa QuantumTensorNetwork end @testset "rand" begin @@ -60,7 +60,7 @@ @testset "χ = $χ" for χ in [4, 32] ψ = rand(MatrixProduct{State,Open}, n = 7, p = 2, χ = χ) - @test ψ isa MPS{Open} + @test ψ isa QuantumTensorNetwork @test length(tensors(ψ)) == 7 @test maximum(vind -> size(ψ, vind), inds(ψ, :inner)) <= 32 end @@ -71,76 +71,76 @@ # product state @test begin arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - MatrixProduct{State,Periodic}(arrays) isa MPS{Periodic} + MatrixProduct{State,Periodic}(arrays) isa QuantumTensorNetwork end # entangled state @test begin arrays = [rand(3, 4, 2), rand(4, 8, 2), rand(8, 3, 2)] - MatrixProduct{State,Periodic}(arrays) isa MPS{Periodic} + MatrixProduct{State,Periodic}(arrays) isa QuantumTensorNetwork end @testset "custom order" begin arrays = [rand(3, 1, 3), rand(3, 1, 3), rand(3, 1, 3)] ψ = MatrixProduct{State,Periodic}(arrays, order = (:r, :o, :l)) - @test ψ isa MPS{Periodic} + @test ψ isa QuantumTensorNetwork end # alternative constructor @test begin arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - MatrixProduct{State}(arrays; boundary = Periodic) isa MPS{Periodic} + MatrixProduct{State}(arrays; boundary = Periodic) isa QuantumTensorNetwork end # fail on Periodic with Open format @test_throws Exception begin arrays = [rand(1, 2), rand(1, 1, 2), rand(1, 2)] - MatrixProduct{State,Periodic}(arrays) isa MPS{Periodic} + MatrixProduct{State,Periodic}(arrays) isa QuantumTensorNetwork end end - @testset "`Infinite` boundary" begin - # product state - @test begin - arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - MatrixProduct{State,Infinite}(arrays) isa MPS{Infinite} - end - - # entangled state - @test begin - arrays = [rand(3, 4, 2), rand(4, 8, 2), rand(8, 3, 2)] - MatrixProduct{State,Infinite}(arrays) isa MPS{Infinite} - end - - @testset "custom order" begin - arrays = [rand(3, 1, 3), rand(3, 1, 3), rand(3, 1, 3)] - ψ = 
MatrixProduct{State,Infinite}(arrays, order = (:r, :o, :l)) - - @test ψ isa MPS{Infinite} - end - - # alternative constructor - @test begin - arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - MatrixProduct{State}(arrays; boundary = Infinite) isa MPS{Infinite} - end - - # fail on Infinite with Open format - @test_throws Exception begin - arrays = [rand(1, 2), rand(1, 1, 2), rand(1, 2)] - MatrixProduct{State,Infinite}(arrays) isa MPS{Infinite} - end - - # @testset "tensors" begin - # arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] - # ψ = MatrixProduct{State,Infinite}(arrays, order = (:l, :r, :o)) - - # @test tensors(ψ, 1) isa Tensor - # @test tensors(ψ, 4) == tensors(ψ, 1) - # @test tensors(ψ, 0) == tensors(ψ, 3) - # end - end + # @testset "`Infinite` boundary" begin + # # product state + # @test skip = true begin + # arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] + # MatrixProduct{State,Infinite}(arrays) isa MPS{Infinite} + # end + + # # entangled state + # @test skip = true begin + # arrays = [rand(3, 4, 2), rand(4, 8, 2), rand(8, 3, 2)] + # MatrixProduct{State,Infinite}(arrays) isa MPS{Infinite} + # end + + # @testset "custom order" begin + # arrays = [rand(3, 1, 3), rand(3, 1, 3), rand(3, 1, 3)] + # ψ = MatrixProduct{State,Infinite}(arrays, order = (:r, :o, :l)) + + # @test skip = true ψ isa MPS{Infinite} + # end + + # # alternative constructor + # @test skip = true begin + # arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] + # MatrixProduct{State}(arrays; boundary = Infinite) isa MPS{Infinite} + # end + + # # fail on Infinite with Open format + # @test_throws skip = true Exception begin + # arrays = [rand(1, 2), rand(1, 1, 2), rand(1, 2)] + # MatrixProduct{State,Infinite}(arrays) isa MPS{Infinite} + # end + + # # @testset "tensors" begin + # # arrays = [rand(1, 1, 2), rand(1, 1, 2), rand(1, 1, 2)] + # # ψ = MatrixProduct{State,Infinite}(arrays, order = (:l, :r, :o)) + + # # @test tensors(ψ, 1) isa Tensor + # # @test tensors(ψ, 4) == tensors(ψ, 1) + # # @test tensors(ψ, 0) == tensors(ψ, 3) + # # end + # end end @testset "merge" begin From 64ed19340146e3afc872f5c4cc21cfcc4a5b48b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 10 Oct 2023 00:38:22 +0200 Subject: [PATCH 19/29] Fix `normalize!` --- src/Quantum/Quantum.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Quantum/Quantum.jl b/src/Quantum/Quantum.jl index 26b7ce92..bb6d7e65 100644 --- a/src/Quantum/Quantum.jl +++ b/src/Quantum/Quantum.jl @@ -172,7 +172,7 @@ function LinearAlgebra.normalize!( if isnothing(insert) # method 1: divide all tensors by (√v)^(1/n) - n = length(ψ) + n = length(tensors(ψ)) norm ^= 1 / n for tensor in tensors(ψ) tensor ./= norm From b8728012e45e4662faac97f5de135d3313bab563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 10 Oct 2023 00:39:08 +0200 Subject: [PATCH 20/29] Remove legacy code --- src/Quantum/PEP.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/Quantum/PEP.jl b/src/Quantum/PEP.jl index b246c1ef..75ddc45a 100644 --- a/src/Quantum/PEP.jl +++ b/src/Quantum/PEP.jl @@ -62,17 +62,6 @@ function ProjectedEntangledPair{P,B}( oinds = Dict((i, j) => Symbol(uuid4()) for i in 1:m, j in 1:n) iinds = Dict((i, j) => Symbol(uuid4()) for i in 1:m, j in 1:n) - interlayer = if P <: State - [Bijection(Dict(i + j * m => index for ((i, j), index) in oinds))] - elseif P <: Operator - [ - Bijection(Dict(i + j * m => index for ((i, j), index) in iinds)), - 
Bijection(Dict(i + j * m => index for ((i, j), index) in oinds)), - ] - else - throw(ErrorException("Plug $P is not valid")) - end - input, output = if P <: Property Symbol[], Symbol[] elseif P <: State From 4ea6ab918c389a029f41a70fe03e7d809cf32794 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 10 Oct 2023 00:39:25 +0200 Subject: [PATCH 21/29] Update `ChainRules` extensions --- ext/TenetChainRulesCoreExt.jl | 2 +- ext/TenetChainRulesTestUtilsExt.jl | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/ext/TenetChainRulesCoreExt.jl b/ext/TenetChainRulesCoreExt.jl index a38ae735..db52a1ed 100644 --- a/ext/TenetChainRulesCoreExt.jl +++ b/ext/TenetChainRulesCoreExt.jl @@ -64,7 +64,7 @@ end TensorNetwork_pullback(Δ::Tangent{TensorNetwork}) = (NoTangent(), Δ.tensors) TensorNetwork_pullback(Δ::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ)) -function ChainRulesCore.rrule(T::Type{TensorNetwork}, tensors) +function ChainRulesCore.rrule(T::Type{<:absclass(TensorNetwork)}, tensors) T(tensors), TensorNetwork_pullback end diff --git a/ext/TenetChainRulesTestUtilsExt.jl b/ext/TenetChainRulesTestUtilsExt.jl index 223297a7..23db9724 100644 --- a/ext/TenetChainRulesTestUtilsExt.jl +++ b/ext/TenetChainRulesTestUtilsExt.jl @@ -4,11 +4,10 @@ using Tenet using ChainRulesCore using ChainRulesTestUtils using Random +using Classes -function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::TensorNetwork) - return Tangent{TensorNetwork}( - tensors = Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)], - ) +function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::T) where {T<:absclass(TensorNetwork)} + return Tangent{T}(tensors = Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)]) end end \ No newline at end of file From e2e9721377906c991710acfa3c40cdd844e4fd78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 10 Oct 2023 00:39:37 +0200 Subject: [PATCH 22/29] Fix `copy` on `TensorNetwork` --- src/TensorNetwork.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 60d7321b..af894c2d 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -35,7 +35,9 @@ end # TensorNetwork{A}(tn::absclass(TensorNetwork){B}; metadata...) where {A,B} = # TensorNetwork{A}(tensors(tn); merge(tn.metadata, metadata)...) -Base.copy(tn::T) where {T<:absclass(TensorNetwork)} = T(map(field -> copy(getfield(tn, field)), fieldnames(T))...) +Base.copy(tn::T) where {T<:absclass(TensorNetwork)} = T(map(fieldnames(T)) do field + (field === :indices ? deepcopy : copy)(getfield(tn, field)) +end...) Base.summary(io::IO, x::absclass(TensorNetwork)) = print(io, "$(length(x))-tensors $(typeof(x))") Base.show(io::IO, tn::absclass(TensorNetwork)) = @@ -128,7 +130,7 @@ See also: [`append!`](@ref). """ Base.merge!(self::absclass(TensorNetwork), other::absclass(TensorNetwork)) = append!(self, tensors(other)) Base.merge!(self::absclass(TensorNetwork), others::absclass(TensorNetwork)...) = foldl(merge!, others; init = self) -Base.merge(self::absclass(TensorNetwork), others::absclass(TensorNetwork)...) = merge!(deepcopy(self), others...) # TODO deepcopy because `indices` are not correctly copied and it mutates +Base.merge(self::absclass(TensorNetwork), others::absclass(TensorNetwork)...) = merge!(copy(self), others...) 
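+
+# NOTE `copy` above shares the `Tensor`s but deep-copies `indices`, whose `Vector{Int}`
+# values are mutated in place by `push!`/`pop!`. A rough sketch of the aliasing it avoids:
+#
+#   tn = TensorNetwork(Tensor[Tensor(zeros(2, 2), (:i, :j))])
+#   other = copy(tn)
+#   push!(other, Tensor(zeros(2), (:i,)))
+#   tn.indices[:i]     # == [1]  (untouched)
+#   other.indices[:i]  # == [1, 2]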
function Base.popat!(tn::absclass(TensorNetwork), i::Integer) tensor = popat!(tn.tensors, i) From 62cce3b7154b8037ee890072f1ee43f0b757716f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 10 Oct 2023 01:41:43 +0200 Subject: [PATCH 23/29] Fix `PEP` constructor --- src/Quantum/PEP.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Quantum/PEP.jl b/src/Quantum/PEP.jl index 75ddc45a..df3fb553 100644 --- a/src/Quantum/PEP.jl +++ b/src/Quantum/PEP.jl @@ -65,9 +65,9 @@ function ProjectedEntangledPair{P,B}( input, output = if P <: Property Symbol[], Symbol[] elseif P <: State - Symbol[], [oinds[i, j] for i in 1:m, j in 1:n] + Symbol[], vec([oinds[i, j] for i in 1:m, j in 1:n]) elseif P <: Operator - [iinds[i, j] for i in 1:m, j in 1:n], [oinds[i, j] for i in 1:m, j in 1:n] + vec([iinds[i, j] for i in 1:m, j in 1:n]), vec([oinds[i, j] for i in 1:m, j in 1:n]) else throw(ArgumentError("Plug $P is not valid")) end From b090cfc89ed6c9244b5645d503262f4d047dbc4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 10 Oct 2023 01:42:17 +0200 Subject: [PATCH 24/29] Implement trace methods for `QuantumTensorNetwork` #110 --- src/Quantum/Quantum.jl | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/Quantum/Quantum.jl b/src/Quantum/Quantum.jl index bb6d7e65..f781dff6 100644 --- a/src/Quantum/Quantum.jl +++ b/src/Quantum/Quantum.jl @@ -141,7 +141,9 @@ function LinearAlgebra.norm(ψ::absclass(QuantumTensorNetwork), p::Real = 2; kwa p == 2 || throw(ArgumentError("p=$p is not implemented yet")) tn = merge(ψ, ψ') - all(isempty, [tn.input, tn.output]) || throw("unimplemented if <ψ|ψ> is an operator") + if plug(tn) isa Operator + tn = tr(tn) + end return contract(tn; kwargs...) |> only |> sqrt |> abs end @@ -184,6 +186,26 @@ function LinearAlgebra.normalize!( end end +""" + LinearAlgebra.tr(U::AbstractQuantumTensorNetwork) + +Trace `U`: sum of diagonal elements if `U` is viewed as a matrix. + +Depending on the result of `plug(U)`, different actions can be taken: + + - If `Property()`, the result of `contract(U)` will be a "scalar", for which the trace acts like the identity. + - If `State()`, the result of `contract(U)` will be a "vector", for which the trace is undefined and will fail. + - If `Operator()`, the input and output indices of `U` are connected. 
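+
+A rough usage sketch for the `Operator` case (the MPO and its parameters are illustrative):
+
+```julia
+U = rand(MatrixProduct{Operator,Open}, n = 4, p = 2, χ = 4)
+contract(tr(U))  # 0-dimensional Tensor holding the trace, since `tr` reuses the input indices as outputs
+```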
+""" +LinearAlgebra.tr(U::absclass(QuantumTensorNetwork)) = tr!(U) +tr!(U::absclass(QuantumTensorNetwork)) = tr!(plug(U), U) +tr!(::Property, scalar::absclass(QuantumTensorNetwork)) = scalar +function tr!(::Operator, U::absclass(QuantumTensorNetwork)) + sites(U, :in) == sites(U, :out) || throw(ArgumentError("input and output sites do not match")) + copyto!(U.output, U.input) + U +end + """ fidelity(ψ,ϕ) From 7be04487877c97ae2e7115a81d8d87ff3c326e7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 10 Oct 2023 01:47:29 +0200 Subject: [PATCH 25/29] Fix docs --- docs/src/contraction.md | 4 ++-- docs/src/examples/ad-tn.md | 2 +- docs/src/examples/google-rqc.md | 2 +- docs/src/quantum/index.md | 28 +++++++--------------------- docs/src/tensor-network.md | 23 ++++++++++------------- docs/src/transformations.md | 10 +++++----- src/Quantum/Quantum.jl | 2 +- src/TensorNetwork.jl | 9 ++++----- 8 files changed, 31 insertions(+), 49 deletions(-) diff --git a/docs/src/contraction.md b/docs/src/contraction.md index d698d218..b3415adb 100644 --- a/docs/src/contraction.md +++ b/docs/src/contraction.md @@ -3,7 +3,7 @@ Contraction path optimization and execution is delegated to the [`EinExprs`](https://github.com/bsc-quantic/EinExprs) library. A `EinExpr` is a lower-level form of a Tensor Network, in which the contraction path has been laid out as a tree. It is similar to a symbolic expression (i.e. `Expr`) but in which every node represents an Einstein summation expression (aka `einsum`). ```@docs -einexpr(::TensorNetwork) -contract(::TensorNetwork) +einexpr(::Tenet.AbstractTensorNetwork) +contract(::Tenet.AbstractTensorNetwork) contract! ``` diff --git a/docs/src/examples/ad-tn.md b/docs/src/examples/ad-tn.md index 7aea7dfe..638dd735 100644 --- a/docs/src/examples/ad-tn.md +++ b/docs/src/examples/ad-tn.md @@ -19,7 +19,7 @@ rng = seed!(4) # hide ψ = rand(rng, MPS{Open}, n = 4, p = 2, χ = 2) # hide ϕ = rand(rng, MPS{Open}, n = 4, p = 2, χ = 4) # hide -tn = hcat(ψ, ϕ) +tn = merge(ψ, ϕ') plot(tn) # hide ``` diff --git a/docs/src/examples/google-rqc.md b/docs/src/examples/google-rqc.md index 9a3fa8b6..8f945228 100644 --- a/docs/src/examples/google-rqc.md +++ b/docs/src/examples/google-rqc.md @@ -42,7 +42,7 @@ _sites = [5, 6, 14, 15, 16, 17, 24, 25, 26, 27, 28, 32, 33, 34, 35, 36, 37, 38, # load circuit and convert to `TensorNetwork` circuit = QuacIO.parse(joinpath(@__DIR__, "sycamore_53_10_0.qasm"), format = QuacIO.Qflex(), sites = _sites); -tn = TensorNetwork(circuit) +tn = QuantumTensorNetwork(circuit) tn = view(tn, [i => 1 for i in inds(tn, set=:open)]...) plot(tn) # hide ``` diff --git a/docs/src/quantum/index.md b/docs/src/quantum/index.md index d15d7250..e8924934 100644 --- a/docs/src/quantum/index.md +++ b/docs/src/quantum/index.md @@ -1,44 +1,30 @@ # Introduction -In `Tenet`, we define a [`Quantum`](@ref) Tensor Network as a [`TensorNetwork`](@ref) with a notion of sites and directionality. - -```@docs -Quantum -``` +In `Tenet`, we define a [`QuantumTensorNetwork`](@ref) as a [`TensorNetwork`](@ref) with a notion of sites and directionality. 
```@docs +QuantumTensorNetwork plug -``` - -```@docs sites ``` -```@docs -tensors(::TensorNetwork{<:Quantum}, ::Integer) -``` - -```@docs -boundary -``` - ## Adjoint ```@docs adjoint ``` -## Concatenation +## Norm ```@docs -hcat(::TensorNetwork{<:Quantum}, ::TensorNetwork{<:Quantum}) +LinearAlgebra.norm(::Tenet.AbstractQuantumTensorNetwork, ::Real) +LinearAlgebra.normalize!(::Tenet.AbstractQuantumTensorNetwork, ::Real) ``` -## Norm +## Trace ```@docs -LinearAlgebra.norm(::TensorNetwork{<:Quantum}, p::Real) -LinearAlgebra.normalize!(::TensorNetwork{<:Quantum}, ::Real) +LinearAlgebra.tr(::Tenet.AbstractQuantumTensorNetwork) ``` ## Fidelity diff --git a/docs/src/tensor-network.md b/docs/src/tensor-network.md index c78cb112..1129e08f 100644 --- a/docs/src/tensor-network.md +++ b/docs/src/tensor-network.md @@ -28,11 +28,9 @@ Information about a `TensorNetwork` can be queried with the following functions. ## Query information ```@docs -inds(::TensorNetwork) -size(::TensorNetwork) -tensors(::TensorNetwork) -length(::TensorNetwork) -ansatz +inds(::Tenet.AbstractTensorNetwork) +size(::Tenet.AbstractTensorNetwork) +tensors(::Tenet.AbstractTensorNetwork) ``` ## Modification @@ -40,17 +38,16 @@ ansatz ### Add/Remove tensors ```@docs -push!(::TensorNetwork, ::Tensor) -append!(::TensorNetwork, ::Base.AbstractVecOrTuple{<:Tensor}) -merge!(::AbstractTensorNetwork, ::AbstractTensorNetwork) -pop!(::TensorNetwork, ::Tensor) -delete!(::TensorNetwork, ::Any) +push!(::Tenet.AbstractTensorNetwork, ::Tensor) +append!(::Tenet.AbstractTensorNetwork, ::Base.AbstractVecOrTuple{<:Tensor}) +merge!(::Tenet.AbstractTensorNetwork, ::Tenet.AbstractTensorNetwork) +pop!(::Tenet.AbstractTensorNetwork, ::Tensor) +delete!(::Tenet.AbstractTensorNetwork, ::Any) ``` ### Replace existing elements ```@docs -replace replace! ``` @@ -60,12 +57,12 @@ replace! select selectdim slice! 
-view(::TensorNetwork)
+view(::Tenet.AbstractTensorNetwork)
 ```
 
-## Miscelaneous
+## Miscellaneous
 
 ```@docs
-Base.copy(::TensorNetwork)
+Base.copy(::Tenet.AbstractTensorNetwork)
 Base.rand(::Type{TensorNetwork}, n::Integer, regularity::Integer)
 ```
diff --git a/docs/src/transformations.md b/docs/src/transformations.md
index d6a79fd1..624114e8 100644
--- a/docs/src/transformations.md
+++ b/docs/src/transformations.md
@@ -81,7 +81,7 @@ A = Tensor(data, (:i, :j, :k, :l)) #hide
 B = Tensor(rand(2, 2), (:i, :m)) #hide
 C = Tensor(rand(2, 2), (:j, :n)) #hide
 
-tn = TensorNetwork([A, B, C]) #hide
+tn = TensorNetwork(Tensor[A, B, C]) #hide
 reduced = transform(tn, Tenet.DiagonalReduction) #hide
 
 smooth_annotation!( #hide
@@ -139,7 +139,7 @@ B = Tensor(rand(2, 2), (:i, :m)) #hide
 C = Tensor(rand(2, 2, 2), (:m, :n, :o)) #hide
 E = Tensor(rand(2, 2, 2, 2), (:o, :p, :q, :j)) #hide
 
-tn = TensorNetwork([A, B, C, E]) #hide
+tn = TensorNetwork(Tensor[A, B, C, E]) #hide
 reduced = transform(tn, Tenet.RankSimplification) #hide
 
 smooth_annotation!( #hide
@@ -193,7 +193,7 @@ A = Tensor(data, (:i, :j, :k)) #hide
 B = Tensor(rand(3, 3), (:j, :l)) #hide
 C = Tensor(rand(3, 3), (:l, :m)) #hide
 
-tn = TensorNetwork([A, B, C]) #hide
+tn = TensorNetwork(Tensor[A, B, C]) #hide
 reduced = transform(tn, Tenet.ColumnReduction) #hide
 
 smooth_annotation!( #hide
@@ -247,7 +247,7 @@ m1 = Tensor(rand(3, 3), (:k, :l)) #hide
 t1 = contract(v1, v2) #hide
 tensor = contract(t1, m1) #hide
 
-tn = TensorNetwork([tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) #hide
+tn = TensorNetwork(Tensor[tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) #hide
 reduced = transform(tn, Tenet.SplitSimplification) #hide
 
 smooth_annotation!( #hide
@@ -294,7 +294,7 @@ set_theme!(resolution=(800,400)) # hide
 sites = [5, 6, 14, 15, 16, 17, 24, 25, 26, 27, 28, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 72, 73, 74, 75, 76, 83, 84, 85, 94]
 circuit = QuacIO.parse(joinpath(@__DIR__, "sycamore_53_10_0.qasm"), format=QuacIO.Qflex(), sites=sites)
-tn = TensorNetwork(circuit)
+tn = QuantumTensorNetwork(circuit)
 
 # Apply transformations to the tensor network
 transformed_tn = transform(tn, [Tenet.AntiDiagonalGauging, Tenet.DiagonalReduction, Tenet.ColumnReduction, Tenet.RankSimplification])
diff --git a/src/Quantum/Quantum.jl b/src/Quantum/Quantum.jl
index f781dff6..e0952873 100644
--- a/src/Quantum/Quantum.jl
+++ b/src/Quantum/Quantum.jl
@@ -133,7 +133,7 @@ end
 """
     norm(ψ::AbstractQuantumTensorNetwork, p::Real=2)
 
-Compute the ``p``-norm of a [`Quantum`](@ref) [`TensorNetwork`](@ref).
+Compute the ``p``-norm of a [`QuantumTensorNetwork`](@ref).
 
 See also: [`normalize!`](@ref).
 """
diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl
index af894c2d..bb12738f 100644
--- a/src/TensorNetwork.jl
+++ b/src/TensorNetwork.jl
@@ -31,10 +31,11 @@ function TensorNetwork(tensors)
     return TensorNetwork(indices, tensors)
 end
 
-# TODO maybe rename it as `convert` method?
-# TensorNetwork{A}(tn::absclass(TensorNetwork){B}; metadata...) where {A,B} =
-#     TensorNetwork{A}(tensors(tn); merge(tn.metadata, metadata)...)
+"""
+    copy(tn::TensorNetwork)
 
+Return a shallow copy of a [`TensorNetwork`](@ref).
+"""
 Base.copy(tn::T) where {T<:absclass(TensorNetwork)} = T(map(fieldnames(T)) do field
     (field === :indices ? deepcopy : copy)(getfield(tn, field))
 end...)
@@ -189,8 +192,6 @@ Replace the element in `old` with the one in `new`. 
Depending on the types of `o - If `Symbol`s, it will correspond to a index renaming. - If `Tensor`s, first element that satisfies _egality_ (`≡` or `===`) will be replaced. - -See also: [`replace`](@ref). """ Base.replace!(tn::absclass(TensorNetwork), old_new::Pair...) = replace!(tn, old_new) function Base.replace!(tn::absclass(TensorNetwork), old_new::Base.AbstractVecOrTuple{Pair}) From d3e4c11e7ba4e7239e645bc90cb6404f9f1fd7de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 10 Oct 2023 02:00:06 +0200 Subject: [PATCH 26/29] Enable `norm` test on `MPO` --- test/MatrixProductOperator_test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/MatrixProductOperator_test.jl b/test/MatrixProductOperator_test.jl index 24ee9a2b..d0f5dbc5 100644 --- a/test/MatrixProductOperator_test.jl +++ b/test/MatrixProductOperator_test.jl @@ -157,6 +157,6 @@ @testset "norm" begin mpo = rand(MatrixProduct{Operator,Open}, n = 8, p = 2, χ = 8) - @test_broken norm(mpo) ≈ 1 + @test norm(mpo) ≈ 1 end end From 6567aa9e32b4ef56eb22eeb89d41530c688349a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 10 Oct 2023 10:53:47 +0200 Subject: [PATCH 27/29] Fix `tensors` type in `TensorNetwork` constructor --- src/TensorNetwork.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index bb12738f..d8187e61 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -28,6 +28,8 @@ function TensorNetwork(tensors) throw(DimensionMismatch("Different sizes specified for index $index")) end + tensors = convert(Vector{Tensor}, tensors) + return TensorNetwork(indices, tensors) end From 86f3a054434e49cb33eda30846d47458cc67e2bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= Date: Tue, 10 Oct 2023 12:20:08 +0200 Subject: [PATCH 28/29] Relax `Vector{Tensor}` conversion on `TensorNetwork` constructor --- docs/src/transformations.md | 8 ++--- ext/TenetChainRulesTestUtilsExt.jl | 2 +- test/Quantum_test.jl | 4 +-- test/TensorNetwork_test.jl | 54 +++++++++++++---------------- test/Transformations_test.jl | 16 ++++----- test/integration/ChainRules_test.jl | 8 ++--- test/integration/Makie_test.jl | 2 +- 7 files changed, 45 insertions(+), 49 deletions(-) diff --git a/docs/src/transformations.md b/docs/src/transformations.md index 624114e8..ac482326 100644 --- a/docs/src/transformations.md +++ b/docs/src/transformations.md @@ -81,7 +81,7 @@ A = Tensor(data, (:i, :j, :k, :l)) #hide B = Tensor(rand(2, 2), (:i, :m)) #hide C = Tensor(rand(2, 2), (:j, :n)) #hide -tn = TensorNetwork(Tensor[A, B, C]) #hide +tn = TensorNetwork([A, B, C]) #hide reduced = transform(tn, Tenet.DiagonalReduction) #hide smooth_annotation!( #hide @@ -139,7 +139,7 @@ B = Tensor(rand(2, 2), (:i, :m)) #hide C = Tensor(rand(2, 2, 2), (:m, :n, :o)) #hide E = Tensor(rand(2, 2, 2, 2), (:o, :p, :q, :j)) #hide -tn = TensorNetwork(Tensor[A, B, C, E]) #hide +tn = TensorNetwork([A, B, C, E]) #hide reduced = transform(tn, Tenet.RankSimplification) #hide smooth_annotation!( #hide @@ -193,7 +193,7 @@ A = Tensor(data, (:i, :j, :k)) #hide B = Tensor(rand(3, 3), (:j, :l)) #hide C = Tensor(rand(3, 3), (:l, :m)) #hide -tn = TensorNetwork(Tensor[A, B, C]) #hide +tn = TensorNetwork([A, B, C]) #hide reduced = transform(tn, Tenet.ColumnReduction) #hide smooth_annotation!( #hide @@ -247,7 +247,7 @@ m1 = Tensor(rand(3, 3), (:k, :l)) #hide t1 = contract(v1, v2) #hide tensor = contract(t1, m1) #hide -tn = 
TensorNetwork(Tensor[tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) #hide +tn = TensorNetwork([tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) #hide reduced = transform(tn, Tenet.SplitSimplification) #hide smooth_annotation!( #hide diff --git a/ext/TenetChainRulesTestUtilsExt.jl b/ext/TenetChainRulesTestUtilsExt.jl index 23db9724..aea16d94 100644 --- a/ext/TenetChainRulesTestUtilsExt.jl +++ b/ext/TenetChainRulesTestUtilsExt.jl @@ -7,7 +7,7 @@ using Random using Classes function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::T) where {T<:absclass(TensorNetwork)} - return Tangent{T}(tensors = Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)]) + return Tangent{T}(tensors = [ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)]) end end \ No newline at end of file diff --git a/test/Quantum_test.jl b/test/Quantum_test.jl index b5f49978..c7aa2f28 100644 --- a/test/Quantum_test.jl +++ b/test/Quantum_test.jl @@ -1,12 +1,12 @@ @testset "Quantum" begin state = QuantumTensorNetwork( - TensorNetwork(Tensor[Tensor(rand(2, 2), (:i, :k)), Tensor(rand(3, 2, 4), (:j, :k, :l))]), + TensorNetwork([Tensor(rand(2, 2), (:i, :k)), Tensor(rand(3, 2, 4), (:j, :k, :l))]), Symbol[], # input [:i, :j], # output ) operator = QuantumTensorNetwork( - TensorNetwork(Tensor[Tensor(rand(2, 4, 2), (:a, :c, :d)), Tensor(rand(3, 4, 3, 5), (:b, :c, :e, :f))]), + TensorNetwork([Tensor(rand(2, 4, 2), (:a, :c, :d)), Tensor(rand(3, 4, 3, 5), (:b, :c, :e, :f))]), [:a, :b], # input [:d, :e], # output ) diff --git a/test/TensorNetwork_test.jl b/test/TensorNetwork_test.jl index 86d2f4bf..9acc05f8 100644 --- a/test/TensorNetwork_test.jl +++ b/test/TensorNetwork_test.jl @@ -9,7 +9,7 @@ @testset "list" begin tensor = Tensor(zeros(2, 3), (:i, :j)) - tn = TensorNetwork(Tensor[tensor]) + tn = TensorNetwork([tensor]) @test only(tensors(tn)) === tensor @@ -56,7 +56,7 @@ @testset "merge!" begin tensor = Tensor(zeros(2, 3), (:i, :j)) - A = TensorNetwork(Tensor[tensor]) + A = TensorNetwork([tensor]) B = TensorNetwork() merge!(A, B) @@ -66,7 +66,7 @@ @testset "pop!" begin @testset "by reference" begin tensor = Tensor(zeros(2, 3), (:i, :j)) - tn = TensorNetwork(Tensor[tensor]) + tn = TensorNetwork([tensor]) @test pop!(tn, tensor) === tensor @test length(tn.tensors) == 0 @@ -76,7 +76,7 @@ @testset "by symbol" begin tensor = Tensor(zeros(2, 3), (:i, :j)) - tn = TensorNetwork(Tensor[tensor]) + tn = TensorNetwork([tensor]) @test only(pop!(tn, :i)) === tensor @test length(tn.tensors) == 0 @@ -86,7 +86,7 @@ @testset "by symbols" begin tensor = Tensor(zeros(2, 3), (:i, :j)) - tn = TensorNetwork(Tensor[tensor]) + tn = TensorNetwork([tensor]) @test only(pop!(tn, (:i, :j))) === tensor @test length(tn.tensors) == 0 @@ -98,7 +98,7 @@ # TODO by simbols @testset "delete!" 
begin tensor = Tensor(zeros(2, 3), (:i, :j)) - tn = TensorNetwork(Tensor[tensor]) + tn = TensorNetwork([tensor]) @test delete!(tn, tensor) === tn @test length(tn.tensors) == 0 @@ -126,7 +126,7 @@ @testset "copy" begin tensor = Tensor(zeros(2, 2), (:i, :j)) - tn = TensorNetwork(Tensor[tensor]) + tn = TensorNetwork([tensor]) tn_copy = copy(tn) @test tensors(tn_copy) !== tensors(tn) && all(tensors(tn_copy) .=== tensors(tn)) @@ -134,14 +134,12 @@ end @testset "inds" begin - tn = TensorNetwork( - Tensor[ - Tensor(zeros(2, 2), (:i, :j)), - Tensor(zeros(2, 2), (:i, :k)), - Tensor(zeros(2, 2, 2), (:i, :l, :m)), - Tensor(zeros(2, 2), (:l, :m)), - ], - ) + tn = TensorNetwork([ + Tensor(zeros(2, 2), (:i, :j)), + Tensor(zeros(2, 2), (:i, :k)), + Tensor(zeros(2, 2, 2), (:i, :l, :m)), + Tensor(zeros(2, 2), (:l, :m)), + ],) @test issetequal(inds(tn), (:i, :j, :k, :l, :m)) @test issetequal(inds(tn, :open), (:j, :k)) @@ -150,14 +148,12 @@ end @testset "size" begin - tn = TensorNetwork( - Tensor[ - Tensor(zeros(2, 3), (:i, :j)), - Tensor(zeros(2, 4), (:i, :k)), - Tensor(zeros(2, 5, 6), (:i, :l, :m)), - Tensor(zeros(5, 6), (:l, :m)), - ], - ) + tn = TensorNetwork([ + Tensor(zeros(2, 3), (:i, :j)), + Tensor(zeros(2, 4), (:i, :k)), + Tensor(zeros(2, 5, 6), (:i, :l, :m)), + Tensor(zeros(5, 6), (:l, :m)), + ],) @test size(tn) == Dict((:i => 2, :j => 3, :k => 4, :l => 5, :m => 6)) @test all([size(tn, :i) == 2, size(tn, :j) == 3, size(tn, :k) == 4, size(tn, :l) == 5, size(tn, :m) == 6]) @@ -170,7 +166,7 @@ t_ik = Tensor(zeros(2, 2), (:i, :k)) t_ilm = Tensor(zeros(2, 2, 2), (:i, :l, :m)) t_lm = Tensor(zeros(2, 2), (:l, :m)) - tn = TensorNetwork(Tensor[t_ij, t_ik, t_ilm, t_lm]) + tn = TensorNetwork([t_ij, t_ik, t_ilm, t_lm]) @test issetequal(select(tn, :i), (t_ij, t_ik, t_ilm)) @test issetequal(select(tn, :j), (t_ij,)) @@ -211,7 +207,7 @@ A = Tensor(rand(2, 2, 2), (:i, :j, :k)) B = Tensor(rand(2, 2, 2), (:k, :l, :m)) - tn = TensorNetwork(Tensor[A, B]) + tn = TensorNetwork([A, B]) @test contract(tn) isa Tensor end @@ -220,7 +216,7 @@ t_ik = Tensor(zeros(2, 2), (:i, :k)) t_ilm = Tensor(zeros(2, 2, 2), (:i, :l, :m)) t_lm = Tensor(zeros(2, 2), (:l, :m)) - tn = TensorNetwork(Tensor[t_ij, t_ik, t_ilm, t_lm]) + tn = TensorNetwork([t_ij, t_ik, t_ilm, t_lm]) @testset "replace inds" begin mapping = (:i => :u, :j => :v, :k => :w, :l => :x, :m => :y) @@ -261,7 +257,7 @@ # New tensor network with two tensors with the same inds A = Tensor(rand(2, 2), (:u, :w)) B = Tensor(rand(2, 2), (:u, :w)) - tn = TensorNetwork(Tensor[A, B]) + tn = TensorNetwork([A, B]) new_tensor = Tensor(rand(2, 2), (:u, :w)) @@ -269,7 +265,7 @@ @test A === tn.tensors[1] @test new_tensor === tn.tensors[2] - tn = TensorNetwork(Tensor[A, B]) + tn = TensorNetwork([A, B]) replace!(tn, A => new_tensor) @test issetequal(tensors(tn), [new_tensor, B]) @@ -278,7 +274,7 @@ A = Tensor(zeros(2, 2), (:i, :j)) B = Tensor(zeros(2, 2), (:j, :k)) C = Tensor(zeros(2, 2), (:k, :l)) - tn = TensorNetwork(Tensor[A, B, C]) + tn = TensorNetwork([A, B, C]) @test_throws ArgumentError replace!(tn, A => B, B => C, C => A) diff --git a/test/Transformations_test.jl b/test/Transformations_test.jl index 5c690336..e8813a2b 100644 --- a/test/Transformations_test.jl +++ b/test/Transformations_test.jl @@ -25,7 +25,7 @@ t_ik = Tensor(zeros(2, 2), (:i, :k)) t_ilm = Tensor(zeros(2, 2, 2), (:i, :l, :m)) t_lm = Tensor(zeros(2, 2), (:l, :m)) - tn = TensorNetwork(Tensor[t_ij, t_ik, t_ilm, t_lm]) + tn = TensorNetwork([t_ij, t_ik, t_ilm, t_lm]) transform!(tn, HyperindConverter) @test isempty(inds(tn, 
:hyper)) @@ -66,7 +66,7 @@ @test issetequal(find_diag_axes(A), [[:i, :j]]) - tn = TensorNetwork(Tensor[A, B, C]) + tn = TensorNetwork([A, B, C]) reduced = transform(tn, DiagonalReduction) @test all( @@ -100,7 +100,7 @@ @test issetequal(find_diag_axes(A), [[:i, :l], [:j, :m]]) @test issetequal(find_diag_axes(B), [[:j, :n, :o]]) - tn = TensorNetwork(Tensor[A, B, C]) + tn = TensorNetwork([A, B, C]) reduced = transform(tn, DiagonalReduction) # Test that all tensors (that are no COPY tensors) in reduced have no @@ -124,7 +124,7 @@ D = Tensor(rand(2), (:p,)) E = Tensor(rand(2, 2, 2, 2), (:o, :p, :q, :j)) - tn = TensorNetwork(Tensor[A, B, C, D, E]) + tn = TensorNetwork([A, B, C, D, E]) reduced = transform(tn, RankSimplification) # Test that the resulting tn contains no tensors with larger rank than the original @@ -175,7 +175,7 @@ @test issetequal(find_anti_diag_axes(parent(A)), [(1, 4), (2, 5)]) @test issetequal(find_anti_diag_axes(parent(B)), [(1, 2)]) - tn = TensorNetwork(Tensor[A, B, C]) + tn = TensorNetwork([A, B, C]) gauged = transform(tn, AntiDiagonalGauging) # Test that all tensors in gauged have no antidiagonals @@ -201,7 +201,7 @@ @test issetequal(find_zero_columns(parent(A)), [(2, 1), (2, 2)]) - tn = TensorNetwork(Tensor[A, B, C]) + tn = TensorNetwork([A, B, C]) reduced = transform(tn, ColumnReduction) # Test that all the tensors in reduced have no columns and they do not have the 2nd :j index @@ -226,7 +226,7 @@ @test issetequal(find_zero_columns(parent(A)), [(2, 2)]) - tn = TensorNetwork(Tensor[A, B, C]) + tn = TensorNetwork([A, B, C]) reduced = transform(tn, ColumnReduction) # Test that all the tensors in reduced have no columns and they have smaller dimensions in the 2nd :j index @@ -252,7 +252,7 @@ t1 = contract(v1, v2) tensor = contract(t1, m1) # Define a tensor which can be splitted in three - tn = TensorNetwork(Tensor[tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) + tn = TensorNetwork([tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) reduced = transform(tn, SplitSimplification) # Test that the new tensors in reduced are smaller than the deleted ones diff --git a/test/integration/ChainRules_test.jl b/test/integration/ChainRules_test.jl index a67784e2..e41a7fb3 100644 --- a/test/integration/ChainRules_test.jl +++ b/test/integration/ChainRules_test.jl @@ -18,13 +18,13 @@ @testset "TensorNetwork" begin # TODO it crashes - # test_frule(TensorNetwork, Tensor[]) - # test_rrule(TensorNetwork, Tensor[]) + # test_frule(TensorNetwork, []) + # test_rrule(TensorNetwork, []) a = Tensor(rand(4, 2), (:i, :j)) b = Tensor(rand(2, 3), (:j, :k)) - test_frule(TensorNetwork, Tensor[a, b]) - test_rrule(TensorNetwork, Tensor[a, b]) + test_frule(TensorNetwork, [a, b]) + test_rrule(TensorNetwork, [a, b]) end end diff --git a/test/integration/Makie_test.jl b/test/integration/Makie_test.jl index f95bfa5b..4956425f 100644 --- a/test/integration/Makie_test.jl +++ b/test/integration/Makie_test.jl @@ -2,7 +2,7 @@ using CairoMakie using NetworkLayout: Spring - tensors = Tensor[Tensor(rand(2, 2, 2, 2), (:x, :y, :z, :t)), Tensor(rand(2, 2), (:x, :y)), Tensor(rand(2), (:x,))] + tensors = [Tensor(rand(2, 2, 2, 2), (:x, :y, :z, :t)), Tensor(rand(2, 2), (:x, :y)), Tensor(rand(2), (:x,))] tn = TensorNetwork(tensors) @testset "plot!" 
begin

From ecdc993cd1e55e5798c50a6b4280ea45cf60c153 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?=
Date: Tue, 10 Oct 2023 12:20:17 +0200
Subject: [PATCH 29/29] Update `contract` docstring

---
 src/Numerics.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/Numerics.jl b/src/Numerics.jl
index ec58264c..33e18e59 100644
--- a/src/Numerics.jl
+++ b/src/Numerics.jl
@@ -33,7 +33,7 @@ end
 __omeinsum_sym2str(x) = String[string(i) for i in x]
 
 """
-    contract(a::Tensor[, b::Tensor, dims=nonunique([inds(a)..., inds(b)...])])
+    contract(a::Tensor[, b::Tensor]; dims=nonunique([inds(a)..., inds(b)...]))
 
-Perform tensor contraction operation.
+Perform a tensor contraction operation.
 """
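To close, a short sketch of the corrected call form (the tensor shapes are illustrative and mirror the ones used in the ChainRules tests; `dims` defaults to the indices repeated between the operands):

```julia
using Tenet

A = Tensor(rand(4, 2), (:i, :j))
B = Tensor(rand(2, 3), (:j, :k))

C = contract(A, B)  # contracts the shared index :j by default
inds(C)             # expected: (:i, :k)
```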