diff --git a/Project.toml b/Project.toml index a967c86..439a335 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TensorAlgebra" uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" authors = ["ITensor developers and contributors"] -version = "0.1.11" +version = "0.2.0" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/blockedpermutation.jl b/src/blockedpermutation.jl index 663f480..2d29aea 100644 --- a/src/blockedpermutation.jl +++ b/src/blockedpermutation.jl @@ -32,23 +32,27 @@ widened_constructorof(::Type{<:AbstractBlockPermutation}) = BlockedTuple # blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1)) # TODO: Optimize with StaticNumbers.jl or generated functions, see: # https://discourse.julialang.org/t/avoiding-type-instability-when-slicing-a-tuple/38567 -function blockperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}}) +function blockedperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}}) return blockedperm(BlockedTuple(perm, blocklengths)) end -function blockperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val) +function blockedperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val) return blockedperm(BlockedTuple(perm, BlockLengths)) end -function Base.invperm(blockedperm::AbstractBlockPermutation) +function Base.invperm(bp::AbstractBlockPermutation) # use Val to preserve compile time info - return blockperm(invperm(Tuple(blockedperm)), Val(blocklengths(blockedperm))) + return blockedperm(invperm(Tuple(bp)), Val(blocklengths(bp))) end # # Constructors # +function blockedperm(bt::AbstractBlockTuple) + return permmortar(blocks(bt)) +end + # Bipartition a vector according to the # bipartitioned permutation. # Like `Base.permute!` block out-of-place and blocked. @@ -56,26 +60,32 @@ function blockpermute(v, blockedperm::AbstractBlockPermutation) return map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm)) end -# blockedperm((4, 3), (2, 1)) -function blockedperm(permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing) - return blockedperm(length, permblocks...) +# blockedpermvcat((4, 3), (2, 1)) +function blockedpermvcat( + permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing +) + return blockedpermvcat(length, permblocks...) end -function blockedperm(::Nothing, permblocks::Tuple{Vararg{Int}}...) - return blockedperm(Val(sum(length, permblocks; init=zero(Bool))), permblocks...) +function blockedpermvcat(::Nothing, permblocks::Tuple{Vararg{Int}}...) + return blockedpermvcat(Val(sum(length, permblocks; init=zero(Bool))), permblocks...) end -# blockedperm((3, 2), 1) == blockedperm((3, 2), (1,)) -function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...) - return blockedperm(collect_tuple.(permblocks)...; kwargs...) +# blockedpermvcat((3, 2), 1) == blockedpermvcat((3, 2), (1,)) +function blockedpermvcat(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...) + return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...) end -function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwargs...) - return blockedperm(collect_tuple.(permblocks)...; kwargs...) +function blockedpermvcat( + permblocks::Union{Tuple{Vararg{Int}},Tuple{Ellipsis},Int,Ellipsis}...; kwargs... +) + return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...) end -function blockedperm(bt::AbstractBlockTuple) - return blockedperm(Val(length(bt)), blocks(bt)...) +function blockedpermvcat(len::Val, permblocks::Tuple{Vararg{Int}}...) + value(len) != sum(length.(permblocks); init=0) && + throw(ArgumentError("Invalid total length")) + return permmortar(Tuple(permblocks)) end function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}}) @@ -86,25 +96,39 @@ function _blockedperm_length(vallength::Val, ::Tuple{Vararg{Int}}) return value(vallength) end -# blockedperm((4, 3), .., 1) == blockedperm((4, 3), 2, 1) -# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), 2, 5, 1) -function blockedperm( - permblocks::Union{Tuple{Vararg{Int}},Ellipsis}...; length::Union{Val,Nothing}=nothing +# blockedpermvcat((4, 3), .., 1) == blockedpermvcat((4, 3), (2,), (1,)) +# blockedpermvcat((4, 3), .., 1; length=Val(5)) == blockedpermvcat((4, 3), (2,), (5,), (1,)) +# blockedpermvcat((4, 3), (..,), 1) == blockedpermvcat((4, 3), (2,), (1,)) +# blockedpermvcat((4, 3), (..,), 1; length=Val(5)) == blockedpermvcat((4, 3), (2, 5), (1,)) +function blockedpermvcat( + permblocks::Union{Tuple{Vararg{Int}},Ellipsis,Tuple{Ellipsis}}...; + length::Union{Val,Nothing}=nothing, ) # Check there is only one `Ellipsis`. - @assert isone(count(x -> x isa Ellipsis, permblocks)) - specified_permblocks = filter(x -> !(x isa Ellipsis), permblocks) - unspecified_dim = findfirst(x -> x isa Ellipsis, permblocks) + @assert isone(count(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks)) + specified_permblocks = filter(x -> !(x isa Union{Ellipsis,Tuple{Ellipsis}}), permblocks) + unspecified_dim = findfirst(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks) specified_perm = flatten_tuples(specified_permblocks) len = _blockedperm_length(length, specified_perm) - unspecified_dims = Tuple(setdiff(Base.OneTo(len), flatten_tuples(specified_permblocks))) - permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, unspecified_dims) - return blockedperm(permblocks_specified...) + unspecified_dims_vec = setdiff(Base.OneTo(len), specified_perm) + ndims_unspecified = Val(len - sum(Base.length.(specified_permblocks))) # preserve type stability when possible + insert = unspecified_dims( + permblocks[unspecified_dim], unspecified_dims_vec, ndims_unspecified + ) + permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, insert) + return blockedpermvcat(permblocks_specified...) +end + +function unspecified_dims(::Tuple{Ellipsis}, unspecified_dims_vec, ndims_unspecified::Val) + return (ntuple(i -> unspecified_dims_vec[i], ndims_unspecified),) +end +function unspecified_dims(::Ellipsis, unspecified_dims_vec, ndims_unspecified::Val) + return ntuple(i -> (unspecified_dims_vec[i],), ndims_unspecified) end # Version of `indexin` that outputs a `blockedperm`. function blockedperm_indexin(collection, subs...) - return blockedperm(map(sub -> BaseExtensions.indexin(sub, collection), subs)...) + return blockedpermvcat(map(sub -> BaseExtensions.indexin(sub, collection), subs)...) end # @@ -138,7 +162,7 @@ function BlockArrays.blocklengths( return BlockLengths end -function blockedperm(::Val, permblocks::Tuple{Vararg{Int}}...) +function permmortar(permblocks::Tuple{Vararg{Tuple{Vararg{Int}}}}) blockedperm = BlockedPermutation{length(permblocks),length.(permblocks)}( flatten_tuples(permblocks) ) diff --git a/src/contract/blockedperms.jl b/src/contract/blockedperms.jl index bc0679e..6f33134 100644 --- a/src/contract/blockedperms.jl +++ b/src/contract/blockedperms.jl @@ -22,10 +22,10 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) perm_domain2 = BaseExtensions.indexin(domain, dimnames2) permblocks_dest = (perm_codomain_dest, perm_domain_dest) - biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...) + biperm_dest = blockedpermvcat(filter(!isempty, permblocks_dest)...) permblocks1 = (perm_codomain1, perm_domain1) - biperm1 = blockedperm(filter(!isempty, permblocks1)...) + biperm1 = blockedpermvcat(filter(!isempty, permblocks1)...) permblocks2 = (perm_codomain2, perm_domain2) - biperm2 = blockedperm(filter(!isempty, permblocks2)...) + biperm2 = blockedpermvcat(filter(!isempty, permblocks2)...) return biperm_dest, biperm1, biperm2 end diff --git a/src/fusedims.jl b/src/fusedims.jl index 2e87346..38205c8 100644 --- a/src/fusedims.jl +++ b/src/fusedims.jl @@ -45,9 +45,8 @@ end # Fix ambiguity issue fusedims(a::AbstractArray{<:Any,0}, ::Vararg{Tuple{}}) = a -# TODO: Is this needed? Maybe delete. function fusedims(a::AbstractArray, permblocks...) - return fusedims(a, blockedperm(permblocks...; length=Val(ndims(a)))) + return fusedims(a, blockedpermvcat(permblocks...; length=Val(ndims(a)))) end function fuseaxes( diff --git a/test/Project.toml b/test/Project.toml index 275179d..31461ab 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" -BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949" GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" @@ -10,7 +9,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e" @@ -21,10 +19,8 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a" [compat] Aqua = "0.8.9" -BlockSparseArrays = "0.2" Random = "1.10" SafeTestsets = "0.1" -SparseArraysBase = "0.2.11" Suppressor = "0.2" SymmetrySectors = "0.1" TensorOperations = "5.1.3" diff --git a/test/test_basics.jl b/test/test_basics.jl index 221274b..6b75c69 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -32,9 +32,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a_fused = fusedims(a, (3, 1), .., 2) @test eltype(a_fused) === elt @test a_fused ≈ reshape(permutedims(a, (3, 1, 4, 2)), (8, 5, 3)) - a_fused = fusedims(a, (3, 1), ..) + a_fused = fusedims(a, (3, 1), (..,)) @test eltype(a_fused) === elt - @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 3, 5)) + @test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15)) end @testset "splitdims (eltype=$elt)" for elt in elts a = randn(elt, 6, 20) diff --git a/test/test_blockarrays_contract.jl b/test/test_blockarrays_contract.jl index c9094f0..e7754b0 100644 --- a/test/test_blockarrays_contract.jl +++ b/test/test_blockarrays_contract.jl @@ -1,12 +1,10 @@ using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize -using BlockSparseArrays: BlockSparseArray -using SparseArraysBase: densearray using TensorAlgebra: contract using Random: randn! using Test: @test, @test_broken, @testset function randn_blockdiagonal(elt::Type, axes::Tuple) - a = BlockSparseArray{elt}(axes) + a = zeros(elt, axes) blockdiaglength = minimum(blocksize(a)) for i in 1:blockdiaglength b = Block(ntuple(Returns(i), ndims(a))) @@ -18,74 +16,14 @@ end const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @testset "`contract` blocked arrays (eltype=$elt)" for elt in elts d = blockedrange([2, 3]) - a1_sba = randn_blockdiagonal(elt, (d, d, d, d)) - a2_sba = randn_blockdiagonal(elt, (d, d, d, d)) - a3_sba = randn_blockdiagonal(elt, (d, d)) - a1_dense = densearray(a1_sba) - a2_dense = densearray(a2_sba) - a3_dense = densearray(a3_sba) - - @testset "BlockArray" begin - a1 = BlockArray(a1_sba) - a2 = BlockArray(a2_sba) - a3 = BlockArray(a3_sba) - - # matrix matrix - @test_broken a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) - #= - a_dest_dense, dimnames_dest_dense = contract( - a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4) - ) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray - @test a_dest ≈ a_dest_dense - =# - - # matrix vector - @test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - #= - a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray - @test a_dest ≈ a_dest_dense - =# - - # vector matrix - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) - #= - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray - @test a_dest ≈ a_dest_dense - =# - - # vector vector - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) - a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test_broken a_dest isa BlockArray # TBD relax to AbstractArray{elt,0}? - @test a_dest ≈ a_dest_dense - - # outer product - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) - #= - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockArray - @test a_dest ≈ a_dest_dense - =# - end + a1 = randn_blockdiagonal(elt, (d, d, d, d)) + a2 = randn_blockdiagonal(elt, (d, d, d, d)) + a3 = randn_blockdiagonal(elt, (d, d)) + a1_dense = convert(Array, a1) + a2_dense = convert(Array, a2) + a3_dense = convert(Array, a3) @testset "BlockedArray" begin - a1 = BlockedArray(a1_sba) - a2 = BlockedArray(a2_sba) - a3 = BlockedArray(a3_sba) - # matrix matrix a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) a_dest_dense, dimnames_dest_dense = contract( @@ -97,31 +35,27 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a_dest ≈ a_dest_dense # matrix vector - @test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - #= + a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) @test a_dest isa BlockedArray @test a_dest ≈ a_dest_dense - =# # vector matrix - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) - #= + a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) @test a_dest isa BlockedArray @test a_dest ≈ a_dest_dense - =# # vector vector a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test_broken a_dest isa BlockedArray # TBD relax to AbstractArray{elt,0}? + @test_broken a_dest isa BlockedArray{elt,0} @test a_dest ≈ a_dest_dense # outer product @@ -133,8 +67,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test a_dest ≈ a_dest_dense end - @testset "BlockSparseArray" begin - a1, a2, a3 = a1_sba, a2_sba, a3_sba + @testset "BlockArray" begin + a1, a3, a3 = BlockArray.((a1, a2, a3)) # matrix matrix a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) @@ -143,25 +77,23 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) ) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray + @test a_dest isa BlockArray @test a_dest ≈ a_dest_dense # matrix vector - @test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - #= + a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray + @test a_dest isa BlockArray @test a_dest ≈ a_dest_dense - =# # vector matrix a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray + @test a_dest isa BlockArray @test a_dest ≈ a_dest_dense # vector vector @@ -169,15 +101,15 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray + @test_broken a_dest isa BlockArray{elt,0} @test a_dest ≈ a_dest_dense # outer product - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) + a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) @test dimnames_dest == dimnames_dest_dense @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray + @test a_dest isa BlockArray @test a_dest ≈ a_dest_dense end end diff --git a/test/test_blockedpermutation.jl b/test/test_blockedpermutation.jl index 37f8e6a..e411540 100644 --- a/test/test_blockedpermutation.jl +++ b/test/test_blockedpermutation.jl @@ -11,10 +11,13 @@ using TensorAlgebra: blockedperm, blockedperm_indexin, blockedtrivialperm, - trivialperm + blockedpermvcat, + permmortar, + trivialperm, + tuplemortar @testset "BlockedPermutation" begin - p = @constinferred blockedperm((3, 4, 5), (2, 1)) + p = @constinferred permmortar(((3, 4, 5), (2, 1))) @test Tuple(p) === (3, 4, 5, 2, 1) @test isperm(p) @test length(p) == 5 @@ -23,16 +26,23 @@ using TensorAlgebra: @test blocklengths(p) == (3, 2) @test blockfirsts(p) == (1, 4) @test blocklasts(p) == (3, 5) - @test (@constinferred invperm(p)) == blockedperm((5, 4, 1), (2, 3)) + @test p == (@constinferred blockedpermvcat((3, 4, 5), (2, 1))) + @test p == blockedperm((3, 4, 5, 2, 1), (3, 2)) + @test p == (@constinferred blockedperm((3, 4, 5, 2, 1), Val((3, 2)))) + @test (@constinferred invperm(p)) == blockedpermvcat((5, 4, 1), (2, 3)) @test p isa BlockedPermutation{2} flat = (3, 4, 5, 2, 1) @test_throws DimensionMismatch BlockedPermutation{2,(1, 2, 2)}(flat) @test_throws DimensionMismatch BlockedPermutation{3,(1, 2, 3)}(flat) @test_throws DimensionMismatch BlockedPermutation{3,(-1, 3, 3)}(flat) + @test_throws AssertionError blockedpermvcat((3, 5), (2, 1)) + @test_throws AssertionError blockedpermvcat((0, 1), (2, 3)) + @test_throws AssertionError blockedpermvcat((0,)) + @test_throws AssertionError blockedpermvcat((2,)) # Empty block. - p = @constinferred blockedperm((3, 2), (), (1,)) + p = @constinferred blockedpermvcat((3, 2), (), (1,)) @test Tuple(p) === (3, 2, 1) @test isperm(p) @test length(p) == 3 @@ -41,10 +51,10 @@ using TensorAlgebra: @test blocklengths(p) == (2, 0, 1) @test blockfirsts(p) == (1, 3, 3) @test blocklasts(p) == (2, 2, 3) - @test invperm(p) == blockedperm((3, 2), (), (1,)) + @test invperm(p) == blockedpermvcat((3, 2), (), (1,)) @test p isa BlockedPermutation{3} - p = @constinferred blockedperm((), ()) + p = @constinferred blockedpermvcat((), ()) @test Tuple(p) === () @test blocklength(p) == 2 @test blocklengths(p) == (0, 0) @@ -53,7 +63,7 @@ using TensorAlgebra: @test blocks(p) == ((), ()) @test p isa BlockedPermutation{2} - p = @constinferred blockedperm() + p = @constinferred blockedpermvcat() @test Tuple(p) === () @test blocklength(p) == 0 @test blocklengths(p) == () @@ -62,46 +72,67 @@ using TensorAlgebra: @test blocks(p) == () @test p isa BlockedPermutation{0} - p = blockedperm((3, 2), (), (1,)) - bt = BlockedTuple{3,(2, 0, 1)}((3, 2, 1)) + p = blockedpermvcat((3, 2), (), (1,)) + bt = tuplemortar(((3, 2), (), (1,))) @test (@constinferred BlockedTuple(p)) == bt @test (@constinferred map(identity, p)) == bt - @test (@constinferred p .+ p) == BlockedTuple{3,(2, 0, 1)}((6, 4, 2)) + @test (@constinferred p .+ p) == tuplemortar(((6, 4), (), (2,))) @test (@constinferred blockedperm(p)) == p @test (@constinferred blockedperm(bt)) == p + @test_throws ArgumentError blockedpermvcat((1, 3), (2, 4); length=Val(6)) + # Split collection into `BlockedPermutation`. p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d")) - @test p == blockedperm((3, 1), (2, 4)) + @test p == blockedpermvcat((3, 1), (2, 4)) # Singleton dimensions. - p = @constinferred blockedperm((2, 3), 1) - @test p == blockedperm((2, 3), (1,)) + p = @constinferred blockedpermvcat((2, 3), 1) + @test p == blockedpermvcat((2, 3), (1,)) # First dimensions are unspecified. - p = blockedperm(.., (4, 3)) - @test p == blockedperm(1, 2, (4, 3)) + p = blockedpermvcat(.., (4, 3)) + @test p == blockedpermvcat((1,), (2,), (4, 3)) # Specify length - p = blockedperm(.., (4, 3); length=Val(6)) - @test p == blockedperm(1, 2, 5, 6, (4, 3)) + p = @constinferred blockedpermvcat(.., (4, 3); length=Val(6)) + @test p == blockedpermvcat((1,), (2,), (5,), (6,), (4, 3)) # Last dimensions are unspecified. - p = blockedperm((4, 3), ..) - @test p == blockedperm((4, 3), 1, 2) + p = blockedpermvcat((4, 3), ..) + @test p == blockedpermvcat((4, 3), (1,), (2,)) # Specify length - p = blockedperm((4, 3), ..; length=Val(6)) - @test p == blockedperm((4, 3), 1, 2, 5, 6) + p = @constinferred blockedpermvcat((4, 3), ..; length=Val(6)) + @test p == blockedpermvcat((4, 3), (1,), (2,), (5,), (6,)) # Middle dimensions are unspecified. - p = blockedperm((4, 3), .., 1) - @test p == blockedperm((4, 3), 2, 1) + p = blockedpermvcat((4, 3), .., 1) + @test p == blockedpermvcat((4, 3), (2,), (1,)) # Specify length - p = blockedperm((4, 3), .., 1; length=Val(6)) - @test p == blockedperm((4, 3), 2, 5, 6, 1) + p = @constinferred blockedpermvcat((4, 3), .., 1; length=Val(6)) + @test p == blockedpermvcat((4, 3), (2,), (5,), (6,), (1,)) # No dimensions are unspecified. - p = blockedperm((3, 2), .., 1) - @test p == blockedperm((3, 2), 1) + p = blockedpermvcat((3, 2), .., 1) + @test p == blockedpermvcat((3, 2), (1,)) + + # same with (..,) instead of .. + p = blockedpermvcat((..,), (4, 3)) + @test p == blockedpermvcat((1, 2), (4, 3)) + p = @constinferred blockedpermvcat((..,), (4, 3); length=Val(6)) + @test p == blockedpermvcat((1, 2, 5, 6), (4, 3)) + + p = blockedpermvcat((4, 3), (..,)) + @test p == blockedpermvcat((4, 3), (1, 2)) + p = @constinferred blockedpermvcat((4, 3), (..,); length=Val(6)) + @test p == blockedpermvcat((4, 3), (1, 2, 5, 6)) + + p = blockedpermvcat((4, 3), (..,), 1) + @test p == blockedpermvcat((4, 3), (2,), (1,)) + p = @constinferred blockedpermvcat((4, 3), (..,), 1; length=Val(6)) + @test p == blockedpermvcat((4, 3), (2, 5, 6), (1,)) + + p = blockedpermvcat((3, 2), (..,), 1) + @test p == blockedpermvcat((3, 2), (), (1,)) end @testset "BlockedTrivialPermutation" begin @@ -111,13 +142,13 @@ end @test Tuple(tp) == (1, 2, 3) @test blocklength(tp) == 3 @test blocklengths(tp) == (2, 0, 1) - @test trivialperm(blockedperm((3, 2), (), (1,))) == tp + @test trivialperm(blockedpermvcat((3, 2), (), (1,))) == tp - bt = BlockedTuple{3,(2, 0, 1)}((1, 2, 3)) + bt = tuplemortar(((1, 2), (), (3,))) @test (@constinferred BlockedTuple(tp)) == bt @test (@constinferred blocks(tp)) == blocks(bt) @test (@constinferred map(identity, tp)) == bt - @test (@constinferred tp .+ tp) == BlockedTuple{3,(2, 0, 1)}((2, 4, 6)) + @test (@constinferred tp .+ tp) == tuplemortar(((2, 4), (), (6,))) @test (@constinferred blockedperm(tp)) == tp @test (@constinferred trivialperm(tp)) == tp @test (@constinferred trivialperm(bt)) == tp diff --git a/test/test_gradedunitrangesext_contract.jl b/test/test_gradedunitrangesext_contract.jl deleted file mode 100644 index 4304c62..0000000 --- a/test/test_gradedunitrangesext_contract.jl +++ /dev/null @@ -1,79 +0,0 @@ -using BlockArrays: Block, blocksize -using BlockSparseArrays: BlockSparseArray -using GradedUnitRanges: dual, gradedrange -using SparseArraysBase: densearray -using SymmetrySectors: U1 -using TensorAlgebra: contract -using Random: randn! -using Test: @test, @test_broken, @testset - -function randn_blockdiagonal(elt::Type, axes::Tuple) - a = BlockSparseArray{elt}(axes) - blockdiaglength = minimum(blocksize(a)) - for i in 1:blockdiaglength - b = Block(ntuple(Returns(i), ndims(a))) - a[b] = randn!(a[b]) - end - return a -end - -const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) -@testset "`contract` `BlockSparseArray` (eltype=$elt)" for elt in elts - d = gradedrange([U1(0) => 2, U1(1) => 3]) - a1 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d))) - a2 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d))) - a3 = randn_blockdiagonal(elt, (d, dual(d))) - a1_dense = densearray(a1) - a2_dense = densearray(a2) - a3_dense = densearray(a3) - - # matrix matrix - a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4)) - a_dest_dense, dimnames_dest_dense = contract( - a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4) - ) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray - @test a_dest ≈ a_dest_dense - - # matrix vector - @test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2)) - #= - a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray - @test a_dest ≈ a_dest_dense - =# - - # vector matrix - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1)) - #= - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray - @test a_dest ≈ a_dest_dense - =# - - # vector vector - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1)) - #= - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray - @test a_dest ≈ a_dest_dense - =# - - # outer product - @test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4)) - #= - a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4)) - @test dimnames_dest == dimnames_dest_dense - @test size(a_dest) == size(a_dest_dense) - @test a_dest isa BlockSparseArray - @test a_dest ≈ a_dest_dense - =# -end