-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement spatially symmetric LoopTNR (#7)
* add Zygote and OptimKit as dependencies * implement SLoopTNR * add default `LBFGS` parameters * add `SLoopTNR` to the spaces testset * remove `SLoopTNR` from spaces test again * add `SLoopTNR` to Ising test * add comment to `SLoopTNR` implementation * formatting Co-authored-by: darts <[email protected]>
- Loading branch information
1 parent
5f6b372
commit dc6d972
Showing
5 changed files
with
93 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# TODO: rewrite SLoopTNR contractions to work with symmetric tensors | ||
mutable struct SLoopTNR <: TRGScheme | ||
T::TensorMap | ||
|
||
optimization_algorithm::OptimKit.OptimizationAlgorithm | ||
finalize!::Function | ||
function SLoopTNR(T::TensorMap; | ||
optimization_algorithm=LBFGS(8; verbosity=1, maxiter=500, | ||
gradtol=1e-4), finalize=finalize!) | ||
@assert scalartype(T) <: Real "SLoopTNR only supports real-valued TensorMaps" | ||
return new(T, optimization_algorithm, finalize) | ||
end | ||
end | ||
|
||
function step!(scheme::SLoopTNR, trunc::TensorKit.TruncationScheme) | ||
f(A) = _SLoopTNR_cost(permute(scheme.T, ((1, 2), (4, 3))), A) # Another convention was used when implementing SLoopTNR | ||
|
||
function fg(f, A) | ||
f, g = Zygote.withgradient(f, A) | ||
return f, g[1] | ||
end | ||
|
||
Zygote.refresh() | ||
|
||
U, S, _ = tsvd(permute(scheme.T, ((1, 2), (4, 3))); trunc=trunc) | ||
S₀ = U * sqrt(S) | ||
if norm(imag(S)) > 1e-12 | ||
@error "S is not real" | ||
end | ||
S_opt, _, _, _, _ = optimize(A -> fg(f, A), S₀, scheme.optimization_algorithm) | ||
|
||
@tensor scheme.T[-1 -2; -4 -3] := S_opt[1 2 -3] * S_opt[1 4 -1] * S_opt[3 4 -2] * | ||
S_opt[3 2 -4] | ||
end | ||
|
||
function ψAψA(T::AbstractTensorMap) | ||
@tensor M[-1 -2 -3 -4] := T[1 -2 2 -4] * conj(T[1 -1 2 -3]) | ||
@tensor MM[-1 -2 -3 -4] := M[-1 -2 1 2] * M[-3 -4 1 2] | ||
return @tensor MM[1 2 3 4] * MM[1 2 3 4] | ||
end | ||
|
||
function ψAψB(T::AbstractTensorMap, S::AbstractTensorMap) | ||
@tensor M[-1 -2 -3 -4] := T[1 -2 2 -4] * conj(S[1 -1 3]) * conj(S[2 -3 3]) | ||
@tensor MM[-1 -2 -3 -4] := M[-1 -2 1 2] * M[-3 -4 1 2] | ||
@tensor result = MM[1 2 3 4] * MM[1 2 3 4] | ||
if norm(imag(result)) > 1e-12 | ||
@error "We only support real tensors" | ||
end | ||
return result | ||
end | ||
|
||
function ψBψB(S::AbstractTensorMap) | ||
@tensor M[-1 -2 -3 -4] := S[1 -1 3] * conj(S[1 -2 4]) * S[2 -3 3] * conj(S[2 -4 4]) | ||
@tensor MM[-1 -2 -3 -4] := M[-1 -2 1 2] * M[-3 -4 1 2] | ||
return @tensor MM[1 2 3 4] * MM[1 2 3 4] # This seems very bad for complex numbers | ||
end | ||
|
||
function _SLoopTNR_cost(T::AbstractTensorMap, S::AbstractTensorMap) | ||
return ψAψA(T) - 2 * real(ψAψB(T, S)) + ψBψB(S) | ||
end | ||
|
||
slooptnr_convcrit(steps::Int, data) = abs(log(data[end]) * 2.0^(-steps)) | ||
|
||
function Base.show(io::IO, scheme::SLoopTNR) | ||
println(io, "SLoopTNR - Symmetric Loop TNR") | ||
println(io, " * T: $(summary(scheme.T))") | ||
return println(io, | ||
" * Optimization algorithm: $(summary(scheme.optimization_algorithm))") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters