Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor updates and new features for MPO-MPO contraction #14

Merged
merged 9 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/abstracttensortrain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ function evaluate(
if length(indexset) != length(tt)
throw(ArgumentError("To evaluate a tt of length $(length(tt)), you have to provide $(length(tt)) indices, but there were $(length(indexset))."))
end
for (n, (T, i)) in enumerate(zip(tt, indexset))
length(size(T)) == length(i) + 2 || throw(ArgumentError("The index set $(i) at position $n does not have the correct length for the tensor of size $(size(T))."))
end
return only(prod(T[:, i..., :] for (T, i) in zip(tt, indexset)))
end

Expand Down
26 changes: 26 additions & 0 deletions src/batcheval.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
"""
Wrap any function to support batch evaluation.
"""
struct BatchEvaluatorAdapter{T} <: BatchEvaluator{T}
f::Function
localdims::Vector{Int}
end

"""
    makebatchevaluatable(::Type{T}, f, localdims)

Wrap the function `f` (defined on index sets with local dimensions `localdims`)
in a `BatchEvaluatorAdapter{T}`, so it can be used wherever a `BatchEvaluator`
with value type `T` is expected.
"""
function makebatchevaluatable(::Type{T}, f, localdims) where {T}
    return BatchEvaluatorAdapter{T}(f, localdims)
end

"""
Evaluate the wrapped function at a single multi-index; the result is converted
to the adapter's value type `T`.
"""
function (bf::BatchEvaluatorAdapter{T})(indexset::MultiIndex)::T where {T}
    return bf.f(indexset)
end

"""
Batch-evaluate the wrapped function over all combinations of `leftindexset`,
`M` free central site indices, and `rightindexset`. Returns an
`M + 2`-dimensional array; if either index set is empty, an empty array of the
correct rank is returned.
"""
function (bf::BatchEvaluatorAdapter{T})(
    leftindexset::AbstractVector{MultiIndex},
    rightindexset::AbstractVector{MultiIndex},
    ::Val{M}
)::Array{T,M + 2} where {T,M}
    # Nothing to evaluate if either side contributes no index sets.
    if isempty(leftindexset) || isempty(rightindexset)
        return Array{T,M + 2}(undef, ntuple(_ -> 0, M + 2))
    end
    return _batchevaluate_dispatch(T, bf.f, bf.localdims, leftindexset, rightindexset, Val(M))
end


"""
This file contains functions for evaluating a function on a batch of indices mainly for TensorCI2.
If the function supports batch evaluation, then it should implement the `BatchEvaluator` interface.
Expand Down
94 changes: 77 additions & 17 deletions src/cachedtensortrain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,55 @@ abstract type BatchEvaluator{V} <: AbstractTensorTrain{V} end


"""
struct TTCache{ValueType, N}
struct TTCache{ValueType}

Cached evalulation of TT
Cached evaluation of a tensor train. This is useful when the same TT is evaluated multiple times with the same indices. The number of site indices per tensor core can be arbitrary, irrespective of the number of site indices of the original tensor train.
"""
struct TTCache{ValueType} <: BatchEvaluator{ValueType}
sitetensors::Vector{Array{ValueType,3}}
cacheleft::Vector{Dict{MultiIndex,Vector{ValueType}}}
cacheright::Vector{Dict{MultiIndex,Vector{ValueType}}}
sitedims::Vector{Vector{Int}}

function TTCache(sitetensors::AbstractVector{<:AbstractArray{ValueType}}) where {ValueType}
function TTCache{ValueType}(sitetensors::AbstractVector{<:AbstractArray{ValueType}}, sitedims) where {ValueType}
length(sitetensors) == length(sitedims) || throw(ArgumentError("The number of site tensors and site dimensions must be the same."))
for n in 1:length(sitetensors)
prod(sitedims[n]) == prod(size(sitetensors[n])[2:end-1]) || error("Site dimensions do not match the site tensor dimensions at $n.")
end
new{ValueType}(
sitetensors,
[reshape(x, size(x, 1), :, size(x)[end]) for x in sitetensors],
[Dict{MultiIndex,Vector{ValueType}}() for _ in sitetensors],
[Dict{MultiIndex,Vector{ValueType}}() for _ in sitetensors],
[Dict{MultiIndex,Vector{ValueType}}() for _ in sitetensors])
sitedims
)
end
function TTCache{ValueType}(sitetensors::AbstractVector{<:AbstractArray{ValueType}}) where {ValueType}
return TTCache{ValueType}(sitetensors, [collect(size(x)[2:end-1]) for x in sitetensors])
end
end

# Build a cache from an existing tensor train, inferring site dimensions from its tensors.
TTCache(tt::AbstractTensorTrain{ValueType}) where {ValueType} = TTCache{ValueType}(sitetensors(tt))

# Build a cache directly from site tensors, inferring the value type.
TTCache(sitetensors::AbstractVector{<:AbstractArray{ValueType}}) where {ValueType} = TTCache{ValueType}(sitetensors)

# Build a cache from site tensors with explicitly specified site dimensions.
TTCache(sitetensors::AbstractVector{<:AbstractArray{ValueType}}, sitedims) where {ValueType} = TTCache{ValueType}(sitetensors, sitedims)

# Build a cache from a tensor train with explicitly specified site dimensions.
TTCache(tt::AbstractTensorTrain{ValueType}, sitedims) where {ValueType} = TTCache{ValueType}(sitetensors(tt), sitedims)

# Number of tensor cores in the cached tensor train.
Base.length(obj::TTCache) = length(obj.sitetensors)

# Site dimensions (one vector of dimensions per core) of the cached tensor train.
sitedims(obj::TTCache)::Vector{Vector{Int}} = obj.sitedims

"""
Collect every site tensor of `tt`, each reshaped (via `sitetensor`) to expose
its full set of site dimensions.
"""
function sitetensors(tt::TTCache{V}) where {V}
    return map(n -> sitetensor(tt, n), eachindex(tt.sitetensors))
end

"""
Return the `i`-th core of `tt`, reshaped so its site indices appear as
separate axes: `(left bond, tt.sitedims[i]..., right bond)`.
"""
function sitetensor(tt::TTCache{V}, i::Integer) where {V}
    # Renamed local to avoid shadowing the function name.
    core = tt.sitetensors[i]
    leftdim = size(core, 1)
    rightdim = size(core, ndims(core))
    return reshape(core, leftdim, tt.sitedims[i]..., rightdim)
end

function Base.empty!(tt::TTCache{V}) where {V}
Expand All @@ -27,9 +60,9 @@ function Base.empty!(tt::TTCache{V}) where {V}
end


function TTCache(TT::AbstractTensorTrain{ValueType}) where {ValueType}
TTCache(sitetensors(TT))
end
#function TTCache(TT::AbstractTensorTrain{ValueType}) where {ValueType}
#TTCache(sitetensors(TT))
#end

function ttcache(tt::TTCache{V}, leftright::Symbol, b::Int) where {V}
if leftright == :left
Expand Down Expand Up @@ -98,7 +131,7 @@ function evaluate(
tt::TTCache{V},
indexset::AbstractVector{Int};
usecache::Bool=true,
midpoint::Int = div(length(tt), 2)
midpoint::Int=div(length(tt), 2)
)::V where {V}
if length(tt) != length(indexset)
throw(ArgumentError("To evaluate a tensor train of length $(length(tt)), need $(length(tt)) index values, but got $(length(indexset))."))
Expand All @@ -112,35 +145,48 @@ function evaluate(
end
end


function (tt::TTCache{V})(
"""
projector: an entry of 0 means no projection on that site index; a positive entry projects onto that single index value
"""
function batchevaluate(tt::TTCache{V},
leftindexset::AbstractVector{MultiIndex},
rightindexset::AbstractVector{MultiIndex},
::Val{M}
)::Array{V,M + 2} where {V,M}
::Val{M},
projector::Union{Nothing,AbstractVector{<:AbstractVector{<:Integer}}}=nothing)::Array{V,M + 2} where {V,M}
if length(leftindexset) * length(rightindexset) == 0
return Array{V,M+2}(undef, ntuple(d->0, M+2)...)
return Array{V,M + 2}(undef, ntuple(d -> 0, M + 2)...)
end
N = length(tt)
nleft = length(leftindexset[1])
nright = length(rightindexset[1])
nleftindexset = length(leftindexset)
nrightindexset = length(rightindexset)
ncent = N - nleft - nright
s_, e_ = nleft + 1, N - nright

if ncent != M
error("Invalid parameter M: $(M)")
end
if projector === nothing
projector = [fill(0, length(s)) for s in sitedims(tt)[nleft+1:(N-nright)]]
end
if length(projector) != M
error("Invalid length of projector: $(projector), correct length should be M=$M")
end
for n in s_:e_
length(projector[n-s_+1]) == length(tt.sitedims[n]) || error("Invalid projector at $n: $(projector[n - s_ + 1]), the length must be $(length(tt.sitedims[n]))")
all(0 .<= projector[n-s_+1] .<= tt.sitedims[n]) || error("Invalid projector: $(projector[n - s_ + 1])")
end

DL = nleft == 0 ? 1 : size(tt[nleft], 3)
DL = (0 < nleft < N) ? linkdims(tt)[nleft] : 1
lenv = ones(V, nleftindexset, DL)
if nleft > 0
for (il, lindex) in enumerate(leftindexset)
lenv[il, :] .= evaluateleft(tt, lindex)
end
end

DR = nright == 0 ? 1 : linkdim(tt, nleft + ncent)
DR = (0 < nright < N) ? linkdim(tt, nleft + ncent) : 1
renv = ones(V, DR, nrightindexset)
if nright > 0
for (ir, rindex) in enumerate(rightindexset)
Expand All @@ -151,7 +197,11 @@ function (tt::TTCache{V})(
localdim = zeros(Int, ncent)
for n in nleft+1:(N-nright)
# (nleftindexset, d, ..., d, D) x (D, d, D)
T_ = sitetensor(tt, n)
T_ = begin
slice, slice_size = projector_to_slice(projector[n-nleft])
s = sitetensor(tt, n)[:, slice..., :]
reshape(s, size(s)[1], :, size(s)[end])
end
localdim[n-nleft] = size(T_, 2)
bonddim_ = size(T_, 1)
lenv = reshape(lenv, :, bonddim_) * reshape(T_, bonddim_, :)
Expand All @@ -165,5 +215,15 @@ function (tt::TTCache{V})(
end



"""
Make a `TTCache` callable with the batch-evaluation interface. Delegates to
`batchevaluate` without a projector, i.e. all central site indices are kept.
"""
function (tt::TTCache{V})(
leftindexset::AbstractVector{MultiIndex},
rightindexset::AbstractVector{MultiIndex},
::Val{M}
)::Array{V,M + 2} where {V,M}
return batchevaluate(tt, leftindexset, rightindexset, Val(M))
end


# Trait query: does `f` support the batch-evaluation call signature? Fallback: no.
isbatchevaluable(f) = false
# Every BatchEvaluator subtype implements batch evaluation.
isbatchevaluable(f::BatchEvaluator) = true
42 changes: 39 additions & 3 deletions src/contraction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ end

# Number of sites of the contraction, taken from the first MPO factor.
Base.length(obj::Contraction) = length(obj.mpo[1])

"""
Return the site dimensions stored on the contraction object, one vector of
dimensions per site.
"""
function sitedims(obj::Contraction{T})::Vector{Vector{Int}} where {T}
return obj.sitedims
end

# Support `obj[end]`-style indexing by delegating to the first MPO factor.
function Base.lastindex(obj::Contraction{T}) where {T}
return lastindex(obj.mpo[1])
end
Expand Down Expand Up @@ -226,12 +230,30 @@ function (obj::Contraction{T})(
rightindexset::AbstractVector{MultiIndex},
::Val{M},
)::Array{T,M + 2} where {T,M}
return batchevaluate(obj, leftindexset, rightindexset, Val(M))
end

function batchevaluate(obj::Contraction{T},
leftindexset::AbstractVector{MultiIndex},
rightindexset::AbstractVector{MultiIndex},
::Val{M},
projector::Union{Nothing,AbstractVector{<:AbstractVector{<:Integer}}}=nothing)::Array{T,M + 2} where {T,M}
N = length(obj)
Nr = length(rightindexset[1])
s_ = length(leftindexset[1]) + 1
e_ = N - length(rightindexset[1])
a, b = obj.mpo

if projector === nothing
projector = [fill(0, length(obj.sitedims[n])) for n in s_:e_]
end
length(projector) == M || error("Length mismatch: length of projector (=$(length(projector))) must be $(M)")
for n in s_:e_
length(projector[n - s_ + 1]) == 2 || error("Invalid projector at $n: $(projector[n - s_ + 1]), the length must be 2")
all(0 .<= projector[n - s_ + 1] .<= obj.sitedims[n]) || error("Invalid projector: $(projector[n - s_ + 1])")
end


# Unfused index
leftindexset_unfused = [
[_unfuse_idx(obj, n, idx) for (n, idx) in enumerate(idxs)] for idxs in leftindexset
Expand Down Expand Up @@ -264,14 +286,28 @@ function (obj::Contraction{T})(

# (left_index, link_a, link_b, site[s_] * site'[s_] * ... * site[e_] * site'[e_])
leftobj::Array{T,4} = reshape(left_, size(left_)..., 1)
return_size_siteinds = Int[]
for n = s_:e_
slice_ab, shape_ab = projector_to_slice(projector[n - s_ + 1])
a_n = begin
a_n_org = obj.mpo[1][n]
tmp = a_n_org[:, slice_ab[1], :, :]
reshape(tmp, size(a_n_org, 1), shape_ab[1], size(a_n_org)[3:4]...)
end
b_n = begin
b_n_org = obj.mpo[2][n]
tmp = b_n_org[:, :, slice_ab[2], :]
reshape(tmp, size(b_n_org, 1), size(b_n_org, 2), shape_ab[2], size(b_n_org, 4))
end
push!(return_size_siteinds, size(a_n, 2) * size(b_n, 3))

#(left_index, link_a, link_b, S) * (link_a, site[n], shared, link_a')
# => (left_index, link_b, S, site[n], shared, link_a')
tmp1 = _contract(leftobj, a[n], (2,), (1,))
tmp1 = _contract(leftobj, a_n, (2,), (1,))

# (left_index, link_b, S, site[n], shared, link_a') * (link_b, shared, site'[n], link_b')
# => (left_index, S, site[n], link_a', site'[n], link_b')
tmp2 = _contract(tmp1, b[n], (2, 5), (1, 2))
tmp2 = _contract(tmp1, b_n, (2, 5), (1, 2))

# (left_index, S, site[n], link_a', site'[n], link_b')
# => (left_index, link_a', link_b', S, site[n], site'[n])
Expand All @@ -282,7 +318,7 @@ function (obj::Contraction{T})(

return_size = (
length(leftindexset),
ntuple(i -> prod(obj.sitedims[i+s_-1]), M)...,
return_size_siteinds...,
length(rightindexset),
)
t5 = time_ns()
Expand Down
2 changes: 1 addition & 1 deletion src/tensortrain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,4 +239,4 @@ end
"""
Least-squares cost functional: the sum of squared absolute deviations between
the tensor train parametrized by the flat vector `x` and the target values
`obj.values` at the index sets `obj.indexsets`.
"""
function (obj::TensorTrainFit{ValueType})(x::Vector{ValueType}) where {ValueType}
# Rebuild the site tensors from the flat parameter vector.
tensors = to_tensors(obj, x)
return sum((abs2(_evaluate(tensors, indexset) - obj.values[i]) for (i, indexset) in enumerate(obj.indexsets)))
end
end
9 changes: 9 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,12 @@ function replacenothing(value::Union{T, Nothing}, default::T)::T where {T}
return value
end
end


"""
Construct slice for the site indces of one tensor core
Returns a slice and the corresponding shape for `resize`
"""
function projector_to_slice(p::AbstractVector{<:Integer})
return [x == 0 ? Colon() : x for x in p], [x == 0 ? Colon() : 1 for x in p]
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ include("test_matrixaca.jl")
include("test_matrixlu.jl")
include("test_matrixluci.jl")
include("test_batcheval.jl")
include("test_cachedtensortrain.jl")
include("test_tensorci1.jl")
include("test_tensorci2.jl")
include("test_tensortrain.jl")
Expand Down
19 changes: 19 additions & 0 deletions test/test_batcheval.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
using Test
import TensorCrossInterpolation as TCI


# Callable test helper WITHOUT the BatchEvaluator interface: a plain function
# object that sums its index vector, used to exercise the adapter fallback.
struct NonBatchEvaluator{T} <: Function end

function (f::NonBatchEvaluator{T})(x::Vector{Int})::T where {T}
    # Equivalent to sum(x), including the empty-vector case (init = 0).
    return foldl(+, x; init=0)
end


@testset "batcheval" begin
@testset "M=1" begin
localdims = [2, 2, 2, 2, 2]
Expand All @@ -26,6 +34,17 @@ import TensorCrossInterpolation as TCI
@test result ≈ ref
end

@testset "BatchEvaluator" begin
tbf = NonBatchEvaluator{Float64}()

leftindexset = [[1], [2]]
rightindexset = [[1], [2]]
localdims = [3, 3, 3, 3]

bf = TCI.makebatchevaluatable(Float64, tbf, localdims)
@test size(bf(leftindexset, rightindexset, Val(1))) == (2, 3, 2)
end

@testset "ThreadedBatchEvaluator" begin
L = 20
localdims = fill(2, L)
Expand Down
Loading
Loading