Refactor Ansatz to traits #97

Closed · wants to merge 18 commits
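Before the per-file changes, here is a minimal, hypothetical sketch of the kind of "ansatz types to traits" move this diff implements. None of this is the PR's actual code: the names `Ansatz`, `Quantum`, and `domain` are borrowed from the diff below, everything else is assumed.

```julia
# Before (sketch): behaviour hangs off an abstract type hierarchy, so every
# network is parametrised as `TensorNetwork{A} where {A<:Ansatz}`.
abstract type Ansatz end
abstract type Quantum <: Ansatz end

# After (sketch): the parameter is an unconstrained domain tag `D`, which is
# why the diff relaxes `TensorNetwork{A} where {A<:Ansatz}` to plain `TensorNetwork{D}`.
struct TensorNetwork{D}
    tensors::Vector{Any}
    metadata::Dict{Symbol,Any}
end

# hypothetical trait accessor, matching the `ansatz` → `domain` docs rename below
domain(::TensorNetwork{D}) where {D} = D

domain(TensorNetwork{Quantum}([], Dict()))  # Quantum
```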
1 change: 1 addition & 0 deletions Project.toml
@@ -10,6 +10,7 @@ DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
+ImmutableArrays = "667c17eb-ab9b-4487-935f-1c621bb82497"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Muscle = "21fe5c4b-a943-414d-bf3e-516f24900631"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
8 changes: 4 additions & 4 deletions docs/src/quantum/index.md
@@ -15,7 +15,7 @@ sites
```

```@docs
-tensors(::TensorNetwork{<:Quantum}, ::Integer)
+tensors(::TensorNetwork{Quantum}, ::Integer)
```

```@docs
@@ -31,14 +31,14 @@ adjoint
## Concatenation

```@docs
-hcat(::TensorNetwork{<:Quantum}, ::TensorNetwork{<:Quantum})
+hcat(::TensorNetwork{Quantum}, ::TensorNetwork{Quantum})
```

## Norm

```@docs
-LinearAlgebra.norm(::TensorNetwork{<:Quantum}, p::Real)
-LinearAlgebra.normalize!(::TensorNetwork{<:Quantum}, ::Real)
+LinearAlgebra.norm(::TensorNetwork{Quantum}, p::Real)
+LinearAlgebra.normalize!(::TensorNetwork{Quantum}, ::Real)
```

## Fidelity
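A note on why the docs swap `TensorNetwork{<:Quantum}` for `TensorNetwork{Quantum}`: once `Quantum` is the domain tag itself rather than the abstract supertype of concrete ansatzes, methods bind to the exact parameter. A hedged illustration with stand-in types (not Tenet's):

```julia
abstract type Ansatz end
abstract type Quantum <: Ansatz end
struct MPS <: Quantum end   # hypothetical concrete ansatz under the old scheme
struct TensorNetwork{D} end

f(::TensorNetwork{<:Quantum}) = :subtype_match  # old style: covers TensorNetwork{MPS} too
g(::TensorNetwork{Quantum}) = :exact_match      # new style: only TensorNetwork{Quantum}

f(TensorNetwork{MPS}())      # :subtype_match
g(TensorNetwork{Quantum}())  # :exact_match
# g(TensorNetwork{MPS}())    # would throw a MethodError; only the exact tag dispatches
```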
2 changes: 1 addition & 1 deletion docs/src/tensor-network.md
@@ -32,7 +32,7 @@ inds(::TensorNetwork)
size(::TensorNetwork)
tensors(::TensorNetwork)
length(::TensorNetwork)
-ansatz
+domain
```

## Modification
12 changes: 6 additions & 6 deletions ext/TenetChainRulesCoreExt.jl
@@ -4,18 +4,18 @@ using Tenet
using ChainRulesCore

 function ChainRulesCore.ProjectTo(tensor::T) where {T<:Tensor}
-    ProjectTo{T}(; data = ProjectTo(tensor.data), inds = tensor.inds, meta = tensor.meta)
+    ProjectTo{T}(; data = ProjectTo(tensor.data), inds = tensor.inds)
 end

 function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:Tensor}
-    T(projector.data(dx.data), projector.inds; projector.meta...)
+    T(projector.data(dx.data), projector.inds)
 end

-ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; meta...), T(Δ, inds; meta...)
+ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds) = T(data, inds), T(Δ, inds)

Tensor_pullback(Δ) = (NoTangent(), Δ.data, NoTangent())
Tensor_pullback(Δ::AbstractThunk) = Tensor_pullback(unthunk(Δ))
-ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds; meta...) = T(data, inds; meta...), Tensor_pullback
+ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds) = T(data, inds), Tensor_pullback

# NOTE fix problem with vector generator in `contract`
@non_differentiable Tenet.__omeinsum_sym2str(x)
@@ -35,10 +35,10 @@ function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:TensorNetwork}
Tangent{TensorNetwork}(tensors = projector.tensors(dx.tensors))
end

-function Base.:+(x::TensorNetwork{A}, Δ::Tangent{TensorNetwork}) where {A<:Ansatz}
+function Base.:+(x::TensorNetwork{D}, Δ::Tangent{TensorNetwork}) where {D}
     # TODO match tensors by indices
     tensors = map(+, x.tensors, Δ.tensors)
-    TensorNetwork{A}(tensors; x.metadata...)
+    TensorNetwork{D}(tensors; x.metadata...)
 end

function ChainRulesCore.frule((_, Δ), T::Type{<:TensorNetwork}, tensors; metadata...)
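The dropped `meta` plumbing leaves a clean three-argument rule: the pullback returns `NoTangent()` for the type and the indices and routes the cotangent's `data` field back to the `data` argument. A usage sketch, assuming the `Tensor(data, inds)` constructor form shown in this diff:

```julia
using ChainRulesCore
using Tenet: Tensor

data = rand(2, 2)
t, pullback = ChainRulesCore.rrule(Tensor, data, (:i, :j))

# seed a structural cotangent and pull it back through the constructor
Δ = Tangent{typeof(t)}(data = ones(2, 2), inds = NoTangent())
∂T, ∂data, ∂inds = pullback(Δ)  # (NoTangent(), Δ.data, NoTangent())
```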
4 changes: 2 additions & 2 deletions ext/TenetFiniteDifferencesExt.jl
@@ -3,11 +3,11 @@ module TenetFiniteDifferencesExt
using Tenet
using FiniteDifferences

-function FiniteDifferences.to_vec(x::TensorNetwork{A}) where {A<:Ansatz}
+function FiniteDifferences.to_vec(x::TensorNetwork{D}) where {D}
     x_vec, back = to_vec(x.tensors)
     function TensorNetwork_from_vec(v)
         tensors = back(v)
-        TensorNetwork{A}(tensors; x.metadata...)
+        TensorNetwork{D}(tensors; x.metadata...)
     end

     return x_vec, TensorNetwork_from_vec
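`to_vec` is what lets FiniteDifferences treat a whole network as a flat vector of numbers: perturb entries, rebuild, re-evaluate. A hedged round-trip sketch (the `TensorNetwork` constructor call is an assumption, not part of this diff):

```julia
using FiniteDifferences
using Tenet

tn = TensorNetwork([Tensor(rand(2, 2), (:i, :j))])  # hypothetical construction

v, back = FiniteDifferences.to_vec(tn)
tn2 = back(v)  # rebuilds the network, carrying `metadata` over unchanged
# perturbing entries of `v` before calling `back` yields the nearby networks
# that finite differencing evaluates
```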
9 changes: 5 additions & 4 deletions ext/TenetMakieExt.jl
@@ -19,7 +19,7 @@ Plot a [`TensorNetwork`](@ref) as a graph.
- `labels` If `true`, show the labels of the tensor indices. Defaults to `false`.
- The rest of `kwargs` are passed to `GraphMakie.graphplot`.
"""
-function Makie.plot(tn::TensorNetwork; kwargs...)
+function Makie.plot(@nospecialize tn::TensorNetwork; kwargs...)
     f = Figure()
     ax, p = plot!(f[1, 1], tn; kwargs...)
     return Makie.FigureAxisPlot(f, ax, p)
@@ -28,7 +28,7 @@ end
# NOTE this is a hack! we did it in order not to depend on NetworkLayout but can be unstable
__networklayout_dim(x) = typeof(x).super.parameters |> first

-function Makie.plot!(f::Union{Figure,GridPosition}, tn::TensorNetwork; kwargs...)
+function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::TensorNetwork; kwargs...)
     ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3
         Axis3(f[1, 1])
     else
@@ -45,14 +45,15 @@ function Makie.plot!(f::Union{Figure,GridPosition}, tn::TensorNetwork; kwargs...
return Makie.AxisPlot(ax, p)
end

-function Makie.plot!(ax::Union{Axis,Axis3}, tn::TensorNetwork; labels = false, kwargs...)
+function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::TensorNetwork; labels = false, kwargs...)
+    hypermap = Tenet.hyperflatten(tn)
     tn = transform(tn, Tenet.HyperindConverter)

     # TODO how to mark multiedges? (i.e. parallel edges)
     graph = SimpleGraph([Edge(tensors...) for (_, tensors) in tn.indices if length(tensors) > 1])

     # TODO recognise `copytensors` by using `DeltaArray` or `Diagonal` representations
-    copytensors = findall(t -> haskey(t.meta, :dual), tensors(tn))
+    copytensors = findall(tensor -> any(flatinds -> issetequal(inds(tensor), flatinds), values(hypermap)), tensors(tn))
     ghostnodes = map(inds(tn, :open)) do ind
         # create new ghost node
         add_vertex!(graph)
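With these methods a network plots in one call; the added `@nospecialize` only curbs compile time and changes nothing about usage. A hedged sketch (constructor form assumed; a Makie backend such as CairoMakie must be active):

```julia
using CairoMakie
using Tenet

tn = TensorNetwork([              # hypothetical two-tensor network sharing :j
    Tensor(rand(2, 3), (:i, :j)),
    Tensor(rand(3, 2), (:j, :k)),
])

fig, ax, p = plot(tn; labels = true)  # labels = true draws index names on edges
save("tn.png", fig)
```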
8 changes: 4 additions & 4 deletions ext/TenetQuacExt.jl
@@ -2,7 +2,7 @@ module TenetQuacExt

using Tenet
using Quac: Circuit, lanes, arraytype, Swap
-using Bijections
+using Bijections: Bijection

function Tenet.TensorNetwork(circuit::Circuit)
n = lanes(circuit)
@@ -28,16 +28,16 @@ function Tenet.TensorNetwork(circuit::Circuit)
(from, to)
end |> x -> zip(x...) |> Iterators.flatten |> collect

-        tensor = Tensor(array, tuple(inds...); gate = gate)
+        tensor = Tensor(array, inds)
push!(tensors, tensor)
end

-    interlayer = [
+    plug = [
         Bijection(Dict([site => first(index) for (site, index) in enumerate(wire)])),
         Bijection(Dict([site => last(index) for (site, index) in enumerate(wire)])),
     ]

-    return TensorNetwork{Quantum}(tensors; plug = Tenet.Operator, interlayer)
+    return TensorNetwork{Quantum}(tensors; plug = Tenet.Operator, interlayer = plug)
end

end
5 changes: 0 additions & 5 deletions src/Helpers.jl
@@ -72,11 +72,6 @@ Base.merge(@nospecialize(A::Type{<:NamedTuple}), @nospecialize(Bs::Type{<:NamedTuple}...))
foldl((acc, B) -> Tuple{fieldtypes(acc)...,B...}, Iterators.map(fieldtypes, Bs); init = A),
}

-function superansatzes(T)
-    S = supertype(T)
-    return T === Ansatz ? (T,) : (T, superansatzes(S)...)
-end

# NOTE from https://stackoverflow.com/q/54652787
function nonunique(x)
uniqueindexes = indexin(unique(x), x)
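The deleted helper existed only to walk the abstract hierarchy up to `Ansatz`; with a plain domain tag there is no chain left to walk. For reference, a sketch of the old behaviour under a hypothetical `Quantum <: Ansatz` hierarchy:

```julia
abstract type Ansatz end
abstract type Quantum <: Ansatz end

# the removed helper, reproduced for illustration
superansatzes(T) = T === Ansatz ? (T,) : (T, superansatzes(supertype(T))...)

superansatzes(Quantum)  # (Quantum, Ansatz)
```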
8 changes: 3 additions & 5 deletions src/Numerics.jl
@@ -51,19 +51,17 @@ function contract(a::Tensor, b::Tensor; dims = (∩(inds(a), inds(b))))

     data = EinCode((_ia, _ib), _ic)(parent(a), parent(b))

-    # TODO merge metadata?
     return Tensor(data, ic)
 end

 function contract(a::Tensor; dims = nonunique(inds(a)))
     ia = inds(a)
     i = ∩(dims, ia)

-    ic = tuple(setdiff(ia, i isa Base.AbstractVecOrTuple ? i : (i,))...)
+    ic = setdiff(ia, i isa Base.AbstractVecOrTuple ? i : (i,))

     data = EinCode((String.(ia),), String.(ic))(parent(a))

-    # TODO merge metadata
     return Tensor(data, ic)
 end
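Behaviour is unchanged apart from `ic` staying a vector, which `Tensor` accepts directly. A hedged usage sketch of both the binary and the unary form:

```julia
using Tenet: Tensor, contract

a = Tensor(rand(2, 3), (:i, :j))
b = Tensor(rand(3, 4), (:j, :k))

c = contract(a, b)  # contracts the shared index :j by default, leaving :i and :k
size(c)             # (2, 4), assuming the natural (:i, :k) output order

t = Tensor(rand(2, 2), (:i, :i))
contract(t)         # repeated indices contract with themselves: the trace
```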

@@ -79,8 +77,8 @@ contract(tensors::Tensor...; kwargs...) = reduce((x, y) -> contract(x, y; kwargs...), tensors)
Alias for [`contract`](@ref).
"""
Base.:*(a::Tensor, b::Tensor) = contract(a, b)
-Base.:*(a::T, b::Number) where {T<:Tensor} = T(parent(a) * b, inds(a); a.meta...)
-Base.:*(a::Number, b::T) where {T<:Tensor} = T(a * parent(b), inds(b); b.meta...)
+Base.:*(a::T, b::Number) where {T<:Tensor} = T(parent(a) * b, inds(a))
+Base.:*(a::Number, b::T) where {T<:Tensor} = T(a * parent(b), inds(b))

LinearAlgebra.svd(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke svd(t::Tensor; left_inds = (first(inds(t)),), kwargs...)

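The scalar methods now rebuild the tensor without the removed `meta` plumbing: scaling touches only the data and leaves the indices alone. A quick sketch:

```julia
using Tenet: Tensor, inds

a = Tensor(rand(2, 3), (:i, :j))

b = 2.0 * a          # parent(b) == 2.0 .* parent(a)
inds(b) == (:i, :j)  # scalar multiplication never renames or drops indices
```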