Skip to content

Commit 6a8e2d8

Browse files
authored
adapt to TensorAlgebra.blockedperm changes (#56)
1 parent 4fae77f commit 6a8e2d8

File tree

2 files changed

+127
-14
lines changed

2 files changed

+127
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ MacroTools = "0.5.13"
4545
MapBroadcast = "0.1.5"
4646
SparseArraysBase = "0.2.10"
4747
SplitApplyCombine = "1.2.3"
48-
TensorAlgebra = "0.1.0"
48+
TensorAlgebra = "0.1.0, 0.2"
4949
Test = "1.10"
5050
TypeParameterAccessors = "0.2.0, 0.3"
5151
julia = "1.10"

test/test_tensoralgebraext.jl

Lines changed: 126 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,130 @@
1-
@eval module $(gensym())
2-
using Test: @test, @testset
1+
using Random: randn!
2+
using Test: @test, @test_broken, @testset
3+
4+
using BlockArrays: Block, BlockArray, BlockedArray, blockedrange, blocksize
5+
36
using BlockSparseArrays: BlockSparseArray
7+
using GradedUnitRanges: dual, gradedrange
8+
using SymmetrySectors: U1
49
using TensorAlgebra: contract
5-
using SparseArraysBase: densearray
6-
@testset "BlockSparseArraysTensorAlgebraExt (eltype=$elt)" for elt in (
7-
Float32, Float64, Complex{Float32}, Complex{Float64}
8-
)
9-
a1 = BlockSparseArray{elt}([1, 2], [2, 3], [3, 2])
10-
a2 = BlockSparseArray{elt}([2, 2], [3, 2], [2, 3])
11-
a_dest, dimnames_dest = contract(a1, (1, -1, -2), a2, (2, -2, -1))
12-
a_dest_dense, dimnames_dest_dense = contract(
13-
densearray(a1), (1, -1, -2), densearray(a2), (2, -2, -1)
14-
)
15-
@test a_dest a_dest_dense
10+
11+
function randn_blockdiagonal(elt::Type, axes::Tuple)
12+
a = BlockSparseArray{elt}(axes)
13+
blockdiaglength = minimum(blocksize(a))
14+
for i in 1:blockdiaglength
15+
b = Block(ntuple(Returns(i), ndims(a)))
16+
a[b] = randn!(a[b])
17+
end
18+
return a
1619
end
20+
21+
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
22+
@testset "`contract` `BlockSparseArray` (eltype=$elt)" for elt in elts
23+
@testset "BlockedOneTo" begin
24+
d = blockedrange([2, 3])
25+
a1 = randn_blockdiagonal(elt, (d, d, d, d))
26+
a2 = randn_blockdiagonal(elt, (d, d, d, d))
27+
a3 = randn_blockdiagonal(elt, (d, d))
28+
a1_dense = convert(Array, a1)
29+
a2_dense = convert(Array, a2)
30+
a3_dense = convert(Array, a3)
31+
32+
# matrix matrix
33+
a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4))
34+
a_dest_dense, dimnames_dest_dense = contract(
35+
a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4)
36+
)
37+
@test dimnames_dest == dimnames_dest_dense
38+
@test size(a_dest) == size(a_dest_dense)
39+
@test a_dest isa BlockSparseArray
40+
@test a_dest a_dest_dense
41+
42+
# matrix vector
43+
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
44+
#=
45+
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
46+
@test dimnames_dest == dimnames_dest_dense
47+
@test size(a_dest) == size(a_dest_dense)
48+
@test a_dest isa BlockSparseArray
49+
@test a_dest ≈ a_dest_dense
50+
=#
51+
52+
# vector matrix
53+
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
54+
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1))
55+
@test dimnames_dest == dimnames_dest_dense
56+
@test size(a_dest) == size(a_dest_dense)
57+
@test a_dest isa BlockSparseArray
58+
@test a_dest a_dest_dense
59+
60+
# vector vector
61+
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
62+
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
63+
@test dimnames_dest == dimnames_dest_dense
64+
@test size(a_dest) == size(a_dest_dense)
65+
@test a_dest isa BlockSparseArray{elt,0}
66+
@test a_dest a_dest_dense
67+
68+
# outer product
69+
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))
70+
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4))
71+
@test dimnames_dest == dimnames_dest_dense
72+
@test size(a_dest) == size(a_dest_dense)
73+
@test a_dest isa BlockSparseArray
74+
@test a_dest a_dest_dense
75+
end
76+
77+
@testset "GradedOneTo with U(1)" begin
78+
d = gradedrange([U1(0) => 2, U1(1) => 3])
79+
a1 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
80+
a2 = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
81+
a3 = randn_blockdiagonal(elt, (d, dual(d)))
82+
a1_dense = convert(Array, a1)
83+
a2_dense = convert(Array, a2)
84+
a3_dense = convert(Array, a3)
85+
86+
# matrix matrix
87+
a_dest, dimnames_dest = contract(a1, (1, -1, 2, -2), a2, (2, -3, 1, -4))
88+
a_dest_dense, dimnames_dest_dense = contract(
89+
a1_dense, (1, -1, 2, -2), a2_dense, (2, -3, 1, -4)
90+
)
91+
@test dimnames_dest == dimnames_dest_dense
92+
@test size(a_dest) == size(a_dest_dense)
93+
@test a_dest isa BlockSparseArray
94+
@test a_dest a_dest_dense
95+
96+
# matrix vector
97+
@test_broken a_dest, dimnames_dest = contract(a1, (2, -1, -2, 1), a3, (1, 2))
98+
#=
99+
a_dest_dense, dimnames_dest_dense = contract(a1_dense, (2, -1, -2, 1), a3_dense, (1, 2))
100+
@test dimnames_dest == dimnames_dest_dense
101+
@test size(a_dest) == size(a_dest_dense)
102+
@test a_dest isa BlockSparseArray
103+
@test a_dest ≈ a_dest_dense
104+
=#
105+
106+
# vector matrix
107+
a_dest, dimnames_dest = contract(a3, (1, 2), a1, (2, -1, -2, 1))
108+
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a1_dense, (2, -1, -2, 1))
109+
@test dimnames_dest == dimnames_dest_dense
110+
@test size(a_dest) == size(a_dest_dense)
111+
@test a_dest isa BlockSparseArray
112+
@test a_dest a_dest_dense
113+
114+
# vector vector
115+
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (2, 1))
116+
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (2, 1))
117+
@test dimnames_dest == dimnames_dest_dense
118+
@test size(a_dest) == size(a_dest_dense)
119+
@test a_dest isa BlockSparseArray{elt,0}
120+
@test a_dest a_dest_dense
121+
122+
# outer product
123+
a_dest, dimnames_dest = contract(a3, (1, 2), a3, (3, 4))
124+
a_dest_dense, dimnames_dest_dense = contract(a3_dense, (1, 2), a3_dense, (3, 4))
125+
@test dimnames_dest == dimnames_dest_dense
126+
@test size(a_dest) == size(a_dest_dense)
127+
@test a_dest isa BlockSparseArray
128+
@test a_dest a_dest_dense
129+
end
17130
end

0 commit comments

Comments
 (0)