From 5ecbcf5c0316a0783ae1cfa314847ced97c6c4b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 13 Feb 2025 12:08:31 -0500 Subject: [PATCH 01/17] merge Ellipsis in one block --- src/blockedpermutation.jl | 8 ++++++-- test/test_blockedpermutation.jl | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 663f480..2c3a79b 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -97,8 +97,12 @@ function blockedperm( unspecified_dim = findfirst(x -> x isa Ellipsis, permblocks) specified_perm = flatten_tuples(specified_permblocks) len = _blockedperm_length(length, specified_perm) - unspecified_dims = Tuple(setdiff(Base.OneTo(len), flatten_tuples(specified_permblocks))) - permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, unspecified_dims) + unspecified_dims_vec = setdiff(Base.OneTo(len), flatten_tuples(specified_permblocks)) + UD = len - sum(Base.length.(specified_permblocks)) # preserve type stability when possible + unspecified_dims = NTuple{UD}(unspecified_dims_vec) + permblocks_specified = TupleTools.insertat( + permblocks, unspecified_dim, (unspecified_dims,) + ) return blockedperm(permblocks_specified...) end diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index 37f8e6a..e01675a 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -80,28 +80,28 @@ using TensorAlgebra: # First dimensions are unspecified. p = blockedperm(.., (4, 3)) - @test p == blockedperm(1, 2, (4, 3)) + @test p == blockedperm((1, 2), (4, 3)) # Specify length - p = blockedperm(.., (4, 3); length=Val(6)) - @test p == blockedperm(1, 2, 5, 6, (4, 3)) + p = @constinferred blockedperm(.., (4, 3); length=Val(6)) + @test p == blockedperm((1, 2, 5, 6), (4, 3)) # Last dimensions are unspecified. p = blockedperm((4, 3), ..) - @test p == blockedperm((4, 3), 1, 2) + @test p == blockedperm((4, 3), (1, 2)) # Specify length - p = blockedperm((4, 3), ..; length=Val(6)) - @test p == blockedperm((4, 3), 1, 2, 5, 6) + p = @constinferred blockedperm((4, 3), ..; length=Val(6)) + @test p == blockedperm((4, 3), (1, 2, 5, 6)) # Middle dimensions are unspecified. p = blockedperm((4, 3), .., 1) - @test p == blockedperm((4, 3), 2, 1) + @test p == blockedperm((4, 3), (2,), (1,)) # Specify length - p = blockedperm((4, 3), .., 1; length=Val(6)) - @test p == blockedperm((4, 3), 2, 5, 6, 1) + p = @constinferred blockedperm((4, 3), .., 1; length=Val(6)) + @test p == blockedperm((4, 3), (2, 5, 6), (1,)) # No dimensions are unspecified. p = blockedperm((3, 2), .., 1) - @test p == blockedperm((3, 2), 1) + @test p == blockedperm((3, 2), (), (1,)) end @testset "BlockedTrivialPermutation" begin From a1d0ca3b662e5fa90b98bd65e2f2b37cee392f1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 13 Feb 2025 12:14:52 -0500 Subject: [PATCH 02/17] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9d18dc3..02b9d2b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.1.8" +version = "0.1.9" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" From 5ee79b10a87232856d58f04b31f48baf629968ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 13 Feb 2025 12:41:14 -0500 Subject: [PATCH 03/17] adapt tests --- test/test_basics.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/test_basics.jl b/test/test_basics.jl index 221274b..b0e83f9 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -22,19 +22,16 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (24, 5)) a_fused = fusedims(a, .., (3, 1)) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (2, 4, 3, 1)), (3, 5, 8)) + @test a_fused ≈ reshape(permutedims(a, (2, 4, 3, 1)), (15, 8)) a_fused = fusedims(a, (3, 1), ..) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 3, 5)) + @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) a_fused = fusedims(a, .., (3, 1), 2) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (4, 3, 1, 2)), (5, 8, 3)) a_fused = fusedims(a, (3, 1), .., 2) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 4, 2)), (8, 5, 3)) - a_fused = fusedims(a, (3, 1), ..) - @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 3, 5)) end @testset "splitdims (eltype=$elt)" for elt in elts a = randn(elt, 6, 20) From da74ebdd53c8eb2d9dd3620b5500be810f82962c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 13 Feb 2025 13:02:07 -0500 Subject: [PATCH 04/17] adapt comments --- src/blockedpermutation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 2c3a79b..bb65f49 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -86,8 +86,8 @@ function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}}) return value(vallength) end -# blockedperm((4, 3), .., 1) == blockedperm((4, 3), 2, 1) -# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), 2, 5, 1) +# blockedperm((4, 3), .., 1) == blockedperm((4, 3), (2,), (1,)) +# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), (2, 5), (1,)) function blockedperm( permblocks::Union{Tuple{Vararg{Int}},Ellipsis}...; length::Union{Val,Nothing}=nothing ) From a6d3706429f26fa018802e283b3e5f6510a074bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 18 Feb 2025 16:27:26 -0500 Subject: [PATCH 05/17] recover Ellipsis behavior --- src/blockedpermutation.jl | 34 +++++++++++++++++++++------------ test/test_blockedpermutation.jl | 29 +++++++++++++++++++++++----- 2 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index bb65f49..7183782 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -70,7 +70,9 @@ function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...) return blockedperm(collect_tuple.(permblocks)...; kwargs...) end -function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwargs...) +function blockedperm( + permblocks::Union{Tuple{Vararg{Int}},Tuple{Ellipsis},Int,Ellipsis}...; kwargs... +) return blockedperm(collect_tuple.(permblocks)...; kwargs...) end @@ -87,25 +89,33 @@ function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}}) end # blockedperm((4, 3), .., 1) == blockedperm((4, 3), (2,), (1,)) -# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), (2, 5), (1,)) +# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), (2,), (5,), (1,)) +# blockedperm((4, 3), (..,), 1) == blockedperm((4, 3), (2,), (1,)) +# blockedperm((4, 3), (..,), 1; length=Val(5)) == blockedperm((4, 3), (2, 5), (1,)) function blockedperm( - permblocks::Union{Tuple{Vararg{Int}},Ellipsis}...; length::Union{Val,Nothing}=nothing + permblocks::Union{Tuple{Vararg{Int}},Ellipsis,Tuple{Ellipsis}}...; + length::Union{Val,Nothing}=nothing, ) # Check there is only one `Ellipsis`. - @assert isone(count(x -> x isa Ellipsis, permblocks)) - specified_permblocks = filter(x -> !(x isa Ellipsis), permblocks) - unspecified_dim = findfirst(x -> x isa Ellipsis, permblocks) + @assert isone(count(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks)) + specified_permblocks = filter(x -> !(x isa Union{Ellipsis,Tuple{Ellipsis}}), permblocks) + unspecified_dim = findfirst(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks) specified_perm = flatten_tuples(specified_permblocks) len = _blockedperm_length(length, specified_perm) - unspecified_dims_vec = setdiff(Base.OneTo(len), flatten_tuples(specified_permblocks)) - UD = len - sum(Base.length.(specified_permblocks)) # preserve type stability when possible - unspecified_dims = NTuple{UD}(unspecified_dims_vec) - permblocks_specified = TupleTools.insertat( - permblocks, unspecified_dim, (unspecified_dims,) - ) + unspecified_dims_vec = setdiff(Base.OneTo(len), specified_perm) + UD = Val(len - sum(Base.length.(specified_permblocks))) # preserve type stability when possible + insert = unspecified_dims(typeof(permblocks[unspecified_dim]), unspecified_dims_vec, UD) + permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, insert) return blockedperm(permblocks_specified...) end +function unspecified_dims(::Type{Tuple{Ellipsis}}, unspecified_dims_vec, UD::Val) + return (NTuple{value(UD),Int}(unspecified_dims_vec),) +end +function unspecified_dims(::Type{Ellipsis}, unspecified_dims_vec, UD::Val) + return NTuple{value(UD),Tuple{Int}}(Tuple.(unspecified_dims_vec)) +end + # Version of `indexin` that outputs a `blockedperm`. function blockedperm_indexin(collection, subs...) return blockedperm(map(sub -> BaseExtensions.indexin(sub, collection), subs)...) diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index e01675a..ca1f922 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -80,27 +80,46 @@ using TensorAlgebra: # First dimensions are unspecified. p = blockedperm(.., (4, 3)) - @test p == blockedperm((1, 2), (4, 3)) + @test p == blockedperm((1,), (2,), (4, 3)) # Specify length p = @constinferred blockedperm(.., (4, 3); length=Val(6)) - @test p == blockedperm((1, 2, 5, 6), (4, 3)) + @test p == blockedperm((1,), (2,), (5,), (6,), (4, 3)) # Last dimensions are unspecified. p = blockedperm((4, 3), ..) - @test p == blockedperm((4, 3), (1, 2)) + @test p == blockedperm((4, 3), (1,), (2,)) # Specify length p = @constinferred blockedperm((4, 3), ..; length=Val(6)) - @test p == blockedperm((4, 3), (1, 2, 5, 6)) + @test p == blockedperm((4, 3), (1,), (2,), (5,), (6,)) # Middle dimensions are unspecified. p = blockedperm((4, 3), .., 1) @test p == blockedperm((4, 3), (2,), (1,)) # Specify length p = @constinferred blockedperm((4, 3), .., 1; length=Val(6)) - @test p == blockedperm((4, 3), (2, 5, 6), (1,)) + @test p == blockedperm((4, 3), (2,), (5,), (6,), (1,)) # No dimensions are unspecified. p = blockedperm((3, 2), .., 1) + @test p == blockedperm((3, 2), (1,)) + + # same with (..,) instead of .. + p = blockedperm((..,), (4, 3)) + @test p == blockedperm((1, 2), (4, 3)) + p = @constinferred blockedperm((..,), (4, 3); length=Val(6)) + @test p == blockedperm((1, 2, 5, 6), (4, 3)) + + p = blockedperm((4, 3), (..,)) + @test p == blockedperm((4, 3), (1, 2)) + p = @constinferred blockedperm((4, 3), (..,); length=Val(6)) + @test p == blockedperm((4, 3), (1, 2, 5, 6)) + + p = blockedperm((4, 3), (..,), 1) + @test p == blockedperm((4, 3), (2,), (1,)) + p = @constinferred blockedperm((4, 3), (..,), 1; length=Val(6)) + @test p == blockedperm((4, 3), (2, 5, 6), (1,)) + + p = blockedperm((3, 2), (..,), 1) @test p == blockedperm((3, 2), (), (1,)) end From a7944b4edddc7ccf62809dcd216085c73f82829f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 18 Feb 2025 16:31:12 -0500 Subject: [PATCH 06/17] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 02b9d2b..d343c94 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.1.9" +version = "0.1.10" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" From f19e447f3f7cc9a9f40167a79e1a27dac62086a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 18 Feb 2025 17:35:47 -0500 Subject: [PATCH 07/17] fix tests --- test/test_basics.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_basics.jl b/test/test_basics.jl index b0e83f9..6b75c69 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -22,16 +22,19 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (24, 5)) a_fused = fusedims(a, .., (3, 1)) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (2, 4, 3, 1)), (15, 8)) + @test a_fused ≈ reshape(permutedims(a, (2, 4, 3, 1)), (3, 5, 8)) a_fused = fusedims(a, (3, 1), ..) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) + @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 3, 5)) a_fused = fusedims(a, .., (3, 1), 2) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (4, 3, 1, 2)), (5, 8, 3)) a_fused = fusedims(a, (3, 1), .., 2) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 4, 2)), (8, 5, 3)) + a_fused = fusedims(a, (3, 1), (..,)) + @test eltype(a_fused) === elt + @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) end @testset "splitdims (eltype=$elt)" for elt in elts a = randn(elt, 6, 20) From 9910fe4efb033fb75da027085b4081012a3a2ef3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Wed, 19 Feb 2025 10:45:11 -0500 Subject: [PATCH 08/17] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d343c94..a967c86 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.1.10" +version = "0.1.11" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" From 1b0fcd8ffca33d5a3cf8b0339a7e3ae60a59c905 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 20 Feb 2025 14:26:38 -0500 Subject: [PATCH 09/17] check consistent lengths --- src/blockedpermutation.jl | 10 +++++++--- test/test_blockedpermutation.jl | 5 +++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 7183782..1278836 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -76,8 +76,11 @@ function blockedperm( return blockedperm(collect_tuple.(permblocks)...; kwargs...) end -function blockedperm(bt::AbstractBlockTuple) - return blockedperm(Val(length(bt)), blocks(bt)...) +# keep len kwarg to be consistent with other method signatures +function blockedperm(bt::AbstractBlockTuple; length::Union{Val,Nothing}=nothing) + !(length ∈ (nothing, Val(Base.length(bt)))) && + throw(ArgumentError("Invalid total length")) + return blockedperm(Val(Base.length(bt)), blocks(bt)...) end function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}}) @@ -152,10 +155,11 @@ function BlockArrays.blocklengths( return BlockLengths end -function blockedperm(::Val, permblocks::Tuple{Vararg{Int}}...) +function blockedperm(len::Val, permblocks::Tuple{Vararg{Int}}...) blockedperm = BlockedPermutation{length(permblocks),length.(permblocks)}( flatten_tuples(permblocks) ) + value(len) != length(blockedperm) && throw(ArgumentError("Invalid total length")) @assert isperm(blockedperm) return blockedperm end diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index ca1f922..cd492bc 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -70,6 +70,11 @@ using TensorAlgebra: @test (@constinferred blockedperm(p)) == p @test (@constinferred blockedperm(bt)) == p + @test_throws ArgumentError blockedperm((1, 3), (2, 4); length=Val(6)) + @test_throws ArgumentError blockedperm(tuplemortar(((1, 3), (2, 4))); length=Val(5)) + @test (@constinferred blockedperm(tuplemortar(((1, 3), (2, 4))); length=Val(4))) == + blockedperm((1, 3), (2, 4)) + # Split collection into `BlockedPermutation`. p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d")) @test p == blockedperm((3, 1), (2, 4)) From 0ca766a6fb7ebf11fc599f8f14bc51fb6165864a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Thu, 20 Feb 2025 15:31:10 -0500 Subject: [PATCH 10/17] fix tests --- test/test_blockedpermutation.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index cd492bc..060ed28 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -11,7 +11,8 @@ using TensorAlgebra: blockedperm, blockedperm_indexin, blockedtrivialperm, - trivialperm + trivialperm, + tuplemortar @testset "BlockedPermutation" begin p = @constinferred blockedperm((3, 4, 5), (2, 1)) @@ -63,10 +64,10 @@ using TensorAlgebra: @test p isa BlockedPermutation{0} p = blockedperm((3, 2), (), (1,)) - bt = BlockedTuple{3,(2, 0, 1)}((3, 2, 1)) + bt = tuplemortar(((3, 2), (), (1,))) @test (@constinferred BlockedTuple(p)) == bt @test (@constinferred map(identity, p)) == bt - @test (@constinferred p .+ p) == BlockedTuple{3,(2, 0, 1)}((6, 4, 2)) + @test (@constinferred p .+ p) == tuplemortar(((6, 4), (), (2,))) @test (@constinferred blockedperm(p)) == p @test (@constinferred blockedperm(bt)) == p @@ -137,11 +138,11 @@ end @test blocklengths(tp) == (2, 0, 1) @test trivialperm(blockedperm((3, 2), (), (1,))) == tp - bt = BlockedTuple{3,(2, 0, 1)}((1, 2, 3)) + bt = tuplemortar(((1, 2), (), (3,))) @test (@constinferred BlockedTuple(tp)) == bt @test (@constinferred blocks(tp)) == blocks(bt) @test (@constinferred map(identity, tp)) == bt - @test (@constinferred tp .+ tp) == BlockedTuple{3,(2, 0, 1)}((2, 4, 6)) + @test (@constinferred tp .+ tp) == tuplemortar(((2, 4), (), (6,))) @test (@constinferred blockedperm(tp)) == tp @test (@constinferred trivialperm(tp)) == tp @test (@constinferred trivialperm(bt)) == tp From 05155cf6aeaa7879e0b956a87aba360265493c4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 24 Feb 2025 10:49:05 -0500 Subject: [PATCH 11/17] clenaer default kwarg --- src/blockedpermutation.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 1278836..72d57d0 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -77,10 +77,16 @@ function blockedperm( end # keep len kwarg to be consistent with other method signatures -function blockedperm(bt::AbstractBlockTuple; length::Union{Val,Nothing}=nothing) - !(length ∈ (nothing, Val(Base.length(bt)))) && - throw(ArgumentError("Invalid total length")) - return blockedperm(Val(Base.length(bt)), blocks(bt)...) +function blockedperm(bt::AbstractBlockTuple; length=nothing) + return _blockedperm(length, bt) +end + +function _blockedperm(::Nothing, bt::AbstractBlockTuple) + return _blockedperm(Val(length(bt)), bt) +end + +function _blockedperm(vallength::Val, bt::AbstractBlockTuple) + return blockedperm(vallength, blocks(bt)...) end function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}}) From 777a114e1153bcbbaa84f7d4e7a3f1f6648423b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 24 Feb 2025 10:53:50 -0500 Subject: [PATCH 12/17] cleaner ndims_unspecified --- src/blockedpermutation.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 72d57d0..fb1a800 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -112,17 +112,19 @@ function blockedperm( specified_perm = flatten_tuples(specified_permblocks) len = _blockedperm_length(length, specified_perm) unspecified_dims_vec = setdiff(Base.OneTo(len), specified_perm) - UD = Val(len - sum(Base.length.(specified_permblocks))) # preserve type stability when possible - insert = unspecified_dims(typeof(permblocks[unspecified_dim]), unspecified_dims_vec, UD) + ndims_unspecified = Val(len - sum(Base.length.(specified_permblocks))) # preserve type stability when possible + insert = unspecified_dims( + permblocks[unspecified_dim], unspecified_dims_vec, ndims_unspecified + ) permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, insert) return blockedperm(permblocks_specified...) end -function unspecified_dims(::Type{Tuple{Ellipsis}}, unspecified_dims_vec, UD::Val) - return (NTuple{value(UD),Int}(unspecified_dims_vec),) +function unspecified_dims(::Tuple{Ellipsis}, unspecified_dims_vec, ndims_unspecified::Val) + return (ntuple(i -> unspecified_dims_vec[i], ndims_unspecified),) end -function unspecified_dims(::Type{Ellipsis}, unspecified_dims_vec, UD::Val) - return NTuple{value(UD),Tuple{Int}}(Tuple.(unspecified_dims_vec)) +function unspecified_dims(::Ellipsis, unspecified_dims_vec, ndims_unspecified::Val) + return ntuple(i -> (unspecified_dims_vec[i],), ndims_unspecified) end # Version of `indexin` that outputs a `blockedperm`. From b5832631fe5d24602a30e74383ae1a2e040c4b74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 24 Feb 2025 12:00:33 -0500 Subject: [PATCH 13/17] permmortar --- src/blockedpermutation.jl | 9 +++++++-- test/test_blockedpermutation.jl | 8 +++++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index fb1a800..bb36316 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -89,6 +89,12 @@ function _blockedperm(vallength::Val, bt::AbstractBlockTuple) return blockedperm(vallength, blocks(bt)...) end +function blockedperm(len::Val, permblocks::Tuple{Vararg{Int}}...) + value(len) != sum(length.(permblocks); init=0) && + throw(ArgumentError("Invalid total length")) + return permmortar(Tuple(permblocks)) +end + function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}}) return maximum(specified_perm) end @@ -163,11 +169,10 @@ function BlockArrays.blocklengths( return BlockLengths end -function blockedperm(len::Val, permblocks::Tuple{Vararg{Int}}...) +function permmortar(permblocks::Tuple{Vararg{Tuple{Vararg{Int}}}}) blockedperm = BlockedPermutation{length(permblocks),length.(permblocks)}( flatten_tuples(permblocks) ) - value(len) != length(blockedperm) && throw(ArgumentError("Invalid total length")) @assert isperm(blockedperm) return blockedperm end diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index 060ed28..27480c9 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -11,11 +11,12 @@ using TensorAlgebra: blockedperm, blockedperm_indexin, blockedtrivialperm, + permmortar, trivialperm, tuplemortar @testset "BlockedPermutation" begin - p = @constinferred blockedperm((3, 4, 5), (2, 1)) + p = @constinferred permmortar(((3, 4, 5), (2, 1))) @test Tuple(p) === (3, 4, 5, 2, 1) @test isperm(p) @test length(p) == 5 @@ -26,11 +27,16 @@ using TensorAlgebra: @test blocklasts(p) == (3, 5) @test (@constinferred invperm(p)) == blockedperm((5, 4, 1), (2, 3)) @test p isa BlockedPermutation{2} + @test p == (@constinferred blockedperm((3, 4, 5), (2, 1))) flat = (3, 4, 5, 2, 1) @test_throws DimensionMismatch BlockedPermutation{2,(1, 2, 2)}(flat) @test_throws DimensionMismatch BlockedPermutation{3,(1, 2, 3)}(flat) @test_throws DimensionMismatch BlockedPermutation{3,(-1, 3, 3)}(flat) + @test_throws AssertionError blockedperm((3, 5), (2, 1)) + @test_throws AssertionError blockedperm((0, 1), (2, 3)) + @test_throws AssertionError blockedperm((0,)) + @test_throws AssertionError blockedperm((2,)) # Empty block. p = @constinferred blockedperm((3, 2), (), (1,)) From 723fc15a5db23abf0ac513612b8a4ae3991d0d5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 24 Feb 2025 16:18:00 -0500 Subject: [PATCH 14/17] define blockedpermvcat --- src/blockedpermutation.jl | 63 +++++++++----------- src/fusedims.jl | 3 +- test/test_blockedpermutation.jl | 100 ++++++++++++++++---------------- 3 files changed, 79 insertions(+), 87 deletions(-) diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index bb36316..2d29aea 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -32,23 +32,27 @@ widened_constructorof(::Type{<:AbstractBlockPermutation}) = BlockedTuple # blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1)) # TODO: Optimize with StaticNumbers.jl or generated functions, see: # https://discourse.julialang.org/t/avoiding-type-instability-when-slicing-a-tuple/38567 -function blockperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}}) +function blockedperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}}) return blockedperm(BlockedTuple(perm, blocklengths)) end -function blockperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val) +function blockedperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val) return blockedperm(BlockedTuple(perm, BlockLengths)) end -function Base.invperm(blockedperm::AbstractBlockPermutation) +function Base.invperm(bp::AbstractBlockPermutation) # use Val to preserve compile time info - return blockperm(invperm(Tuple(blockedperm)), Val(blocklengths(blockedperm))) + return blockedperm(invperm(Tuple(bp)), Val(blocklengths(bp))) end # # Constructors # +function blockedperm(bt::AbstractBlockTuple) + return permmortar(blocks(bt)) +end + # Bipartition a vector according to the # bipartitioned permutation. # Like `Base.permute!` block out-of-place and blocked. @@ -56,40 +60,29 @@ function blockpermute(v, blockedperm::AbstractBlockPermutation) return map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm)) end -# blockedperm((4, 3), (2, 1)) -function blockedperm(permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing) - return blockedperm(length, permblocks...) +# blockedpermvcat((4, 3), (2, 1)) +function blockedpermvcat( + permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing +) + return blockedpermvcat(length, permblocks...) end -function blockedperm(::Nothing, permblocks::Tuple{Vararg{Int}}...) - return blockedperm(Val(sum(length, permblocks; init=zero(Bool))), permblocks...) +function blockedpermvcat(::Nothing, permblocks::Tuple{Vararg{Int}}...) + return blockedpermvcat(Val(sum(length, permblocks; init=zero(Bool))), permblocks...) end -# blockedperm((3, 2), 1) == blockedperm((3, 2), (1,)) -function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...) - return blockedperm(collect_tuple.(permblocks)...; kwargs...) +# blockedpermvcat((3, 2), 1) == blockedpermvcat((3, 2), (1,)) +function blockedpermvcat(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...) + return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...) end -function blockedperm( +function blockedpermvcat( permblocks::Union{Tuple{Vararg{Int}},Tuple{Ellipsis},Int,Ellipsis}...; kwargs... ) - return blockedperm(collect_tuple.(permblocks)...; kwargs...) -end - -# keep len kwarg to be consistent with other method signatures -function blockedperm(bt::AbstractBlockTuple; length=nothing) - return _blockedperm(length, bt) -end - -function _blockedperm(::Nothing, bt::AbstractBlockTuple) - return _blockedperm(Val(length(bt)), bt) -end - -function _blockedperm(vallength::Val, bt::AbstractBlockTuple) - return blockedperm(vallength, blocks(bt)...) + return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...) end -function blockedperm(len::Val, permblocks::Tuple{Vararg{Int}}...) +function blockedpermvcat(len::Val, permblocks::Tuple{Vararg{Int}}...) value(len) != sum(length.(permblocks); init=0) && throw(ArgumentError("Invalid total length")) return permmortar(Tuple(permblocks)) @@ -103,11 +96,11 @@ function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}}) return value(vallength) end -# blockedperm((4, 3), .., 1) == blockedperm((4, 3), (2,), (1,)) -# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), (2,), (5,), (1,)) -# blockedperm((4, 3), (..,), 1) == blockedperm((4, 3), (2,), (1,)) -# blockedperm((4, 3), (..,), 1; length=Val(5)) == blockedperm((4, 3), (2, 5), (1,)) -function blockedperm( +# blockedpermvcat((4, 3), .., 1) == blockedpermvcat((4, 3), (2,), (1,)) +# blockedpermvcat((4, 3), .., 1; length=Val(5)) == blockedpermvcat((4, 3), (2,), (5,), (1,)) +# blockedpermvcat((4, 3), (..,), 1) == blockedpermvcat((4, 3), (2,), (1,)) +# blockedpermvcat((4, 3), (..,), 1; length=Val(5)) == blockedpermvcat((4, 3), (2, 5), (1,)) +function blockedpermvcat( permblocks::Union{Tuple{Vararg{Int}},Ellipsis,Tuple{Ellipsis}}...; length::Union{Val,Nothing}=nothing, ) @@ -123,7 +116,7 @@ function blockedperm( permblocks[unspecified_dim], unspecified_dims_vec, ndims_unspecified ) permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, insert) - return blockedperm(permblocks_specified...) + return blockedpermvcat(permblocks_specified...) end function unspecified_dims(::Tuple{Ellipsis}, unspecified_dims_vec, ndims_unspecified::Val) @@ -135,7 +128,7 @@ end # Version of `indexin` that outputs a `blockedperm`. function blockedperm_indexin(collection, subs...) - return blockedperm(map(sub -> BaseExtensions.indexin(sub, collection), subs)...) + return blockedpermvcat(map(sub -> BaseExtensions.indexin(sub, collection), subs)...) end # diff --git a/src/fusedims.jl b/src/fusedims.jl index 2e87346..38205c8 100644 --- a/src/fusedims.jl +++ b/src/fusedims.jl @@ -45,9 +45,8 @@ end # Fix ambiguity issue fusedims(a::AbstractArray{<:Any,0}, ::Vararg{Tuple{}}) = a -# TODO: Is this needed? Maybe delete. function fusedims(a::AbstractArray, permblocks...) - return fusedims(a, blockedperm(permblocks...; length=Val(ndims(a)))) + return fusedims(a, blockedpermvcat(permblocks...; length=Val(ndims(a)))) end function fuseaxes( diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index 27480c9..e411540 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -11,6 +11,7 @@ using TensorAlgebra: blockedperm, blockedperm_indexin, blockedtrivialperm, + blockedpermvcat, permmortar, trivialperm, tuplemortar @@ -25,21 +26,23 @@ using TensorAlgebra: @test blocklengths(p) == (3, 2) @test blockfirsts(p) == (1, 4) @test blocklasts(p) == (3, 5) - @test (@constinferred invperm(p)) == blockedperm((5, 4, 1), (2, 3)) + @test p == (@constinferred blockedpermvcat((3, 4, 5), (2, 1))) + @test p == blockedperm((3, 4, 5, 2, 1), (3, 2)) + @test p == (@constinferred blockedperm((3, 4, 5, 2, 1), Val((3, 2)))) + @test (@constinferred invperm(p)) == blockedpermvcat((5, 4, 1), (2, 3)) @test p isa BlockedPermutation{2} - @test p == (@constinferred blockedperm((3, 4, 5), (2, 1))) flat = (3, 4, 5, 2, 1) @test_throws DimensionMismatch BlockedPermutation{2,(1, 2, 2)}(flat) @test_throws DimensionMismatch BlockedPermutation{3,(1, 2, 3)}(flat) @test_throws DimensionMismatch BlockedPermutation{3,(-1, 3, 3)}(flat) - @test_throws AssertionError blockedperm((3, 5), (2, 1)) - @test_throws AssertionError blockedperm((0, 1), (2, 3)) - @test_throws AssertionError blockedperm((0,)) - @test_throws AssertionError blockedperm((2,)) + @test_throws AssertionError blockedpermvcat((3, 5), (2, 1)) + @test_throws AssertionError blockedpermvcat((0, 1), (2, 3)) + @test_throws AssertionError blockedpermvcat((0,)) + @test_throws AssertionError blockedpermvcat((2,)) # Empty block. - p = @constinferred blockedperm((3, 2), (), (1,)) + p = @constinferred blockedpermvcat((3, 2), (), (1,)) @test Tuple(p) === (3, 2, 1) @test isperm(p) @test length(p) == 3 @@ -48,10 +51,10 @@ using TensorAlgebra: @test blocklengths(p) == (2, 0, 1) @test blockfirsts(p) == (1, 3, 3) @test blocklasts(p) == (2, 2, 3) - @test invperm(p) == blockedperm((3, 2), (), (1,)) + @test invperm(p) == blockedpermvcat((3, 2), (), (1,)) @test p isa BlockedPermutation{3} - p = @constinferred blockedperm((), ()) + p = @constinferred blockedpermvcat((), ()) @test Tuple(p) === () @test blocklength(p) == 2 @test blocklengths(p) == (0, 0) @@ -60,7 +63,7 @@ using TensorAlgebra: @test blocks(p) == ((), ()) @test p isa BlockedPermutation{2} - p = @constinferred blockedperm() + p = @constinferred blockedpermvcat() @test Tuple(p) === () @test blocklength(p) == 0 @test blocklengths(p) == () @@ -69,7 +72,7 @@ using TensorAlgebra: @test blocks(p) == () @test p isa BlockedPermutation{0} - p = blockedperm((3, 2), (), (1,)) + p = blockedpermvcat((3, 2), (), (1,)) bt = tuplemortar(((3, 2), (), (1,))) @test (@constinferred BlockedTuple(p)) == bt @test (@constinferred map(identity, p)) == bt @@ -77,62 +80,59 @@ using TensorAlgebra: @test (@constinferred blockedperm(p)) == p @test (@constinferred blockedperm(bt)) == p - @test_throws ArgumentError blockedperm((1, 3), (2, 4); length=Val(6)) - @test_throws ArgumentError blockedperm(tuplemortar(((1, 3), (2, 4))); length=Val(5)) - @test (@constinferred blockedperm(tuplemortar(((1, 3), (2, 4))); length=Val(4))) == - blockedperm((1, 3), (2, 4)) + @test_throws ArgumentError blockedpermvcat((1, 3), (2, 4); length=Val(6)) # Split collection into `BlockedPermutation`. p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d")) - @test p == blockedperm((3, 1), (2, 4)) + @test p == blockedpermvcat((3, 1), (2, 4)) # Singleton dimensions. - p = @constinferred blockedperm((2, 3), 1) - @test p == blockedperm((2, 3), (1,)) + p = @constinferred blockedpermvcat((2, 3), 1) + @test p == blockedpermvcat((2, 3), (1,)) # First dimensions are unspecified. - p = blockedperm(.., (4, 3)) - @test p == blockedperm((1,), (2,), (4, 3)) + p = blockedpermvcat(.., (4, 3)) + @test p == blockedpermvcat((1,), (2,), (4, 3)) # Specify length - p = @constinferred blockedperm(.., (4, 3); length=Val(6)) - @test p == blockedperm((1,), (2,), (5,), (6,), (4, 3)) + p = @constinferred blockedpermvcat(.., (4, 3); length=Val(6)) + @test p == blockedpermvcat((1,), (2,), (5,), (6,), (4, 3)) # Last dimensions are unspecified. - p = blockedperm((4, 3), ..) - @test p == blockedperm((4, 3), (1,), (2,)) + p = blockedpermvcat((4, 3), ..) + @test p == blockedpermvcat((4, 3), (1,), (2,)) # Specify length - p = @constinferred blockedperm((4, 3), ..; length=Val(6)) - @test p == blockedperm((4, 3), (1,), (2,), (5,), (6,)) + p = @constinferred blockedpermvcat((4, 3), ..; length=Val(6)) + @test p == blockedpermvcat((4, 3), (1,), (2,), (5,), (6,)) # Middle dimensions are unspecified. - p = blockedperm((4, 3), .., 1) - @test p == blockedperm((4, 3), (2,), (1,)) + p = blockedpermvcat((4, 3), .., 1) + @test p == blockedpermvcat((4, 3), (2,), (1,)) # Specify length - p = @constinferred blockedperm((4, 3), .., 1; length=Val(6)) - @test p == blockedperm((4, 3), (2,), (5,), (6,), (1,)) + p = @constinferred blockedpermvcat((4, 3), .., 1; length=Val(6)) + @test p == blockedpermvcat((4, 3), (2,), (5,), (6,), (1,)) # No dimensions are unspecified. - p = blockedperm((3, 2), .., 1) - @test p == blockedperm((3, 2), (1,)) + p = blockedpermvcat((3, 2), .., 1) + @test p == blockedpermvcat((3, 2), (1,)) # same with (..,) instead of .. - p = blockedperm((..,), (4, 3)) - @test p == blockedperm((1, 2), (4, 3)) - p = @constinferred blockedperm((..,), (4, 3); length=Val(6)) - @test p == blockedperm((1, 2, 5, 6), (4, 3)) - - p = blockedperm((4, 3), (..,)) - @test p == blockedperm((4, 3), (1, 2)) - p = @constinferred blockedperm((4, 3), (..,); length=Val(6)) - @test p == blockedperm((4, 3), (1, 2, 5, 6)) - - p = blockedperm((4, 3), (..,), 1) - @test p == blockedperm((4, 3), (2,), (1,)) - p = @constinferred blockedperm((4, 3), (..,), 1; length=Val(6)) - @test p == blockedperm((4, 3), (2, 5, 6), (1,)) - - p = blockedperm((3, 2), (..,), 1) - @test p == blockedperm((3, 2), (), (1,)) + p = blockedpermvcat((..,), (4, 3)) + @test p == blockedpermvcat((1, 2), (4, 3)) + p = @constinferred blockedpermvcat((..,), (4, 3); length=Val(6)) + @test p == blockedpermvcat((1, 2, 5, 6), (4, 3)) + + p = blockedpermvcat((4, 3), (..,)) + @test p == blockedpermvcat((4, 3), (1, 2)) + p = @constinferred blockedpermvcat((4, 3), (..,); length=Val(6)) + @test p == blockedpermvcat((4, 3), (1, 2, 5, 6)) + + p = blockedpermvcat((4, 3), (..,), 1) + @test p == blockedpermvcat((4, 3), (2,), (1,)) + p = @constinferred blockedpermvcat((4, 3), (..,), 1; length=Val(6)) + @test p == blockedpermvcat((4, 3), (2, 5, 6), (1,)) + + p = blockedpermvcat((3, 2), (..,), 1) + @test p == blockedpermvcat((3, 2), (), (1,)) end @testset "BlockedTrivialPermutation" begin @@ -142,7 +142,7 @@ end @test Tuple(tp) == (1, 2, 3) @test blocklength(tp) == 3 @test blocklengths(tp) == (2, 0, 1) - @test trivialperm(blockedperm((3, 2), (), (1,))) == tp + @test trivialperm(blockedpermvcat((3, 2), (), (1,))) == tp bt = tuplemortar(((1, 2), (), (3,))) @test (@constinferred BlockedTuple(tp)) == bt From 5a8863c194a39ea1786b70f63228481864b8f530 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Mon, 24 Feb 2025 16:37:14 -0500 Subject: [PATCH 15/17] adapt blockperms --- src/contract/blockedperms.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index bc0679e..6f33134 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -22,10 +22,10 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) perm_domain2 = BaseExtensions.indexin(domain, dimnames2) permblocks_dest = (perm_codomain_dest, perm_domain_dest) - biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...) + biperm_dest = blockedpermvcat(filter(!isempty, permblocks_dest)...) permblocks1 = (perm_codomain1, perm_domain1) - biperm1 = blockedperm(filter(!isempty, permblocks1)...) + biperm1 = blockedpermvcat(filter(!isempty, permblocks1)...) permblocks2 = (perm_codomain2, perm_domain2) - biperm2 = blockedperm(filter(!isempty, permblocks2)...) + biperm2 = blockedpermvcat(filter(!isempty, permblocks2)...) return biperm_dest, biperm1, biperm2 end From 0c4ba32f4b204f841b7e4243ffcd8e7d05cce607 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Tue, 25 Feb 2025 13:46:18 -0500 Subject: [PATCH 16/17] bump version number to breaking change --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a967c86..439a335 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.1.11" +version = "0.2.0" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" From 673dd2ba556999897f69f7fe5ee3a866026a4e86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olivier=20Gauth=C3=A9?= Date: Fri, 28 Feb 2025 17:04:46 -0500 Subject: [PATCH 17/17] remove BlockSparseArrays from tests --- test/Project.toml | 4 - test/test_blockarrays_contract.jl | 106 ++++------------------ test/test_gradedunitrangesext_contract.jl | 79 ---------------- 3 files changed, 19 insertions(+), 170 deletions(-) delete mode 100644 test/test_gradedunitrangesext_contract.jl diff --git a/test/Project.toml b/test/Project.toml index 275179d..31461ab 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" @@ -10,7 +9,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e" @@ -21,10 +19,8 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [compat] Aqua = "0.8.9" -BlockSparseArrays = "0.2" Random = "1.10" SafeTestsets = "0.1" -SparseArraysBase = "0.2.11" Suppressor = "0.2" SymmetrySectors = "0.1" TensorOperations = "5.1.3" diff --git a/test/test_blockarrays_contract.jl b/test/test_blockarrays_contract.jl index c9094f0..e7754b0 100644 --- a/test/test_blockarrays_contract.jl +++ b/test/test_blockarrays_contract.jl @@ -1,12 +1,10 @@ using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize -using BlockSparseArrays: BlockSparseArray -using SparseArraysBase: densearray using TensorAlgebra: contract using Random: randn! using Test: @test, @test_broken, @testset function randn_blockdiagonal(elt::Type, axes::Tuple) - a = BlockSparseArray{elt}(axes) + a = zeros(elt, axes) blockdiaglength = minimum(blocksize(a)) for i in 1:blockdiaglength b = Block(ntuple(Returns(i), ndims(a))) @@ -18,74 +16,14 @@ end const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "`contract` blocked arrays (eltype=$elt)" for elt in elts d = blockedrange([2, 3]) - a1_sba = randn_blockdiagonal(elt, (d, d, d, d)) - a2_sba = randn_blockdiagonal(elt, (d, d, d, d)) - a3_sba = randn_blockdiagonal(elt, (d, d)) - a1_dense = densearray(a1_sba) - a2_dense = densearray(a2_sba) - a3_dense = densearray(a3_sba) - - @testset "BlockArray" begin - a1 = BlockArray(a1_sba) - a2 = BlockArray(a2_sba) - a3 = BlockArray(a3_sba) - - # matrix matrix - @test_broken a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) - #= - a_dest_dense, dimnames_dest_dense = contract( - a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4) - ) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray - @test a_dest ≈ a_dest_dense - =# - - # matrix vector - @test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - #= - a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray - @test a_dest ≈ a_dest_dense - =# - - # vector matrix - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) - #= - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray - @test a_dest ≈ a_dest_dense - =# - - # vector vector - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) - a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test_broken a_dest isa BlockArray # TBD relax to AbstractArray{elt,0}? - @test a_dest ≈ a_dest_dense - - # outer product - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) - #= - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray - @test a_dest ≈ a_dest_dense - =# - end + a1 = randn_blockdiagonal(elt, (d, d, d, d)) + a2 = randn_blockdiagonal(elt, (d, d, d, d)) + a3 = randn_blockdiagonal(elt, (d, d)) + a1_dense = convert(Array, a1) + a2_dense = convert(Array, a2) + a3_dense = convert(Array, a3) @testset "BlockedArray" begin - a1 = BlockedArray(a1_sba) - a2 = BlockedArray(a2_sba) - a3 = BlockedArray(a3_sba) - # matrix matrix a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) a_dest_dense, dimnames_dest_dense = contract( @@ -97,31 +35,27 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a_dest ≈ a_dest_dense # matrix vector - @test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - #= + a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) @test a_dest isa BlockedArray @test a_dest ≈ a_dest_dense - =# # vector matrix - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) - #= + a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) @test a_dest isa BlockedArray @test a_dest ≈ a_dest_dense - =# # vector vector a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test_broken a_dest isa BlockedArray # TBD relax to AbstractArray{elt,0}? + @test_broken a_dest isa BlockedArray{elt,0} @test a_dest ≈ a_dest_dense # outer product @@ -133,8 +67,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a_dest ≈ a_dest_dense end - @testset "BlockSparseArray" begin - a1, a2, a3 = a1_sba, a2_sba, a3_sba + @testset "BlockArray" begin + a1, a3, a3 = BlockArray.((a1, a2, a3)) # matrix matrix a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) @@ -143,25 +77,23 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) ) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray + @test a_dest isa BlockArray @test a_dest ≈ a_dest_dense # matrix vector - @test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - #= + a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray + @test a_dest isa BlockArray @test a_dest ≈ a_dest_dense - =# # vector matrix a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray + @test a_dest isa BlockArray @test a_dest ≈ a_dest_dense # vector vector @@ -169,15 +101,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray + @test_broken a_dest isa BlockArray{elt,0} @test a_dest ≈ a_dest_dense # outer product - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) + a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray + @test a_dest isa BlockArray @test a_dest ≈ a_dest_dense end end diff --git a/test/test_gradedunitrangesext_contract.jl b/test/test_gradedunitrangesext_contract.jl deleted file mode 100644 index 4304c62..0000000 --- a/test/test_gradedunitrangesext_contract.jl +++ /dev/null @@ -1,79 +0,0 @@ -using BlockArrays: Block, blocksize -using BlockSparseArrays: BlockSparseArray -using GradedUnitRanges: dual, gradedrange -using SparseArraysBase: densearray -using SymmetrySectors: U1 -using TensorAlgebra: contract -using Random: randn! -using Test: @test, @test_broken, @testset - -function randn_blockdiagonal(elt::Type, axes::Tuple) - a = BlockSparseArray{elt}(axes) - blockdiaglength = minimum(blocksize(a)) - for i in 1:blockdiaglength - b = Block(ntuple(Returns(i), ndims(a))) - a[b] = randn!(a[b]) - end - return a -end - -const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) -@testset "`contract` `BlockSparseArray` (eltype=$elt)" for elt in elts - d = gradedrange([U1(0) => 2, U1(1) => 3]) - a1 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d))) - a2 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d))) - a3 = randn_blockdiagonal(elt, (d, dual(d))) - a1_dense = densearray(a1) - a2_dense = densearray(a2) - a3_dense = densearray(a3) - - # matrix matrix - a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) - a_dest_dense, dimnames_dest_dense = contract( - a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4) - ) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray - @test a_dest ≈ a_dest_dense - - # matrix vector - @test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - #= - a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray - @test a_dest ≈ a_dest_dense - =# - - # vector matrix - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) - #= - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray - @test a_dest ≈ a_dest_dense - =# - - # vector vector - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) - #= - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray - @test a_dest ≈ a_dest_dense - =# - - # outer product - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) - #= - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray - @test a_dest ≈ a_dest_dense - =# -end