Skip to content

Generalize blockedperm ellipsis inputs, change constructor names #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Mar 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TensorAlgebra"
uuid = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.11"
version = "0.2.0"

[deps]
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
Expand Down
80 changes: 52 additions & 28 deletions src/blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,50 +32,60 @@
# blockperm((4, 3, 2, 1), (2, 2)) == blockedperm((4, 3), (2, 1))
# TODO: Optimize with StaticNumbers.jl or generated functions, see:
# https://discourse.julialang.org/t/avoiding-type-instability-when-slicing-a-tuple/38567
function blockperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}})
function blockedperm(perm::Tuple{Vararg{Int}}, blocklengths::Tuple{Vararg{Int}})

Check warning on line 35 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L35

Added line #L35 was not covered by tests
return blockedperm(BlockedTuple(perm, blocklengths))
end

function blockperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val)
function blockedperm(perm::Tuple{Vararg{Int}}, BlockLengths::Val)

Check warning on line 39 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L39

Added line #L39 was not covered by tests
return blockedperm(BlockedTuple(perm, BlockLengths))
end

function Base.invperm(blockedperm::AbstractBlockPermutation)
function Base.invperm(bp::AbstractBlockPermutation)

Check warning on line 43 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L43

Added line #L43 was not covered by tests
# use Val to preserve compile time info
return blockperm(invperm(Tuple(blockedperm)), Val(blocklengths(blockedperm)))
return blockedperm(invperm(Tuple(bp)), Val(blocklengths(bp)))

Check warning on line 45 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L45

Added line #L45 was not covered by tests
end

#
# Constructors
#

function blockedperm(bt::AbstractBlockTuple)
return permmortar(blocks(bt))

Check warning on line 53 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L52-L53

Added lines #L52 - L53 were not covered by tests
end

# Bipartition a vector according to the
# bipartitioned permutation.
# Like `Base.permute!` block out-of-place and blocked.
function blockpermute(v, blockedperm::AbstractBlockPermutation)
return map(blockperm -> map(i -> v[i], blockperm), blocks(blockedperm))
end

# blockedperm((4, 3), (2, 1))
function blockedperm(permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing)
return blockedperm(length, permblocks...)
# blockedpermvcat((4, 3), (2, 1))
function blockedpermvcat(

Check warning on line 64 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L64

Added line #L64 was not covered by tests
permblocks::Tuple{Vararg{Int}}...; length::Union{Val,Nothing}=nothing
)
return blockedpermvcat(length, permblocks...)

Check warning on line 67 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L67

Added line #L67 was not covered by tests
end

function blockedperm(::Nothing, permblocks::Tuple{Vararg{Int}}...)
return blockedperm(Val(sum(length, permblocks; init=zero(Bool))), permblocks...)
function blockedpermvcat(::Nothing, permblocks::Tuple{Vararg{Int}}...)
return blockedpermvcat(Val(sum(length, permblocks; init=zero(Bool))), permblocks...)

Check warning on line 71 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L70-L71

Added lines #L70 - L71 were not covered by tests
end

# blockedperm((3, 2), 1) == blockedperm((3, 2), (1,))
function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...)
return blockedperm(collect_tuple.(permblocks)...; kwargs...)
# blockedpermvcat((3, 2), 1) == blockedpermvcat((3, 2), (1,))
function blockedpermvcat(permblocks::Union{Tuple{Vararg{Int}},Int}...; kwargs...)
return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...)

Check warning on line 76 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L75-L76

Added lines #L75 - L76 were not covered by tests
end

function blockedperm(permblocks::Union{Tuple{Vararg{Int}},Int,Ellipsis}...; kwargs...)
return blockedperm(collect_tuple.(permblocks)...; kwargs...)
function blockedpermvcat(

Check warning on line 79 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L79

Added line #L79 was not covered by tests
permblocks::Union{Tuple{Vararg{Int}},Tuple{Ellipsis},Int,Ellipsis}...; kwargs...
)
return blockedpermvcat(collect_tuple.(permblocks)...; kwargs...)

Check warning on line 82 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L82

Added line #L82 was not covered by tests
end

