Refactor Ansatz to traits #97

Closed · wants to merge 18 commits
Changes from 6 commits
1 change: 1 addition & 0 deletions Project.toml
@@ -10,6 +10,7 @@ DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188"
 EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
 GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
+ImmutableArrays = "667c17eb-ab9b-4487-935f-1c621bb82497"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Muscle = "21fe5c4b-a943-414d-bf3e-516f24900631"
 OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
2 changes: 1 addition & 1 deletion docs/src/tensor-network.md
@@ -32,7 +32,7 @@ inds(::TensorNetwork)
 size(::TensorNetwork)
 tensors(::TensorNetwork)
 length(::TensorNetwork)
-ansatz
+domain
 ```
 
 ## Modification
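The public query is renamed from `ansatz` to `domain`. A minimal sketch of the assumed usage after this PR (the exact return values are not shown in this diff):

```julia
using Tenet

tn = TensorNetwork([Tensor(rand(2, 2), (:i, :j))])
D = domain(tn)  # trait-style query that replaces the old `ansatz(tn)`
```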
12 changes: 6 additions & 6 deletions ext/TenetChainRulesCoreExt.jl
@@ -4,18 +4,18 @@ using Tenet
 using ChainRulesCore
 
 function ChainRulesCore.ProjectTo(tensor::T) where {T<:Tensor}
-    ProjectTo{T}(; data = ProjectTo(tensor.data), inds = tensor.inds, meta = tensor.meta)
+    ProjectTo{T}(; data = ProjectTo(tensor.data), inds = tensor.inds)
 end
 
 function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:Tensor}
-    T(projector.data(dx.data), projector.inds; projector.meta...)
+    T(projector.data(dx.data), projector.inds)
 end
 
-ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; meta...), T(Δ, inds; meta...)
+ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds) = T(data, inds), T(Δ, inds)
 
 Tensor_pullback(Δ) = (NoTangent(), Δ.data, NoTangent())
 Tensor_pullback(Δ::AbstractThunk) = Tensor_pullback(unthunk(Δ))
-ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; meta...), Tensor_pullback
+ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds) = T(data, inds), Tensor_pullback
 
 # NOTE fix problem with vector generator in `contract`
 @non_differentiable Tenet.__omeinsum_sym2str(x)
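With metadata gone from `Tensor`, the constructor rules close over `data` and `inds` only. A sketch of the `rrule` defined above (names local to this example):

```julia
using ChainRulesCore, Tenet

t, pullback = ChainRulesCore.rrule(Tensor, rand(2, 2), (:i, :j))
# the pullback routes the cotangent's `data` field back to the array argument
_, ∂data, ∂inds = pullback(Tangent{typeof(t)}(data = ones(2, 2)))  # ∂inds isa NoTangent
```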
@@ -35,10 +35,10 @@ function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:TensorNetw
     Tangent{TensorNetwork}(tensors = projector.tensors(dx.tensors))
 end
 
-function Base.:+(x::TensorNetwork{A}, Δ::Tangent{TensorNetwork}) where {A<:Ansatz}
+function Base.:+(x::TensorNetwork{D}, Δ::Tangent{TensorNetwork}) where {D}
     # TODO match tensors by indices
     tensors = map(+, x.tensors, Δ.tensors)
-    TensorNetwork{A}(tensors; x.metadata...)
+    TensorNetwork{D}(tensors; x.metadata...)
 end
 
 function ChainRulesCore.frule((_, Δ), T::Type{<:TensorNetwork}, tensors; metadata...)
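The `TensorNetwork` rules now dispatch on an arbitrary domain parameter `D` rather than `A<:Ansatz`. A hedged end-to-end sketch, assuming an AD engine such as Zygote picks these rules up through ChainRulesCore:

```julia
using Tenet, Zygote

tn = TensorNetwork([Tensor(rand(2, 3), (:i, :j)), Tensor(rand(3, 2), (:j, :k))])
loss(tn) = sum(parent(contract(tn)))  # scalar function of the full contraction
∇tn = Zygote.gradient(loss, tn)
```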
4 changes: 2 additions & 2 deletions ext/TenetFiniteDifferencesExt.jl
@@ -3,11 +3,11 @@ module TenetFiniteDifferencesExt
 using Tenet
 using FiniteDifferences
 
-function FiniteDifferences.to_vec(x::TensorNetwork{A}) where {A<:Ansatz}
+function FiniteDifferences.to_vec(x::TensorNetwork{D}) where {D}
     x_vec, back = to_vec(x.tensors)
     function TensorNetwork_from_vec(v)
         tensors = back(v)
-        TensorNetwork{A}(tensors; x.metadata...)
+        TensorNetwork{D}(tensors; x.metadata...)
     end
 
     return x_vec, TensorNetwork_from_vec
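`to_vec` is what lets FiniteDifferences flatten a network into a plain vector and back. A sketch of a finite-difference gradient check built on it (the loss function is illustrative):

```julia
using Tenet, FiniteDifferences

tn = TensorNetwork([Tensor(rand(2, 3), (:i, :j)), Tensor(rand(3, 2), (:j, :k))])
# round-trips through the to_vec definition above
∇fd = FiniteDifferences.grad(central_fdm(5, 1), tn -> sum(parent(contract(tn))), tn)
```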
9 changes: 5 additions & 4 deletions ext/TenetMakieExt.jl
@@ -19,7 +19,7 @@ Plot a [`TensorNetwork`](@ref) as a graph.
   - `labels` If `true`, show the labels of the tensor indices. Defaults to `false`.
   - The rest of `kwargs` are passed to `GraphMakie.graphplot`.
 """
-function Makie.plot(tn::TensorNetwork; kwargs...)
+function Makie.plot(@nospecialize tn::TensorNetwork; kwargs...)
     f = Figure()
     ax, p = plot!(f[1, 1], tn; kwargs...)
     return Makie.FigureAxisPlot(f, ax, p)
@@ -28,7 +28,7 @@ end
 # NOTE this is a hack! we did it in order not to depend on NetworkLayout but can be unstable
 __networklayout_dim(x) = typeof(x).super.parameters |> first
 
-function Makie.plot!(f::Union{Figure,GridPosition}, tn::TensorNetwork; kwargs...)
+function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::TensorNetwork; kwargs...)
     ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3
         Axis3(f[1, 1])
     else
@@ -45,14 +45,15 @@ function Makie.plot!(f::Union{Figure,GridPosition}, tn::TensorNetwork; kwargs...
     return Makie.AxisPlot(ax, p)
 end
 
-function Makie.plot!(ax::Union{Axis,Axis3}, tn::TensorNetwork; labels = false, kwargs...)
+function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::TensorNetwork; labels = false, kwargs...)
+    hypermap = Tenet.hyperflatten(tn)
     tn = transform(tn, Tenet.HyperindConverter)
 
     # TODO how to mark multiedges? (i.e. parallel edges)
     graph = SimpleGraph([Edge(tensors...) for (_, tensors) in tn.indices if length(tensors) > 1])
 
     # TODO recognise `copytensors` by using `DeltaArray` or `Diagonal` representations
-    copytensors = findall(t -> haskey(t.meta, :dual), tensors(tn))
+    copytensors = findall(tensor -> any(flatinds -> issetequal(inds(tensor), flatinds), values(hypermap)), tensors(tn))
     ghostnodes = map(inds(tn, :open)) do ind
         # create new ghost node
         add_vertex!(graph)
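Copy tensors are now recognised by matching hyperindex groups from `hyperflatten` instead of `:dual` metadata. Plotting itself is unchanged; a minimal sketch (any Makie backend):

```julia
using Tenet, CairoMakie

tn = TensorNetwork([Tensor(rand(2, 3), (:i, :j)), Tensor(rand(3, 2), (:j, :k))])
fig = plot(tn; labels = true)  # tensors as nodes, shared indices as edges
```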
2 changes: 1 addition & 1 deletion ext/TenetQuacExt.jl
@@ -28,7 +28,7 @@ function Tenet.TensorNetwork(circuit::Circuit)
             (from, to)
         end |> x -> zip(x...) |> Iterators.flatten |> collect
 
-        tensor = Tensor(array, tuple(inds...); gate = gate)
+        tensor = Tensor(array, inds)
         push!(tensors, tensor)
     end
 
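Gate tensors lose their `gate` metadata entry; the conversion now emits plain `Tensor`s. A sketch of the entry point, where the gate constructors are assumptions about Quac's API rather than something this diff shows:

```julia
using Quac, Tenet

circ = Circuit(2)
push!(circ, H(1))      # assumed Quac gate constructors
push!(circ, CX(1, 2))

tn = TensorNetwork(circ)  # one Tensor per gate, no per-tensor metadata
```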
5 changes: 0 additions & 5 deletions src/Helpers.jl
@@ -72,11 +72,6 @@ Base.merge(@nospecialize(A::Type{<:NamedTuple}), @nospecialize(Bs::Type{<:NamedT
     foldl((acc, B) -> Tuple{fieldtypes(acc)...,B...}, Iterators.map(fieldtypes, Bs); init = A),
 }
 
-function superansatzes(T)
-    S = supertype(T)
-    return T === Ansatz ? (T,) : (T, superansatzes(S)...)
-end
-
 # NOTE from https://stackoverflow.com/q/54652787
 function nonunique(x)
     uniqueindexes = indexin(unique(x), x)
8 changes: 3 additions & 5 deletions src/Numerics.jl
@@ -51,19 +51,17 @@ function contract(a::Tensor, b::Tensor; dims = (∩(inds(a), inds(b))))
 
     data = EinCode((_ia, _ib), _ic)(parent(a), parent(b))
 
-    # TODO merge metadata?
     return Tensor(data, ic)
 end
 
 function contract(a::Tensor; dims = nonunique(inds(a)))
     ia = inds(a)
     i = ∩(dims, ia)
 
-    ic = tuple(setdiff(ia, i isa Base.AbstractVecOrTuple ? i : (i,))...)
+    ic = setdiff(ia, i isa Base.AbstractVecOrTuple ? i : (i,))
 
     data = EinCode((String.(ia),), String.(ic))(parent(a))
 
-    # TODO merge metadata
     return Tensor(data, ic)
 end
 
@@ -79,8 +77,8 @@ contract(tensors::Tensor...; kwargs...) = reduce((x, y) -> contract(x, y; kwargs
 Alias for [`contract`](@ref).
 """
 Base.:*(a::Tensor, b::Tensor) = contract(a, b)
-Base.:*(a::T, b::Number) where {T<:Tensor} = T(parent(a) * b, inds(a); a.meta...)
-Base.:*(a::Number, b::T) where {T<:Tensor} = T(a * parent(b), inds(b); b.meta...)
+Base.:*(a::T, b::Number) where {T<:Tensor} = T(parent(a) * b, inds(a))
+Base.:*(a::Number, b::T) where {T<:Tensor} = T(a * parent(b), inds(b))
 
 LinearAlgebra.svd(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke svd(t::Tensor; left_inds = (first(inds(t)),), kwargs...)
 
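Scalar multiplication now rebuilds the tensor from `data` and `inds` alone. A small sketch of the contraction API touched here:

```julia
using Tenet

A = Tensor(rand(2, 3), (:i, :j))
B = Tensor(rand(3, 4), (:j, :k))

C = A * B    # alias for contract(A, B); shared index :j is summed over
D = 2.0 * C  # scalar product keeps inds(C), no metadata to carry anymore
```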
5 changes: 1 addition & 4 deletions src/Quantum/MP.jl
@@ -101,9 +101,8 @@ function MatrixProduct{P,B}(
                 iinds[i]
             end
         end
-        alias = Dict(dir => label for (dir, label) in zip(dirs, inds))
 
-        Tensor(array, inds; alias = alias)
+        Tensor(array, inds)
     end
 
     return TensorNetwork{MatrixProduct{P,B}}(tensors; χ, plug = P, interlayer, metadata...)
@@ -115,8 +114,6 @@ const MPO = MatrixProduct{Operator}
 
 tensors(ψ::TensorNetwork{MatrixProduct{P,Infinite}}, site::Int, args...) where {P<:Plug} =
     tensors(plug(ψ), ψ, mod1(site, length(ψ.tensors)), args...)
 
-Base.length(ψ::TensorNetwork{MatrixProduct{P,Infinite}}) where {P<:Plug} = Inf
-
 # NOTE does not use optimal contraction path, but "parallel-optimal" which costs x2 more
 # function contractpath(a::TensorNetwork{<:MatrixProductState}, b::TensorNetwork{<:MatrixProductState})
 #     !issetequal(sites(a), sites(b)) && throw(ArgumentError("both tensor networks are expected to have same sites"))
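The site tensors drop their `alias` metadata, so directions must now be recovered structurally. A hedged construction sketch, assuming `MPS = MatrixProduct{State}` with an `Open` boundary default and the array shapes below:

```julia
using Tenet

# assumed shapes: (physical, virtual) at the edges, (physical, left, right) in the bulk
arrays = [rand(2, 2), rand(2, 2, 2), rand(2, 2)]
ψ = MatrixProduct{State}(arrays)
tensors(ψ, 1)  # site access; Infinite networks wrap the site via mod1, as above
```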
5 changes: 1 addition & 4 deletions src/Quantum/PEP.jl
@@ -107,9 +107,8 @@ function ProjectedEntangledPair{P,B}(
                 oinds[(i, j)]
             end
         end
-        alias = Dict(dir => label for (dir, label) in zip(dirs, inds))
 
-        Tensor(array, inds; alias = alias)
+        Tensor(array, inds)
     end |> vec
 
     return TensorNetwork{ProjectedEntangledPair{P,B}}(tensors; χ, plug = P, interlayer, metadata...)
@@ -121,8 +120,6 @@ const PEPO = ProjectedEntangledPair{Operator}
 
 tensors(ψ::TensorNetwork{ProjectedEntangledPair{P,Infinite}}, site::Int, args...) where {P<:Plug} =
     tensors(plug(ψ), ψ, mod1(site, length(ψ.tensors)), args...)
 
-Base.length(ψ::TensorNetwork{ProjectedEntangledPair{P,Infinite}}) where {P<:Plug} = Inf
-
 # TODO normalize
 # TODO let choose the orthogonality center
 # TODO different input/output physical dims
8 changes: 3 additions & 5 deletions src/Quantum/Quantum.jl
@@ -159,7 +159,9 @@ function layers(tn::TensorNetwork{As}, i) where {As<:Composite}
     end
 
     return TensorNetwork{A}(
-        filter(tensor -> get(tensor.meta, :layer, nothing) == i, tensors(tn));
+        # TODO revise this
+        #filter(tensor -> get(tensor.meta, :layer, nothing) == i, tensors(tn));
+        tensors(tn);
         plug = layer_plug,
         interlayer,
         meta...,
@@ -197,10 +199,6 @@ function Base.hcat(A::TensorNetwork{QA}, B::TensorNetwork{QB}) where {QA<:Quantu
     # rename inner indices of B to avoid hyperindices
     replace!(B, [i => Symbol(uuid4()) for i in inds(B, :inner)]...)
 
-    # TODO refactor this part to be compatible with more layers
-    foreach(tensor -> tensor.meta[:layer] = 1, tensors(A))
-    foreach(tensor -> tensor.meta[:layer] = 2, tensors(B))
-
     combined_plug = merge(plug(A), plug(B))
 
     # merge tensors and indices
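`hcat` no longer tags tensors with a `:layer` metadata entry; it renames the inner indices of `B` and merges the plugs. A sketch of the composition (constructors as in the MP example above, so equally hedged):

```julia
using Tenet

ψ = MatrixProduct{State}([rand(2, 2), rand(2, 2, 2), rand(2, 2)])
ϕ = MatrixProduct{State}([rand(2, 2), rand(2, 2, 2), rand(2, 2)])
tn = hcat(ψ, ϕ)  # composite network; ϕ's inner indices get fresh UUID names
```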
3 changes: 1 addition & 2 deletions src/Tenet.jl
@@ -4,14 +4,13 @@ include("Helpers.jl")
 
 include("Tensor.jl")
 export Tensor, contract, dim, expand
-export tags, hastag, tag!, untag!
 
 include("Numerics.jl")
 
 include("TensorNetwork.jl")
 export TensorNetwork, tensors, arrays, select, slice!
+export Domain, domain
 export contract, contract!
-export Ansatz, ansatz, Arbitrary
 
 include("Transformations.jl")
 export transform, transform!