Skip to content

Commit e1219b1

Browse files
authored
Fix splitdims(::BlockSparseArrays) (#52)
1 parent 711dc43 commit e1219b1

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.2.18"
4+
version = "0.2.19"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

ext/BlockSparseArraysTensorAlgebraExt/BlockSparseArraysTensorAlgebraExt.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,15 @@ function TensorAlgebra.splitdims(
9494
groupreducewhile(tensor_product, split_axes, ndims(a); init=OneToOne()) do i, axis
9595
return length(axis) length(axes(a, i))
9696
end
97-
blockperms = invblockperm.(blocksortperm.(axes_prod))
97+
blockperms = blocksortperm.(axes_prod)
98+
sorted_axes = map((r, I) -> only(axes(r[I])), axes_prod, blockperms)
99+
98100
# TODO: This is doing extra copies of the blocks,
99101
# use `@view a[axes_prod...]` instead.
100102
# That will require implementing some reindexing logic
101103
# for this combination of slicing.
102-
a_unblocked = a[axes_prod...]
103-
a_blockpermed = a_unblocked[blockperms...]
104+
a_unblocked = a[sorted_axes...]
105+
a_blockpermed = a_unblocked[invblockperm.(blockperms)...]
104106
return splitdims(BlockReshapeFusion(), a_blockpermed, split_axes...)
105107
end
106108

test/test_gradedunitrangesext.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@ using SymmetrySectors: U1
1919
using TensorAlgebra: fusedims, splitdims
2020
using LinearAlgebra: adjoint
2121
using Random: randn!
22-
function blockdiagonal!(f, a::AbstractArray)
23-
for i in 1:minimum(blocksize(a))
22+
function randn_blockdiagonal(elt::Type, axes::Tuple)
23+
a = BlockSparseArray{elt}(axes)
24+
blockdiaglength = minimum(blocksize(a))
25+
for i in 1:blockdiaglength
2426
b = Block(ntuple(Returns(i), ndims(a)))
25-
a[b] = f(a[b])
27+
a[b] = randn!(a[b])
2628
end
2729
return a
2830
end
@@ -32,8 +34,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
3234
@testset "map" begin
3335
d1 = gradedrange([U1(0) => 2, U1(1) => 2])
3436
d2 = gradedrange([U1(0) => 2, U1(1) => 2])
35-
a = BlockSparseArray{elt}(d1, d2, d1, d2)
36-
blockdiagonal!(randn!, a)
37+
a = randn_blockdiagonal(elt, (d1, d2, d1, d2))
3738
@test axes(a, 1) isa GradedOneTo
3839
@test axes(view(a, 1:4, 1:4, 1:4, 1:4), 1) isa GradedOneTo
3940

@@ -89,8 +90,7 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
8990
@testset "fusedims" begin
9091
d1 = gradedrange([U1(0) => 1, U1(1) => 1])
9192
d2 = gradedrange([U1(0) => 1, U1(1) => 1])
92-
a = BlockSparseArray{elt}(d1, d2, d1, d2)
93-
blockdiagonal!(randn!, a)
93+
a = randn_blockdiagonal(elt, (d1, d2, d1, d2))
9494
m = fusedims(a, (1, 2), (3, 4))
9595
for ax in axes(m)
9696
@test ax isa GradedOneTo
@@ -107,6 +107,11 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
107107
@test a[2, 2, 2, 2] == m[4, 4]
108108
@test blocksize(m) == (3, 3)
109109
@test a == splitdims(m, (d1, d2), (d1, d2))
110+
111+
# check block fusing and splitting
112+
d = gradedrange([U1(0) => 2, U1(1) => 1])
113+
a = randn_blockdiagonal(elt, (d, d, dual(d), dual(d)))
114+
@test splitdims(fusedims(a, (1, 2), (3, 4)), axes(a)...) == a
110115
end
111116

112117
@testset "dual axes" begin

0 commit comments

Comments
 (0)