function blockedperm(bt::AbstractBlockTuple)
return blockedperm(Val(length(bt)), blocks(bt)...)
function blockedpermvcat(len::Val, permblocks::Tuple{Vararg{Int}}...)
value(len) != sum(length.(permblocks); init=0) &&

Check warning on line 86 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L85-L86

Added lines #L85 - L86 were not covered by tests
throw(ArgumentError("Invalid total length"))
return permmortar(Tuple(permblocks))

Check warning on line 88 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L88

Added line #L88 was not covered by tests
end

function _blockedperm_length(::Nothing, specified_perm::Tuple{Vararg{Int}})
Expand All @@ -86,25 +96,39 @@
return value(vallength)
end

# blockedperm((4, 3), .., 1) == blockedperm((4, 3), 2, 1)
# blockedperm((4, 3), .., 1; length=Val(5)) == blockedperm((4, 3), 2, 5, 1)
function blockedperm(
permblocks::Union{Tuple{Vararg{Int}},Ellipsis}...; length::Union{Val,Nothing}=nothing
# blockedpermvcat((4, 3), .., 1) == blockedpermvcat((4, 3), (2,), (1,))
# blockedpermvcat((4, 3), .., 1; length=Val(5)) == blockedpermvcat((4, 3), (2,), (5,), (1,))
# blockedpermvcat((4, 3), (..,), 1) == blockedpermvcat((4, 3), (2,), (1,))
# blockedpermvcat((4, 3), (..,), 1; length=Val(5)) == blockedpermvcat((4, 3), (2, 5), (1,))
function blockedpermvcat(

Check warning on line 103 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L103

Added line #L103 was not covered by tests
permblocks::Union{Tuple{Vararg{Int}},Ellipsis,Tuple{Ellipsis}}...;
length::Union{Val,Nothing}=nothing,
)
# Check there is only one `Ellipsis`.
@assert isone(count(x -> x isa Ellipsis, permblocks))
specified_permblocks = filter(x -> !(x isa Ellipsis), permblocks)
unspecified_dim = findfirst(x -> x isa Ellipsis, permblocks)
@assert isone(count(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks))
specified_permblocks = filter(x -> !(x isa Union{Ellipsis,Tuple{Ellipsis}}), permblocks)
unspecified_dim = findfirst(x -> x isa Union{Ellipsis,Tuple{Ellipsis}}, permblocks)

Check warning on line 110 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L108-L110

Added lines #L108 - L110 were not covered by tests
specified_perm = flatten_tuples(specified_permblocks)
len = _blockedperm_length(length, specified_perm)
unspecified_dims = Tuple(setdiff(Base.OneTo(len), flatten_tuples(specified_permblocks)))
permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, unspecified_dims)
return blockedperm(permblocks_specified...)
unspecified_dims_vec = setdiff(Base.OneTo(len), specified_perm)
ndims_unspecified = Val(len - sum(Base.length.(specified_permblocks))) # preserve type stability when possible
insert = unspecified_dims(

Check warning on line 115 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L113-L115

Added lines #L113 - L115 were not covered by tests
permblocks[unspecified_dim], unspecified_dims_vec, ndims_unspecified
)
permblocks_specified = TupleTools.insertat(permblocks, unspecified_dim, insert)
return blockedpermvcat(permblocks_specified...)

Check warning on line 119 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L118-L119

Added lines #L118 - L119 were not covered by tests
end

function unspecified_dims(::Tuple{Ellipsis}, unspecified_dims_vec, ndims_unspecified::Val)
return (ntuple(i -> unspecified_dims_vec[i], ndims_unspecified),)

Check warning on line 123 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L122-L123

Added lines #L122 - L123 were not covered by tests
end
function unspecified_dims(::Ellipsis, unspecified_dims_vec, ndims_unspecified::Val)
return ntuple(i -> (unspecified_dims_vec[i],), ndims_unspecified)

Check warning on line 126 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L125-L126

Added lines #L125 - L126 were not covered by tests
end

