Merge pull request #103 from bsc-quantic/refactor/oop
Refactor `TensorNetwork` to class-based OOP organization
arturgs authored Oct 25, 2023
2 parents 41e4e7a + ecdc993 commit d4d0f78
Showing 31 changed files with 689 additions and 896 deletions.
3 changes: 1 addition & 2 deletions Project.toml
@@ -4,7 +4,7 @@ authors = ["Sergio Sánchez Ramírez <[email protected]>"]
 version = "0.2.0"

 [deps]
-Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
+Classes = "1a9c1350-211b-5766-99cd-4544d885a0d1"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188"
 EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
@@ -34,7 +34,6 @@ TenetMakieExt = "Makie"
 TenetQuacExt = "Quac"

 [compat]
-Bijections = "0.1"
 ChainRulesCore = "1.0"
 Combinatorics = "1.0"
 DeltaArrays = "0.1.1"
4 changes: 2 additions & 2 deletions docs/src/contraction.md
@@ -3,7 +3,7 @@
 Contraction path optimization and execution is delegated to the [`EinExprs`](https://github.com/bsc-quantic/EinExprs) library. An `EinExpr` is a lower-level form of a Tensor Network, in which the contraction path has been laid out as a tree. It is similar to a symbolic expression (i.e. `Expr`), but one in which every node represents an Einstein summation expression (aka `einsum`).

 ```@docs
-einexpr(::TensorNetwork)
-contract(::TensorNetwork)
+einexpr(::Tenet.AbstractTensorNetwork)
+contract(::Tenet.AbstractTensorNetwork)
 contract!
 ```
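For reference, a minimal sketch of the renamed entry points; the random-network constructor is the one documented in `docs/src/tensor-network.md` below, and everything else here is an assumption rather than part of this diff:

```julia
using Tenet

# build a small random network: 10 tensors, regularity 3
tn = rand(TensorNetwork, 10, 3)

path = einexpr(tn)    # contraction path laid out as a tree (an `EinExpr`)
result = contract(tn) # execute the full contraction
```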
2 changes: 1 addition & 1 deletion docs/src/examples/ad-tn.md
@@ -19,7 +19,7 @@ rng = seed!(4) # hide
 ψ = rand(rng, MPS{Open}, n = 4, p = 2, χ = 2) # hide
 ϕ = rand(rng, MPS{Open}, n = 4, p = 2, χ = 4) # hide
-tn = hcat(ψ, ϕ)
+tn = merge(ψ, ϕ')
 plot(tn) # hide
 ```
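The replaced line reflects the new composition API: `merge` of a state with an adjoint state builds the overlap network ⟨ϕ|ψ⟩. A hedged sketch without the docs' hidden setup (constructor keywords copied from the example above):

```julia
using Tenet

ψ = rand(MPS{Open}, n = 4, p = 2, χ = 2) # random matrix product states
ϕ = rand(MPS{Open}, n = 4, p = 2, χ = 4)

tn = merge(ψ, ϕ')      # ⟨ϕ|ψ⟩ as a closed tensor network
overlap = contract(tn) # contract it down to a scalar
```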
2 changes: 1 addition & 1 deletion docs/src/examples/google-rqc.md
@@ -42,7 +42,7 @@ _sites = [5, 6, 14, 15, 16, 17, 24, 25, 26, 27, 28, 32, 33, 34, 35, 36, 37, 38,
 # load circuit and convert to `TensorNetwork`
 circuit = QuacIO.parse(joinpath(@__DIR__, "sycamore_53_10_0.qasm"), format = QuacIO.Qflex(), sites = _sites);
-tn = TensorNetwork(circuit)
+tn = QuantumTensorNetwork(circuit)
 tn = view(tn, [i => 1 for i in inds(tn, set=:open)]...)
 plot(tn) # hide
 ```
Expand Down
28 changes: 7 additions & 21 deletions docs/src/quantum/index.md
@@ -1,44 +1,30 @@
 # Introduction

-In `Tenet`, we define a [`Quantum`](@ref) Tensor Network as a [`TensorNetwork`](@ref) with a notion of sites and directionality.
-
-```@docs
-Quantum
-```
+In `Tenet`, we define a [`QuantumTensorNetwork`](@ref) as a [`TensorNetwork`](@ref) with a notion of sites and directionality.

 ```@docs
+QuantumTensorNetwork
 plug
 ```

 ```@docs
 sites
 ```

-```@docs
-tensors(::TensorNetwork{<:Quantum}, ::Integer)
-```
-
-```@docs
-boundary
-```
-
 ## Adjoint

 ```@docs
 adjoint
 ```

-## Concatenation
+## Norm

 ```@docs
-hcat(::TensorNetwork{<:Quantum}, ::TensorNetwork{<:Quantum})
+LinearAlgebra.norm(::Tenet.AbstractQuantumTensorNetwork, ::Real)
+LinearAlgebra.normalize!(::Tenet.AbstractQuantumTensorNetwork, ::Real)
 ```

-## Norm
+## Trace

 ```@docs
-LinearAlgebra.norm(::TensorNetwork{<:Quantum}, p::Real)
-LinearAlgebra.normalize!(::TensorNetwork{<:Quantum}, ::Real)
+LinearAlgebra.tr(::Tenet.AbstractQuantumTensorNetwork)
 ```

 ## Fidelity
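A rough usage sketch of the norm methods documented above (the `MPS` constructor mirrors the examples elsewhere in the docs; not part of this diff):

```julia
using Tenet
using LinearAlgebra

ψ = rand(MPS{Open}, n = 4, p = 2, χ = 2)

norm(ψ)       # 2-norm of the state, contracting ⟨ψ|ψ⟩ under the hood
normalize!(ψ) # rescale the tensors so that norm(ψ) ≈ 1
```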
22 changes: 10 additions & 12 deletions docs/src/tensor-network.md
@@ -28,28 +28,26 @@ Information about a `TensorNetwork` can be queried with the following functions.
 ## Query information

 ```@docs
-inds(::TensorNetwork)
-size(::TensorNetwork)
-tensors(::TensorNetwork)
-length(::TensorNetwork)
-ansatz
+inds(::Tenet.AbstractTensorNetwork)
+size(::Tenet.AbstractTensorNetwork)
+tensors(::Tenet.AbstractTensorNetwork)
 ```

 ## Modification

 ### Add/Remove tensors

 ```@docs
-push!(::TensorNetwork, ::Tensor)
-append!(::TensorNetwork, ::Base.AbstractVecOrTuple{<:Tensor})
-pop!(::TensorNetwork, ::Tensor)
-delete!(::TensorNetwork, ::Any)
+push!(::Tenet.AbstractTensorNetwork, ::Tensor)
+append!(::Tenet.AbstractTensorNetwork, ::Base.AbstractVecOrTuple{<:Tensor})
+merge!(::Tenet.AbstractTensorNetwork, ::Tenet.AbstractTensorNetwork)
+pop!(::Tenet.AbstractTensorNetwork, ::Tensor)
+delete!(::Tenet.AbstractTensorNetwork, ::Any)
 ```

 ### Replace existing elements

 ```@docs
-replace
 replace!
 ```

@@ -59,12 +57,12 @@ replace!
 select
 selectdim
 slice!
-view(::TensorNetwork)
+view(::Tenet.AbstractTensorNetwork)
 ```

 ## Miscelaneous

 ```@docs
-Base.copy(::TensorNetwork)
+Base.copy(::Tenet.AbstractTensorNetwork)
 Base.rand(::Type{TensorNetwork}, n::Integer, regularity::Integer)
 ```
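A quick sketch of the add/merge entry points documented above (shapes and index names are illustrative):

```julia
using Tenet

tn = TensorNetwork()
push!(tn, Tensor(rand(2, 2), (:i, :j))) # add a tensor
push!(tn, Tensor(rand(2, 2), (:j, :k))) # shares index :j with the first

other = TensorNetwork([Tensor(rand(2, 2), (:k, :l))])
merge!(tn, other) # absorb the tensors of `other` into `tn`

@assert length(tensors(tn)) == 3
```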
2 changes: 1 addition & 1 deletion docs/src/transformations.md
@@ -294,7 +294,7 @@ set_theme!(resolution=(800,400)) # hide
 sites = [5, 6, 14, 15, 16, 17, 24, 25, 26, 27, 28, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 72, 73, 74, 75, 76, 83, 84, 85, 94]
 circuit = QuacIO.parse(joinpath(@__DIR__, "sycamore_53_10_0.qasm"), format=QuacIO.Qflex(), sites=sites)
-tn = TensorNetwork(circuit)
+tn = QuantumTensorNetwork(circuit)

 # Apply transformations to the tensor network
 transformed_tn = transform(tn, [Tenet.AntiDiagonalGauging, Tenet.DiagonalReduction, Tenet.ColumnReduction, Tenet.RankSimplification])
37 changes: 27 additions & 10 deletions ext/TenetChainRulesCoreExt.jl
@@ -1,6 +1,7 @@
 module TenetChainRulesCoreExt

 using Tenet
+using Classes
 using ChainRulesCore

 function ChainRulesCore.ProjectTo(tensor::T) where {T<:Tensor}
@@ -26,29 +27,45 @@ ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds) = T(data, inds), Tensor_pull
 @non_differentiable intersect(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)
 @non_differentiable symdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)

-function ChainRulesCore.ProjectTo(tn::T) where {T<:TensorNetwork}
-    ProjectTo{T}(; tensors = ProjectTo(tn.tensors), metadata = tn.metadata)
+function ChainRulesCore.ProjectTo(tn::T) where {T<:absclass(TensorNetwork)}
+    # TODO create function to extract extra fields
+    fields = map(fieldnames(T)) do fieldname
+        if fieldname === :tensors
+            :tensors => ProjectTo(tn.tensors)
+        else
+            fieldname => getfield(tn, fieldname)
+        end
+    end
+    ProjectTo{T}(; fields...)
 end

-function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:TensorNetwork}
+function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:absclass(TensorNetwork)}
     dx.tensors isa NoTangent && return NoTangent()
     Tangent{TensorNetwork}(tensors = projector.tensors(dx.tensors))
 end

-function Base.:+(x::TensorNetwork{A}, Δ::Tangent{TensorNetwork}) where {A<:Ansatz}
+function Base.:+(x::T, Δ::Tangent{TensorNetwork}) where {T<:absclass(TensorNetwork)}
     # TODO match tensors by indices
-    tensors = map(+, x.tensors, Δ.tensors)
-    TensorNetwork{A}(tensors; x.metadata...)
+    tensors = map(+, tensors(x), Δ.tensors)
+
+    # TODO create function fitted for this? or maybe standardize constructors?
+    T(map(fieldnames(T)) do fieldname
+        if fieldname === :tensors
+            tensors
+        else
+            getfield(x, fieldname)
+        end
+    end...)
 end

-function ChainRulesCore.frule((_, Δ), T::Type{<:TensorNetwork}, tensors; metadata...)
-    T(tensors; metadata...), Tangent{TensorNetwork}(tensors = Δ)
+function ChainRulesCore.frule((_, Δ), T::Type{<:absclass(TensorNetwork)}, tensors)
+    T(tensors), Tangent{TensorNetwork}(tensors = Δ)
 end

 TensorNetwork_pullback(Δ::Tangent{TensorNetwork}) = (NoTangent(), Δ.tensors)
 TensorNetwork_pullback(Δ::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ))
-function ChainRulesCore.rrule(T::Type{<:TensorNetwork}, tensors; metadata...)
-    T(tensors; metadata...), TensorNetwork_pullback
+function ChainRulesCore.rrule(T::Type{<:absclass(TensorNetwork)}, tensors)
+    T(tensors), TensorNetwork_pullback
 end

 end
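For context, these rules are what a ChainRules-aware AD package consumes; a hedged sketch of the intended use (Zygote is an assumption here, not part of this commit):

```julia
using Tenet
using Zygote # assumption: any ChainRules-aware reverse-mode AD

tn = rand(TensorNetwork, 4, 3)
loss(tn) = sum(contract(tn)) # scalar function of the whole network

# the `rrule` above makes the constructor differentiable, and the
# `Base.:+` overload lets a Tangent be applied back onto the network
g, = Zygote.gradient(loss, tn)
```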
7 changes: 3 additions & 4 deletions ext/TenetChainRulesTestUtilsExt.jl
@@ -4,11 +4,10 @@ using Tenet
 using ChainRulesCore
 using ChainRulesTestUtils
 using Random
+using Classes

-function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::TensorNetwork)
-    return Tangent{TensorNetwork}(
-        tensors = Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)],
-    )
+function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::T) where {T<:absclass(TensorNetwork)}
+    return Tangent{T}(tensors = [ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)])
 end

 end
13 changes: 11 additions & 2 deletions ext/TenetFiniteDifferencesExt.jl
@@ -1,13 +1,22 @@
 module TenetFiniteDifferencesExt

 using Tenet
+using Classes
 using FiniteDifferences

-function FiniteDifferences.to_vec(x::TensorNetwork{A}) where {A<:Ansatz}
+function FiniteDifferences.to_vec(x::T) where {T<:absclass(TensorNetwork)}
     x_vec, back = to_vec(x.tensors)
     function TensorNetwork_from_vec(v)
         tensors = back(v)
-        TensorNetwork{A}(tensors; x.metadata...)
+
+        # TODO create function fitted for this? or maybe standardize constructors?
+        T(map(fieldnames(T)) do fieldname
+            if fieldname === :tensors
+                tensors
+            else
+                getfield(x, fieldname)
+            end
+        end...)
     end

     return x_vec, TensorNetwork_from_vec
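With `to_vec` defined, finite-difference gradient checks become possible; a minimal sketch (the loss function is illustrative, not from this commit):

```julia
using Tenet
using FiniteDifferences

tn = rand(TensorNetwork, 4, 3)
f(tn) = sum(contract(tn)) # scalar function of the network

# 5-point central finite-difference gradient, useful to cross-check AD
fd_grad, = grad(central_fdm(5, 1), f, tn)
```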
7 changes: 4 additions & 3 deletions ext/TenetMakieExt.jl
@@ -4,6 +4,7 @@ using Tenet
 using Combinatorics: combinations
 using Graphs
 using Makie
+using Classes

 using GraphMakie

@@ -19,7 +20,7 @@ Plot a [`TensorNetwork`](@ref) as a graph.
 - `labels` If `true`, show the labels of the tensor indices. Defaults to `false`.
 - The rest of `kwargs` are passed to `GraphMakie.graphplot`.
 """
-function Makie.plot(@nospecialize tn::TensorNetwork; kwargs...)
+function Makie.plot(@nospecialize tn::absclass(TensorNetwork); kwargs...)
     f = Figure()
     ax, p = plot!(f[1, 1], tn; kwargs...)
     return Makie.FigureAxisPlot(f, ax, p)
@@ -28,7 +29,7 @@ end
 # NOTE this is a hack! we did it in order not to depend on NetworkLayout but can be unstable
 __networklayout_dim(x) = typeof(x).super.parameters |> first

-function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::TensorNetwork; kwargs...)
+function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::absclass(TensorNetwork); kwargs...)
     ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3
         Axis3(f[1, 1])
     else
@@ -45,7 +46,7 @@ function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::TensorNetw
     return Makie.AxisPlot(ax, p)
 end

-function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::TensorNetwork; labels = false, kwargs...)
+function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::absclass(TensorNetwork); labels = false, kwargs...)
     hypermap = Tenet.hyperflatten(tn)
     tn = transform(tn, Tenet.HyperindConverter)
15 changes: 6 additions & 9 deletions ext/TenetQuacExt.jl
@@ -2,12 +2,11 @@ module TenetQuacExt

 using Tenet
 using Quac: Circuit, lanes, arraytype, Swap
-using Bijections

-function Tenet.TensorNetwork(circuit::Circuit)
+function Tenet.QuantumTensorNetwork(circuit::Circuit)
     n = lanes(circuit)
     wire = [[Tenet.letter(i)] for i in 1:n]
-    tensors = Tensor[]
+    tn = TensorNetwork()

     i = n + 1

@@ -29,15 +28,13 @@ function Tenet.TensorNetwork(circuit::Circuit)
         end |> x -> zip(x...) |> Iterators.flatten |> collect

         tensor = Tensor(array, inds)
-        push!(tensors, tensor)
+        push!(tn, tensor)
     end

-    interlayer = [
-        Bijection(Dict([site => first(index) for (site, index) in enumerate(wire)])),
-        Bijection(Dict([site => last(index) for (site, index) in enumerate(wire)])),
-    ]
+    input = first.(wire)
+    output = last.(wire)

-    return TensorNetwork{Quantum}(tensors; plug = Tenet.Operator, interlayer)
+    return QuantumTensorNetwork(tn, input, output)
 end

 end
10 changes: 0 additions & 10 deletions src/Helpers.jl
@@ -67,16 +67,6 @@ julia> letter(20204)
 letter(i) =
     Iterators.drop(Iterators.filter(isletter, Iterators.map(Char, 1:2^21-1)), i - 1) |> iterate |> first |> Symbol

-Base.merge(@nospecialize(A::Type{<:NamedTuple}), @nospecialize(Bs::Type{<:NamedTuple}...)) = NamedTuple{
-    foldl((acc, B) -> (acc..., B...), Iterators.map(fieldnames, Bs); init = fieldnames(A)),
-    foldl((acc, B) -> Tuple{fieldtypes(acc)...,B...}, Iterators.map(fieldtypes, Bs); init = A),
-}
-
-function superansatzes(T)
-    S = supertype(T)
-    return T === Ansatz ? (T,) : (T, superansatzes(S)...)
-end
-
 # NOTE from https://stackoverflow.com/q/54652787
 function nonunique(x)
     uniqueindexes = indexin(unique(x), x)
3 changes: 1 addition & 2 deletions src/Numerics.jl
@@ -1,7 +1,6 @@
 using OMEinsum
 using LinearAlgebra
 using UUIDs: uuid4
-using EinExprs: inds

 # TODO test array container typevar on output
 for op in [
@@ -34,7 +33,7 @@
 __omeinsum_sym2str(x) = String[string(i) for i in x]

 """
-    contract(a::Tensor[, b::Tensor, dims=nonunique([inds(a)..., inds(b)...])])
+    contract(a::Tensor[, b::Tensor]; dims=nonunique([inds(a)..., inds(b)...]))

 Perform tensor contraction operation.
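The corrected signature moves `dims` behind the semicolon, i.e. it is a keyword argument. A quick sketch of the default behaviour (index names are arbitrary):

```julia
using Tenet

A = Tensor(rand(2, 3), (:i, :j))
B = Tensor(rand(3, 4), (:j, :k))

# `dims` defaults to the repeated indices, here (:j,), so this is
# the usual matrix-product-like contraction, yielding a (:i, :k) tensor
C = contract(A, B)
```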