Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve support for mixed scalartypes #259

Merged
merged 14 commits into from
Mar 3, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
TensorKitManifolds = "11fa318c-39cb-4a83-b1ed-cdc7ba1e3684"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"

[compat]
Expand Down Expand Up @@ -50,7 +51,6 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

Expand Down
1 change: 1 addition & 0 deletions src/MPSKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ using Compat: @compat
using TensorKit
using TensorKit: BraidingTensor
using BlockTensorKit
using TensorOperations
using KrylovKit
using KrylovKit: KrylovAlgorithm
using OptimKit
Expand Down
7 changes: 4 additions & 3 deletions src/algorithms/timestep/tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,10 @@ $(TYPEDFIELDS)
finalize::F = Defaults._finalize
end

function timestep(ψ::InfiniteMPS, H, t::Number, dt::Number, alg::TDVP,
envs::AbstractMPSEnvironments=environments(ψ, H);
function timestep(ψ_::InfiniteMPS, H, t::Number, dt::Number, alg::TDVP,
envs::AbstractMPSEnvironments=environments(ψ_, H);
leftorthflag=true)
ψ = complex(ψ_)
temp_ACs = similar(ψ.AC)
temp_Cs = similar(ψ.C)

Expand Down Expand Up @@ -172,5 +173,5 @@ end
function timestep(ψ::AbstractFiniteMPS, H, time::Number, timestep::Number,
alg::Union{TDVP,TDVP2}, envs::AbstractMPSEnvironments=environments(ψ, H);
kwargs...)
return timestep!(copy(ψ), H, time, timestep, alg, envs; kwargs...)
return timestep!(copy(complex(ψ)), H, time, timestep, alg, envs; kwargs...)
end
8 changes: 7 additions & 1 deletion src/operators/abstractmpo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
# Properties
# ----------
left_virtualspace(mpo::AbstractMPO, site::Int) = left_virtualspace(mpo[site])
left_virtualspace(mpo::AbstractMPO) = map(left_virtualspace, parent(mpo))

Check warning on line 27 in src/operators/abstractmpo.jl

View check run for this annotation

Codecov / codecov/patch

src/operators/abstractmpo.jl#L27

Added line #L27 was not covered by tests
right_virtualspace(mpo::AbstractMPO, site::Int) = right_virtualspace(mpo[site])
right_virtualspace(mpo::AbstractMPO) = map(right_virtualspace, parent(mpo))

Check warning on line 29 in src/operators/abstractmpo.jl

View check run for this annotation

Codecov / codecov/patch

src/operators/abstractmpo.jl#L29

Added line #L29 was not covered by tests
physicalspace(mpo::AbstractMPO, site::Int) = physicalspace(mpo[site])
physicalspace(mpo::AbstractMPO) = map(physicalspace, mpo)

Expand Down Expand Up @@ -170,7 +172,11 @@
Base.:/(mpo::AbstractMPO, α::Number) = scale(mpo, inv(α))
Base.:\(α::Number, mpo::AbstractMPO) = scale(mpo, inv(α))

VectorInterface.scale(mpo::AbstractMPO, α::Number) = scale!(copy(mpo), α)
function VectorInterface.scale(mpo::AbstractMPO, α::Number)
T = VectorInterface.promote_scale(scalartype(mpo), scalartype(α))
dst = similar(mpo, T)
return scale!(dst, mpo, α)
end

LinearAlgebra.norm(mpo::AbstractMPO) = sqrt(abs(dot(mpo, mpo)))

Expand Down
2 changes: 2 additions & 0 deletions src/operators/lazysum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
Base.similar(x::LazySum, ::Type{S}, dims::Dims) where {S} = LazySum(similar(x.ops, S, dims))
Base.setindex!(A::LazySum, X, i::Int) = (setindex!(A.ops, X, i); A)

Base.complex(x::LazySum) = LazySum(complex.(x.ops))

Check warning on line 27 in src/operators/lazysum.jl

View check run for this annotation

Codecov / codecov/patch

src/operators/lazysum.jl#L27

Added line #L27 was not covered by tests

# Holy traits
TimeDependence(x::LazySum) = istimed(x) ? TimeDependent() : NotTimeDependent()
istimed(x::LazySum) = any(istimed, x)
Expand Down
90 changes: 43 additions & 47 deletions src/operators/mpo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,12 @@
Base.parent(mpo::MPO) = mpo.O
Base.copy(mpo::MPO) = MPO(map(copy, mpo))

function Base.similar(mpo::MPO, ::Type{O}, L::Int) where {O}
function Base.similar(mpo::MPO{<:MPOTensor}, ::Type{O}, L::Int) where {O<:MPOTensor}

Check warning on line 58 in src/operators/mpo.jl

View check run for this annotation

Codecov / codecov/patch

src/operators/mpo.jl#L58

Added line #L58 was not covered by tests
return MPO(similar(parent(mpo), O, L))
end
function Base.similar(mpo::MPO, ::Type{T}) where {T<:Number}
return MPO(similar.(parent(mpo), T))
end

Base.repeat(mpo::MPO, n::Int) = MPO(repeat(parent(mpo), n))
Base.repeat(mpo::MPO, rows::Int, cols::Int) = MultilineMPO(fill(repeat(mpo, cols), rows))
Expand Down Expand Up @@ -102,19 +105,20 @@
return convert(TensorMap, _instantiate_finitempo(L, M, R))
end

Base.complex(mpo::MPO) = MPO(map(complex, parent(mpo)))

# Linear Algebra
# --------------
# VectorInterface.scalartype(::Type{FiniteMPO{O}}) where {O} = scalartype(O)

Base.:+(mpo::MPO) = MPO(map(+, parent(mpo)))
function Base.:+(mpo1::FiniteMPO{TO}, mpo2::FiniteMPO{TO}) where {TO<:MPOTensor}
(N = length(mpo1)) == length(mpo2) || throw(ArgumentError("dimension mismatch"))
function Base.:+(mpo1::FiniteMPO{<:MPOTensor}, mpo2::FiniteMPO{<:MPOTensor})
N = check_length(mpo1, mpo2)
@assert left_virtualspace(mpo1, 1) == left_virtualspace(mpo2, 1) &&
right_virtualspace(mpo1, N) == right_virtualspace(mpo2, N)

mpo = similar(parent(mpo1))
halfN = N ÷ 2
A = storagetype(TO)
A = storagetype(eltype(mpo1))

# left half
F₁ = isometry(A, (right_virtualspace(mpo1, 1) ⊕ right_virtualspace(mpo2, 1)),
Expand All @@ -127,7 +131,9 @@

# making sure that the new operator is "full rank"
O, R = leftorth!(O)
mpo[1] = transpose(O, ((2, 3), (1, 4)))
O′ = transpose(O, ((2, 3), (1, 4)))
mpo = similar(mpo1, typeof(O′))
mpo[1] = O′

for i in 2:halfN
# incorporate fusers from left side
Expand Down Expand Up @@ -193,11 +199,18 @@
scale!(first(mpo), α)
return mpo
end
function VectorInterface.scale!(dst::MPO, src::MPO, α::Number)
N = check_length(dst, src)
for i in 1:N
scale!(dst[i], src[i], i == 1 ? α : One())
end
return dst
end

function Base.:*(mpo1::FiniteMPO{<:MPOTensor}, mpo2::FiniteMPO{<:MPOTensor})
N = check_length(mpo1, mpo2)
(S = spacetype(mpo1)) == spacetype(mpo2) || throw(SectorMismatch())

# TODO: merge implementation with that of InfiniteMPO
function Base.:*(mpo1::FiniteMPO{TO}, mpo2::FiniteMPO{TO}) where {TO<:MPOTensor}
(N = length(mpo1)) == length(mpo2) || throw(ArgumentError("dimension mismatch"))
S = spacetype(TO)
if (left_virtualspace(mpo1, 1) != oneunit(S) ||
left_virtualspace(mpo2, 1) != oneunit(S)) ||
(right_virtualspace(mpo1, N) != oneunit(S) ||
Expand All @@ -207,44 +220,34 @@
# would work and for now I dont feel like figuring out if this is important
end

O = similar(parent(mpo1))
A = storagetype(TO)

# note order of mpos: mpo1 * mpo2 * state -> mpo2 on top of mpo1
local Fᵣ # trick to make Fᵣ defined in the loop
for i in 1:N
Fₗ = i != 1 ? Fᵣ : fuser(A, left_virtualspace(mpo2, i), left_virtualspace(mpo1, i))
Fᵣ = fuser(A, right_virtualspace(mpo2, i), right_virtualspace(mpo1, i))
@plansor O[i][-1 -2; -3 -4] := Fₗ[-1; 1 4] * mpo2[i][1 2; -3 3] *
mpo1[i][4 -2; 2 5] *
conj(Fᵣ[-4; 3 5])
end

O = map(fuse_mul_mpo, parent(mpo1), parent(mpo2))
return changebonds!(FiniteMPO(O), SvdCut(; trscheme=notrunc()))
end
function Base.:*(mpo1::InfiniteMPO, mpo2::InfiniteMPO)
check_length(mpo1, mpo2)
Os = map(fuse_mul_mpo, parent(mpo1), parent(mpo2))
return InfiniteMPO(Os)
end

function Base.:*(mpo::FiniteMPO, mps::FiniteMPS)
length(mpo) == length(mps) || throw(ArgumentError("dimension mismatch"))

A = [mps.AC[1]; mps.AR[2:end]]
TT = storagetype(eltype(A))

local Fᵣ # trick to make Fᵣ defined in the loop
for i in 1:length(mps)
Fₗ = i != 1 ? Fᵣ : fuser(TT, left_virtualspace(mps, i), left_virtualspace(mpo, i))
Fᵣ = fuser(TT, right_virtualspace(mps, i), right_virtualspace(mpo, i))
A[i] = _fuse_mpo_mps(mpo[i], A[i], Fₗ, Fᵣ)
N = check_length(mpo, mps)
T = TensorOperations.promote_contract(scalartype(mpo), scalartype(mps))
A = TensorKit.similarstoragetype(eltype(mps), T)
Fᵣ = fuser(A, left_virtualspace(mps, 1), left_virtualspace(mpo, 1))
A2 = map(1:N) do i
A1 = i == 1 ? mps.AC[1] : mps.AR[i]
Fₗ = Fᵣ
Fᵣ = fuser(A, right_virtualspace(mps, i), right_virtualspace(mpo, i))
return _fuse_mpo_mps(mpo[i], A1, Fₗ, Fᵣ)
end

return changebonds!(FiniteMPS(A),
SvdCut(; trscheme=truncbelow(eps(real(scalartype(TT)))));
normalize=false)
trscheme = truncbelow(eps(real(T)))
return changebonds!(FiniteMPS(A2), SvdCut(; trscheme); normalize=false)
end

function Base.:*(mpo::InfiniteMPO, mps::InfiniteMPS)
L = check_length(mpo, mps)
T = promote_type(scalartype(mpo), scalartype(mps))
fusers = PeriodicArray(fuser.(T, left_virtualspace.(Ref(mps), 1:L),
A = TensorKit.similarstoragetype(eltype(mps), T)
fusers = PeriodicArray(fuser.(A, left_virtualspace.(Ref(mps), 1:L),
left_virtualspace.(Ref(mpo), 1:L)))
As = map(1:L) do i
return _fuse_mpo_mps(mpo[i], mps.AL[i], fusers[i], fusers[i + 1])
Expand All @@ -260,12 +263,6 @@
return A′ isa AbstractBlockTensorMap ? TensorMap(A′) : A′
end

function Base.:*(mpo1::InfiniteMPO, mpo2::InfiniteMPO)
check_length(mpo1, mpo2)
Os = map(fuse_mul_mpo, parent(mpo1), parent(mpo2))
return InfiniteMPO(Os)
end

function Base.:*(mpo::FiniteMPO{<:MPOTensor}, x::AbstractTensorMap)
@assert length(mpo) > 1
@assert numout(x) == length(mpo)
Expand All @@ -281,8 +278,7 @@
# in the middle
function TensorKit.dot(bra::FiniteMPS{T}, mpo::FiniteMPO{<:MPOTensor},
ket::FiniteMPS{T}) where {T}
(N = length(bra)) == length(mpo) == length(ket) ||
throw(ArgumentError("dimension mismatch"))
N = check_length(bra, mpo, ket)
Nhalf = N ÷ 2
# left half
ρ_left = isomorphism(storagetype(T),
Expand Down
67 changes: 56 additions & 11 deletions src/operators/mpohamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,32 @@
end
end

# TODO: remove once complex(::BraidingTensor) isa BraidingTensor
# Base.complex(H::MPOHamiltonian) = MPOHamiltonian(map(complex, parent(H)))
function Base.complex(H::MPOHamiltonian)
scalartype(H) <: Complex && return H
Ws = map(parent(H)) do W
W′ = jordanmpotensortype(spacetype(W), complex(scalartype(W)))
W′[1] = W[1]
W′[end] = W[end]
for (I, v) in nonzero_pairs(W)
if v isa BraidingTensor
W′[I] = BraidingTensor{scalartype(W′)}(space(v), v.adjoint)

Check warning on line 387 in src/operators/mpohamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/operators/mpohamiltonian.jl#L379-L387

Added lines #L379 - L387 were not covered by tests
else
W′[I] = complex(v)

Check warning on line 389 in src/operators/mpohamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/operators/mpohamiltonian.jl#L389

Added line #L389 was not covered by tests
end
end

Check warning on line 391 in src/operators/mpohamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/operators/mpohamiltonian.jl#L391

Added line #L391 was not covered by tests
end
return MPOHamiltonian(H)

Check warning on line 393 in src/operators/mpohamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/operators/mpohamiltonian.jl#L393

Added line #L393 was not covered by tests
end

function Base.similar(H::MPOHamiltonian, ::Type{O}, L::Int) where {O<:MPOTensor}
return MPOHamiltonian(similar(parent(H), O, L))

Check warning on line 397 in src/operators/mpohamiltonian.jl

View check run for this annotation

Codecov / codecov/patch

src/operators/mpohamiltonian.jl#L396-L397

Added lines #L396 - L397 were not covered by tests
end
function Base.similar(H::MPOHamiltonian, ::Type{T}) where {T<:Number}
return MPOHamiltonian(similar.(parent(H), T))
end

# Linear Algebra
# --------------

Expand Down Expand Up @@ -496,15 +522,36 @@
return H
end

function VectorInterface.scale!(dst::MPOHamiltonian, src::MPOHamiltonian,
λ::Number)
N = check_length(dst, src)
for i in 1:N
space(dst[i]) == space(src[i]) || throw(SpaceMismatch())
zerovector!(dst[i])
for (I, v) in nonzero_pairs(src[i])
# only scale "starting" terms
isstarting = I[1] == 1 &&
((isfinite(dst) && i == N && I[4] == size(src[i], 4)) ||
((!isfinite(dst) || i != N) && I[4] > 1))
if v isa BraidingTensor && !isstarting
dst[i][I] = v
else
dst[i][I] = scale!(dst[i][I], v, isstarting ? λ : One())
end
end
end
return dst
end

function Base.:*(H1::MPOHamiltonian, H2::MPOHamiltonian)
check_length(H1, H2)
Ws = fuse_mul_mpo.(parent(H1), parent(H2))
return MPOHamiltonian(Ws)
end

function Base.:*(H::FiniteMPOHamiltonian, mps::FiniteMPS)
check_length(H, mps)
@assert length(mps) > 2 "MPS should have at least three sites, to be implemented otherwise"
N = check_length(H, mps)
@assert N > 2 "MPS should have at least three sites, to be implemented otherwise"
A = convert.(BlockTensorMap, [mps.AC[1]; mps.AR[2:end]])
A′ = similar(A,
tensormaptype(spacetype(mps), numout(eltype(mps)), numin(eltype(mps)),
Expand All @@ -515,30 +562,30 @@
Q, R = leftorth!(a; alg=QR())
A′[1] = convert(TensorMap, Q)

for i in 2:(length(mps) ÷ 2)
for i in 2:(N ÷ 2)
@plansor a[-1 -2; -3 -4] := R[-1; 1 2] * A[i][1 3; -3] * H[i][2 -2; 3 -4]
Q, R = leftorth!(a; alg=QR())
A′[i] = convert(TensorMap, Q)
end

# right to middle
U = ones(scalartype(H), right_virtualspace(H, length(H)))
U = ones(scalartype(H), right_virtualspace(H, N))
@plansor a[-1 -2; -3 -4] := A[end][-1 2; -3] * H[end][-2 -4; 2 1] * U[1]
L, Q = rightorth!(a; alg=LQ())
A′[end] = transpose(convert(TensorMap, Q), ((1, 3), (2,)))

for i in (length(mps) - 1):-1:(length(mps) ÷ 2 + 2)
for i in (N - 1):-1:(N ÷ 2 + 2)
@plansor a[-1 -2; -3 -4] := A[i][-1 3; 1] * H[i][-2 -4; 3 2] * L[1 2; -3]
L, Q = rightorth!(a; alg=LQ())
A′[i] = transpose(convert(TensorMap, Q), ((1, 3), (2,)))
end

# connect pieces
@plansor a[-1 -2; -3] := R[-1; 1 2] *
A[length(mps) ÷ 2 + 1][1 3; 4] *
H[length(mps) ÷ 2 + 1][2 -2; 3 5] *
A[N ÷ 2 + 1][1 3; 4] *
H[N ÷ 2 + 1][2 -2; 3 5] *
L[4 5; -3]
A′[length(mps) ÷ 2 + 1] = convert(TensorMap, a)
A′[N ÷ 2 + 1] = convert(TensorMap, a)

return FiniteMPS(A′)
end
Expand All @@ -553,9 +600,7 @@
end

function TensorKit.dot(H₁::FiniteMPOHamiltonian, H₂::FiniteMPOHamiltonian)
check_length(H₁, H₂)

N = length(H₁)
N = check_length(H₁, H₂)
Nhalf = N ÷ 2
# left half
@plansor ρ_left[-1; -2] := conj(H₁[1][1 2; 3 -1]) * H₂[1][1 2; 3 -2]
Expand Down
13 changes: 13 additions & 0 deletions src/states/finitemps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,19 @@
end
end

_complex_if_not_missing(x) = ismissing(x) ? x : complex(x)

Check warning on line 319 in src/states/finitemps.jl

View check run for this annotation

Codecov / codecov/patch

src/states/finitemps.jl#L319

Added line #L319 was not covered by tests
function Base.complex(mps::FiniteMPS)
scalartype(mps) <: Complex && return mps
ALs = _complex_if_not_missing.(mps.ALs)
ARs = _complex_if_not_missing.(mps.ARs)
Cs = _complex_if_not_missing.(mps.Cs)
ACs = _complex_if_not_missing.(mps.ACs)
return FiniteMPS(collect(Union{Missing,eltype(ALs)}, ALs),

Check warning on line 326 in src/states/finitemps.jl

View check run for this annotation

Codecov / codecov/patch

src/states/finitemps.jl#L322-L326

Added lines #L322 - L326 were not covered by tests
collect(Union{Missing,eltype(ARs)}, ARs),
collect(Union{Missing,eltype(ACs)}, ACs),
collect(Union{Missing,eltype(Cs)}, Cs))
end

@inline function Base.getindex(ψ::FiniteMPS, I::AbstractUnitRange)
return Base.getindex.(Ref(ψ), I)
end
Expand Down
5 changes: 5 additions & 0 deletions src/states/infinitemps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,11 @@
return ψ
end

function Base.complex(ψ::InfiniteMPS)
scalartype(ψ) <: Complex && return ψ
return InfiniteMPS(complex.(ψ.AL), complex.(ψ.AR), complex.(ψ.C), complex.(ψ.AC))

Check warning on line 229 in src/states/infinitemps.jl

View check run for this annotation

Codecov / codecov/patch

src/states/infinitemps.jl#L229

Added line #L229 was not covered by tests
end

function Base.repeat(ψ::InfiniteMPS, i::Int)
return InfiniteMPS(repeat(ψ.AL, i), repeat(ψ.AR, i), repeat(ψ.C, i), repeat(ψ.AC, i))
end
Expand Down
Loading