diff --git a/src/Transformations.jl b/src/Transformations.jl index 75f524f4..0853b479 100644 --- a/src/Transformations.jl +++ b/src/Transformations.jl @@ -277,19 +277,16 @@ function transform!(tn::TensorNetwork, config::SplitSimplification) for tensor in tensor_list inds = labels(tensor) - partitions = Iterators.flatten(combinations(inds, r) for r = 1:(length(inds)-1)) - - # Iterate over all possible bipartitions of the tensor's indices - for bipartition in partitions + # iterate all bipartitions of the tensor's indices + bipartitions = Iterators.flatten(combinations(inds, r) for r = 1:(length(inds)-1)) + for bipartition in bipartitions left_inds = collect(bipartition) right_inds = setdiff(inds, left_inds) - # Perform an SVD across the bipartition + # perform an SVD across the bipartition u, s, v = svd(tensor; left_inds=left_inds) - # Get the singular values and decide the rank - singular_values = diag(s) - rank_s = sum(singular_values .> config.atol) + rank_s = sum(diag(s) .> config.atol) if rank_s < length(singular_values) # Remove unnecessary data in u, s, v