# Version of `indexin` that outputs a `blockedperm`.
function blockedperm_indexin(collection, subs...)
return blockedperm(map(sub -> BaseExtensions.indexin(sub, collection), subs)...)
return blockedpermvcat(map(sub -> BaseExtensions.indexin(sub, collection), subs)...)

Check warning on line 131 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L131

Added line #L131 was not covered by tests
end

#
Expand Down Expand Up @@ -138,7 +162,7 @@
return BlockLengths
end

function blockedperm(::Val, permblocks::Tuple{Vararg{Int}}...)
function permmortar(permblocks::Tuple{Vararg{Tuple{Vararg{Int}}}})

Check warning on line 165 in src/blockedpermutation.jl

View check run for this annotation

Codecov / codecov/patch

src/blockedpermutation.jl#L165

Added line #L165 was not covered by tests
blockedperm = BlockedPermutation{length(permblocks),length.(permblocks)}(
flatten_tuples(permblocks)
)
Expand Down
6 changes: 3 additions & 3 deletions src/contract/blockedperms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)

permblocks_dest = (perm_codomain_dest, perm_domain_dest)
biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...)
biperm_dest = blockedpermvcat(filter(!isempty, permblocks_dest)...)

Check warning on line 25 in src/contract/blockedperms.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/blockedperms.jl#L25

Added line #L25 was not covered by tests
permblocks1 = (perm_codomain1, perm_domain1)
biperm1 = blockedperm(filter(!isempty, permblocks1)...)
biperm1 = blockedpermvcat(filter(!isempty, permblocks1)...)

Check warning on line 27 in src/contract/blockedperms.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/blockedperms.jl#L27

Added line #L27 was not covered by tests
permblocks2 = (perm_codomain2, perm_domain2)
biperm2 = blockedperm(filter(!isempty, permblocks2)...)
biperm2 = blockedpermvcat(filter(!isempty, permblocks2)...)

Check warning on line 29 in src/contract/blockedperms.jl

View check run for this annotation

Codecov / codecov/patch

src/contract/blockedperms.jl#L29

Added line #L29 was not covered by tests
return biperm_dest, biperm1, biperm2
end
3 changes: 1 addition & 2 deletions src/fusedims.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,8 @@
# Fix ambiguity issue
fusedims(a::AbstractArray{<:Any,0}, ::Vararg{Tuple{}}) = a

# TODO: Is this needed? Maybe delete.
function fusedims(a::AbstractArray, permblocks...)
return fusedims(a, blockedperm(permblocks...; length=Val(ndims(a))))
return fusedims(a, blockedpermvcat(permblocks...; length=Val(ndims(a))))

Check warning on line 49 in src/fusedims.jl

View check run for this annotation

Codecov / codecov/patch

src/fusedims.jl#L49

Added line #L49 was not covered by tests
end

function fuseaxes(
Expand Down
4 changes: 0 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand All @@ -10,7 +9,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArraysBase = "0d5efcca-f356-4864-8770-e1ed8d78f208"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
SymmetrySectors = "f8a8ad64-adbc-4fce-92f7-ffe2bb36a86e"
Expand All @@ -21,10 +19,8 @@ TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"

[compat]
Aqua = "0.8.9"
BlockSparseArrays = "0.2"
Random = "1.10"
SafeTestsets = "0.1"
SparseArraysBase = "0.2.11"
Suppressor = "0.2"
SymmetrySectors = "0.1"
TensorOperations = "5.1.3"
Expand Down
4 changes: 2 additions & 2 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
a_fused = fusedims(a, (3, 1), .., 2)
@test eltype(a_fused) === elt
@test a_fused ≈ reshape(permutedims(a, (3, 1, 4, 2)), (8, 5, 3))
a_fused = fusedims(a, (3, 1), ..)
a_fused = fusedims(a, (3, 1), (..,))
@test eltype(a_fused) === elt
@test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 3, 5))
@test a_fused ≈ reshape(permutedims(a, (3, 1, 2, 4)), (8, 15))
end
@testset "splitdims (eltype=$elt)" for elt in elts
a = randn(elt, 6, 20)
Expand Down
106 changes: 19 additions & 87 deletions test/test_blockarrays_contract.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize
using BlockSparseArrays: BlockSparseArray
using SparseArraysBase: densearray
using TensorAlgebra: contract
using Random: randn!
using Test: @test, @test_broken, @testset

