diff --git a/Project.toml b/Project.toml index 60fefbbf..aeef8d6c 100644 --- a/Project.toml +++ b/Project.toml @@ -45,6 +45,6 @@ OMEinsum = "0.7" Permutations = "0.4" Quac = "0.2" Requires = "1.3" -Tensors = "0.1.4" +Tensors = "0.1.9" ValSplit = "0.1" julia = "1.8" diff --git a/src/Transformations.jl b/src/Transformations.jl index 34b101d5..75f524f4 100644 --- a/src/Transformations.jl +++ b/src/Transformations.jl @@ -302,8 +302,8 @@ function transform!(tn::TensorNetwork, config::SplitSimplification) tensor_r = v pop!(tn, tensor) # Remove the old tensor - push!(tn, tensor_l) # Add the new tensors - push!(tn, tensor_r) + push!(tn, dropdims(tensor_l, dims = tuple(findall(size(tensor_l) .== 1)...))) # Add the new tensors + push!(tn, dropdims(tensor_r, dims = tuple(findall(size(tensor_r) .== 1)...))) done = false # A change was made, so we'll need to go another pass break # Exit the inner loop early