Refactor TensorNetwork to class-based OOP organization #103

Merged · 29 commits · Oct 25, 2023

Changes shown are from 26 of the 29 commits.

Commits
146838d  Refactor `TensorNetwork` to `@class` (mofeing, Oct 4, 2023)
6f32e3d  Fix refactor in `ChainRulesCore`, `FiniteDifferences` extensions (mofeing, Oct 4, 2023)
dc00471  Fix `ProjectTo` to `TensorNetwork` (mofeing, Oct 4, 2023)
b5b3317  Import `EinExprs.inds` symbol (mofeing, Oct 5, 2023)
eb5094a  Fix invalidation of `EinExprs.inds` symbol import (mofeing, Oct 5, 2023)
2c91f0e  Fix refactor on `Makie` extension (mofeing, Oct 5, 2023)
a941fcd  Fix `Classes` import in `Makie` extension (mofeing, Oct 6, 2023)
270be2b  Split functionality from `append!(::TensorNetwork)` to `merge!` (mofeing, Oct 6, 2023)
79027fd  Autoimplement `copy` for `TensorNetwork` subtypes (mofeing, Oct 6, 2023)
7b48e03  Fix `replace!(::TensorNetwork)` for list of `Pair`s (mofeing, Oct 6, 2023)
e26b1f0  Fix mutation on `merge(::TensorNetwork)` (mofeing, Oct 6, 2023)
4e8ea4b  Refactor `Quantum` TNs (mofeing, Oct 6, 2023)
dac1cb8  Refactor `TNSampler` to new OOP architecture (mofeing, Oct 7, 2023)
7bcc2f7  Refactor `MatrixProduct` (mofeing, Oct 7, 2023)
462a161  Refactor `Quac` extension (mofeing, Oct 9, 2023)
2057a14  Test changes for MPO (mofeing, Oct 9, 2023)
ebcab1a  Refactor `replace` (mofeing, Oct 9, 2023)
f28bcf4  Refactor things for which I'm too lazy to name (mofeing, Oct 9, 2023)
64ed193  Fix `normalize!` (mofeing, Oct 9, 2023)
b872801  Remove legacy code (mofeing, Oct 9, 2023)
4ea6ab9  Update `ChainRules` extensions (mofeing, Oct 9, 2023)
e2e9721  Fix `copy` on `TensorNetwork` (mofeing, Oct 9, 2023)
62cce3b  Fix `PEP` constructor (mofeing, Oct 9, 2023)
b090cfc  Implement trace methods for `QuantumTensorNetwork` (#110) (mofeing, Oct 9, 2023)
7be0448  Fix docs (mofeing, Oct 9, 2023)
d3e4c11  Enable `norm` test on `MPO` (mofeing, Oct 10, 2023)
6567aa9  Fix `tensors` type in `TensorNetwork` constructor (mofeing, Oct 10, 2023)
86f3a05  Relax `Vector{Tensor}` conversion on `TensorNetwork` constructor (mofeing, Oct 10, 2023)
ecdc993  Update `contract` docstring (mofeing, Oct 10, 2023)
Project.toml (3 changes: 1 addition & 2 deletions)

@@ -4,7 +4,7 @@ authors = ["Sergio Sánchez Ramírez <[email protected]>"]
 version = "0.2.0"
 
 [deps]
-Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
+Classes = "1a9c1350-211b-5766-99cd-4544d885a0d1"
 Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
 DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188"
 EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
@@ -34,7 +34,6 @@ TenetMakieExt = "Makie"
 TenetQuacExt = "Quac"
 
 [compat]
-Bijections = "0.1"
 ChainRulesCore = "1.0"
 Combinatorics = "1.0"
 DeltaArrays = "0.1.1"
docs/src/contraction.md (4 changes: 2 additions & 2 deletions)

@@ -3,7 +3,7 @@
 Contraction path optimization and execution are delegated to the [`EinExprs`](https://github.com/bsc-quantic/EinExprs) library. An `EinExpr` is a lower-level form of a Tensor Network, in which the contraction path has been laid out as a tree. It is similar to a symbolic expression (i.e. `Expr`), but one in which every node represents an Einstein summation expression (aka `einsum`).
 
 ```@docs
-einexpr(::TensorNetwork)
-contract(::TensorNetwork)
+einexpr(::Tenet.AbstractTensorNetwork)
+contract(::Tenet.AbstractTensorNetwork)
 contract!
 ```
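The new `Tenet.AbstractTensorNetwork` signatures can be exercised as below; a minimal sketch (not part of this diff) that assumes only the `Tensor` and `TensorNetwork(Tensor[...])` constructors used elsewhere in this PR:

```julia
using Tenet

# Build a small two-tensor network; :j is the shared (contracted) index.
A = Tensor(rand(2, 3), (:i, :j))
B = Tensor(rand(3, 4), (:j, :k))
tn = TensorNetwork(Tensor[A, B])

path = einexpr(tn)     # lay the contraction path out as an `EinExpr` tree
result = contract(tn)  # execute the contraction; the result spans (:i, :k)
```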
docs/src/examples/ad-tn.md (2 changes: 1 addition & 1 deletion)

@@ -19,7 +19,7 @@ rng = seed!(4) # hide
 ψ = rand(rng, MPS{Open}, n = 4, p = 2, χ = 2) # hide
 ϕ = rand(rng, MPS{Open}, n = 4, p = 2, χ = 4) # hide
-tn = hcat(ψ, ϕ)
+tn = merge(ψ, ϕ')
 plot(tn) # hide
 ```
docs/src/examples/google-rqc.md (2 changes: 1 addition & 1 deletion)

@@ -42,7 +42,7 @@ _sites = [5, 6, 14, 15, 16, 17, 24, 25, 26, 27, 28, 32, 33, 34, 35, 36, 37, 38,
 # load circuit and convert to `TensorNetwork`
 circuit = QuacIO.parse(joinpath(@__DIR__, "sycamore_53_10_0.qasm"), format = QuacIO.Qflex(), sites = _sites);
-tn = TensorNetwork(circuit)
+tn = QuantumTensorNetwork(circuit)
 tn = view(tn, [i => 1 for i in inds(tn, set=:open)]...)
 plot(tn) # hide
 ```
docs/src/quantum/index.md (28 changes: 7 additions & 21 deletions)

@@ -1,44 +1,30 @@
 # Introduction
 
-In `Tenet`, we define a [`Quantum`](@ref) Tensor Network as a [`TensorNetwork`](@ref) with a notion of sites and directionality.
-
-```@docs
-Quantum
-```
+In `Tenet`, we define a [`QuantumTensorNetwork`](@ref) as a [`TensorNetwork`](@ref) with a notion of sites and directionality.
 
 ```@docs
+QuantumTensorNetwork
 plug
 ```
 
 ```@docs
 sites
 ```
 
-```@docs
-tensors(::TensorNetwork{<:Quantum}, ::Integer)
-```
-
-```@docs
-boundary
-```
-
 ## Adjoint
 
 ```@docs
 adjoint
 ```
 
-## Concatenation
+## Norm
 
 ```@docs
-hcat(::TensorNetwork{<:Quantum}, ::TensorNetwork{<:Quantum})
+LinearAlgebra.norm(::Tenet.AbstractQuantumTensorNetwork, ::Real)
+LinearAlgebra.normalize!(::Tenet.AbstractQuantumTensorNetwork, ::Real)
 ```
 
-## Norm
+## Trace
 
 ```@docs
-LinearAlgebra.norm(::TensorNetwork{<:Quantum}, p::Real)
-LinearAlgebra.normalize!(::TensorNetwork{<:Quantum}, ::Real)
+LinearAlgebra.tr(::Tenet.AbstractQuantumTensorNetwork)
 ```
 
 ## Fidelity
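A usage sketch of the renamed norm/trace API (assumptions: the `rand(rng, MPS{Open}, ...)` generator from `docs/src/examples/ad-tn.md`, and that matrix-product states fall under `AbstractQuantumTensorNetwork` after this refactor):

```julia
using Tenet
using LinearAlgebra
using Random: seed!

rng = seed!(4)
ψ = rand(rng, MPS{Open}, n = 4, p = 2, χ = 2)

normalize!(ψ, 2)   # in-place 2-norm normalization, per the docstring above
norm(ψ, 2)         # ≈ 1.0 afterwards

tn = merge(ψ, ψ')  # ⟨ψ|ψ⟩ as a closed network, via the `merge(ψ, ϕ')` idiom above
```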
docs/src/tensor-network.md (22 changes: 10 additions & 12 deletions)

@@ -28,28 +28,26 @@ Information about a `TensorNetwork` can be queried with the following functions.
 ## Query information
 
 ```@docs
-inds(::TensorNetwork)
-size(::TensorNetwork)
-tensors(::TensorNetwork)
-length(::TensorNetwork)
-ansatz
+inds(::Tenet.AbstractTensorNetwork)
+size(::Tenet.AbstractTensorNetwork)
+tensors(::Tenet.AbstractTensorNetwork)
 ```
 
 ## Modification
 
 ### Add/Remove tensors
 
 ```@docs
-push!(::TensorNetwork, ::Tensor)
-append!(::TensorNetwork, ::Base.AbstractVecOrTuple{<:Tensor})
-pop!(::TensorNetwork, ::Tensor)
-delete!(::TensorNetwork, ::Any)
+push!(::Tenet.AbstractTensorNetwork, ::Tensor)
+append!(::Tenet.AbstractTensorNetwork, ::Base.AbstractVecOrTuple{<:Tensor})
+merge!(::Tenet.AbstractTensorNetwork, ::Tenet.AbstractTensorNetwork)
+pop!(::Tenet.AbstractTensorNetwork, ::Tensor)
+delete!(::Tenet.AbstractTensorNetwork, ::Any)
 ```
 
 ### Replace existing elements
 
 ```@docs
-replace
 replace!
 ```
 
@@ -59,12 +57,12 @@ replace!
 select
 selectdim
 slice!
-view(::TensorNetwork)
+view(::Tenet.AbstractTensorNetwork)
 ```
 
 ## Miscellaneous
 
 ```@docs
-Base.copy(::TensorNetwork)
+Base.copy(::Tenet.AbstractTensorNetwork)
 Base.rand(::Type{TensorNetwork}, n::Integer, regularity::Integer)
 ```
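For orientation, a sketch of the query/modification API listed above (return forms are indicative, not verified against this revision):

```julia
using Tenet

tn = TensorNetwork(Tensor[
    Tensor(rand(2, 3), (:i, :j)),
    Tensor(rand(3, 4), (:j, :k)),
])

inds(tn)     # index names, e.g. [:i, :j, :k]
size(tn)     # dimension of each index, e.g. :i => 2, :j => 3, :k => 4
tensors(tn)  # the tensors of the network

push!(tn, Tensor(rand(4, 2), (:k, :l)))  # add a tensor in place
pop!(tn, tensors(tn)[end])               # remove (and return) it again
```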
docs/src/transformations.md (10 changes: 5 additions & 5 deletions)

@@ -81,7 +81,7 @@ A = Tensor(data, (:i, :j, :k, :l)) #hide
 B = Tensor(rand(2, 2), (:i, :m)) #hide
 C = Tensor(rand(2, 2), (:j, :n)) #hide
 
-tn = TensorNetwork([A, B, C]) #hide
+tn = TensorNetwork(Tensor[A, B, C]) #hide
 reduced = transform(tn, Tenet.DiagonalReduction) #hide
 
 smooth_annotation!( #hide
@@ -139,7 +139,7 @@ B = Tensor(rand(2, 2), (:i, :m)) #hide
 C = Tensor(rand(2, 2, 2), (:m, :n, :o)) #hide
 E = Tensor(rand(2, 2, 2, 2), (:o, :p, :q, :j)) #hide
 
-tn = TensorNetwork([A, B, C, E]) #hide
+tn = TensorNetwork(Tensor[A, B, C, E]) #hide
 reduced = transform(tn, Tenet.RankSimplification) #hide
 
 smooth_annotation!( #hide
@@ -193,7 +193,7 @@ A = Tensor(data, (:i, :j, :k)) #hide
 B = Tensor(rand(3, 3), (:j, :l)) #hide
 C = Tensor(rand(3, 3), (:l, :m)) #hide
 
-tn = TensorNetwork([A, B, C]) #hide
+tn = TensorNetwork(Tensor[A, B, C]) #hide
 reduced = transform(tn, Tenet.ColumnReduction) #hide
 
 smooth_annotation!( #hide
@@ -247,7 +247,7 @@ m1 = Tensor(rand(3, 3), (:k, :l)) #hide
 t1 = contract(v1, v2) #hide
 tensor = contract(t1, m1) #hide
 
-tn = TensorNetwork([tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) #hide
+tn = TensorNetwork(Tensor[tensor, Tensor(rand(3, 3, 3), (:k, :m, :n)), Tensor(rand(3, 3, 3), (:l, :n, :o))]) #hide
 reduced = transform(tn, Tenet.SplitSimplification) #hide
 
 smooth_annotation!( #hide
@@ -294,7 +294,7 @@ set_theme!(resolution=(800,400)) # hide
 
 sites = [5, 6, 14, 15, 16, 17, 24, 25, 26, 27, 28, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 72, 73, 74, 75, 76, 83, 84, 85, 94]
 circuit = QuacIO.parse(joinpath(@__DIR__, "sycamore_53_10_0.qasm"), format=QuacIO.Qflex(), sites=sites)
-tn = TensorNetwork(circuit)
+tn = QuantumTensorNetwork(circuit)
 
 # Apply transformations to the tensor network
 transformed_tn = transform(tn, [Tenet.AntiDiagonalGauging, Tenet.DiagonalReduction, Tenet.ColumnReduction, Tenet.RankSimplification])
ext/TenetChainRulesCoreExt.jl (37 changes: 27 additions & 10 deletions)

@@ -1,6 +1,7 @@
 module TenetChainRulesCoreExt
 
 using Tenet
+using Classes
 using ChainRulesCore
 
 function ChainRulesCore.ProjectTo(tensor::T) where {T<:Tensor}
@@ -26,29 +27,45 @@ ChainRulesCore.rrule(T::Type{<:Tensor}, data, inds) = T(data, inds), Tensor_pullback
 @non_differentiable intersect(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)
 @non_differentiable symdiff(s::Base.AbstractVecOrTuple{Symbol}, itrs::Base.AbstractVecOrTuple{Symbol}...)
 
-function ChainRulesCore.ProjectTo(tn::T) where {T<:TensorNetwork}
-    ProjectTo{T}(; tensors = ProjectTo(tn.tensors), metadata = tn.metadata)
+function ChainRulesCore.ProjectTo(tn::T) where {T<:absclass(TensorNetwork)}
+    # TODO create function to extract extra fields
+    fields = map(fieldnames(T)) do fieldname
+        if fieldname === :tensors
+            :tensors => ProjectTo(tn.tensors)
+        else
+            fieldname => getfield(tn, fieldname)
+        end
+    end
+    ProjectTo{T}(; fields...)
 end
 
-function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:TensorNetwork}
+function (projector::ProjectTo{T})(dx::Union{T,Tangent{T}}) where {T<:absclass(TensorNetwork)}
     dx.tensors isa NoTangent && return NoTangent()
     Tangent{TensorNetwork}(tensors = projector.tensors(dx.tensors))
 end
 
-function Base.:+(x::TensorNetwork{A}, Δ::Tangent{TensorNetwork}) where {A<:Ansatz}
+function Base.:+(x::T, Δ::Tangent{TensorNetwork}) where {T<:absclass(TensorNetwork)}
     # TODO match tensors by indices
-    tensors = map(+, x.tensors, Δ.tensors)
-    TensorNetwork{A}(tensors; x.metadata...)
+    tensors = map(+, tensors(x), Δ.tensors)
+
+    # TODO create function fitted for this? or maybe standardize constructors?
+    T(map(fieldnames(T)) do fieldname
+        if fieldname === :tensors
+            tensors
+        else
+            getfield(x, fieldname)
+        end
+    end...)
 end
 
-function ChainRulesCore.frule((_, Δ), T::Type{<:TensorNetwork}, tensors; metadata...)
-    T(tensors; metadata...), Tangent{TensorNetwork}(tensors = Δ)
+function ChainRulesCore.frule((_, Δ), T::Type{<:absclass(TensorNetwork)}, tensors)
+    T(tensors), Tangent{TensorNetwork}(tensors = Δ)
 end
 
 TensorNetwork_pullback(Δ::Tangent{TensorNetwork}) = (NoTangent(), Δ.tensors)
 TensorNetwork_pullback(Δ::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ))
-function ChainRulesCore.rrule(T::Type{<:TensorNetwork}, tensors; metadata...)
-    T(tensors; metadata...), TensorNetwork_pullback
+function ChainRulesCore.rrule(T::Type{<:absclass(TensorNetwork)}, tensors)
+    T(tensors), TensorNetwork_pullback
 end
 
 end
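For readers unfamiliar with `Classes.jl`: as a rough mental model (assumed semantics, not taken from this diff), `@class` defines a concrete type together with an abstract supertype, and `absclass` retrieves that abstract type, which is why the methods above dispatch on `T<:absclass(TensorNetwork)` rather than the concrete `TensorNetwork`:

```julia
using Classes

# Assumed Classes.jl behavior: this defines a concrete `Point` plus an abstract
# supertype, retrievable as `absclass(Point)`, so subclasses share its methods.
@class Point begin
    x::Float64
    y::Float64
end

# Dispatching on the abstract class covers `Point` and any future subclass,
# mirroring the `where {T<:absclass(TensorNetwork)}` signatures in this file.
norm2(p::absclass(Point)) = sqrt(p.x^2 + p.y^2)
```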
ext/TenetChainRulesTestUtilsExt.jl (7 changes: 3 additions & 4 deletions)

@@ -4,11 +4,10 @@ using Tenet
 using ChainRulesCore
 using ChainRulesTestUtils
 using Random
+using Classes
 
-function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::TensorNetwork)
-    return Tangent{TensorNetwork}(
-        tensors = Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)],
-    )
+function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::T) where {T<:absclass(TensorNetwork)}
+    return Tangent{T}(tensors = Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)])
 end
 
 end
ext/TenetFiniteDifferencesExt.jl (13 changes: 11 additions & 2 deletions)

@@ -1,13 +1,22 @@
 module TenetFiniteDifferencesExt
 
 using Tenet
+using Classes
 using FiniteDifferences
 
-function FiniteDifferences.to_vec(x::TensorNetwork{A}) where {A<:Ansatz}
+function FiniteDifferences.to_vec(x::T) where {T<:absclass(TensorNetwork)}
     x_vec, back = to_vec(x.tensors)
     function TensorNetwork_from_vec(v)
         tensors = back(v)
-        TensorNetwork{A}(tensors; x.metadata...)
+
+        # TODO create function fitted for this? or maybe standardize constructors?
+        T(map(fieldnames(T)) do fieldname
+            if fieldname === :tensors
+                tensors
+            else
+                getfield(x, fieldname)
+            end
+        end...)
     end
 
     return x_vec, TensorNetwork_from_vec
ext/TenetMakieExt.jl (7 changes: 4 additions & 3 deletions)

@@ -4,6 +4,7 @@ using Tenet
 using Combinatorics: combinations
 using Graphs
 using Makie
+using Classes
 
 using GraphMakie
 
@@ -19,7 +20,7 @@ Plot a [`TensorNetwork`](@ref) as a graph.
 - `labels` If `true`, show the labels of the tensor indices. Defaults to `false`.
 - The rest of `kwargs` are passed to `GraphMakie.graphplot`.
 """
-function Makie.plot(@nospecialize tn::TensorNetwork; kwargs...)
+function Makie.plot(@nospecialize tn::absclass(TensorNetwork); kwargs...)
     f = Figure()
     ax, p = plot!(f[1, 1], tn; kwargs...)
     return Makie.FigureAxisPlot(f, ax, p)
@@ -28,7 +29,7 @@ end
 # NOTE this is a hack! we did it in order not to depend on NetworkLayout but can be unstable
 __networklayout_dim(x) = typeof(x).super.parameters |> first
 
-function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::TensorNetwork; kwargs...)
+function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::absclass(TensorNetwork); kwargs...)
     ax = if haskey(kwargs, :layout) && __networklayout_dim(kwargs[:layout]) == 3
         Axis3(f[1, 1])
     else
@@ -45,7 +46,7 @@ function Makie.plot!(f::Union{Figure,GridPosition}, @nospecialize tn::TensorNetwork; kwargs...)
     return Makie.AxisPlot(ax, p)
 end
 
-function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::TensorNetwork; labels = false, kwargs...)
+function Makie.plot!(ax::Union{Axis,Axis3}, @nospecialize tn::absclass(TensorNetwork); labels = false, kwargs...)
     hypermap = Tenet.hyperflatten(tn)
     tn = transform(tn, Tenet.HyperindConverter)
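A plotting sketch against the widened signatures (a Makie backend is assumed; `CairoMakie` here is illustrative):

```julia
using Tenet
using CairoMakie

tn = TensorNetwork(Tensor[Tensor(rand(2, 2), (:i, :j)), Tensor(rand(2, 2), (:j, :k))])
plot(tn; labels = true)  # `labels` draws index names; other kwargs reach `graphplot`
```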
ext/TenetQuacExt.jl (15 changes: 6 additions & 9 deletions)

@@ -2,12 +2,11 @@ module TenetQuacExt
 
 using Tenet
 using Quac: Circuit, lanes, arraytype, Swap
-using Bijections
 
-function Tenet.TensorNetwork(circuit::Circuit)
+function Tenet.QuantumTensorNetwork(circuit::Circuit)
     n = lanes(circuit)
     wire = [[Tenet.letter(i)] for i in 1:n]
-    tensors = Tensor[]
+    tn = TensorNetwork()
 
     i = n + 1
 
@@ -29,15 +28,13 @@ function Tenet.TensorNetwork(circuit::Circuit)
        end |> x -> zip(x...) |> Iterators.flatten |> collect
 
        tensor = Tensor(array, inds)
-       push!(tensors, tensor)
+       push!(tn, tensor)
     end
 
-    interlayer = [
-        Bijection(Dict([site => first(index) for (site, index) in enumerate(wire)])),
-        Bijection(Dict([site => last(index) for (site, index) in enumerate(wire)])),
-    ]
+    input = first.(wire)
+    output = last.(wire)
 
-    return TensorNetwork{Quantum}(tensors; plug = Tenet.Operator, interlayer)
+    return QuantumTensorNetwork(tn, input, output)
 end
 
 end
src/Helpers.jl (10 changes: 0 additions & 10 deletions)

@@ -67,16 +67,6 @@ julia> letter(20204)
 letter(i) =
     Iterators.drop(Iterators.filter(isletter, Iterators.map(Char, 1:2^21-1)), i - 1) |> iterate |> first |> Symbol
 
-Base.merge(@nospecialize(A::Type{<:NamedTuple}), @nospecialize(Bs::Type{<:NamedTuple}...)) = NamedTuple{
-    foldl((acc, B) -> (acc..., B...), Iterators.map(fieldnames, Bs); init = fieldnames(A)),
-    foldl((acc, B) -> Tuple{fieldtypes(acc)...,B...}, Iterators.map(fieldtypes, Bs); init = A),
-}
-
-function superansatzes(T)
-    S = supertype(T)
-    return T === Ansatz ? (T,) : (T, superansatzes(S)...)
-end
-
 # NOTE from https://stackoverflow.com/q/54652787
 function nonunique(x)
     uniqueindexes = indexin(unique(x), x)
src/Numerics.jl (1 change: 0 additions & 1 deletion)

@@ -1,7 +1,6 @@
 using OMEinsum
 using LinearAlgebra
 using UUIDs: uuid4
-using EinExprs: inds
 
 # TODO test array container typevar on output
 for op in [