function randn_blockdiagonal(elt::Type, axes::Tuple)
a = BlockSparseArray{elt}(axes)
a = zeros(elt, axes)
blockdiaglength = minimum(blocksize(a))
for i in 1:blockdiaglength
b = Block(ntuple(Returns(i), ndims(a)))
Expand All @@ -18,74 +16,14 @@ end
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@testset "`contract` blocked arrays (eltype=$elt)" for elt in elts
d = blockedrange([2, 3])
a1_sba = randn_blockdiagonal(elt, (d, d, d, d))
a2_sba = randn_blockdiagonal(elt, (d, d, d, d))
a3_sba = randn_blockdiagonal(elt, (d, d))
a1_dense = densearray(a1_sba)
a2_dense = densearray(a2_sba)
a3_dense = densearray(a3_sba)

@testset "BlockArray" begin
a1 = BlockArray(a1_sba)
a2 = BlockArray(a2_sba)
a3 = BlockArray(a3_sba)

# matrix matrix
@test_broken a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4))
#=
a_dest_dense, dimnames_dest_dense = contract(
a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4)
)
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockArray
@test a_dest ≈ a_dest_dense
=#

# matrix vector
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
#=
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockArray
@test a_dest ≈ a_dest_dense
=#

# vector matrix
@test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
#=
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockArray
@test a_dest ≈ a_dest_dense
=#

# vector vector
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test_broken a_dest isa BlockArray # TBD relax to AbstractArray{elt,0}?
@test a_dest ≈ a_dest_dense

# outer product
@test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4))
#=
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockArray
@test a_dest ≈ a_dest_dense
=#
end
a1 = randn_blockdiagonal(elt, (d, d, d, d))
a2 = randn_blockdiagonal(elt, (d, d, d, d))
a3 = randn_blockdiagonal(elt, (d, d))
a1_dense = convert(Array, a1)
a2_dense = convert(Array, a2)
a3_dense = convert(Array, a3)

@testset "BlockedArray" begin
a1 = BlockedArray(a1_sba)
a2 = BlockedArray(a2_sba)
a3 = BlockedArray(a3_sba)

# matrix matrix
a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4))
a_dest_dense, dimnames_dest_dense = contract(
Expand All @@ -97,31 +35,27 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test a_dest ≈ a_dest_dense

# matrix vector
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
#=
a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockedArray
@test a_dest ≈ a_dest_dense
=#

# vector matrix
@test_broken a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
#=
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockedArray
@test a_dest ≈ a_dest_dense
=#

# vector vector
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test_broken a_dest isa BlockedArray # TBD relax to AbstractArray{elt,0}?
@test_broken a_dest isa BlockedArray{elt,0}
@test a_dest ≈ a_dest_dense

# outer product
Expand All @@ -133,8 +67,8 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test a_dest ≈ a_dest_dense
end

@testset "BlockSparseArray" begin
a1, a2, a3 = a1_sba, a2_sba, a3_sba
@testset "BlockArray" begin
a1, a3, a3 = BlockArray.((a1, a2, a3))

# matrix matrix
a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4))
Expand All @@ -143,41 +77,39 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
)
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockSparseArray
@test a_dest isa BlockArray
@test a_dest ≈ a_dest_dense

# matrix vector
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
#=
a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockSparseArray
@test a_dest isa BlockArray
@test a_dest ≈ a_dest_dense
=#

# vector matrix
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockSparseArray
@test a_dest isa BlockArray
@test a_dest ≈ a_dest_dense

# vector vector
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockSparseArray
@test_broken a_dest isa BlockArray{elt,0}
@test a_dest ≈ a_dest_dense

# outer product
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4))
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))
@test dimnames_dest == dimnames_dest_dense
@test size(a_dest) == size(a_dest_dense)
@test a_dest isa BlockSparseArray
@test a_dest isa BlockArray
@test a_dest ≈ a_dest_dense
end
end
Loading
Loading